After training and evaluating the model, let’s first test its inference performance in Python.
To check how well the model generalizes, write the test characters somewhat sloppily rather than neatly. Here’s an example of the character “知” (zhi):
![知 (zhi)](zhi.png)
Python inference code:
def get_labels(path="data/label.txt"):
    """Read class labels from a tab-separated label file.

    Each line is expected to look like ``<label><TAB><class index>`` (e.g. ``!``
    followed by a tab and ``0``). Only the first column is kept, so the returned
    list is ordered by line position. The index column is ignored, so labels are
    assumed to be listed in class-index order.

    Args:
        path: Path to the UTF-8 label file. Defaults to ``data/label.txt``.

    Returns:
        list[str]: Labels in file order; ``labels[i]`` is the label on line ``i``.
    """
    labels = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # line: "<label>\t<index>" -> keep only the label column
            label = line.strip().split("\t")[0]
            labels.append(label)
    return labels
if __name__ == "__main__":
    # Use the GPU when present; fall back to CPU so the script also runs without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = HandwritingTrainer.load_from_checkpoint("logs/version_0/checkpoint-epoch=32-val_loss=0.156.ckpt")
    model.eval()
    model = model.to(device)

    # Test image of the character 知 (zhi), the same sample the Rust example loads.
    img = Image.open("./testdata/zhi.png")
    img = img.convert("RGB")
    img = img.resize((96, 96))
    # NOTE(review): presumably mirrors the training-time preprocessing — confirm.
    # The one-element mean/std is broadcast across all three RGB channels.
    trans = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.95], std=[0.2]),
    ])
    img = trans(img)
    img = img.unsqueeze(0)  # add batch dimension -> (1, 3, 96, 96)
    img = img.to(device)

    labels = get_labels()
    with torch.no_grad():
        output = model(img)
        output = torch.nn.functional.softmax(output, dim=1)
        # Get the top 5 predictions
        top5_prob, top5_idx = torch.topk(output, 5)
        top5_prob = top5_prob.cpu().numpy()
        top5_idx = top5_idx.cpu().numpy()
        for i in range(5):
            idx = top5_idx[0][i]
            print(f"Top {i+1} predicted label: {labels[idx]}, probability: {top5_prob[0][i]:.4f}")
The results are as follows:
Top 1 predicted label: 知, probability: 0.9505
Top 2 predicted label: 勉, probability: 0.0095
Top 3 predicted label: 贮, probability: 0.0025
Top 4 predicted label: 处, probability: 0.0025
Top 5 predicted label: ‰, probability: 0.0025
Once confirmed, we proceed to export the model in safetensors format:
model = HandwritingTrainer.load_from_checkpoint("logs/version_0/checkpoint-epoch=32-val_loss=0.156.ckpt")
# Export only the inner network (`.model`), not the training wrapper; presumably so
# the saved tensor names match what the Rust implementation loads — confirm.
model = model.model
model.eval()
# Dump the layer structure as a reference for the Rust port; `with` guarantees the
# file is flushed and closed instead of relying on garbage collection.
with open("model.txt", "w", encoding="utf-8") as f:
    print(model, file=f)
save_model(model, "ochw_mobilenetv2.safetensors")
The `print(model, file=...)` line also writes the model’s layer structure to model.txt, which is a handy reference when reimplementing the network in Rust.
For the MobileNetV2 model, I’ve previously written about how to implement it using the Rust Candle framework, so I won’t go into detail here.
One thing to note is that since I removed the downsampling during training, the stride in the Candle implementation should also be changed to 1:
// Stride of features[0][0] changed from 2 to 1 so the Candle model matches the
// trained network, whose initial downsampling was removed during training.
let cbr1 = Conv2dNormActivation::new(vb.pp(0), 3, c_in, 3, 1, 1)?;
(Update: after later ablation experiments, this modification was reverted, so the stride change is no longer used.)
The full implementation can be found here: https://github.com/ximeiorg/ochw/blob/main/ochw-wasm/src/models/mobilenetv2.rs#L169
Now, let’s switch to Rust for inference:
/// Runs the exported MobileNetV2 model on the bundled `zhi.png` test image and prints
/// the five most likely classes with their probabilities.
///
/// # Errors
/// Returns an error if the weights, the test image, or the label file cannot be loaded,
/// or if any tensor operation fails.
fn predict() -> anyhow::Result<()> {
    /// How many of the highest-probability classes to report.
    const TOP_K: usize = 5;

    let device = Device::Cpu;
    let model_path =
        Path::new(env!("CARGO_MANIFEST_DIR")).join("ochw_mobilenetv2.safetensors");
    // SAFETY: the weights file is memory-mapped; it must not be modified or truncated
    // by anyone else while `vb` and the model built from it are alive.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(
            &[model_path.as_path()],
            candle_core::DType::F32,
            &device,
        )?
    };
    // Number of output classes the checkpoint was trained with.
    let nclasses = 4037;
    let model = mobilenetv2::Mobilenetv2::new(vb, nclasses)?;

    let image_data = include_bytes!("../../../testdata/zhi.png");
    // Add a batch dimension in front of the (C, H, W) image tensor.
    let image = load_image_from_buffer(image_data, &device)?.unsqueeze(0)?;

    let output = model.forward(&image)?;
    let output = softmax(&output, 1)?;
    println!("{output}");

    // Candle has no `torch.topk` equivalent, so rank all class probabilities manually.
    let mut predictions: Vec<(usize, f32)> = output
        .flatten_all()?
        .to_vec1::<f32>()?
        .into_iter()
        .enumerate()
        .collect();
    // `total_cmp` is a total order, so a NaN cannot panic the sort the way
    // `partial_cmp(..).unwrap()` would. The sort is stable: ties keep ascending class order.
    predictions.sort_by(|(_, a), (_, b)| b.total_cmp(a));

    let label_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("testdata/label.txt");
    let labels = get_labels(label_path)
        .map_err(|e| anyhow::anyhow!("failed to load labels: {e:?}"))?;

    for (rank, (class_idx, prob)) in predictions.iter().take(TOP_K).enumerate() {
        // `.get` avoids a panic if the label file has fewer entries than the model outputs.
        match labels.get(*class_idx) {
            Some(label) => println!("{}. Class {}: {:.2}%", rank + 1, label, prob * 100.0),
            None => println!(
                "{}. Class #{} (no label): {:.2}%",
                rank + 1,
                class_idx,
                prob * 100.0
            ),
        }
    }
    Ok(())
}
load_image_from_buffer will first load the image, resize it to 96x96, and then apply normalization:
/// Turns a packed 96x96 RGB byte buffer into a normalized `(3, 96, 96)` `f32` tensor.
///
/// The bytes are read in height-width-channel order and moved to channel-first
/// layout. Values are scaled from `[0, 255]` to `[0, 1]` and then standardized as
/// `(x - mean) / std`, with `mean = 0.95` and `std = 0.2` on every channel. These are
/// the same constants as the Python `transforms.Normalize(mean=[0.95], std=[0.2])`.
///
/// # Parameters
/// - `raw`: exactly 96 * 96 * 3 bytes of interleaved RGB pixel data.
/// - `device`: the device on which the resulting tensor is allocated.
///
/// # Returns
/// The normalized tensor, or `Err` if `raw` has the wrong length or a tensor
/// operation fails.
fn load_image_raw(raw: Vec<u8>, device: &Device) -> Result<Tensor> {
    const SIDE: usize = 96;
    const CHANNELS: usize = 3;

    // (H, W, C) as stored in the image buffer -> (C, H, W) as consumed by the model.
    let pixels = Tensor::from_vec(raw, (SIDE, SIDE, CHANNELS), device)?.permute((2, 0, 1))?;

    // Broadcastable (C, 1, 1) tensor holding the same value for every channel.
    let per_channel = |value: f32| -> Result<Tensor> {
        Tensor::new(&[value; CHANNELS], device)?.reshape((CHANNELS, 1, 1))
    };
    let mean = per_channel(0.95)?;
    let std = per_channel(0.2)?;

    let unit = (pixels.to_dtype(DType::F32)? / 255.0)?;
    unit.broadcast_sub(&mean)?.broadcast_div(&std)
}
/// Decode an in-memory image and convert it into a normalized model-input tensor.
///
/// The image is decoded and stretched to exactly 96x96 pixels; the aspect ratio is
/// not preserved and nothing is cropped. It is then converted to 8-bit RGB and passed
/// to `load_image_raw` for scaling and normalization.
///
/// # Parameters
/// - `buffer`: A byte slice containing encoded image data (any format `image` can decode).
/// - `device`: The device used to create the tensor.
///
/// # Returns
/// - `Result<Tensor>`: the preprocessed image tensor, or an error if decoding or any
///   tensor operation fails.
pub fn load_image_from_buffer(buffer: &[u8], device: &Device) -> Result<Tensor> {
    // Stretch (not crop) to 96x96 so preprocessing matches the Python pipeline
    // (`img.resize((96, 96))` / `transforms.Resize((96, 96))`). A fill-and-crop resize
    // would cut strokes off the edges of non-square inputs.
    let img = image::load_from_memory(buffer)
        .map_err(Error::wrap)?
        .resize_exact(96, 96, image::imageops::FilterType::Triangle);
    // Convert to packed 8-bit RGB so the raw buffer is exactly 96 * 96 * 3 bytes.
    let rgb = img.to_rgb8();
    load_image_raw(rgb.into_raw(), device)
}
The results are as follows:
Tensor[[1, 4037], f32]
1. Class 知: 96.17%
2. Class 勉: 0.62%
3. Class 处: 0.28%
4. Class 贮: 0.23%
5. Class 矩: 0.12%