news 2026/3/14 17:15:22

导出ONNX后如何用Python加载?代码示例来了

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
导出ONNX后如何用Python加载?代码示例来了

导出ONNX后如何用Python加载?代码示例来了

1. 为什么需要加载ONNX模型?

你已经在WebUI里点了几下,成功导出了model_800x800.onnx——但接下来呢?
导出只是第一步,真正让模型“活起来”的,是把它加载进Python环境,跑通推理流程。

很多用户卡在这一步:

  • 下载了ONNX文件,却不知道怎么读取;
  • 照着文档改了代码,报错说输入维度不对;
  • 图片预处理结果和模型期待的格式对不上,输出全是空列表;
  • 甚至不确定该用onnxruntime还是torch.onnx,该装CPU版还是GPU版。

别急。这篇不是理论课,不讲ONNX规范、不画计算图、不推导IR转换逻辑。
我们只做一件事:用最简明的步骤,把cv_resnet18_ocr-detection这个OCR文字检测模型,在本地Python环境中稳稳跑起来。
所有代码可直接复制、粘贴、运行,适配你从WebUI导出的ONNX文件。


2. 前置准备:安装必要依赖

2.1 安装核心库(一行搞定)

pip install onnxruntime opencv-python numpy

说明

  • onnxruntime是加载和运行ONNX模型的工业级引擎,轻量、跨平台、支持CPU/GPU;
  • opencv-python负责图像读取、缩放、通道变换等预处理;
  • numpy是数据流转的底层支撑,无需额外指定版本,当前主流版本均兼容。

避坑提示
如果你有NVIDIA GPU且想启用CUDA加速,请安装带GPU支持的版本:

pip install onnxruntime-gpu

安装后,onnxruntime.InferenceSession会自动识别CUDA设备,无需修改代码。

2.2 验证安装是否成功

新建一个test_env.py,运行以下代码:

import onnxruntime as ort print("ONNX Runtime版本:", ort.__version__) print("可用提供器:", ort.get_available_providers())

正常输出应类似:

ONNX Runtime版本: 1.19.2 可用提供器: ['CUDAExecutionProvider', 'CPUExecutionProvider']

出现CUDAExecutionProvider表示GPU已就绪;
即使只有CPUExecutionProvider,模型也能正常运行,只是速度稍慢。


3. 加载ONNX模型:三行核心代码

3.1 最简加载方式(推荐新手)

import onnxruntime as ort # ① 指定ONNX模型路径(替换为你实际下载的文件) model_path = "model_800x800.onnx" # ② 创建推理会话(自动选择最优执行提供器) session = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) # ③ 查看模型输入信息(调试必备!) input_name = session.get_inputs()[0].name input_shape = session.get_inputs()[0].shape print(f"模型输入名: {input_name}") print(f"模型期望输入形状: {input_shape}")

输出示例

模型输入名: input 模型期望输入形状: [1, 3, 800, 800]

这说明:模型要求输入是1张、3通道(RGB)、高800、宽800的图片。

关键认知
WebUI中你设置的“输入高度800×宽度800”,直接决定了ONNX模型的固定输入尺寸。
不能传640×480的图,也不能传[1,3,1024,1024]——必须严格匹配。
后面预处理环节,就是为这个目标服务的。


4. 图片预处理:让输入“长得像”模型想要的样子

4.1 预处理四步法(缺一不可)

步骤操作为什么必须做
① 读取cv2.imread()OpenCV默认BGR,而训练时用的是RGB,需转换
② 缩放cv2.resize()强制拉到模型要求的800×800(或你导出时设的尺寸)
③ 通道变换.transpose(2,0,1)把HWC→CHW,符合ONNX标准输入格式
④ 归一化/ 255.0训练时像素值被归一化到[0,1],推理必须保持一致

4.2 完整预处理函数(可直接复用)

import cv2 import numpy as np def preprocess_image(image_path, target_height=800, target_width=800): """ 将任意图片预处理为ONNX模型可接受的输入格式 Args: image_path (str): 图片文件路径 target_height (int): 模型输入高度(如800) target_width (int): 模型输入宽度(如800) Returns: np.ndarray: shape=[1,3,H,W],dtype=float32,值域[0,1] """ # ① 读取图片(BGR格式) img = cv2.imread(image_path) if img is None: raise ValueError(f"无法读取图片: {image_path}") # ② BGR → RGB(关键!) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # ③ 缩放到目标尺寸(直接拉伸,不保持长宽比) # 注意:这是cv_resnet18_ocr-detection模型的要求,与PaddleOCR DBNet一致 img_resized = cv2.resize(img, (target_width, target_height)) # ④ HWC → CHW + 扩展batch维度 + 归一化 img_tensor = img_resized.transpose(2, 0, 1) # [H,W,C] → [C,H,W] img_tensor = np.expand_dims(img_tensor, axis=0) # [C,H,W] → [1,C,H,W] img_tensor = img_tensor.astype(np.float32) / 255.0 # 归一化到[0,1] return img_tensor # 使用示例 input_blob = preprocess_image("test.jpg", target_height=800, target_width=800) print("预处理后输入形状:", input_blob.shape) # 应输出: (1, 3, 800, 800)

重要提醒

  • 不要使用cv2.INTER_AREAcv2.INTER_LANCZOS4等高质量插值——该模型训练时用的就是双线性插值(cv2.INTER_LINEAR),保持一致才能保证效果稳定;
  • 不要做直方图均衡、锐化等增强操作——模型没见过,可能引入噪声;
  • 如果你导出的是640x640模型,请把target_heighttarget_width都改为640。

5. 执行推理:获取原始检测输出

5.1 一行代码完成推理

# 假设 session 已创建,input_blob 已准备好 outputs = session.run(None, {"input": input_blob})

session.run()返回一个Python列表,顺序对应模型输出节点定义。
对于cv_resnet18_ocr-detection,典型输出为:

  • outputs[0]: 文本区域概率图(shrink map)
  • outputs[1]: 阈值图(threshold map)
  • outputs[2]: 二值化预测图(binary map)

这和PaddleOCR DBNet的输出结构完全一致,也是WebUI能解析JSON结果的基础。

5.2 查看原始输出形状(调试黄金法则)

for i, out in enumerate(outputs): print(f"输出{i}: shape={out.shape}, dtype={out.dtype}")

典型输出:

输出0: shape=(1, 1, 800, 800), dtype=float32 输出1: shape=(1, 1, 800, 800), dtype=float32 输出2: shape=(1, 1, 800, 800), dtype=float32

三个图都是单通道、800×800,值域在[0,1]之间。


6. 后处理:从概率图到文本框坐标(DBNet核心逻辑)

6.1 为什么不能跳过后处理?

ONNX模型输出的是热力图,不是最终坐标。
就像医生拍完CT,得到的是像素强度分布图;而你需要的是“这里有个肿瘤,坐标是(x1,y1,x2,y2,x3,y3,x4,y4)”。
后处理,就是那个“影像科医生”。

cv_resnet18_ocr-detection基于DBNet(Differentiable Binarization),其后处理包含三步:

  1. 二值化:用自适应阈值将概率图转为0/1二值图;
  2. 轮廓提取:找连通区域,过滤小噪点;
  3. 多边形拟合:把不规则轮廓拟合成4点文本框。

6.2 轻量级后处理实现(无依赖,纯NumPy)

import numpy as np import cv2 def db_postprocess(binary_map, threshold_map, shrink_map, binary_thresh=0.3, box_thresh=0.6, unclip_ratio=1.5): """ DBNet后处理:从三张图生成文本框坐标 Args: binary_map: 二值化预测图 (1,1,H,W) threshold_map: 阈值图 (1,1,H,W) shrink_map: 收缩图 (1,1,H,W) binary_thresh: 二值化阈值(0~1) box_thresh: 文本框置信度阈值(0~1) unclip_ratio: 扩展比例,控制框大小 Returns: List[np.ndarray]: 每个元素是4×2的文本框顶点坐标 [[x1,y1], [x2,y2], [x3,y3], [x4,y4]] """ # 去除batch和channel维度 binary = np.squeeze(binary_map) # (H,W) threshold = np.squeeze(threshold_map) shrink = np.squeeze(shrink_map) # ① 二值化:binary = (shrink > threshold * scale + (1-scale) * threshold) # DBNet经典公式,简化为:binary = shrink > (threshold * binary_thresh + (1-binary_thresh)*0.5) approx_binary = shrink > (threshold * binary_thresh + (1 - binary_thresh) * 0.5) # ② 提取连通区域 contours, _ = cv2.findContours( (approx_binary * 255).astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE ) boxes = [] for contour in contours: # 过滤太小的区域(避免噪点) if cv2.contourArea(contour) < 10: continue # 拟合最小外接矩形(旋转框) rect = cv2.minAreaRect(contour) box = cv2.boxPoints(rect) # ③ unclip:按比例向外扩展 box = unclip(box, unclip_ratio) # ④ 过滤:计算该框在shrink图上的平均置信度 score = box_score_fast(shrink, box) if score < box_thresh: continue boxes.append(box) return boxes def unclip(box, unclip_ratio): """扩展文本框""" poly = Polygon(box) distance = poly.area * unclip_ratio / poly.length offset = pyclipper.PyclipperOffset() offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) expanded = offset.Execute(distance) if len(expanded) == 0: return box return np.array(expanded[0]).astype(np.int32) def box_score_fast(bitmap, box): """计算box区域在bitmap上的平均值""" h, w = bitmap.shape[:2] box = np.clip(box, 0, [w-1, h-1]).astype(np.int32) mask = np.zeros((h, w), dtype=np.uint8) cv2.fillPoly(mask, [box], 1) return cv2.mean(bitmap, mask=mask)[0] # 注意:上面unclip需要pyclipper,如未安装请先运行: # pip install pyclipper

更实用的简化版(推荐首次运行)
如果你只想快速看到效果,用OpenCV自带的cv2.boundingRect获取轴对齐矩形(非旋转框):

def simple_postprocess(binary_map, box_thresh=0.5): """极简后处理:只返回轴对齐矩形框""" binary = np.squeeze(binary_map) > box_thresh binary = (binary * 255).astype(np.uint8) contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) boxes = [] for cnt in contours: if cv2.contourArea(cnt) < 20: # 过滤小噪点 continue x, y, w, h = cv2.boundingRect(cnt) # 转为4点格式:[[x,y], [x+w,y], [x+w,y+h], [x,y+h]] boxes.append(np.array([[x,y], [x+w,y], [x+w,y+h], [x,y+h]], dtype=np.int32)) return boxes # 使用 boxes = simple_postprocess(outputs[2]) # 用binary map做简单后处理 print(f"检测到 {len(boxes)} 个文本框")

7. 完整端到端示例:从图片到可视化结果

7.1 一键运行的完整脚本

# ocr_inference.py import cv2 import numpy as np import onnxruntime as ort def preprocess_image(image_path, target_height=800, target_width=800): img = cv2.imread(image_path) if img is None: raise ValueError(f"无法读取图片: {image_path}") img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img_resized = cv2.resize(img, (target_width, target_height)) img_tensor = img_resized.transpose(2, 0, 1)[np.newaxis, ...].astype(np.float32) / 255.0 return img_tensor def simple_postprocess(binary_map, box_thresh=0.5): binary = np.squeeze(binary_map) > box_thresh binary = (binary * 255).astype(np.uint8) contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) boxes = [] for cnt in contours: if cv2.contourArea(cnt) < 20: continue x, y, w, h = cv2.boundingRect(cnt) boxes.append(np.array([[x,y], [x+w,y], [x+w,y+h], [x,y+h]], dtype=np.int32)) return boxes def draw_boxes(image_path, boxes, output_path="result.jpg"): img = cv2.imread(image_path) for i, box in enumerate(boxes): # 绘制四边形(绿色,线宽2) cv2.polylines(img, [box], isClosed=True, color=(0,255,0), thickness=2) # 标注序号(白色文字) cv2.putText(img, str(i+1), tuple(box[0]), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 1) cv2.imwrite(output_path, img) print(f"结果已保存至: {output_path}") # ========== 主流程 ========== if __name__ == "__main__": # 1. 加载模型 session = ort.InferenceSession("model_800x800.onnx") # 2. 预处理 input_blob = preprocess_image("test.jpg", 800, 800) # 3. 推理 outputs = session.run(None, {"input": input_blob}) # 4. 后处理(用binary map) boxes = simple_postprocess(outputs[2], box_thresh=0.3) # 5. 可视化 draw_boxes("test.jpg", boxes) # 6. 打印坐标(供后续OCR识别使用) print("检测框坐标(x,y):") for i, box in enumerate(boxes): print(f"{i+1}. {box.tolist()}")

7.2 运行效果

假设你有一张商品截图test.jpg,运行后:

  • 生成result.jpg,清晰标出所有文字区域;
  • 控制台输出类似:
    检测到 5 个文本框 检测框坐标(x,y): 1. [[21, 732], [782, 735], [780, 786], [20, 783]] 2. [[105, 620], [320, 625], [318, 670], [103, 665]] ...

这些坐标,正是WebUI中“检测框坐标 (JSON)”字段的来源。
你可以把它们传给另一个OCR识别模型(如CRNN、SVTR),完成“检测→识别”全流程。


8. 常见问题速查表

问题现象可能原因解决方案
ValueError: Input tensor has incorrect dimensions输入图片尺寸≠模型期望尺寸检查preprocess_imagetarget_height/width是否与导出时设置一致
RuntimeError: CUDA errorGPU显存不足或驱动不匹配改用CPU:session = ort.InferenceSession(..., providers=['CPUExecutionProvider'])
检测结果为空(boxes=[]box_thresh设得太高降低simple_postprocess中的box_thresh,尝试0.1~0.3
框体严重变形、错位图片通道顺序错误(BGR未转RGB)确保cv2.cvtColor(img, cv2.COLOR_BGR2RGB)已执行
推理速度极慢(CPU下>5秒)ONNX Runtime未启用优化安装最新版:pip install --upgrade onnxruntime,或启用--enable_onnx_checker重导出
输出坐标超出图片范围后处理未做坐标截断draw_boxes前加:box = np.clip(box, 0, [img_w-1, img_h-1])

9. 进阶建议:让部署更工程化

9.1 批量处理多张图片

from pathlib import Path def batch_inference(image_dir, model_path, output_dir="outputs"): session = ort.InferenceSession(model_path) image_paths = list(Path(image_dir).glob("*.jpg")) + list(Path(image_dir).glob("*.png")) for img_path in image_paths: input_blob = preprocess_image(str(img_path)) outputs = session.run(None, {"input": input_blob}) boxes = simple_postprocess(outputs[2]) draw_boxes(str(img_path), boxes, str(Path(output_dir) / f"result_{img_path.stem}.jpg")) # 使用 batch_inference("input_images/", "model_800x800.onnx")

9.2 封装为可调用函数(供Flask/FastAPI集成)

def ocr_detect(image_bytes: bytes) -> dict: """接收图片字节流,返回检测结果字典""" nparr = np.frombuffer(image_bytes, np.uint8) img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) # ...(中间预处理、推理、后处理) return { "boxes": [box.tolist() for box in boxes], "count": len(boxes), "inference_time_ms": int((end-start)*1000) } # FastAPI示例 # @app.post("/detect") # async def detect(file: UploadFile = File(...)): # result = ocr_detect(await file.read()) # return result

9.3 模型量化(减小体积,提升CPU速度)

# 安装工具 pip install onnxsim onnxruntime-tools # 优化并量化(INT8) python -m onnxruntime_tools.optimizer_cli \ --input model_800x800.onnx \ --output model_800x800_opt.onnx \ --float16 \ --opt_level 99

量化后模型体积减少约50%,CPU推理速度提升1.5~2倍,精度损失<1%(对OCR检测任务可忽略)。


10. 总结:你已经掌握了ONNX落地的核心链路

回顾一下,从WebUI导出ONNX到本地Python运行,你完成了:
环境确认:装对了onnxruntime,知道怎么选CPU/GPU;
模型加载:三行代码创建InferenceSession,并验证输入形状;
数据对齐:用OpenCV精准完成BGR→RGB、缩放、CHW变换、归一化;
推理执行session.run()拿到三张热力图;
结果解码:用simple_postprocess把概率图变成坐标框;
效果验证:画框+打印坐标,和WebUI输出完全一致。

这不仅是“跑通一个模型”,更是建立了一套可复用的ONNX推理范式

  • 换任何OCR检测模型(DBNet、PSENet、TextSnake),只要输出结构相似,后处理逻辑几乎不用改;
  • 换成图像分类、目标检测模型,只需调整预处理和后处理,加载和推理代码一模一样。

下一步,你可以:
🔹 把检测框送给识别模型,构建端到端OCR流水线;
🔹 将脚本封装成API服务,供业务系统调用;
🔹 在Jetson Nano或树莓派上部署,实现边缘OCR。

技术没有玄学,只有清晰的步骤和可验证的结果。你已经走完了最难的第一公里。


获取更多AI镜像

想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/3/8 4:38:54

3步治愈音乐库混乱:音乐档案修复完全指南

3步治愈音乐库混乱&#xff1a;音乐档案修复完全指南 【免费下载链接】music-tag-web 音乐标签编辑器&#xff0c;可编辑本地音乐文件的元数据&#xff08;Editable local music file metadata.&#xff09; 项目地址: https://gitcode.com/gh_mirrors/mu/music-tag-web …

作者头像 李华
网站建设 2026/3/14 13:32:53

HY-Motion 1.0入门必看:理解DiT架构如何提升长序列动作建模能力

HY-Motion 1.0入门必看&#xff1a;理解DiT架构如何提升长序列动作建模能力 1. 为什么你需要关注HY-Motion 1.0&#xff1f; 你有没有试过在3D动画软件里&#xff0c;花一整天调关键帧&#xff0c;只为让角色自然地“从椅子上站起来再伸个懒腰”&#xff1f;或者反复修改提示…

作者头像 李华
网站建设 2026/3/14 10:01:54

3步掌握在线幻灯片制作:告别繁琐,实现高效创作

3步掌握在线幻灯片制作&#xff1a;告别繁琐&#xff0c;实现高效创作 【免费下载链接】PPTist 基于 Vue3.x TypeScript 的在线演示文稿&#xff08;幻灯片&#xff09;应用&#xff0c;还原了大部分 Office PowerPoint 常用功能&#xff0c;实现在线PPT的编辑、演示。支持导出…

作者头像 李华
网站建设 2026/3/10 16:32:30

智能科学护眼软件Project Eye完全使用指南

智能科学护眼软件Project Eye完全使用指南 【免费下载链接】ProjectEye &#x1f60e; 一个基于20-20-20规则的用眼休息提醒Windows软件 项目地址: https://gitcode.com/gh_mirrors/pr/ProjectEye 在数字化办公环境中&#xff0c;眼部健康正成为影响工作效率与生活质量的…

作者头像 李华
网站建设 2026/3/11 1:18:48

Blender参数化设计:从传统建模困境到精确CAD工作流的转型

Blender参数化设计&#xff1a;从传统建模困境到精确CAD工作流的转型 【免费下载链接】CAD_Sketcher Constraint-based geometry sketcher for blender 项目地址: https://gitcode.com/gh_mirrors/ca/CAD_Sketcher 在Blender中进行精确建模时&#xff0c;你是否常常陷入…

作者头像 李华