您现在的位置是:首页 > 文章详情

ONNX整理

日期:2022-10-20点击:587

ONNX(Open Neural Network Exchange)——开放神经网络交换格式,作为框架共用的一种模型交换格式,使用protobuf二进制格式来序列化模型(protobuf序列化可以参考Netty整合Protobuffer ),可以提供更好的传输性能。官方github:GitHub - onnx/onnx at f2daca5e9b9315a2034da61c662d2a7ac28a9488

ONNX将每一个网络的每一层或者说是每一个算子当作节点Node,再由这些Node去构建一个Graph,相当于是一个网络。最后将Graph和这个onnx模型的其他信息结合在一起,生成一个model,也就是最终的onnx模型。实例如下

在这里插入图片描述

创建ONNX模型

创建onnx模型有两种方法,一种是其他框架转换过来,如Pytorch、PaddlePaddle等,从Pytorch转换onnx可以参考模型部署篇 的Pytorch 权重 pth 转换 onnx;PaddlePaddle转换onnx可以参考PaddleOCR使用指南 中的Paddle2ONNX

我们先来生成一个onnx文件

 import torch import torch.nn as nn from torch.autograd import Variable class Network(nn.Module): def __init__(self): super(Network, self).__init__() self.conv = nn.Conv2d(1, 1, 1) self.act = nn.ReLU() def forward(self, x): return self.act(self.conv(x)) if __name__ == '__main__': net = Network() input = Variable(torch.randn([1, 1, 1, 1])) torch.onnx.export(net, input, 'net.onnx', opset_version=10)

然后来打印这个onnx文件的结构

 import torch import torch.nn as nn from torch.autograd import Variable import onnx class Network(nn.Module): def __init__(self): super(Network, self).__init__() self.conv = nn.Conv2d(1, 1, 1) self.act = nn.ReLU() def forward(self, x): return self.act(self.conv(x)) if __name__ == '__main__': # net = Network()  # input = Variable(torch.randn([1, 1, 1, 1]))  # torch.onnx.export(net, input, 'net.onnx', opset_version=10)   print(onnx.load("./net.onnx"))

运行结果

 ir_version: 5 producer_name: "pytorch" producer_version: "1.12.1" graph { node { input: "input.1" input: "conv.weight" input: "conv.bias" output: "input" name: "Conv_0" op_type: "Conv" attribute { name: "dilations" ints: 1 ints: 1 type: INTS } attribute { name: "group" i: 1 type: INT } attribute { name: "kernel_shape" ints: 1 ints: 1 type: INTS } attribute { name: "pads" ints: 0 ints: 0 ints: 0 ints: 0 type: INTS } attribute { name: "strides" ints: 1 ints: 1 type: INTS } } node { input: "input" output: "4" name: "Relu_1" op_type: "Relu" } name: "torch_jit" initializer { dims: 1 dims: 1 dims: 1 dims: 1 data_type: 1 name: "conv.weight" raw_data: "\014\317B?" } initializer { dims: 1 data_type: 1 name: "conv.bias" raw_data: "\344\n\026\277" } input { name: "input.1" type { tensor_type { elem_type: 1 shape { dim { dim_value: 1 } dim { dim_value: 1 } dim { dim_value: 1 } dim { dim_value: 1 } } } } } output { name: "4" type { tensor_type { elem_type: 1 shape { dim { dim_value: 1 } dim { dim_value: 1 } dim { dim_value: 1 } dim { dim_value: 1 } } } } } } opset_import { version: 10 }

首先是onnx版本,我们这里为ir_version: 5,然后是从什么框架转换过来的,这里是从Pytorch转换过来的producer_name: "pytorch",版本号是producer_version: "1.12.1"。

然后是graph->node,第一个node是2D卷积核,第二个node是Relu激活函数。node中的op_type是节点类型,所有类型可以参考https://github.com/onnx/onnx/blob/f2daca5e9b9315a2034da61c662d2a7ac28a9488/docs/Operators.md。name是节点名称,它跟op_type是不同的。attribute是节点属性,在Conv_0中就是2D卷积的各种属性,比如"group"是分组卷积,"kernel_shape"是卷积核尺寸等等。initializer是初始化,包含了权重初始化和偏置初始化。input是输入,包含输入的形状,output是输出,包含输出的形状。opset_import为当前的模型文件所依赖的算子domain和版本。

最后我们来检查该模型,运行是没有问题的。

 import torch import torch.nn as nn from torch.autograd import Variable import onnx class Network(nn.Module): def __init__(self): super(Network, self).__init__() self.conv = nn.Conv2d(1, 1, 1) self.act = nn.ReLU() def forward(self, x): return self.act(self.conv(x)) if __name__ == '__main__': # net = Network()  # input = Variable(torch.randn([1, 1, 1, 1]))  # torch.onnx.export(net, input, 'net.onnx', opset_version=10)   # print(onnx.load("./net.onnx"))  model = onnx.load("./net.onnx") onnx.checker.check_model(model)

另外一种就是用onnx自己的方法创建onnx模型。

 import onnx import onnx.helper as helper import numpy as np if __name__ == '__main__': input = helper.make_tensor_value_info(name='input', elem_type=onnx.TensorProto.FLOAT, shape=[1, 3, 244, 244]) output = helper.make_tensor_value_info(name='output', elem_type=onnx.TensorProto.FLOAT, shape=[1, 3, 244, 244]) weight = helper.make_tensor(name='weight', data_type=onnx.TensorProto.FLOAT, dims=[3, 3, 1, 1], vals=np.random.randn(3, 3, 1, 1)) bias = helper.make_tensor(name='bias', data_type=onnx.TensorProto.FLOAT, dims=[3], vals=np.random.randn(3)) node = helper.make_node(op_type='Conv', inputs=['input', 'weight', 'bias'], outputs=['output'], kernel_shape=[1, 1], strides=[1, 1],  group=1, pads=[0, 0, 0, 0]) graph = helper.make_graph(name='graph', nodes=[node], inputs=[input], outputs=[output], initializer=[weight, bias]) model = helper.make_model(graph) onnx.checker.check_model(model) print(model) onnx.save_model(model, 'model.onnx')

运行结果

 ir_version: 8 graph { node { input: "input" input: "weight" input: "bias" output: "output" op_type: "Conv" attribute { name: "group" i: 1 type: INT } attribute { name: "kernel_shape" ints: 1 ints: 1 type: INTS } attribute { name: "pads" ints: 0 ints: 0 ints: 0 ints: 0 type: INTS } attribute { name: "strides" ints: 1 ints: 1 type: INTS } } name: "graph" initializer { dims: 3 dims: 3 dims: 1 dims: 1 data_type: 1 float_data: 0.45837152004241943 float_data: 0.10209446400403976 float_data: 1.0382566452026367 float_data: -0.09292714297771454 float_data: 1.58871591091156 float_data: 0.3746287226676941 float_data: -0.35588690638542175 float_data: 0.7165427207946777 float_data: 0.10244251787662506 name: "weight" } initializer { dims: 3 data_type: 1 float_data: -0.36782845854759216 float_data: 2.305680513381958 float_data: -0.13051341474056244 name: "bias" } input { name: "input" type { tensor_type { elem_type: 1 shape { dim { dim_value: 1 } dim { dim_value: 3 } dim { dim_value: 244 } dim { dim_value: 244 } } } } } output { name: "output" type { tensor_type { elem_type: 1 shape { dim { dim_value: 1 } dim { dim_value: 3 } dim { dim_value: 244 } dim { dim_value: 244 } } } } } } opset_import { version: 17 }

动态设置batch_size

在上面的结果中,我们可以看到input的维度都是固定值[1,3,244,244],现在我们要来改变这个固定值为可以动态输入的值。我们先将模型给运行起来。

 import onnx import onnxruntime import numpy as np if __name__ == '__main__': model = onnx.load("./model.onnx") sess = onnxruntime.InferenceSession('./model.onnx') input = np.random.randn(1, 3, 244, 244).astype(np.float32) print(sess.run(['output'], {'input': input}))

运行结果

 [array([[[[-7.4062514e-01, 2.5951520e-01, -3.5876265e-01, ..., -2.0852795e+00, -1.0078001e-01, -4.9386564e-01], [-6.0379845e-01, 9.2830718e-01, -4.2096943e-02, ..., -1.9139317e-01, 1.6547061e+00, 1.4468774e+00], [ 2.6494553e+00, -9.6209788e-01, 8.2099646e-02, ..., -1.5899204e+00, -1.3295431e+00, 1.1512205e-01], ..., [ 1.4135087e+00, 6.4077592e-01, -5.6514746e-01, ..., 2.1367333e+00, 2.6012421e+00, -1.3565271e+00], [ 6.9879985e-01, 1.2454928e+00, 6.0045028e-01, ..., -6.1302024e-01, -4.3026954e-02, -7.2975445e-01], [-2.1020520e+00, -1.2499222e+00, -9.3896770e-01, ..., -4.6129468e-01, 5.4580927e-01, -7.4599540e-01]], [[ 5.6230574e+00, 2.6218858e+00, 7.1071947e-01, ..., 3.6510468e-02, 2.5771899e+00, 2.0060635e+00], [ 4.2759910e+00, 2.5261867e+00, 1.0787441e+00, ..., 3.3373690e+00, 4.5090003e+00, 3.5535808e+00], [ 1.6522924e+00, 1.5206050e+00, 3.6905313e+00, ..., 1.5963824e+00, 5.1875353e-02, 3.4248161e+00], ..., [ 1.0295208e+00, 4.5397396e+00, 4.3366423e+00, ..., 1.2408195e+00, 3.1239326e+00, 1.7476916e+00], [ 9.7080982e-01, 1.9692242e+00, 3.7690439e+00, ..., -1.6770840e-01, 1.1871569e+00, 4.2690439e+00], [ 4.4730301e+00, 1.5573008e+00, 7.2707558e+00, ..., 4.7898588e+00, 2.9080591e+00, 7.2294927e-01]], [[ 1.3509388e+00, -1.9160898e-01, -1.3318433e+00, ..., -1.0562456e+00, 1.0652192e-01, -4.4993240e-01], [ 7.3106253e-01, -4.0714890e-03, -5.3625894e-01, ..., -6.2385768e-02, 3.3464909e-01, 2.7667671e-01], [-7.8517151e-01, -7.1918708e-01, 5.5366117e-01, ..., -4.7982591e-01, -1.0322813e+00, 8.0901492e-01], ..., [-1.0904443e+00, 4.7577775e-01, 9.5288980e-01, ..., -9.8435390e-01, -5.1632053e-01, 2.4581529e-01], [-6.4627886e-01, -9.8449951e-01, 1.6146483e-01, ..., -1.2009792e+00, -7.3006052e-01, 7.0891309e-01], [ 1.3855783e+00, -8.9338100e-01, 2.4704218e+00, ..., 6.8950468e-01, 1.7709453e-01, -7.6678610e-01]]]], dtype=float32)]

现在我们来把输入的batch_size调整成2

 import onnx import onnxruntime import numpy as np if __name__ == '__main__': model = onnx.load("./model.onnx") inputs = model.graph.input outputs = model.graph.output for i in inputs: i.type.tensor_type.shape.dim[0].dim_value = 2  for o in outputs: o.type.tensor_type.shape.dim[0].dim_value = 2  onnx.checker.check_model(model) onnx.save_model(model, 'dynamic_model.onnx') sess = onnxruntime.InferenceSession('./dynamic_model.onnx') input = np.random.randn(2, 3, 244, 244).astype(np.float32) print(sess.run(['output'], {'input': input}))

运行结果

 [array([[[[-2.10871696e-02, -1.32871771e+00, -1.22335061e-01, ..., 4.77721721e-01, -4.10815179e-01, -1.37511027e+00], [-1.09181249e+00, -2.02204657e+00, 1.54176390e+00, ..., -1.88722742e+00, -2.00726366e+00, 4.24929589e-01], [-7.14685619e-01, 3.82802397e-01, -2.30412316e+00, ..., 7.06834435e-01, -2.36892438e+00, -2.11947155e+00], ..., [-9.51929450e-01, -1.22408187e+00, -1.35213524e-01, ..., 5.55669367e-02, -5.95110297e-01, -2.15206313e+00], [ 8.90325904e-01, -1.89442956e+00, 8.34725618e-01, ..., -2.34860206e+00, -1.09965193e+00, -4.96994108e-01], [ 1.56639183e+00, 5.97145438e-01, -5.28750658e-01, ..., 5.77995658e-01, -1.46205699e+00, 2.80693078e+00]], [[ 3.09728765e+00, -1.42589498e+00, 7.58970022e-01, ..., 3.48910093e+00, 2.95971513e+00, 1.96736765e+00], [ 2.76622701e+00, 1.58350587e+00, 2.41761374e+00, ..., 3.68322372e-01, 3.05963039e-01, 2.99718475e+00], [-1.75151324e+00, 2.79870439e+00, -3.03543806e-01, ..., 2.86027908e+00, 1.78771615e+00, 4.79569674e+00], ..., [ 1.30739605e+00, 1.83714139e+00, 4.55001736e+00, ..., 1.44066858e+00, 4.87037659e+00, 2.10291076e+00], [ 9.44083452e-01, -8.11131001e-02, 2.89160919e+00, ..., 2.34788847e+00, 1.95467031e+00, 3.87145948e+00], [ 2.71238947e+00, 1.46723819e+00, 7.61192560e-01, ..., 2.69581342e+00, 2.11386037e+00, 4.08577728e+00]], [[ 3.52043629e-01, -1.83945060e+00, -9.97831583e-01, ..., -2.60245442e-01, 3.69277894e-01, 1.17505208e-01], [ 2.62015522e-01, -6.50106370e-01, -7.36498535e-01, ..., -3.72626394e-01, -9.92001474e-01, 1.87904552e-01], [-2.00427341e+00, -2.67415404e-01, -1.00334084e+00, ..., 8.22970718e-02, 1.41485706e-01, 1.49001801e+00], ..., [-9.85595703e-01, -9.74879414e-03, 1.27501774e+00, ..., -7.10564435e-01, 1.17551017e+00, -4.15902734e-01], [-9.80473995e-01, -1.07735765e+00, 2.39617974e-02, ..., -1.93872005e-02, 2.48361230e-02, 7.19040394e-01], [-6.61614537e-02, -4.85614896e-01, -7.31452227e-01, ..., -9.65917259e-02, -2.94267178e-01, 1.87805906e-01]]], [[[-1.05941308e+00, 8.10959578e-01, -9.29054856e-01, ..., -1.33419132e+00, -5.62950134e-01, 3.15277368e-01], [-2.45844007e+00, -5.31174302e-01, 8.06264520e-01, ..., -1.37343729e+00, -1.26287377e+00, -1.79255664e+00], [ 5.01155496e-01, 2.53203034e+00, -9.11398768e-01, ..., -2.61194611e+00, -6.27550602e-01, -1.04612875e+00], ..., [ 5.64767838e-01, 1.82380235e+00, -9.87865806e-01, ..., -1.48546624e+00, 5.00284791e-01, -1.14099467e+00], [-1.48488015e-01, -3.75306606e-03, 2.05217457e+00, ..., -4.82964367e-01, 6.37757182e-01, 5.87742925e-01], [-7.62285709e-01, 5.78535438e-01, -9.07517672e-01, ..., -1.40203249e+00, 3.13063234e-01, 9.46564317e-01]], [[ 2.21778965e+00, 1.17825162e+00, 1.17773283e+00, ..., 4.21785736e+00, 1.93207061e+00, 6.90674305e+00], [ 5.16840172e+00, 4.03573513e-02, 3.72957373e+00, ..., 2.57324958e+00, 3.23857665e-01, 8.98278236e-01], [ 1.18916261e+00, 4.03137350e+00, 1.54717636e+00, ..., 5.73142242e+00, 2.54209590e+00, 3.02691102e+00], ..., [ 2.02949071e+00, 4.00444984e+00, 3.55739307e+00, ..., 5.54533482e-01, 3.57894540e+00, 7.03547835e-01], [ 2.57975435e+00, 2.32062602e+00, 4.18669128e+00, ..., 2.15663671e+00, 2.39567637e+00, 7.93485880e-01], [ 3.32399893e+00, 3.12817383e+00, 3.60134292e+00, ..., 1.70791423e+00, 7.71586776e-01, 3.58140349e+00]], [[-6.90246701e-01, -8.55753422e-01, -1.35433823e-01, ..., 9.99482393e-01, -2.96287388e-01, 2.49611807e+00], [ 1.56937921e+00, -9.95752215e-01, 5.38442284e-02, ..., 1.63274094e-01, -9.27845955e-01, -6.64922059e-01], [-5.40241778e-01, 2.26666585e-01, -2.95405626e-01, ..., 1.90356636e+00, 4.94795978e-01, 1.35599896e-01], ..., [-4.09579694e-01, 1.26961544e-01, 5.97525239e-01, ..., -9.00853217e-01, 8.11160445e-01, -8.88532698e-01], [-5.75763881e-01, -1.15364529e-01, 2.42510274e-01, ..., 1.83168098e-01, -3.83193374e-01, -1.10992551e+00], [ 9.75027233e-02, 1.07848495e-02, 2.93477297e-01, ..., -2.67393768e-01, -8.09366763e-01, 6.60410523e-03]]]], dtype=float32)]

但是现在batch_size依然是一个固定值,如果我们修改input的第一个维度,是会报错的。则我们需要修改成以下的方式才能输入任意的batch_size。

 import onnx import onnxruntime import numpy as np if __name__ == '__main__': model = onnx.load("./model.onnx") inputs = model.graph.input outputs = model.graph.output for i in inputs: i.type.tensor_type.shape.dim[0].dim_param = 'batchsize'  for o in outputs: o.type.tensor_type.shape.dim[0].dim_param = 'batchsize'  onnx.checker.check_model(model) onnx.save_model(model, 'dynamic_model.onnx') sess = onnxruntime.InferenceSession('./dynamic_model.onnx') input = np.random.randn(3, 3, 244, 244).astype(np.float32) print(sess.run(['output'], {'input': input}))

运行结果

 [array([[[[-5.6308472e-01, -2.8269453e+00, -2.7103744e+00, ..., 4.2550400e-01, 6.5147376e-01, -4.7779888e-02], [-2.5536952e+00, 1.1469245e-01, 3.4514198e-01, ..., -1.8919052e+00, -5.7445437e-01, -1.5864235e+00], [-1.7443299e-02, -8.9739335e-01, -2.9766396e-01, ..., 2.7872375e-01, -8.8234627e-01, -2.3681331e+00], ..., [-1.3148707e+00, -5.4888296e-01, 4.1061863e-01, ..., -1.0763314e+00, -9.6379507e-01, 1.3077673e+00], [-6.3514382e-02, -5.1493609e-01, -1.5793841e+00, ..., -2.2589236e-02, -2.2170777e+00, 1.2437304e+00], [ 7.4394345e-01, 7.8581774e-01, 2.0062235e-01, ..., -1.4014708e+00, 5.5377036e-02, 3.6608991e-01]], [[ 5.4129419e+00, 1.7448205e+00, 3.4165416e+00, ..., 1.0320716e+00, 1.6988618e+00, 5.1501741e+00], [-1.3918903e+00, 1.7199724e+00, 2.1343894e+00, ..., 8.0553353e-01, 4.7985373e+00, 2.5783958e+00], [ 2.3555427e+00, 6.3222194e-01, 2.9314611e+00, ..., 4.3459427e-01, 1.3417060e+00, 1.6852837e+00], ..., [ 5.7537341e-01, 3.0654173e+00, -5.7629395e-01, ..., 1.0968879e+00, 3.7861698e+00, 1.4928346e+00], [ 3.1267416e+00, 2.0358701e+00, 2.2204084e+00, ..., 5.2084265e+00, 3.9166064e+00, 6.4575119e+00], [-7.2486067e-01, 2.3311584e+00, 2.0912974e+00, ..., 1.8693907e+00, 3.2796674e+00, 3.8991761e+00]], [[ 1.1898929e+00, 8.9648962e-03, 8.5148907e-01, ..., -4.7057205e-01, -7.9108685e-01, 1.0573645e+00], [-1.5732453e+00, -3.8554335e-01, -1.6086581e-01, ..., -8.1125468e-01, 1.2085729e+00, 5.6812420e-02], [-3.0767348e-01, -8.5083431e-01, 4.9003422e-02, ..., -7.8210533e-01, -5.2408022e-01, -2.3199841e-02], ..., [-7.3540843e-01, -2.9384446e-01, -1.6465921e+00, ..., -3.8980949e-01, 7.1137357e-01, -8.0783540e-01], [ 2.3953258e-01, -1.7050017e-01, -1.3933203e-01, ..., 1.6591790e+00, 1.0759927e+00, 1.7683787e+00], [-1.6956003e+00, -4.6602386e-01, -3.4259117e-01, ..., -1.0014131e-01, 2.6990986e-01, 8.6363363e-01]]], [[[ 6.4478827e-01, -8.1067204e-01, -1.2237258e+00, ..., -1.2951733e+00, -6.2070227e-01, -1.2906476e+00], [-6.6038930e-01, -2.8674665e-01, -1.0612940e+00, ..., 4.6769258e-01, 4.8500946e-01, -5.6188315e-01], [ 1.0600269e-02, -1.4934481e+00, 9.1430867e-01, ..., -6.1285675e-01, -3.0706315e+00, -9.9033105e-01], ..., [ 1.7771789e+00, -1.3830042e+00, -1.4351614e+00, ..., -2.6786397e+00, 3.7956804e-02, 6.7189908e-01], [-2.1517308e+00, -5.8123243e-01, -7.7163374e-01, ..., 1.6774191e+00, 7.2239363e-01, 1.3373801e+00], [-8.6465418e-01, -1.3932706e+00, -2.2982714e+00, ..., 1.9587449e+00, -6.2718022e-01, -1.1754386e+00]], [[ 5.2605295e+00, 6.8119764e-01, 1.6433215e+00, ..., 1.4899890e+00, 7.7494907e-01, 1.0885936e+00], [ 1.7135508e+00, 1.7890544e+00, 1.5538380e+00, ..., 4.2714515e+00, 3.4532502e+00, 4.0540075e+00], [ 3.2757509e-01, 2.8093519e+00, 4.4473543e+00, ..., 1.6302650e+00, 2.0791094e+00, -2.7314346e+00], ..., [ 3.1872306e+00, 2.1063502e+00, 4.4839258e+00, ..., 8.6034179e-01, 3.7707591e+00, 3.9809742e+00], [-2.0055294e-02, -4.3134212e-02, 2.1313593e+00, ..., 3.0318618e+00, 2.2852294e+00, 3.9968524e+00], [ 2.1781492e+00, 3.6937137e+00, 1.5003638e+00, ..., 3.5955300e+00, 1.7056749e+00, 1.9585730e+00]], [[ 1.1795213e+00, -7.1754062e-01, -1.3523299e-01, ..., -5.6350648e-01, -1.3417213e+00, -5.0127864e-02], [-5.1167816e-01, -7.7823803e-02, -3.1461412e-01, ..., 7.8631788e-01, 5.9256524e-01, 6.9275266e-01], [-1.3142396e+00, 7.9331988e-01, 5.0062788e-01, ..., -6.4525604e-03, -3.3234254e-02, -2.1546085e+00], ..., [-2.0651843e-01, 1.1771068e-02, 1.3835690e+00, ..., 2.8883666e-03, 4.5511311e-01, 2.9804629e-01], [-9.0822458e-01, -1.3634090e+00, -4.2348909e-01, ..., 2.9903316e-01, -5.9180021e-01, 5.1938176e-01], [-3.7974668e-01, 6.5785772e-01, -4.8025602e-01, ..., 1.5578230e-01, -8.5666311e-01, -8.2990326e-02]]], [[[ 2.7270940e-01, 1.6803369e-01, 6.4784336e-01, ..., -8.6817765e-01, 2.4317000e+00, 9.9560642e-01], [-1.0902294e+00, -1.5418210e+00, -6.4213789e-01, ..., 3.8346985e-01, -2.2009264e-01, -1.4083362e+00], [-1.2999996e+00, -1.0029310e+00, -8.0927563e-01, ..., -9.6844232e-01, 4.7647089e-02, -1.7528368e+00], ..., [ 7.9181468e-01, -7.1245348e-01, -1.2355906e+00, ..., -4.4910422e-01, 7.0296872e-01, -1.8157486e+00], [ 8.5229218e-01, -3.9036795e-01, 3.7029549e-01, ..., -2.0579123e+00, 9.2259049e-03, -1.2485095e+00], [-1.0421257e+00, 9.6360290e-01, -1.9165359e+00, ..., -1.5525728e+00, -2.7757692e+00, 5.9844279e-01]], [[ 2.0120070e+00, 2.5763493e+00, 2.5311258e+00, ..., 2.0375581e+00, 1.6430848e+00, 4.5296006e+00], [-6.4119029e-01, 3.2270002e-01, 2.7286339e+00, ..., 3.4792902e+00, 4.8433290e+00, 1.8760866e+00], [ 5.2160606e+00, 5.8354855e-01, 1.9910555e+00, ..., 3.8761294e-01, 3.4568546e+00, 2.2840927e+00], ..., [ 2.4697292e+00, 3.1099756e+00, 4.5984769e+00, ..., 3.1638999e+00, 1.7895203e+00, 5.1426482e-01], [ 2.0174649e+00, 3.7343421e+00, 1.3838698e+00, ..., 6.8948352e-01, 1.9830887e+00, -1.2911747e+00], [ 2.2970469e+00, 2.8243198e+00, 8.7906146e-01, ..., 3.2837601e+00, 1.0420291e+00, 4.1244802e+00]], [[-4.5218289e-01, -1.1248827e-02, -3.9010030e-01, ..., -2.1441557e-01, -8.8925439e-01, 1.0432711e+00], [-1.5277631e+00, -6.0763943e-01, 8.2450414e-01, ..., 5.1565582e-01, 9.1227055e-01, -4.1257131e-01], [ 1.1678007e+00, -8.4806198e-01, -4.1370481e-01, ..., -9.4888353e-01, 2.4556525e-01, 2.7058780e-02], ..., [-2.6444227e-01, 6.4803612e-01, 1.4935874e+00, ..., 1.9097075e-02, -6.0670388e-01, -3.2186458e-01], [-5.2368152e-01, 6.9923353e-01, -6.0641676e-01, ..., -3.2536793e-01, -3.0933461e-01, -1.7596698e+00], [-5.7884902e-01, -9.0141267e-02, -4.4471401e-01, ..., 3.5021925e-01, -1.7998603e-01, 6.4285696e-01]]]], dtype=float32)]

这里可以把input的第一个维度,也就是batch_size修改成任意数值,程序都可以运行。此时我们打印下model的信息。

 import onnx import onnxruntime import numpy as np if __name__ == '__main__': model = onnx.load("./model.onnx") inputs = model.graph.input outputs = model.graph.output for i in inputs: i.type.tensor_type.shape.dim[0].dim_param = 'batchsize'  for o in outputs: o.type.tensor_type.shape.dim[0].dim_param = 'batchsize'  print(model) # onnx.checker.check_model(model)  # onnx.save_model(model, 'dynamic_model.onnx')  # sess = onnxruntime.InferenceSession('./dynamic_model.onnx')  # input = np.random.randn(3, 3, 244, 244).astype(np.float32)  # print(sess.run(['output'], {'input': input}))

运行结果

 ir_version: 8 graph { node { input: "input" input: "weight" input: "bias" output: "output" op_type: "Conv" attribute { name: "group" i: 1 type: INT } attribute { name: "kernel_shape" ints: 1 ints: 1 type: INTS } attribute { name: "pads" ints: 0 ints: 0 ints: 0 ints: 0 type: INTS } attribute { name: "strides" ints: 1 ints: 1 type: INTS } } name: "graph" initializer { dims: 3 dims: 3 dims: 1 dims: 1 data_type: 1 float_data: 0.45837152004241943 float_data: 0.10209446400403976 float_data: 1.0382566452026367 float_data: -0.09292714297771454 float_data: 1.58871591091156 float_data: 0.3746287226676941 float_data: -0.35588690638542175 float_data: 0.7165427207946777 float_data: 0.10244251787662506 name: "weight" } initializer { dims: 3 data_type: 1 float_data: -0.36782845854759216 float_data: 2.305680513381958 float_data: -0.13051341474056244 name: "bias" } input { name: "input" type { tensor_type { elem_type: 1 shape { dim { dim_param: "batchsize" } dim { dim_value: 3 } dim { dim_value: 244 } dim { dim_value: 244 } } } } } output { name: "output" type { tensor_type { elem_type: 1 shape { dim { dim_param: "batchsize" } dim { dim_value: 3 } dim { dim_value: 244 } dim { dim_value: 244 } } } } } } opset_import { version: 17 }

这里我们可以看到在input中的第一个dim中变成了dim_param: "batchsize"

节点的增加和删除

  • 增加节点
 import onnx import onnx.helper as helper import onnxruntime import numpy as np if __name__ == '__main__': model = onnx.load('./model.onnx') nodes = model.graph.node new_node = helper.make_node(op_type='Relu', name='relu1', inputs=['conv1'], outputs=['output']) nodes.append(new_node) nodes[0].output[0] = 'conv1'  onnx.checker.check_model(model) onnx.save_model(model, 'add_model.onnx') input = np.random.randn(1, 3, 244, 244).astype(np.float32) sess = onnxruntime.InferenceSession('./add_model.onnx') print(sess.run(['output'], {'input': input}))

运行结果

 [array([[[[1.5453527 , 0. , 0. , ..., 0.04255658, 0. , 0.40214583], [0. , 0.5019511 , 0. , ..., 0.34235588, 0.36859825, 0. ], [0. , 0. , 0.34334645, ..., 0. , 0. , 0. ], ..., [1.1857387 , 1.0710502 , 0. , ..., 0. , 1.8497316 , 0. ], [0.37889728, 0. , 0. , ..., 0. , 0. , 0. ], [0.73697627, 0. , 0.4978644 , ..., 0. , 0. , 0.32394186]], [[1.2723072 , 0. , 0.66669345, ..., 5.6399436 , 1.4827138 , 2.7300682 ], [4.5705633 , 2.9856906 , 2.9005556 , ..., 3.505543 , 4.7502317 , 0. ], [1.5251542 , 3.3182473 , 3.8036246 , ..., 0. , 1.6024959 , 1.4051957 ], ..., [1.7204559 , 4.551407 , 4.172427 , ..., 0.9121852 , 3.3593512 , 4.6163626 ], [0.2845726 , 0.13289118, 3.3601975 , ..., 3.9331636 , 0.3700601 , 1.5711328 ], [3.3283763 , 2.128338 , 2.1621299 , ..., 1.7635765 , 0. , 2.1479769 ]], [[0. , 0. , 0. , ..., 1.4292918 , 0. , 0.46683455], [1.0534286 , 0. , 0.02258705, ..., 0.4342987 , 1.1339298 , 0. ], [0. , 0.50237906, 0.20627443, ..., 0. , 0. , 0. ], ..., [0. , 0.78040606, 1.003104 , ..., 0. , 0. , 1.0389903 ], [0. , 0. , 0. , ..., 0.74816215, 0. , 0.02678718], [0.26068228, 0. , 0. , ..., 0. , 0. , 0. ]]]], dtype=float32)]

这里我们再来打印下model的信息

 import onnx import onnx.helper as helper import onnxruntime import numpy as np if __name__ == '__main__': model = onnx.load('./model.onnx') nodes = model.graph.node new_node = helper.make_node(op_type='Relu', name='relu1', inputs=['conv1'], outputs=['output']) nodes.append(new_node) nodes[0].output[0] = 'conv1'  print(model) # onnx.checker.check_model(model)  # onnx.save_model(model, 'add_model.onnx')  #  # input = np.random.randn(1, 3, 244, 244).astype(np.float32)  # sess = onnxruntime.InferenceSession('./add_model.onnx')  # print(sess.run(['output'], {'input': input}))

运行结果

 ir_version: 8 graph { node { input: "input" input: "weight" input: "bias" output: "conv1" op_type: "Conv" attribute { name: "group" i: 1 type: INT } attribute { name: "kernel_shape" ints: 1 ints: 1 type: INTS } attribute { name: "pads" ints: 0 ints: 0 ints: 0 ints: 0 type: INTS } attribute { name: "strides" ints: 1 ints: 1 type: INTS } } node { input: "conv1" output: "output" name: "relu1" op_type: "Relu" } name: "graph" initializer { dims: 3 dims: 3 dims: 1 dims: 1 data_type: 1 float_data: 0.45837152004241943 float_data: 0.10209446400403976 float_data: 1.0382566452026367 float_data: -0.09292714297771454 float_data: 1.58871591091156 float_data: 0.3746287226676941 float_data: -0.35588690638542175 float_data: 0.7165427207946777 float_data: 0.10244251787662506 name: "weight" } initializer { dims: 3 data_type: 1 float_data: -0.36782845854759216 float_data: 2.305680513381958 float_data: -0.13051341474056244 name: "bias" } input { name: "input" type { tensor_type { elem_type: 1 shape { dim { dim_value: 1 } dim { dim_value: 3 } dim { dim_value: 244 } dim { dim_value: 244 } } } } } output { name: "output" type { tensor_type { elem_type: 1 shape { dim { dim_value: 1 } dim { dim_value: 3 } dim { dim_value: 244 } dim { dim_value: 244 } } } } } } opset_import { version: 17 }

这里我们可以看到增加了一个relu1的节点,并且第一个节点的output是conv1,第二个节点的input是conv1,output是output。

  • 删除节点
 import onnx import onnxruntime import numpy as np if __name__ == '__main__': model = onnx.load('./add_model.onnx') nodes = model.graph.node for node in nodes: if node.name == 'relu1': nodes.remove(node) nodes[0].output[0] = 'output'  onnx.checker.check_model(model) onnx.save_model(model, 'del_model.onnx') input = np.random.randn(1, 3, 244, 244).astype(np.float32) sess = onnxruntime.InferenceSession('./del_model.onnx') print(sess.run(['output'], {'input': input}))

运行结果

 [array([[[[-8.5923064e-01, -4.2249173e-01, 3.8687822e-01, ..., -4.8348337e-02, 3.1652334e-01, -5.7166600e-01], [ 3.1469372e-01, -9.4796360e-01, -2.4245100e+00, ..., 4.1007617e-01, -1.4098099e+00, 6.7472184e-01], [-1.2910874e+00, 1.6070822e-01, -1.0217074e+00, ..., 7.1467435e-01, 1.5835044e-01, -6.4228356e-01], ..., [-2.5442154e+00, -8.8969648e-01, 1.1389736e+00, ..., 1.7202379e+00, -1.1968368e+00, -3.3861694e-01], [-9.0216339e-01, 4.8469666e-01, -9.5050204e-01, ..., 4.0511075e-01, -1.0113320e-01, 1.8743831e+00], [ 3.2901958e-01, 4.3780953e-02, 1.4250931e+00, ..., -1.4544667e+00, 9.0659869e-01, 1.7170597e+00]], [[ 1.3439684e+00, 3.0856354e+00, 2.7811766e+00, ..., 4.1714394e-01, -3.3547878e-02, 1.1771207e+00], [ 2.1574910e+00, 2.1122241e+00, -5.8333945e-01, ..., 1.9629711e+00, 3.4840956e+00, 6.1747317e+00], [ 5.2136226e+00, 4.8688288e+00, 1.4613919e+00, ..., 4.1095753e+00, 1.4553337e+00, 3.5171165e+00], ..., [ 7.3736429e-02, 8.4109855e-01, 5.7113109e+00, ..., 3.6336284e+00, 4.4551125e+00, 3.4602299e+00], [ 1.1054695e+00, 2.7417006e+00, 4.9065466e+00, ..., 2.1775680e+00, 4.4132576e+00, 2.3781679e+00], [-1.2788355e+00, 2.5300267e+00, 3.2560487e+00, ..., 2.2025514e+00, 4.2551570e+00, 3.5148311e+00]], [[-8.5124874e-01, 3.1858414e-01, 3.3686757e-03, ..., -1.1497847e+00, -1.1996644e+00, -9.6176589e-01], [-4.2057925e-01, -1.8098265e-01, -7.4302059e-01, ..., -3.5920531e-01, 7.0454830e-01, 1.8304255e+00], [ 1.4177717e+00, 8.4456313e-01, -1.6396353e-01, ..., 4.2133337e-01, -4.6482396e-01, 6.6906375e-01], ..., [-1.0060047e+00, -1.2088763e+00, 1.2608007e+00, ..., 5.1739502e-01, 8.9526463e-01, 7.2866821e-01], [-3.5698372e-01, -5.9943002e-01, 1.0040566e+00, ..., 3.1322885e-01, 3.4513384e-01, -6.2404698e-01], [-2.0622578e+00, 3.9633280e-01, 2.1701033e-01, ..., 2.6992482e-01, 4.4787437e-01, 2.1187775e-01]]]], dtype=float32)]

替换节点

现在我们将add_model.onnx中的Conv节点替换成Squeeze节点(压缩维度)

 import onnx import onnx.helper as helper import onnxruntime import numpy as np if __name__ == '__main__': model = onnx.load('./add_model.onnx') new_node = helper.make_node(op_type='Squeeze', inputs=['input'], outputs=['conv1'], name='squeeze1') nodes = model.graph.node nodes.append(new_node) for node in nodes: if node.op_type == 'Conv': nodes.remove(node) # onnx.checker.check_model(model)  onnx.save_model(model, 'replace.onnx') input = np.random.randn(1, 3, 244, 244).astype(np.float32) sess = onnxruntime.InferenceSession('./replace.onnx') print(sess.run(['output'], {'input': input}))

运行结果

 [array([[[0. , 1.1258854 , 0. , ..., 0.54984987, 0.19069785, 0. ], [1.1481465 , 0. , 1.9025986 , ..., 0. , 0.11273875, 0. ], [1.57912 , 0. , 0. , ..., 0. , 1.7471381 , 0. ], ..., [0.42386332, 0.30908984, 0. , ..., 0. , 0. , 1.8173866 ], [0.07642962, 0.31224537, 0. , ..., 1.6805407 , 2.0282576 , 0. ], [0. , 0.2521538 , 0. , ..., 0. , 0.6431213 , 0.5844705 ]], [[0. , 0. , 0. , ..., 0.23725364, 0.22994171, 0.316093 ], [0.85044146, 1.2757416 , 0.28854838, ..., 0. , 0. , 0. ], [0. , 0. , 1.1362596 , ..., 1.8543358 , 1.1296074 , 0.5114057 ], ..., [0. , 0.00810617, 0. , ..., 1.0819261 , 1.707781 , 0. ], [0. , 0.6385371 , 0. , ..., 0.6565783 , 1.457183 , 0. ], [0. , 0.8315589 , 1.4111192 , ..., 1.0682058 , 0.17328343, 2.3547616 ]], [[0.2426068 , 0. , 0. , ..., 0.89054537, 0.98760164, 0. ], [1.1344411 , 0.8732987 , 0. , ..., 0. , 0. , 0. ], [0. , 0.3664789 , 1.4099371 , ..., 0. , 0.0588427 , 0.5932818 ], ..., [0. , 0.68438137, 0.8869638 , ..., 0. , 0. , 1.4681839 ], [0. , 0. , 0. , ..., 0.16630006, 1.9389246 , 0. ], [0. , 0. , 0.03726726, ..., 0.86296386, 0. , 0. ]]], dtype=float32)]

这里需要注意的是,如果我们将# onnx.checker.check_model(model)的注释打开,运行是会报错的,因为我们添加的新节点squeeze1是在relu1之后的,虽然无法通过检查,但是是可以使用运行时来运行的。那如何才能即能运行又让检查也可以通过。

 import onnx import onnx.helper as helper import onnxruntime import numpy as np if __name__ == '__main__': model = onnx.load('./add_model.onnx') new_node = helper.make_node(op_type='Squeeze', inputs=['input'], outputs=['conv1'], name='squeeze1') nodes = model.graph.node # nodes.append(new_node)  for idx, node in enumerate(nodes): if node.op_type == 'Conv': nodes.remove(node) nodes.insert(idx, new_node) onnx.checker.check_model(model) onnx.save_model(model, 'replace.onnx') input = np.random.randn(1, 3, 244, 244).astype(np.float32) sess = onnxruntime.InferenceSession('./replace.onnx') print(sess.run(['output'], {'input': input}))

这里主要就是调换一下新节点的位置就好了。

ONNXRuntime介绍

ONNXRuntime是微软推出的一个推理框架,可以非常方便的运行ONNX模型,官方GitHub:https://github.com/microsoft/onnxruntime

 

原文链接:https://my.oschina.net/u/3768341/blog/5585440
关注公众号

低调大师中文资讯倾力打造互联网数据资讯、行业资源、电子商务、移动互联网、网络营销平台。

持续更新报道IT业界、互联网、市场资讯、驱动更新,是最及时权威的产业资讯及硬件资讯报道平台。

转载内容版权归作者及来源网站所有,本站原创内容转载请注明来源。

文章评论

共有0条评论来说两句吧...

文章二维码

扫描即可查看该文章

点击排行

推荐阅读

最新文章