In previous articles, we implemented handwriting model inference in Rust. Now, I want this model to run directly in the browser without consuming server resources.
Fortunately, Rust has excellent support for WebAssembly (WASM), and the Candle framework supports CPU inference in a WASM environment.
wasm-pack
wasm-pack is a one-stop solution for building and using Rust-generated WebAssembly, allowing it to interoperate with JavaScript in the browser or with Node.js.
First, let’s install wasm-pack:
cargo install wasm-pack
Create a WASM project from the official template (this uses cargo-generate, which you can install with cargo install cargo-generate):
cargo generate --git https://github.com/rustwasm/wasm-pack-template
We need to add the following dependencies to the project:
[dependencies]
# ...
candle-nn = "0.8.4"
candle-core = "0.8.4"
image = "0.25.6"
anyhow = "1.0.98"
serde = "1.0.219"
serde_json = "1.0.140"
getrandom = { version = "0.3", features = ["wasm_js"] }
serde-wasm-bindgen = "0.6.0"
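Note that wasm-pack also requires the crate to be built as a cdylib. The template's Cargo.toml already sets this up, but if you are adding WASM support to an existing crate, make sure it contains:

[lib]
crate-type = ["cdylib", "rlib"]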
Building on the Rust code from the previous articles, we organize the inference logic as follows:
use std::io::{BufRead, BufReader};

use anyhow::Result;
use candle_core::{DType, Device};
use candle_nn::VarBuilder;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
pub struct Top5 {
    pub label: String,
    pub score: f32,
    pub class_idx: usize,
}

pub struct Worker {
    model: crate::models::mobilenetv2::Mobilenetv2,
}
impl Worker {
    /// Load and initialize a pre-trained MobileNetV2 model.
    pub fn load_model() -> Result<Self> {
        let dev = &Device::Cpu;
        let weights = include_bytes!("../ochw_mobilenetv2.safetensors");
        let vb = VarBuilder::from_buffered_safetensors(weights.to_vec(), DType::F32, dev)?;
        let model = crate::models::mobilenetv2::Mobilenetv2::new(vb, 4037)?;
        Ok(Self { model })
    }
    /// Retrieve all labels from the embedded label file and return them as a vector of strings.
    ///
    /// # Returns
    /// - `Result<Vec<String>>`: If the file content is successfully read and parsed, returns a
    ///   `Vec<String>` containing all labels; if an error occurs (e.g., a line fails to read),
    ///   returns the corresponding error.
    pub fn get_labels(&self) -> Result<Vec<String>> {
        // The label file is embedded into the binary at compile time
        let label_text = include_str!("../../training/data/train/label.txt");
        let reader = BufReader::new(label_text.as_bytes());
        let mut labels = Vec::new();
        // Read the content line by line
        for line in reader.lines() {
            let line = line?;
            let line = line.trim();
            // Each line is tab-separated; the first field is the label
            if let Some(label) = line.split('\t').next() {
                labels.push(label.to_string());
            }
        }
        Ok(labels)
    }
    /// Use the pre-trained model to predict the input image and return the top 5
    /// categories with their corresponding probabilities.
    ///
    /// # Parameters
    /// - `image`: The input image data as a `Vec<u8>`, typically the raw bytes of an encoded image.
    ///
    /// # Returns
    /// - `Result<Vec<Top5>>`: A `Vec<Top5>` containing the top 5 predictions, where each `Top5`
    ///   includes the category label, probability, and class index. Returns `Err` on failure.
    pub fn predict(&self, image: Vec<u8>) -> Result<Vec<Top5>> {
        // Load the image from the buffer and convert it into the tensor format required by the model
        let image = load_image_from_buffer(&image, &Device::Cpu)?;
        // Add a batch dimension
        let image = image.unsqueeze(0)?;
        // Run a forward pass through the model
        let output = self.model.forward(&image)?;
        // Apply softmax to turn the logits into a probability distribution
        let output = candle_nn::ops::softmax(&output, 1)?;
        // Flatten the output into (class index, probability) pairs
        let mut predictions = output
            .flatten_all()?
            .to_vec1::<f32>()?
            .into_iter()
            .enumerate()
            .collect::<Vec<_>>();
        // Sort the predictions by probability, from highest to lowest
        predictions.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
        // Retrieve the category labels
        let labels = self.get_labels()?;
        // Take the top 5 predictions, convert them into `Top5` structs, and print them
        // (note: `println!` output is silently discarded on wasm32-unknown-unknown;
        // use `web_sys::console` if you need logging in the browser)
        let mut top5_data = Vec::with_capacity(5);
        for (i, (class_idx, prob)) in predictions.iter().take(5).enumerate() {
            println!(
                "{}. Class {}: {:.2}%",
                i + 1,
                labels[*class_idx],
                prob * 100.0
            );
            top5_data.push(Top5 {
                label: labels[*class_idx].clone(),
                score: *prob,
                class_idx: *class_idx,
            })
        }
        Ok(top5_data)
    }
}
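predict relies on the load_image_from_buffer helper carried over from the previous articles, which decodes the raw bytes and converts them into a tensor. For completeness, here is a minimal sketch of what such a helper can look like, modeled on Candle's ImageNet preprocessing; the 224×224 input size and the mean/std constants are assumptions and must match whatever preprocessing was used during training:

use anyhow::Result;
use candle_core::{DType, Device, Tensor};

/// Decode raw image bytes into a normalized (3, 224, 224) f32 tensor.
/// Sketch only: the resize dimensions and normalization constants are assumed.
pub fn load_image_from_buffer(buffer: &[u8], dev: &Device) -> Result<Tensor> {
    // Decode (PNG/JPEG/...) and resize to the model's expected input size
    let img = image::load_from_memory(buffer)?
        .resize_exact(224, 224, image::imageops::FilterType::Triangle)
        .to_rgb8();
    // HWC u8 -> CHW f32
    let data = Tensor::from_vec(img.into_raw(), (224, 224, 3), dev)?.permute((2, 0, 1))?;
    // Scale to [0, 1] and normalize with (assumed) ImageNet mean/std
    let mean = Tensor::new(&[0.485f32, 0.456, 0.406], dev)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.229f32, 0.224, 0.225], dev)?.reshape((3, 1, 1))?;
    let tensor = (data.to_dtype(DType::F32)? / 255.0)?
        .broadcast_sub(&mean)?
        .broadcast_div(&std)?;
    Ok(tensor)
}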
However, the above code cannot be called from JavaScript as-is, so we expose it through wasm-bindgen:
use wasm_bindgen::prelude::*;

#[wasm_bindgen]
pub struct Model {
    worker: Worker,
    labels: Vec<String>,
}

#[wasm_bindgen]
impl Model {
    pub fn new() -> Result<Model, JsError> {
        // `anyhow::Error` does not implement `std::error::Error`, so it cannot be
        // converted to `JsError` with `?` and needs an explicit mapping
        let worker = Worker::load_model().map_err(|e| JsError::new(&e.to_string()))?;
        let labels = worker.get_labels().map_err(|e| JsError::new(&e.to_string()))?;
        Ok(Model { worker, labels })
    }

    /// Retrieve the labels as a JSON string
    pub fn get_label(&self) -> Result<String, JsError> {
        let json = serde_json::to_string(&self.labels)?;
        Ok(json)
    }

    /// Perform inference and return the top 5 predictions as a JSON string
    pub fn predict(&self, image: Vec<u8>) -> Result<String, JsError> {
        let output = self
            .worker
            .predict(image)
            .map_err(|e| JsError::new(&e.to_string()))?;
        let json = serde_json::to_string(&output)?;
        Ok(json)
    }
}
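Incidentally, serde-wasm-bindgen from the dependency list is not used in the code above, which round-trips everything through JSON strings. As an alternative, you could hand JavaScript a plain object directly and skip the JSON step; a sketch (the method name predict_value is mine, added alongside the methods above):

#[wasm_bindgen]
impl Model {
    /// Alternative to `predict`: return a JS object via serde-wasm-bindgen
    /// instead of serializing to a JSON string.
    pub fn predict_value(&self, image: Vec<u8>) -> Result<JsValue, JsError> {
        let output = self
            .worker
            .predict(image)
            .map_err(|e| JsError::new(&e.to_string()))?;
        serde_wasm_bindgen::to_value(&output).map_err(|e| JsError::new(&e.to_string()))
    }
}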
Normally, you can build it using the following command:
wasm-pack build --target web
However, in this project, you will encounter the following error:
error: The wasm32-unknown-unknown targets are not supported by default; you may need to enable the "wasm_js" configuration flag. Note that enabling the `wasm_js` feature flag alone is insufficient. For more information see: https://docs.rs/getrandom/#webassembly-support
The reason is that although getrandom supports WASM, wasm-pack builds for the wasm32-unknown-unknown target, and that target name tells getrandom nothing about the environment: it cannot tell whether JavaScript is even available, let alone which JavaScript interface to use. We therefore have to select the backend explicitly with a build flag:
RUSTFLAGS='--cfg getrandom_backend="wasm_js"' wasm-pack build --target web
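If you don't want to prefix every build with RUSTFLAGS, the getrandom documentation suggests persisting the flag in the project's .cargo/config.toml instead:

[target.wasm32-unknown-unknown]
rustflags = ['--cfg', 'getrandom_backend="wasm_js"']

With that in place, a plain wasm-pack build --target web works.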
After a successful build, you have an npm package ready for the frontend; since the weights were embedded with include_bytes!, they are contained in the WASM file itself.
❯ tree pkg
pkg
├── ochw_wasm.d.ts
├── ochw_wasm.js
├── ochw_wasm_bg.wasm
├── ochw_wasm_bg.wasm.d.ts
└── package.json
At this point, the WASM implementation is complete.