Mar 18, 2025
3 min read
Rust,
ONNX,
onnxruntime,

Rust 高效推理新选择:基于 ONNX Runtime 的模型部署实践

本文介绍如何使用Rust和ONNX Runtime进行模型部署,以ResNet50为例,详细说明模型导出、加载、数据前处理、推理及后处理的全流程,并分析其优缺点。

Rust 除了可以使用 tch-rscandle 框架进行模型部署之外,还可以使用 onnxruntime 进行部署。ort 就是 onnxruntime 在 rust 中的绑定库。

[dependencies]
ort = "=2.0.0-rc.9"

ONNX(Open Neural Network Exchange)是微软公司推出的与训练框架无关的通过模型权重格式,它是一种开放式的文件格式。它的设计目标是让不同的深度学习框架(如 PyTorch、谷歌的TensorFlow、 百度的PaddlePaddle 等)之间可以互相转换和共享模型,从而实现跨平台的模型部署和推理。有一些冷门的推理框架甚至会使用 ONNX 作为中间格式来达到支持大部分模型的目地。

ONNX 使用的是 Protobuf 格式进行存储,同时它也存储了计算图结构,因此使用 ONNX 格式并不需要定义模型结构。

当然,ONNX 格式本身也是就是部署格式,它的运行框架就叫 onnxruntime。

导出模型

这里以 Pytorch 框架为例子,可以通过 torch.onnx.export api 直接导出 onnx 模型。

这里导出一个常见的 cnn 模型 - resnet50:

import torch
from torchvision.models import resnet50

if __name__ == "__main__":
    model = resnet50(pretrained=True)
    model.eval()
    print(model)
    
    input_tensor = torch.rand((1, 3, 224, 224), dtype=torch.float32)
    torch.onnx.export(
        model,(input_tensor,),
        "resnet50.onnx", 
        input_names=["input"],
    )

推理模型

实现前,我们先看看在 Pytorch 里是怎么实现 Resnet 推理的:

from PIL import Image
from torchvision import transforms

#加载模型
model = resnet50(pretrained=True)
model.eval()

#数据前处理
input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) 

#模型推理
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)

# 数据后处理
print(output[0])
print(torch.nn.functional.softmax(output[0], dim=0))

模型推理一般分 4 步:

  1. 加载模型
  2. 数据前处理
  3. 模型推理
  4. 数据后处理

加载模型

let session = Session::builder()?
        .with_optimization_level(GraphOptimizationLevel::Level3)?
        .with_execution_providers([
	        CPUExecutionProvider::default().build(),
            CUDAExecutionProvider::default().build(),
        ])?
        .with_intra_threads(4)?
        .commit_from_file("resnet50.onnx")?;

with_optimization_level 用于设置会话的优化级别,不同级别对应不同的图优化策略。

with_execution_providers 主要用于设置执行提供者列表, ONNX Runtime 抽象了不同的执行提供者(EP)来实现硬件加速执行 ONNX 图。常见的EP 有 CUDACPUTensorRTROCm、等,甚至支持国产NPU - RKNPU

with_intra_threads 用于设置会话中节点内并行执行的线程数。如果ONNX Runtime使用OpenMP构建,则线程数由环境变量OMP_NUM_THREADS控制,此函数无效。函数通过调用SetIntraOpNumThreads设置线程数。

前处理

通过上面 Python 的实现我们发现,数据前处理需要先把图片 resize 到 224x224(这是因为 resnet50 的输入尺寸是 224x224),然后转为张量,并为张量进行标准化。

以这张图片为例子 ![[cat 1.jpeg]]

加载待推理图片,并 resize 到224x224:

let image_buffer: ImageBuffer<Rgb<u8>, Vec<u8>> = image::open(
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("tests")
            .join("cat.jpeg"),
    )
    .unwrap()
    .resize(224, 224, FilterType::Nearest)
    .to_rgb8();

ort 使用 ndarray 作为张量处理格式:

let mut array = ndarray::Array::from_shape_fn((1, 3, 224, 224), |(_, c, j, i)| {
        let pixel = image_buffer.get_pixel(i as u32, j as u32);
        let channels = pixel.channels();
        // range [0, 255] -> range [0, 1]
        (channels[c] as f32) / 255.0
    });

标准化:

let mean = [0.485, 0.456, 0.406];
    let std = [0.229, 0.224, 0.225];
    for c in 0..3 {
        let mut channel_array = array.slice_mut(s![0, c, .., ..]);
        channel_array -= mean[c];
        channel_array /= std[c];
    }

转成 Tensor:

let input = Tensor::from_array(array)?;

模型推理

剩下的就简单了,直接调用 session 进行模型推理即可。

let outputs = session.run(inputs![input]?)?;

根据模型 resnet 50 的结构可知,resnet50 的最后一层是全连接层,模型默认输出的是 imagenet 数据集的分类。因此我们需要找到 imagenet 的分类列表:

candle 框架项目里有一份 imagenet 的分类列表,我们可以直接拿来使用:

https://github.com/huggingface/candle/blob/0b24f7f0a41d369942bfcadac3a3cf494167f8a6/candle-examples/src/imagenet.rs

因此 output 的输出默认是 1000 维的概率。我们需要使用 softmax 转换一下:

let mut probabilities: Vec<(usize, f32)> = outputs[0]
		.try_extract_tensor()?
		.softmax(ndarray::Axis(1))
		.iter()
		.copied()
		.enumerate()
		.collect::<Vec<_>>();

让其每一项格式如下:

(
    281, // 分类索引
    0.92174786, // 概率
)

最后按概率排序,取最大的概率:

probabilities.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

// 第一个就是最大概率结果
dbg!(probabilities[0]);
// 获取最大概率结果的标签
let label = CLASSES[probabilities[0].0];
dbg!(label); //结果

结果打印如下:

[src/main.rs:56:5] probabilities[0] = (
    281,
    0.92174786,
)
[src/main.rs:58:5] label = "tabby, tabby cat"

到此,完成了使用 onnxruntime 对 resnet 的图像分类的推理。

ort / onnxruntime 的优点与缺点

onnxruntime 的推理代码是相对来说比较简单的。它的最大优点是通用性很强,几乎支持大部分模型与常见的训练框架。并且,onnxruntime 的依赖很少,集成非常的方便。

不过 使用 onnx 也有可能碰到不支持的算子,它自己本身可能会做一些优化。另外,有一些模型使用 onnxruntime gpu 推理时可能速度反而要慢,已知的比较典型的是一些 OCR 模型。