Apr 27, 2025
4 min read
Rust,
pytorch,
candle,
typescript,
wasm,

从零开始构建手写输入法:wasm 应用篇

之前的文章我们已经实现了在 rust 层的手写模型推理。现在我希望这个模型直接在浏览器内运行,而不用占服务器资源。

刚好 rust 对 wasm 支持非常的好,而 candle 框架是支持 wasm 环境使用 cpu 推理的。

wasm-pack

wasm-pack 是构建和使用 Rust 生成的 WebAssembly 的一站式解决方案,您可以在浏览器中或与 Node.js 配合使用,使其与 JavaScript 进行互操作。

我们先安装wasm-pack :

cargo install wasm-pack

创建一个 wasm 项目:

cargo generate --git https://github.com/rustwasm/wasm-pack-template

我们需要为项目加上下面这些依赖:

[dependencies]
//...
candle-nn = "0.8.4"
candle-core = "0.8.4"
image = "0.25.6"
anyhow = "1.0.98"
serde = "1.0.219"
serde_json = "1.0.140"
getrandom = { version = "0.3", features = ["wasm_js"] }
serde-wasm-bindgen = "0.6.0"

基于之前的 rust 代码整理如下:


#[derive(Serialize, Deserialize)]
pub struct Top5 {
    pub label: String,
    pub score: f32,
    pub class_idx:usize,
}

pub struct Worker {
    model: crate::models::mobilenetv2::Mobilenetv2,
}

impl Worker {

    /// 加载并初始化一个预训练的 MobileNetV2 模型。
    pub fn load_model() -> Result<Self> {
        let dev = &Device::Cpu;
        let weights = include_bytes!("../ochw_mobilenetv2.safetensors");
        let vb = VarBuilder::from_buffered_safetensors(weights.to_vec(), DType::F32, dev)?;
        let model = crate::models::mobilenetv2::Mobilenetv2::new(vb, 4037)?;
        Ok(Self { model })
    }

        /// 从指定的标签文件中获取所有标签,并将其作为字符串向量返回。
        /// ## 返回值
        /// - `Result<Vec<String>>`: 如果成功读取并解析文件,返回包含所有标签的 `Vec<String>`;
        ///   如果过程中发生错误(如文件读取失败或解析错误),返回相应的错误信息。
        pub fn get_labels(&self) -> Result<Vec<String>> {
            // 读取标签文件内容
            let label_text = include_str!("../../training/data/train/label.txt");
            let reader = BufReader::new(label_text.as_bytes());
    
            let mut labels = Vec::new();
            // 逐行读取文件内容
            for line in reader.lines() {
                let line = line?;
                let line = line.trim();
                // 提取每行的第一个字段作为标签
                if let Some(label) = line.split('\t').next() {
                    labels.push(label.to_string());
                }
            }
            Ok(labels)
        }

    /// 使用预训练的模型对输入的图像进行预测,并返回概率最高的前5个类别及其对应的概率。
    ///
    /// # 参数
    /// - `image`: 输入的图像数据,以 `Vec<u8>` 形式表示,通常为图像的二进制数据。
    ///
    /// # 返回值
    /// - `Result<Vec<Top5>>`: 返回一个包含前5个预测结果的 `Vec<Top5>`,每个 `Top5` 结构体包含类别标签、概率值和类别索引。
    ///   如果过程中出现错误,则返回 `Err`。
    pub fn predict(&self, image: Vec<u8>) -> Result<Vec<Top5>> {
        // 从缓冲区加载图像并将其转换为模型所需的张量格式
        let image = load_image_from_buffer(&image, &Device::Cpu)?;
        let image = image.unsqueeze(0)?;
    
        // 使用模型对图像进行前向传播,获取输出结果
        let output = self.model.forward(&image)?;
    
        // 对输出结果进行 softmax 处理,将其转换为概率分布
        let output = candle_nn::ops::softmax(&output, 1)?;
    
        // 将输出结果展平并转换为包含索引和概率值的向量
        let mut predictions = output
            .flatten_all()?
            .to_vec1::<f32>()?
            .into_iter()
            .enumerate()
            .collect::<Vec<_>>();
    
        // 根据概率值对预测结果进行排序,从高到低
        predictions.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
    
        // 获取类别标签
        let labels = self.get_labels()?;
    
        // 取概率最高的前5个预测结果
        let top5 = predictions.iter().take(5).collect::<Vec<_>>();
    
        // 将前5个预测结果转换为 `Top5` 结构体,并打印结果
        let mut top5_data = Vec::with_capacity(5);
        for (i, (class_idx, prob)) in top5.iter().enumerate() {
            println!(
                "{}. Class {}: {:.2}%",
                i + 1,
                labels[*class_idx],
                prob * 100.0
            );
            top5_data.push(Top5 {
                label: labels[*class_idx].clone(),
                score: *prob,
                class_idx: *class_idx,
            })
        }
    
        Ok(top5_data)
    }
}

但是上面的代码在 js 层是无法调用的,我们还需要导出给 js 层使用:


#[wasm_bindgen]
pub struct Model {
    worker: Worker,
    labels: Vec<String>,
}

#[wasm_bindgen]
impl Model {
    pub fn new() -> Result<Self, JsError> {
        let worker = Worker::load_model()?;
        let labels = worker.get_labels()?;
        Ok(Self { worker, labels })
    }

    /// 获取标签
    pub fn get_label(&self) -> Result<String, JsError> {
        let json = serde_json::to_string(&self.labels)?;
        Ok(json)
    }

    /// 推理
    pub fn predict(&self, image: Vec<u8>) -> Result<String, JsError> {
        let output = self.worker.predict(image)?;
        let json = serde_json::to_string(&output)?;
        Ok(json)
    }
}

正常情况下,使用下面命令就可以构建:

wasm-pack build --target web

但是这个项目你会得到下面这个错误:

error: The wasm32-unknown-unknown targets are not supported by default; you may need to enable the "wasm_js" configuration flag. Note that enabling the `wasm_js` feature flag alone is insufficient. For more information see: https://docs.rs/getrandom/#webassembly-support

原因是 getrandom 这个依赖,虽然支持 wasm,但是由于 wasm-pack的编译目标是wasm32-unknown-unknown, getrandom 无法基于上面的目标(wasm32-unknown-unknown)名称中,推断出应该使用哪个 JavaScript 接口(或者 JavaScript 是否可用),因此我们需要手动编译参数:

RUSTFLAGS='--cfg getrandom_backend="wasm_js"' wasm-pack build --target web

编译成功后,你就得到了一个受前端支持的 npm 包。并且模型权重也包涵在 wasm 文件里了。

 tree pkg
 pkg
├── ochw_wasm.d.ts
├── ochw_wasm.js
├── ochw_wasm_bg.wasm
├── ochw_wasm_bg.wasm.d.ts
└── package.json

到了这一步, wasm 的实现算是完成了。