请先关注 [低调大师] 公众号 优秀的自媒体个人博客,低调大师,许军

低调大师

您现在的位置是:首页>文章详情

文章详情

PyTorch 1.4 发布:支持 Java 和分布式模型并行训练

2020-01-26 48热度

PyTorch 团队上周发布了最新的 PyTorch 1.4 版本更新日志显示,此版本包含了 1500 多次提交,并在 JIT、ONNX、分布式、性能和 Eager 前端等方面进行了改进,以及对于移动版本和量化方面的实验领域也进行了改进。1.4 还增加了新的实验性功能,其中包括基于 RPC 的分布式模型并行训练以及对 Java 的语言绑定。

此外,PyTorch 1.4 是支持 Python 2 的最后一个版本,同时也是支持 C++11 的最后一个版本。因此官方建议从 1.4 开始迁移到 Python 3,并使用 C++14 进行构建,以方便将来从 1.4 过渡到 1.5。

更新亮点

为 PyTorch Mobile 提供 Build 级别自定义的支持

在 1.3 中推出处于实验性阶段的 Pytorch Mobile 之后,1.4 版本增加了更多对移动端的支持,包括以细粒度级别( fine-grain level)自定义构建脚本的功能。此功能使得移动端开发者能够优化库的大小 —— 仅在库的模型中包括它们使用的 operators,同时在此过程中显著减少了其设备占用的空间。早期结果显示,定制的 MobileNetV2 比预构建的 PyTorch 移动端库小 40% 至 50%。

用于选择性地仅编译 MobileNetV2 所需的 operators 的示例代码:

 # Dump list of operators used by MobileNetV2: import torch, yaml model = torch.jit.load('MobileNetV2.pt') ops = torch.jit.export_opnames(model) with open('MobileNetV2.yaml', 'w') as output: yaml.dump(ops, output)
 # Build PyTorch Android library customized for MobileNetV2: SELECTED_OP_LIST=MobileNetV2.yaml scripts/build_pytorch_android.sh arm64-v8a # Build PyTorch iOS library customized for MobileNetV2: SELECTED_OP_LIST=MobileNetV2.yaml BUILD_PYTORCH_MOBILE=1 IOS_ARCH=arm64 scripts/build_ios.sh

分布式模型并行训练(实验性)

随着诸如 RoBERTa 等万亿级别参数的大型模型出现,模型并行训练对于帮助研究人员突破极限变得越来越重要。此版本提供了分布式 RPC 框架,以支持分布式模型并行训练。此框架支持远程运行函数,以及在无需复制真实数据的前提下引用远程对象。PyTorch 还提供了 autograd 和 Optimizer API,能够透明地在后台运行并跨 RPC 边界更新参数。

Java bindings(实验性)

除了支持 Python 和 C++ 外,此版本还增加了对 Java bindings 的实验性支持。基于 PyTorch Mobile 中为 Android 开发的接口,我们可通过新的 Java bindings 从任何 Java 程序中调用 TorchScript 模型。不过要注意的是,此版本的 Java bindings 仅支持 Linux 平台,且只能用于进行模型推理。开发团队表示会在后续的版本中扩展支持。

有关如何在 Java 中使用 PyTorch 请查看以下代码片段:

 Module mod = Module.load("demo-model.pt1"); Tensor data = Tensor.fromBlob( new int[] {1, 2, 3, 4, 5, 6}, // data new long[] {2, 3} // shape ); IValue result = mod.forward(IValue.from(data), IValue.from(3.0)); Tensor output = result.toTensor(); System.out.println("shape: " + Arrays.toString(output.shape())); System.out.println("data: " + Arrays.toString(output.getDataAsFloatArray()));

下载地址:https://github.com/pytorch/pytorch/releases/tag/v1.4.0

收藏 (0)

相关文章

    文章评论

    共有0条评论来说两句吧...