Apr 27, 2025
5 min read
Rust, pytorch, candle, typescript, wasm

Building a Handwriting Input Method from Scratch: WASM Application

In previous articles, we implemented handwriting model inference in Rust. Now I want the model to run directly in the browser, without consuming any server resources.

Fortunately, Rust has excellent support for WebAssembly (WASM), and the Candle framework supports CPU inference in a WASM environment.

wasm-pack

wasm-pack is a one-stop solution for building and using Rust-generated WebAssembly, allowing it to interoperate with JavaScript in the browser or with Node.js.

First, let’s install wasm-pack:

cargo install wasm-pack

Create a WASM project:

cargo generate --git https://github.com/rustwasm/wasm-pack-template
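The template gives you a crate skeleton that is already wired up for wasm-bindgen. For reference, the generated src/lib.rs is a minimal example along these lines (shown with the project name filled in; ochw-wasm is assumed here):

use wasm_bindgen::prelude::*;

// The template exports a trivial function to JavaScript so you can verify
// the toolchain end to end before adding real code.
#[wasm_bindgen]
extern "C" {
    fn alert(s: &str);
}

#[wasm_bindgen]
pub fn greet() {
    alert("Hello, ochw-wasm!");
}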

We need to add the following dependencies to the project:

[dependencies]
# ...
candle-nn = "0.8.4"
candle-core = "0.8.4"
image = "0.25.6"
anyhow = "1.0.98"
serde = { version = "1.0.219", features = ["derive"] }
serde_json = "1.0.140"
getrandom = { version = "0.3", features = ["wasm_js"] }
serde-wasm-bindgen = "0.6.0"

Building on the Rust code from the previous article, we organize the inference logic as follows:


use std::io::{BufRead, BufReader};

use anyhow::Result;
use candle_core::{DType, Device, Module};
use candle_nn::VarBuilder;
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
pub struct Top5 {
    pub label: String,
    pub score: f32,
    pub class_idx: usize,
}

pub struct Worker {
    model: crate::models::mobilenetv2::Mobilenetv2,
}

impl Worker {

    /// Load and initialize a pre-trained MobileNetV2 model.
    pub fn load_model() -> Result<Self> {
        let dev = &Device::Cpu;
        let weights = include_bytes!("../ochw_mobilenetv2.safetensors");
        let vb = VarBuilder::from_buffered_safetensors(weights.to_vec(), DType::F32, dev)?;
        let model = crate::models::mobilenetv2::Mobilenetv2::new(vb, 4037)?;
        Ok(Self { model })
    }

    /// Retrieve all labels from the embedded label file and return them as a vector of strings.
    ///
    /// # Returns
    /// - `Result<Vec<String>>`: If the content is successfully parsed, returns a `Vec<String>`
    ///   containing all labels; otherwise returns the corresponding error.
    pub fn get_labels(&self) -> Result<Vec<String>> {
        // The label file content is embedded into the binary at compile time
        let label_text = include_str!("../../training/data/train/label.txt");
        let reader = BufReader::new(label_text.as_bytes());

        let mut labels = Vec::new();
        // Read the content line by line
        for line in reader.lines() {
            let line = line?;
            let line = line.trim();
            // Extract the first tab-separated field of each line as the label
            if let Some(label) = line.split('\t').next() {
                labels.push(label.to_string());
            }
        }
        Ok(labels)
    }

    /// Use the pre-trained model to predict the input image and return the top 5 categories with their corresponding probabilities.
    ///
    /// # Parameters
    /// - `image`: The input image data, represented as a `Vec<u8>`, typically the binary data of the image.
    ///
    /// # Returns
    /// - `Result<Vec<Top5>>`: Returns a `Vec<Top5>` containing the top 5 predictions, where each `Top5` struct includes the category label, probability value, and class index.
    ///   If an error occurs, returns `Err`.
    pub fn predict(&self, image: Vec<u8>) -> Result<Vec<Top5>> {
        // Load the image from the buffer and convert it into the tensor format required by the model
        let image = load_image_from_buffer(&image, &Device::Cpu)?;
        let image = image.unsqueeze(0)?;
    
        // Perform forward propagation on the image using the model to obtain the output
        let output = self.model.forward(&image)?;
    
        // Apply softmax to the output to convert it into a probability distribution
        let output = candle_nn::ops::softmax(&output, 1)?;
    
        // Flatten the output and convert it into a vector containing indices and probability values
        let mut predictions = output
            .flatten_all()?
            .to_vec1::<f32>()?
            .into_iter()
            .enumerate()
            .collect::<Vec<_>>();
    
        // Sort the predictions by probability, from highest to lowest
        // (`total_cmp` avoids the panic that `partial_cmp(..).unwrap()` would hit on NaN)
        predictions.sort_by(|(_, a), (_, b)| b.total_cmp(a));
    
        // Retrieve the category labels
        let labels = self.get_labels()?;
    
        // Take the top 5 predictions with the highest probabilities
        let top5 = predictions.iter().take(5).collect::<Vec<_>>();
    
        // Convert the top 5 predictions into `Top5` structs
        // (note: `println!` output is not visible in the browser; it is kept from the native version)
        let mut top5_data = Vec::with_capacity(5);
        for (i, (class_idx, prob)) in top5.iter().enumerate() {
            println!(
                "{}. Class {}: {:.2}%",
                i + 1,
                labels[*class_idx],
                prob * 100.0
            );
            top5_data.push(Top5 {
                label: labels[*class_idx].clone(),
                score: *prob,
                class_idx: *class_idx,
            })
        }
    
        Ok(top5_data)
    }
}
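The `load_image_from_buffer` helper used in `predict` decodes the raw bytes into a tensor, as in the earlier inference code. For completeness, here is a minimal sketch of what such a helper can look like; the 224×224 RGB input size and the ImageNet mean/std normalization are assumptions for illustration and must match the preprocessing actually used during training:

use candle_core::{DType, Device, Tensor};

/// Decode raw image bytes into the CHW f32 tensor the model expects.
/// Sketch only: the input size and normalization constants are assumptions.
pub fn load_image_from_buffer(buffer: &[u8], dev: &Device) -> anyhow::Result<Tensor> {
    // Decode PNG/JPEG/... bytes and resize to the (assumed) training resolution
    let img = image::load_from_memory(buffer)?
        .resize_exact(224, 224, image::imageops::FilterType::Triangle)
        .to_rgb8();
    let data = img.into_raw(); // HWC layout, one u8 per channel
    let tensor = Tensor::from_vec(data, (224, 224, 3), dev)?
        .permute((2, 0, 1))? // HWC -> CHW
        .to_dtype(DType::F32)?
        .affine(1.0 / 255.0, 0.0)?; // scale to [0, 1]
    // Per-channel normalization (ImageNet statistics, assumed)
    let mean = Tensor::new(&[0.485f32, 0.456, 0.406], dev)?.reshape((3, 1, 1))?;
    let std = Tensor::new(&[0.229f32, 0.224, 0.225], dev)?.reshape((3, 1, 1))?;
    Ok(tensor.broadcast_sub(&mean)?.broadcast_div(&std)?)
}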

However, none of the code above can be called directly from JavaScript, so we need to expose it through wasm-bindgen:


use wasm_bindgen::prelude::*;

#[wasm_bindgen]
pub struct Model {
    worker: Worker,
    labels: Vec<String>,
}

#[wasm_bindgen]
impl Model {
    pub fn new() -> Result<Self, JsError> {
        let worker = Worker::load_model()?;
        let labels = worker.get_labels()?;
        Ok(Self { worker, labels })
    }

    /// Return all labels as a JSON string
    pub fn get_label(&self) -> Result<String, JsError> {
        let json = serde_json::to_string(&self.labels)?;
        Ok(json)
    }

    /// Run inference on the image bytes and return the top-5 predictions as a JSON string
    pub fn predict(&self, image: Vec<u8>) -> Result<String, JsError> {
        let output = self.worker.predict(image)?;
        let json = serde_json::to_string(&output)?;
        Ok(json)
    }
}
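One optional but useful addition: a panic inside WASM otherwise reaches the browser as an opaque `RuntimeError: unreachable`. The wasm-pack template already depends on the `console_error_panic_hook` crate; assuming that dependency is enabled, a start function can install the hook so panic messages land in the JS console:

use wasm_bindgen::prelude::*;

// Runs automatically when the WASM module is instantiated.
#[wasm_bindgen(start)]
pub fn start() {
    // Forward Rust panic messages to console.error (crate assumed enabled).
    console_error_panic_hook::set_once();
}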

Normally, you can build it using the following command:

wasm-pack build --target web

However, in this project, you will encounter the following error:

error: The wasm32-unknown-unknown targets are not supported by default; you may need to enable the "wasm_js" configuration flag. Note that enabling the `wasm_js` feature flag alone is insufficient. For more information see: https://docs.rs/getrandom/#webassembly-support

The reason: although getrandom supports WASM, wasm-pack builds for the wasm32-unknown-unknown target, and getrandom cannot infer from that target name which JavaScript interface to use (or whether JavaScript is available at all). We therefore have to select the backend manually via a rustc cfg flag:

RUSTFLAGS='--cfg getrandom_backend="wasm_js"' wasm-pack build --target web
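If you'd rather not pass the flag on every build, getrandom's documentation suggests persisting it in the project's Cargo configuration instead (file path relative to the crate root):

# .cargo/config.toml
[target.wasm32-unknown-unknown]
rustflags = ['--cfg', 'getrandom_backend="wasm_js"']

With this in place, a plain wasm-pack build --target web works.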

After a successful build, you get an npm package that the frontend can consume directly; since the weights were embedded with include_bytes!, they are bundled inside the .wasm file.

tree pkg
pkg
├── ochw_wasm.d.ts
├── ochw_wasm.js
├── ochw_wasm_bg.wasm
├── ochw_wasm_bg.wasm.d.ts
└── package.json

At this point, the WASM implementation is complete.