Apr 26, 2025
3 min read
Rust,
pytorch,
candle,
typescript,
wasm,

从零开始构建手写输入法:模型推理篇

训练并评估好模型后,我们先试试在 python 下推理的效果。

为了验证模型的泛化能力,我们最好手写的时候稍微潦草一些。以下面这张“知”字图片为例子:

![[zhi.png]]

python 推理代码:

def get_labels():
    labels = []
    with open("data/label.txt", "r", encoding="utf-8") as f:
        for line in f:
            # line: !	0
            line = line.strip()
            label = line.split("\t")[0]
            labels.append(label)
    return labels

if __name__ == "__main__":
    model = HandwritingTrainer.load_from_checkpoint("logs/version_0/checkpoint-epoch=32-val_loss=0.156.ckpt")
    model.eval()
    model = model.to("cuda")
    img = Image.open("./testdata/hui.png")
    img = img.convert("RGB")
    img = img.resize((96,96))
    rans = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.95], std=[0.2])
    ])
    img = trans(img)
    img = img.unsqueeze(0)
    img = img.to("cuda")
    labels = get_labels()
    with torch.no_grad():
        output = model(img)
        output = torch.nn.functional.softmax(output,dim=1)
        # 获取top5的预测结果
        top5_prob, top5_idx = torch.topk(output, 5)
        top5_prob = top5_prob.cpu().numpy()
        top5_idx = top5_idx.cpu().numpy()
        for i in range(5):
            idx = top5_idx[0][i]
            print(f"Top {i+1} 预测标签: {labels[idx]}, 概率: {top5_prob[0][i]:.4f}")
        

得到的结果如下:

Top 1 预测标签: 知, 概率: 0.9505
Top 2 预测标签: 勉, 概率: 0.0095
Top 3 预测标签: 贮, 概率: 0.0025
Top 4 预测标签: 处, 概率: 0.0025
Top 5 预测标签: ‰, 概率: 0.0025

确定没问题后,就开始导出模型,采用 safetensors 格式:

    model = HandwritingTrainer.load_from_checkpoint("logs/version_0/checkpoint-epoch=32-val_loss=0.156.ckpt")
    model = model.model
    model.eval()
    print(model,file=open("model.txt","w"))
    save_model(model, "ochw_mobilenetv2.safetensors")

我们可以顺便把模型结构保存到 model.txt, 方便在 rust 层实现。

对于mobilenetv2 模型,之前有写过关于如何使用 rust candle 框架实现的文章,这里不展开。

需要注意的一点是,由于之前训练的时候,我在 mobilenetv2 开关移除了下采样,同样地,在 candle 实现上,也要把 stride 改为 1:

// 需要把 features[0][0] 的 stride 改为 2=>1,
let cbr1 = Conv2dNormActivation::new(vb.pp(0), 3, c_in, 3, 1, 1)?;

(后面通过消融实现,移除了这个操作。)

完整实现见: https://github.com/ximeiorg/ochw/blob/main/ochw-wasm/src/models/mobilenetv2.rs#L169

改用 rust 推理:

fn predict()->anyhow::Result<()> {
        let model_path =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("ochw_mobilenetv2.safetensors");
        
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(
                &[model_path.as_path()],
                candle_core::DType::F32,
                &candle_core::Device::Cpu,
            )
            .unwrap()
        };
        
        let nclasses = 4037;
        let model = mobilenetv2::Mobilenetv2::new(vb, nclasses).unwrap();

        let image_data = include_bytes!("../../../testdata/zhi.png");
        let device = &Device::Cpu;
        let image = load_image_from_buffer(image_data, device).unwrap();
        let image = image.unsqueeze(0).unwrap();
        let output = model.forward(&image).unwrap();
        //softmax
        // top 5, candle 好像没有类似的 torch.topk 的函数,只能自己实现
        let output = softmax(&output, 1).unwrap();
        println!("{output}");
        // 获取 top 5 预测结果(包含索引和概率值)
        let mut predictions = output
            .flatten_all()?
            .to_vec1::<f32>()?
            .into_iter()
            .enumerate()
            .collect::<Vec<_>>();
            
        predictions.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());

        let label_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("testdata/label.txt");
        let labels = get_labels(label_path).unwrap();
        
        let top5 = predictions.iter().take(5).collect::<Vec<_>>();
        for (i, (class_idx, prob)) in top5.iter().enumerate() {
            println!("{}. Class {}: {:.2}%", i+1, labels[*class_idx], prob * 100.0);
        }

        Ok(())
        
        
    }

load_image_from_buffer 会先加载图片,并 resize 到 96x96,再做标准化处理:

/// 加载并预处理一个96x96的RGB图像,将其转换为适合模型输入的张量。
///
/// 该函数接受一个包含图像原始字节数据的向量,并将其转换为一个经过标准化处理的张量。
/// 标准化过程包括将像素值从[0, 255]范围缩放到[0, 1],然后减去均值并除以标准差。
///
/// # 参数
/// - `raw`: 包含图像原始字节数据的向量,长度为96x96x3(即96x96的RGB图像)。
/// - `device`: 用于创建张量的设备(如CPU或GPU)。
///
/// # 返回值
/// 返回一个`Result<Tensor>`,表示处理后的张量。如果过程中出现错误,则返回`Err`。
fn load_image_raw(raw: Vec<u8>, device: &Device) -> Result<Tensor> {
    // 将原始字节数据转换为形状为(96, 96, 3)的张量,并进行维度置换,将通道维度放在最前面。
    let data = Tensor::from_vec(raw, (96, 96, 3), device)?.permute((2, 0, 1))?;

    // 创建均值和标准差的张量,并调整其形状以匹配输入张量的维度。
    let mean_array = [0.95f32, 0.95, 0.95];//[0.485f32, 0.456, 0.406]
    let std_array = [0.2f32, 0.2, 0.2];//[0.229f32, 0.224, 0.225]
    let mean = Tensor::new(&mean_array, device)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&std_array, device)?.reshape((3, 1, 1))?;

    // 将像素值从[0, 255]缩放到[0, 1],然后进行标准化处理。
    (data.to_dtype(DType::F32)? / 255.0)?
        .broadcast_sub(&mean)?
        .broadcast_div(&std)
}

/// 从内存中的图像缓冲区加载图像,并将其转换为指定大小的张量。
///
/// 该函数首先从内存中的字节缓冲区加载图像,然后将其调整为96x96像素大小,
/// 最后将图像转换为RGB格式并加载为张量。
///
/// # 参数
/// - `buffer`: 包含图像数据的字节切片。
/// - `device`: 用于创建张量的设备。
///
/// # 返回值
/// - `Result<Tensor>`: 如果成功,返回包含图像数据的张量;如果失败,返回错误。
pub fn load_image_from_buffer(buffer: &[u8], device: &Device) -> Result<Tensor> {
    // 从内存中加载图像
    let img = image::load_from_memory(buffer)
        .map_err(Error::wrap)?
        .resize_to_fill(96, 96, image::imageops::FilterType::Triangle);
    
    // 将图像转换为RGB格式
    let img = img.to_rgb8();
    
    // 将图像数据加载为张量
    load_image_raw(img.into_raw(), device)
}

结果如下:

Tensor[[1, 4037], f32]
1. Class: 96.17%
2. Class: 0.62%
3. Class: 0.28%
4. Class: 0.23%
5. Class: 0.12%