PyTorch 的首要编程接口是 Python 言语。虽然关于许多需求动态和快速迭代的场景来说,Python 是一种合适且首选的言语,但同样有许多情况下,Python 的这些特点恰恰是不利的。后者常常运用的一个环境是出产环境——要求低延迟和严格布置。关于出产场景,C++ 通常是首选言语,即便仅仅将其绑定到另一种言语,如 Java、Rust 或 Go。以下内容将概述怎么利用 PyTorch 提供C++ 库加载现有Python序列化模型,完全不依靠于Python的在C++环境中履行。
第 1 步:将 PyTorch 模型转换为 Torch 脚本
PyTorch 模型从 Python 迁移到 C++ 的前言是由Torch Script实现的,Torch Script是 PyTorch 模型的一种中间表明,能够被 Torch Script 编译器了解、编译和序列化。假如是从 vanilla “eager” API 编写的现有 PyTorch 模型开端,则必须首先将模型转换为 Torch 脚本。在下面讨论的最常见的情况下,这只需求很少的作业。假如现已有一个 Torch 脚本模块,能够跳到本教程的下一部分。
将 PyTorch 模型转换为 Torch 脚本有两种办法。第一种称为盯梢,这是一种经过运用示例输入对其进行一次评价来捕获模型结构的机制,并记载这些输入在模型中的活动。这适用于束缚运用
操控流的模型(即模型中不存在 if)。第二种办法是向模型增加显式注释,通知 Torch 脚本编译器它能够直接解析和编译模型代码,受 Torch 脚本言语施加的束缚。
提示:
能够在官方的Torch 脚本参阅中找到这两种办法的完好文档,以及运用的进一步指导。
1.1 经过盯梢转换为 Torch 脚本
要经过盯梢将 PyTorch 模型转换为 Torch 脚本,必须将模型实例
连同示例输入
一同传递给torch.jit.trace
函数。这将生成一个torch.jit.ScriptModule
目标,其间包括嵌入模块forward
办法中的模型评价盯梢:
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
# 把模型 model 和 样例数据 example 传入 torch.jit.trace
traced_script_module = torch.jit.trace(model, example)
和惯例 PyTorch 模块相同,现在能够对盯梢后的ScriptModule
进行相同的推理猜测:
In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
1.2 经过注释转换为 Torch 脚本
在某些情况下,假如模型采用了特定形式的操控流,能够直接在 Torch 脚本中编写模型并相应地注释模型。例如,假定有以下 vanilla Pytorch 模型:
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0: ### <<<<<======这里有if 判断
output = self.weight.mv(input)
else:
output = self.weight + input
return output
因为模块的forward
办法运用依靠于输入的操控流,所以不适合追寻。为了将模块转换为ScriptModule
,需求torch.jit.script
编译模块,如下所示:
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
my_module = MyModule(10,20)
sm = torch.jit.script(my_module)
假如需求排除nn.Module
中的某些办法,因为它们运用了 TorchScript 尚不支撑的 Python 功用,能够运用@torch.jit.ignore
sm
是一个ScriptModule
准备好序列化的实例。
第 2 步:将脚本模块序列化为文件
经过 盯梢或注释 PyTorch模型的办法 获取ScriptModule
后,就能够将其序列化为文件。稍后,将能够在 C++ 中从此文件加载模块并履行它,而不依靠于 Python。假定要序列化前面在盯梢示例中显现的模型ResNet18
。要履行序列化,只需在模块上调用[save] 并将文件名传递给它:(pytorch.org/docs/master…)
traced_script_module.save("traced_resnet_model.pt")
这将在作业目录中生成一个文件traced_resnet_model.pt
。假如想序列化sm
,请调用sm.save("my_module_model.pt")
现在能够离开 Python 范畴,准备跨入 C++ 范畴。
第 3 步:在 C++ 中加载脚本模块
要在 C++ 中加载方才序列化 PyTorch 模型,运用程序必须依靠于 PyTorch C++ API——也称为LibTorch。LibTorch 发行版包括一组共享库、头文件和 CMake 构建装备文件。虽然 CMake 不是依靠 LibTorch 的必要条件,但推荐运用它,并且会在未来得到很好的支撑。在本文中,将运用 CMake 和 LibTorch 构建一个最小的 C++ 运用程序,它仅仅加载和履行序列化的 PyTorch 模型。
3.1 一个最小的 C++ 运用程序
从加载模块的代码开端。以下内容现已完成:
#include <torch/script.h> // One-stop header.
#include <iostream>
#include <memory>
int main(int argc, const char* argv[]) {
if (argc != 2) {
std::cerr << "usage: example-app <path-to-exported-script-module>\n";
return -1;
}
torch::jit::script::Module module;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
module = torch::jit::load(argv[1]);
}
catch (const c10::Error& e) {
std::cerr << "error loading the model\n";
return -1;
}
std::cout << "ok\n";
}
标<torch/script.h>
头包括运转示例所需的 LibTorch 库中的一切相关内容。运用程序接受序列化 PyTorch ScriptModule
的文件途径作为其仅有的命令行参数,然后运用该函数c处理反序列化模块,该torch::jit::load()
函数将此文件途径作为输入。它回来一个torch::jit::script::Module
目标。稍后将研究怎么履行它。
3.2 依靠 LibTorch 并构建运用程序
将上述代码存储到一个名为example-app.cpp
.CMakeLists.txt
构建它的最小化或许看起来很简单:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)
find_package(Torch REQUIRED)
add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 14)
构建示例运用程序所需的最终一件事是 LibTorch 发行版。能够从PyTorch 网站上的下载页面获取最新的稳定版别。下载并解压缩最新的存档,会看到一个具有以下目录结构的文件夹:
libtorch/
bin/
include/
lib/
share/
- 该
lib/
文件夹包括必须链接的共享库, - 该
include/
文件夹包括程序需求包括的头文件, - 该文件夹包括启用上述简单命令
share/
所需的 CMake 装备。find_package(Torch)
提示:
在 Windows 上,debug和release版别 ABI 不兼容。假如在调试形式下构建项目,请测验 LibTorch 的调试版别。此外,请确保在cmake--build.
下面的行中指定正确的装备。
最终一步是构建运用程序。为此,假定示例目录布局如下:
example-app/
CMakeLists.txt
example-app.cpp
现在能够运转以下命令从example-app/
文件夹中构建运用程序:
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
cmake --build . --config Release
/path/to/libtorch
是解压后的 LibTorch 发行版的完好途径。假如一切顺利,它是这样:
root@4b5a67132e81:/example-app# mkdir build
root@4b5a67132e81:/example-app# cd build
root@4b5a67132e81:/example-app/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Configuring done
-- Generating done
-- Build files have been written to: /example-app/build
root@4b5a67132e81:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
将之前创立的盯梢ResNet18
模型 traced_resnet_model.pt
的途径提供给生成的example-app
二进制文件,会得到一个友好的“ok”奖励。请注意,假如测验与my_module_model.pt
您一同运转此示例,将收到一条过错消息,指出输入形状不兼容。my_module_model.pt
希望 1D 而不是 4D。
root@4b5a67132e81:/example-app/build# ./example-app <path_to_model>/traced_resnet_model.pt
ok
第 4 步:在 C++ 中履行脚本模块
在 C++ 中成功加载了序列化ResNet18
后,现在只需几行代码即可履行它!将这些行增加到 C++ 运用程序的main()
函数中:
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));
// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
前两行设置了模型的输入。创立一个torch::jit::IValue
(一个类型擦除的值类型script::Module
办法接受和回来)的向量,并增加一个输入。运用torch::ones()
创立输入张量, 。然后运转script::Module
‘forward
办法,将创立的输入向量传递给它。回来一个IValue
,并调用toTensor()
将其转换为张量。
提示:
要了解更多关于函数torch::ones
和 PyTorch C++ API 的更多信息,请参阅pytorch.org/cppdocs上的文档。PyTorch C++ API 提供与 Python API 挨近的特性,答应您像在 Python 中相同进一步操作和处理张量。
在最终一行,打印输出的前五个条目。因为在文前面的 Python 中为模型提供了相同的输入,因此理想情况下应该看到相同的输出。经过重新编译运用程序并运用相同的序列化模型运转它来测验一下:
root@4b5a67132e81:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
root@4b5a67132e81:/example-app/build# ./example-app traced_resnet_model.pt
-0.2698 -0.0381 0.4023 -0.3010 -0.0448
[ Variable[CPUFloatType]{1,5} ]
作为参阅,之前 Python 中的输出是:
tensor([-0.2698, -0.0381, 0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
看起来相同!
提示:
能够运用model.to(at::kCUDA);
将模型移动到 GPU 内存.经过调用tensor.to(at::kCUDA)
确保模型的输入也存在于 CUDA 内存中,这将在 CUDA 内存中回来一个新的张量。
第 5 步:获取协助和探究 API
本文有望使您对 PyTorch 模型从 Python 迁移到 C++ 的进程有一个大致的了解。运用本文中描述的概念,应该能够从一个一般的、“eager” PyTorch 模型,到用Python 编译模型ScriptModule
,再到磁盘上的序列化文件,然后到C++ script::Module
中的可履行文件。
当然,还有许多概念没有涉及。例如,您或许会发现自己想要运用 C++ 或 CUDA 实现的自定义运算符来扩展ScriptModule
,并在纯 C++ 出产环境中加载并履行ScriptModule
自定义运算符。这是可行的,并且得到了很好的支撑!能够阅读此文件夹中的示例,咱们将很快更新文档。现在,以下链接通常或许会有所协助:
- Torch 脚本参阅:https ://pytorch.org/docs/master/jit.html
- PyTorch C++ API 文档:https ://pytorch.org/cppdocs/
- PyTorch Python API 文档:https ://pytorch.org/docs/
与往常相同,假如遇到任何问题或有疑问,能够运用论坛或GitHub 问题进行联络。
原文地址