模型部署之 ONNX ONNXRuntime

通常我們?cè)谟?xùn)練模型時(shí)可以使用很多不同的框架,比如有的同學(xué)喜歡用 Pytorch,有的同學(xué)喜歡使用 TensorFLow,也有的喜歡 MXNet,以及深度學(xué)習(xí)最開(kāi)始流行的 Caffe等等,這樣不同的訓(xùn)練框架就導(dǎo)致了產(chǎn)生不同的模型結(jié)果包,在模型進(jìn)行部署推理時(shí)就需要不同的依賴庫(kù),而且同一個(gè)框架比如tensorflow 不同的版本之間的差異較大, 為了解決這個(gè)混亂問(wèn)題,LF AI 這個(gè)組織聯(lián)合 Facebook, MicroSoft等公司制定了機(jī)器學(xué)習(xí)模型的標(biāo)準(zhǔn),這個(gè)標(biāo)準(zhǔn)叫做ONNX, Open Neural Network Exchage,所有其他框架產(chǎn)生的模型包 (.pth, .pb) 都可以轉(zhuǎn)換成這個(gè)標(biāo)準(zhǔn)格式,轉(zhuǎn)換成這個(gè)標(biāo)準(zhǔn)格式后,就可以使用統(tǒng)一的 ONNX Runtime等工具進(jìn)行統(tǒng)一部署。

這其實(shí)可以和 JVM 對(duì)比,
A Java virtual machine (JVM) is a virtual machine that enables a computer to run Java programs as well as programs written in other languages that are also compiled to Java bytecode. The JVM is detailed by a specification that formally describes what is required in a JVM implementation. Having a specification ensures interoperability of Java programs across different implementations so that program authors using the Java Development Kit (JDK) need not worry about idiosyncrasies of the underlying hardware platform.

JAVA中有 JAVA 語(yǔ)言 + .jar 包 + JVM,同時(shí)還有其他的語(yǔ)言比如 Scala等也是建立在 JVM上運(yùn)行的,因此不同的語(yǔ)言只要都最后將程序轉(zhuǎn)換成 JVM可以統(tǒng)一識(shí)別的格式,就可以在統(tǒng)一的跨平臺(tái) JVM JAVA 虛擬機(jī)上運(yùn)行。這里JVM使用的 包是二進(jìn)制包,因此里面的內(nèi)容是不可知的,人類難以直觀理解的。

這里 ONNX 標(biāo)準(zhǔn)采取了谷歌開(kāi)發(fā) protocal buffers 作為格式標(biāo)準(zhǔn),這個(gè)格式是在 XML, json的基礎(chǔ)上發(fā)展的,是一個(gè)人類易理解的格式。ONNX 官網(wǎng)對(duì)ONNX的介紹如下:
ONNX defines a common set of operators - the building blocks of machine learning and deep learning models - and a common file format to enable AI developers to use models with a variety of frameworks, tools, runtimes, and compilers.
ONNX支持的模型來(lái)源,基本上囊括了我們?nèi)粘J褂玫乃锌蚣埽?br>

ONNX支持的模型來(lái)源

ONNX的文件格式,采用的是谷歌的 protocal buffers,和 caffe采用的一致。


ONNX定義的數(shù)據(jù)類包括了我們常用的數(shù)據(jù)類型,用來(lái)定義模型中的輸出輸出格式

ONNX中定義了很多我們常用的節(jié)點(diǎn),比如 Conv,ReLU,BN, maxpool等等約124種,同時(shí)也在不停地更新中,當(dāng)遇到自帶節(jié)點(diǎn)庫(kù)中沒(méi)有的節(jié)點(diǎn)時(shí),我們也可以自己寫一個(gè)節(jié)點(diǎn)

  • 有了輸入輸出,以及計(jì)算節(jié)點(diǎn),就可以根據(jù) pytorch框架中的 forward 記錄一張模型從輸入圖片到輸出的計(jì)算圖,ONNX 就是將這張計(jì)算圖用標(biāo)準(zhǔn)的格式存儲(chǔ)下來(lái)了,可以通過(guò)一個(gè)工具 Netron對(duì) ONNX 進(jìn)行可視化,如第一張圖右側(cè)所示;
  • 保存成統(tǒng)一的 ONNX 格式后,就可以使用統(tǒng)一的運(yùn)行平臺(tái)來(lái)進(jìn)行 inference。

pytorch原生支持 ONNX 格式轉(zhuǎn)碼,下面是實(shí)例:

1. 將pytorch模型轉(zhuǎn)換為onnx格式,直接傻瓜式調(diào)用 torch.onnx.export(model, input, output_name)

import torch
from torchvision import models

net = models.resnet.resnet18(pretrained=True)
dummpy_input = torch.randn(1,3,224,224)
torch.onnx.export(net, dummpy_input, 'resnet18.onnx')

2. 對(duì)生成的 onnx 進(jìn)行查看

import onnx

# Load the ONNX model
model = onnx.load("resnet18.onnx")

# Check that the IR is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

輸出:
可以看到其輸出有3個(gè)dict,一個(gè)是 input, 一個(gè)是 initializers,以及最后一個(gè)是operators把輸入和權(quán)重 initialization 進(jìn)行類似于 forward操作,在最后一個(gè)dict operators中其返回是 %191,也就是 gemm 最后一個(gè)全連接的輸出

graph torch-jit-export (
  %input.1[FLOAT, 1x3x224x224]
) initializers (
  %193[FLOAT, 64x3x7x7]
  %194[FLOAT, 64]
  %196[FLOAT, 64x64x3x3]
  %197[FLOAT, 64]
  %199[FLOAT, 64x64x3x3]
  %200[FLOAT, 64]
  %202[FLOAT, 64x64x3x3]
  %203[FLOAT, 64]
  %205[FLOAT, 64x64x3x3]
  %206[FLOAT, 64]
  %208[FLOAT, 128x64x3x3]
  %209[FLOAT, 128]
  %211[FLOAT, 128x128x3x3]
  %212[FLOAT, 128]
  %214[FLOAT, 128x64x1x1]
  %215[FLOAT, 128]
  %217[FLOAT, 128x128x3x3]
  %218[FLOAT, 128]
  %220[FLOAT, 128x128x3x3]
  %221[FLOAT, 128]
  %223[FLOAT, 256x128x3x3]
  %224[FLOAT, 256]
  %226[FLOAT, 256x256x3x3]
  %227[FLOAT, 256]
  %229[FLOAT, 256x128x1x1]
  %230[FLOAT, 256]
  %232[FLOAT, 256x256x3x3]
  %233[FLOAT, 256]
  %235[FLOAT, 256x256x3x3]
  %236[FLOAT, 256]
  %238[FLOAT, 512x256x3x3]
  %239[FLOAT, 512]
  %241[FLOAT, 512x512x3x3]
  %242[FLOAT, 512]
  %244[FLOAT, 512x256x1x1]
  %245[FLOAT, 512]
  %247[FLOAT, 512x512x3x3]
  %248[FLOAT, 512]
  %250[FLOAT, 512x512x3x3]
  %251[FLOAT, 512]
  %fc.bias[FLOAT, 1000]
  %fc.weight[FLOAT, 1000x512]
) {
  %192 = Conv[dilations = [1, 1], group = 1, kernel_shape = [7, 7], pads = [3, 3, 3, 3], strides = [2, 2]](%input.1, %193, %194)
  %125 = Relu(%192)
  %126 = MaxPool[kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]](%125)
  %195 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%126, %196, %197)
  %129 = Relu(%195)
  %198 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%129, %199, %200)
  %132 = Add(%198, %126)
  %133 = Relu(%132)
  %201 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%133, %202, %203)
  %136 = Relu(%201)
  %204 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%136, %205, %206)
  %139 = Add(%204, %133)
  %140 = Relu(%139)
  %207 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]](%140, %208, %209)
  %143 = Relu(%207)
  %210 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%143, %211, %212)
  %213 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [2, 2]](%140, %214, %215)
  %148 = Add(%210, %213)
  %149 = Relu(%148)
  %216 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%149, %217, %218)
  %152 = Relu(%216)
  %219 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%152, %220, %221)
  %155 = Add(%219, %149)
  %156 = Relu(%155)
  %222 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]](%156, %223, %224)
  %159 = Relu(%222)
  %225 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%159, %226, %227)
  %228 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [2, 2]](%156, %229, %230)
  %164 = Add(%225, %228)
  %165 = Relu(%164)
  %231 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%165, %232, %233)
  %168 = Relu(%231)
  %234 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%168, %235, %236)
  %171 = Add(%234, %165)
  %172 = Relu(%171)
  %237 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [2, 2]](%172, %238, %239)
  %175 = Relu(%237)
  %240 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%175, %241, %242)
  %243 = Conv[dilations = [1, 1], group = 1, kernel_shape = [1, 1], pads = [0, 0, 0, 0], strides = [2, 2]](%172, %244, %245)
  %180 = Add(%240, %243)
  %181 = Relu(%180)
  %246 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%181, %247, %248)
  %184 = Relu(%246)
  %249 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%184, %250, %251)
  %187 = Add(%249, %181)
  %188 = Relu(%187)
  %189 = GlobalAveragePool(%188)
  %190 = Flatten[axis = 1](%189)
  %191 = Gemm[alpha = 1, beta = 1, transB = 1](%190, %fc.weight, %fc.bias)
  return %191
}

3. 對(duì)生成的ONNX進(jìn)行可視化:

onnx的可是支持有兩個(gè),一個(gè)是 netron, 一個(gè)是百度飛槳開(kāi)發(fā)的visualDL
這里介紹 netron的下載安裝:https://github.com/lutzroeder/Netron,對(duì)于 mac用戶可以安裝成功直接打開(kāi)軟件進(jìn)行圖形化選取onnx地址就可以打開(kāi)

||
netron可視化圖

4. ONNX Runtime

支持ONNX的runtime就是類似于JVM將統(tǒng)一的ONNX格式的模型包運(yùn)行起來(lái),包括對(duì)ONNX 模型進(jìn)行解讀,優(yōu)化(融合conv-bn等操作),運(yùn)行。

支持ONNX格式的runtime

這里介紹 microsoft 開(kāi)發(fā)的 ONNX Runtime

4.1 ONNXRuntime的安裝

https://github.com/microsoft/onnxruntime
對(duì)于使用cpu來(lái)進(jìn)行推理的 mac os 可以使用

brew install libomp
pip install onnxruntime

推理

import onnxruntime as rt
import numpy as  np
data = np.array(np.random.randn(1,3,224,224))
sess = rt.InferenceSession('resnet18.onnx')
input_name = sess.get_inputs()[0].name
label_name = sess.get_outputs()[0].name

pred_onx = sess.run([label_name], {input_name:data.astype(np.float32)})[0]
print(pred_onx)
print(np.argmax(pred_onx)

可以看到,這樣推理就不需要其他各種各樣的pytorch等依賴,方便部署。

推薦兩個(gè)易懂的視頻講解:
Everything You Want to Know About ONNX
MicroSoft onnx and onnx runtim

?著作權(quán)歸作者所有,轉(zhuǎn)載或內(nèi)容合作請(qǐng)聯(lián)系作者
平臺(tái)聲明:文章內(nèi)容(如有圖片或視頻亦包括在內(nèi))由作者上傳并發(fā)布,文章內(nèi)容僅代表作者本人觀點(diǎn),簡(jiǎn)書(shū)系信息發(fā)布平臺(tái),僅提供信息存儲(chǔ)服務(wù)。

推薦閱讀更多精彩內(nèi)容