之前的文章我们已经实现了在 rust 层的手写模型推理。现在我希望这个模型直接在浏览器内运行,而不用占服务器资源。
刚好 rust 对 wasm 支持非常的好,而 candle 框架是支持 wasm 环境使用 cpu 推理的。
wasm-pack
wasm-pack 是构建和使用 Rust 生成的 WebAssembly 的一站式解决方案,您可以在浏览器中或与 Node.js 配合使用,使其与 JavaScript 进行互操作。
我们先安装wasm-pack :
cargo install wasm-pack
创建一个 wasm 项目:
cargo generate --git https://github.com/rustwasm/wasm-pack-template
我们需要为项目加上下面这些依赖:
[dependencies]
//...
candle-nn = "0.8.4"
candle-core = "0.8.4"
image = "0.25.6"
anyhow = "1.0.98"
serde = "1.0.219"
serde_json = "1.0.140"
getrandom = { version = "0.3", features = ["wasm_js"] }
serde-wasm-bindgen = "0.6.0"
基于之前的 rust 代码整理如下:
#[derive(Serialize, Deserialize)]
pub struct Top5 {
pub label: String,
pub score: f32,
pub class_idx:usize,
}
pub struct Worker {
model: crate::models::mobilenetv2::Mobilenetv2,
}
impl Worker {
/// 加载并初始化一个预训练的 MobileNetV2 模型。
pub fn load_model() -> Result<Self> {
let dev = &Device::Cpu;
let weights = include_bytes!("../ochw_mobilenetv2.safetensors");
let vb = VarBuilder::from_buffered_safetensors(weights.to_vec(), DType::F32, dev)?;
let model = crate::models::mobilenetv2::Mobilenetv2::new(vb, 4037)?;
Ok(Self { model })
}
/// 从指定的标签文件中获取所有标签,并将其作为字符串向量返回。
/// ## 返回值
/// - `Result<Vec<String>>`: 如果成功读取并解析文件,返回包含所有标签的 `Vec<String>`;
/// 如果过程中发生错误(如文件读取失败或解析错误),返回相应的错误信息。
pub fn get_labels(&self) -> Result<Vec<String>> {
// 读取标签文件内容
let label_text = include_str!("../../training/data/train/label.txt");
let reader = BufReader::new(label_text.as_bytes());
let mut labels = Vec::new();
// 逐行读取文件内容
for line in reader.lines() {
let line = line?;
let line = line.trim();
// 提取每行的第一个字段作为标签
if let Some(label) = line.split('\t').next() {
labels.push(label.to_string());
}
}
Ok(labels)
}
/// 使用预训练的模型对输入的图像进行预测,并返回概率最高的前5个类别及其对应的概率。
///
/// # 参数
/// - `image`: 输入的图像数据,以 `Vec<u8>` 形式表示,通常为图像的二进制数据。
///
/// # 返回值
/// - `Result<Vec<Top5>>`: 返回一个包含前5个预测结果的 `Vec<Top5>`,每个 `Top5` 结构体包含类别标签、概率值和类别索引。
/// 如果过程中出现错误,则返回 `Err`。
pub fn predict(&self, image: Vec<u8>) -> Result<Vec<Top5>> {
// 从缓冲区加载图像并将其转换为模型所需的张量格式
let image = load_image_from_buffer(&image, &Device::Cpu)?;
let image = image.unsqueeze(0)?;
// 使用模型对图像进行前向传播,获取输出结果
let output = self.model.forward(&image)?;
// 对输出结果进行 softmax 处理,将其转换为概率分布
let output = candle_nn::ops::softmax(&output, 1)?;
// 将输出结果展平并转换为包含索引和概率值的向量
let mut predictions = output
.flatten_all()?
.to_vec1::<f32>()?
.into_iter()
.enumerate()
.collect::<Vec<_>>();
// 根据概率值对预测结果进行排序,从高到低
predictions.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
// 获取类别标签
let labels = self.get_labels()?;
// 取概率最高的前5个预测结果
let top5 = predictions.iter().take(5).collect::<Vec<_>>();
// 将前5个预测结果转换为 `Top5` 结构体,并打印结果
let mut top5_data = Vec::with_capacity(5);
for (i, (class_idx, prob)) in top5.iter().enumerate() {
println!(
"{}. Class {}: {:.2}%",
i + 1,
labels[*class_idx],
prob * 100.0
);
top5_data.push(Top5 {
label: labels[*class_idx].clone(),
score: *prob,
class_idx: *class_idx,
})
}
Ok(top5_data)
}
}
但是上面的代码在 js 层是无法调用的,我们还需要导出给 js 层使用:
#[wasm_bindgen]
pub struct Model {
worker: Worker,
labels: Vec<String>,
}
#[wasm_bindgen]
impl Model {
pub fn new() -> Result<Self, JsError> {
let worker = Worker::load_model()?;
let labels = worker.get_labels()?;
Ok(Self { worker, labels })
}
/// 获取标签
pub fn get_label(&self) -> Result<String, JsError> {
let json = serde_json::to_string(&self.labels)?;
Ok(json)
}
/// 推理
pub fn predict(&self, image: Vec<u8>) -> Result<String, JsError> {
let output = self.worker.predict(image)?;
let json = serde_json::to_string(&output)?;
Ok(json)
}
}
正常情况下,使用下面命令就可以构建:
wasm-pack build --target web
但是这个项目你会得到下面这个错误:
error: The wasm32-unknown-unknown targets are not supported by default; you may need to enable the "wasm_js" configuration flag. Note that enabling the `wasm_js` feature flag alone is insufficient. For more information see: https://docs.rs/getrandom/#webassembly-support
原因是 getrandom 这个依赖,虽然支持 wasm,但是由于 wasm-pack的编译目标是wasm32-unknown-unknown, getrandom 无法基于上面的目标(wasm32-unknown-unknown)名称中,推断出应该使用哪个 JavaScript 接口(或者 JavaScript 是否可用),因此我们需要手动编译参数:
RUSTFLAGS='--cfg getrandom_backend="wasm_js"' wasm-pack build --target web
编译成功后,你就得到了一个受前端支持的 npm 包。并且模型权重也包涵在 wasm 文件里了。
❯ tree pkg
pkg
├── ochw_wasm.d.ts
├── ochw_wasm.js
├── ochw_wasm_bg.wasm
├── ochw_wasm_bg.wasm.d.ts
└── package.json
到了这一步, wasm 的实现算是完成了。