【技术基础】使用ONNX使模型通用化
bigegpt 2025-01-23 15:28 9 浏览
什么是ONNX?
ONNX(Open Neural Network Exchange)- 开放神经网络交换格式,作为框架共用的一种模型交换格式,使用protobuf 二进制格式来序列化模型,可以提供更好的传输性能我们可能会在某一任务中Pytorch或者TensorFlow模型转化为ONNX模型(ONNX模型一般用于中间部署阶段),然后再拿转化后的ONNX模型进而转化为我们使用不同框架部署需要的类型,ONNX相当于一个翻译的作用。
为什么要用ONNX?
深度学习算法大多通过计算数据流图来完成神经网络的深度学习过程。一些框架(例如CNTK,Caffe2,Theano和TensorFlow)使用静态图形,而其他框架(例如PyTorch和Chainer)使用动态图形。但是这些框架都提供了接口,使开发人员可以轻松构建计算图和运行时,以优化的方式处理图。这些图用作中间表示(IR),捕获开发人员源代码的特定意图,有助于优化和转换在特定设备(CPU,GPU,FPGA等)上运行。假设一个场景:现在某组织因为主要开发用TensorFlow为基础的框架,现在有一个深度算法,需要将其部署在移动设备上,以观测变现。传统地我们需要用Caffe2重新将模型写好,然后再训练参数;试想下这将是一个多么耗时耗力的过程。此时,ONNX便应运而生,Caffe2,PyTorch,Microsoft Cognitive Toolkit,Apache MXNet等主流框架都对ONNX有着不同程度的支持。这就便于我们的算法及模型在不同框架之间的迁移。
ONNX结构分析
ONNX将每一个网络的每一层或者说是每一个算子当作节点Node,再由这些Node去构建一个Graph,相当于是一个网络。最后将Graph和这个ONNX模型的其他信息结合在一起,生成一个Model,也就是最终的.onnx的模型。构建一个简单的ONNX模型,实质上,只要构建好每一个node,然后将它们和输入输出超参数一起塞到Graph,最后转成Model就可以了。
graph{ node{ input: "1" input: "2" output: "12" op_type: "Conv" } attribute{ name: "strides" ints: 1 ints: 1 } attribute{ name: "pads" ints: 2 ints: 2 } ...}
我们查看ONNX网络结构和参数(查看网址:https://netron.app/)
ONNX安装、使用
安装ONNX环境,在终端中执行以下命令,环境中需要提前准本 python3.6. 以下流程以ubunt 20.04 为例。
模型转换流程
超分辨率是一种提高图像、视频分辨率的算法,广泛用于图像处理或视频编辑。首先,让我们在PyTorch中创建一个SuperResolution 模型。该模型使用描述的高效子像素卷积层将图像的分辨率提高了一个放大因子。该模型将图像的YCbCr的Y分量作为输入,并以超分辨率输出放大的Y分量。
# Some standard importsimport ioimport numpy as npfrom torch import nnimport torch.utils.model_zoo as model_zooimport torch.onnx# Super Resolution model definition in PyTorchimport torch.nn as nnimport torch.nn.init as initclass SuperResolutionNet(nn.Module): def __init__(self, upscale_factor, inplace=False): super(SuperResolutionNet, self).__init__() self.relu = nn.ReLU(inplace=inplace) self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2)) self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)) self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1)) self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1)) self.pixel_shuffle = nn.PixelShuffle(upscale_factor) self._initialize_weights() def forward(self, x): x = self.relu(self.conv1(x)) x = self.relu(self.conv2(x)) x = self.relu(self.conv3(x)) x = self.pixel_shuffle(self.conv4(x)) return x def _initialize_weights(self): init.orthogonal_(self.conv1.weight, init.calculate_gain('relu')) init.orthogonal_(self.conv2.weight, init.calculate_gain('relu')) init.orthogonal_(self.conv3.weight, init.calculate_gain('relu')) init.orthogonal_(self.conv4.weight)# Create the super-resolution model by using the above model definition.torch_model = SuperResolutionNet(upscale_factor=3)
模型下载
由于本教程以演示为目的,因此采用下载预先训练好的权重。在导出模型之前调用torch_model.eval()或torch_model.train(False)将模型转换为推理模式很重要。因为dropout或batchnorm等运算符在推理和训练模式下的行为不同。
# Load pretrained model weightsmodel_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'batch_size = 1 # just a random number# Initialize model with the pretrained weightsmap_location = lambda storage, loc: storageif torch.cuda.is_available(): map_location = Nonetorch_model.load_state_dict(model_zoo.load_url(model_url, map_location=map_location))# set the model to inference modetorch_model.eval()
模型导出
要导出模型,我们调用该torch.onnx.export() 函数。这将执行模型,记录用于计算输出的运算符。因为export运行模型,我们需要提供一个输入张量x。只要它是正确的类型和大小,其中的值可以是随机的。请注意,除非指定为动态轴,否则所有输入维度的导出ONNX图中的输入大小将是固定的。在此示例中,我们使用batch_size 1的输入导出模型,但随后在dynamic_axes参数中将第一个维度指定为动态 torch.onnx.export() . 因此,导出的模型将接受大小为[batch_size, 1, 224, 224]的输入,其中batch_size可以是可变的。
# Input to the modelx = torch.randn(batch_size, 1, 224, 224, requires_grad=True)torch_out = torch_model(x)# Export the modeltorch.onnx.export(torch_model, # model being run x, # model input (or a tuple for multiple inputs) "super_resolution.onnx", # where to save the model (can be a file or file-like object) export_params=True, # store the trained parameter weights inside the model file opset_version=10, # the ONNX version to export the model to do_constant_folding=True, # whether to execute constant folding for optimization input_names = ['input'], # the model's input names output_names = ['output'], # the model's output names dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes 'output' : {0 : 'batch_size'}})
导出模型测试
在使用ONNX Runtime验证模型的输出之前,我们将使用ONNX的 API检查ONNX 模型。首先,onnx.load("super_resolution.onnx") 将加载保存的模型并输出 onnx.ModelProto结构(用于捆绑 ML 模型的顶级文件/容器格式)。然后,onnx.checker.check_model(onnx_model) 将验证模型的结构并确认模型具有有效的架构。ONNX 图的有效性通过检查模型的版本、图的结构以及节点及其输入和输出来验证。
import onnxonnx_model = onnx.load("super_resolution.onnx")onnx.checker.check_model(onnx_model)import onnxruntimeort_session = onnxruntime.InferenceSession("super_resolution.onnx")def to_numpy(tensor): return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()# compute ONNX Runtime output predictionort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}ort_outs = ort_session.run(None, ort_inputs)# compare ONNX Runtime and PyTorch resultsnp.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)print("Exported model has been tested with ONNXRuntime, and the result looks good!")
处理前图片:
1.加载处理前图片,使用标准PIL python库对其进行预处理。
2.调整图像大小以适应模型输入的大小 (224x224)。
处理后结果:
注:文章仅代表作者个人的观点,欢迎大家留言交流。
作者介绍
塔超,海云捷迅研发工程师。本科毕业于内蒙古科技大学并获得计算机主修学士学位。拥有丰富的项目经验,开发过AI、K8S相关的项目。
相关推荐
- 【Docker 新手入门指南】第十章:Dockerfile
-
Dockerfile是Docker镜像构建的核心配置文件,通过预定义的指令集实现镜像的自动化构建。以下从核心概念、指令详解、最佳实践三方面展开说明,帮助你系统掌握Dockerfile的使用逻...
- Windows下最简单的ESP8266_ROTS_ESP-IDF环境搭建与腾讯云SDK编译
-
前言其实也没啥可说的,只是我感觉ESP-IDF对新手来说很不友好,很容易踩坑,尤其是对业余DIY爱好者搭建环境非常困难,即使有官方文档,或者网上的其他文档,但是还是很容易踩坑,多研究,记住两点就行了,...
- python虚拟环境迁移(python虚拟环境conda)
-
主机A的虚拟环境向主机B迁移。前提条件:主机A和主机B已经安装了virtualenv1.主机A操作如下虚拟环境目录:venv进入虚拟环境:sourcevenv/bin/active(1)记录虚拟环...
- Python爬虫进阶教程(二):线程、协程
-
简介线程线程也叫轻量级进程,它是一个基本的CPU执行单元,也是程序执行过程中的最小单元,由线程ID、程序计数器、寄存器集合和堆栈共同组成。线程的引入减小了程序并发执行时的开销,提高了操作系统的并发性能...
- 基于网络安全的Docker逃逸(docker)
-
如何判断当前机器是否为Docker容器环境Metasploit中的checkcontainer模块、(判断是否为虚拟机,checkvm模块)搭配学习教程1.检查根目录下是否存在.dockerenv文...
- Python编程语言被纳入浙江高考,小学生都开始学了
-
今年9月份开始的新学期,浙江省三到九年级信息技术课将同步替换新教材。其中,新初二将新增Python编程课程内容。新高一信息技术编程语言由VB替换为Python,大数据、人工智能、程序设计与算法按照教材...
- CentOS 7下安装Python 3.10的完整过程
-
1.安装相应的编译工具yum-ygroupinstall"Developmenttools"yum-yinstallzlib-develbzip2-develope...
- 如何在Ubuntu 20.04上部署Odoo 14
-
Odoo是世界上最受欢迎的多合一商务软件。它提供了一系列业务应用程序,包括CRM,网站,电子商务,计费,会计,制造,仓库,项目管理,库存等等,所有这些都无缝集成在一起。Odoo可以通过几种不同的方式进...
- Ubuntu 系统安装 PyTorch 全流程指南
-
当前环境:Ubuntu22.04,显卡为GeForceRTX3080Ti1、下载显卡驱动驱动网站:https://www.nvidia.com/en-us/drivers/根据自己的显卡型号和...
- spark+python环境搭建(python 环境搭建)
-
最近项目需要用到spark大数据相关技术,周末有空spark环境搭起来...目标spark,python运行环境部署在linux服务器个人通过vscode开发通过远程python解释器执行代码准备...
- centos7.9安装最新python-3.11.1(centos安装python环境)
-
centos7.9安装最新python-3.11.1centos7.9默认安装的是python-2.7.5版本,安全扫描时会有很多漏洞,比如:Python命令注入漏洞(CVE-2015-2010...
- Linux系统下,五大步骤安装Python
-
一、下载Python包网上教程大多是通过官方地址进行下载Python的,但由于国内网络环境问题,会导致下载很慢,所以这里建议通过国内镜像进行下载例如:淘宝镜像http://npm.taobao.or...
- centos7上安装python3(centos7安装python3.7.2一键脚本)
-
centos7上默认安装的是python2,要使用python3则需要自行下载源码编译安装。1.安装依赖yum-ygroupinstall"Developmenttools"...
- 利用本地数据通过微调方式训练 本地DeepSeek-R1 蒸馏模型
-
网络上相应的教程基本都基于LLaMA-Factory进行,本文章主要顺着相应的教程一步步实现大模型的微调和训练。训练环境:可自行定义,mac、linux或者window之类的均可以,本文以ma...
- 【法器篇】天啦噜,库崩了没备份(天啦噜是什么意思?)
-
背景数据库没有做备份,一天突然由于断电或其他原因导致无法启动了,且设置了innodb_force_recovery=6都无法启动,里面的数据怎么才能恢复出来?本例采用解析建表语句+表空间传输的方式进行...
- 一周热门
- 最近发表
- 标签列表
-
- mybatiscollection (79)
- mqtt服务器 (88)
- keyerror (78)
- c#map (65)
- resize函数 (64)
- xftp6 (83)
- bt搜索 (75)
- c#var (76)
- mybatis大于等于 (64)
- xcode-select (66)
- mysql授权 (74)
- 下载测试 (70)
- skip-name-resolve (63)
- linuxlink (65)
- pythonwget (67)
- logstashinput (65)
- hadoop端口 (65)
- vue阻止冒泡 (67)
- oracle时间戳转换日期 (64)
- jquery跨域 (68)
- php写入文件 (73)
- kafkatools (66)
- mysql导出数据库 (66)
- jquery鼠标移入移出 (71)
- 取小数点后两位的函数 (73)