训练并评估好模型后,我们先试试在 python 下推理的效果。
为了验证模型的泛化能力,我们最好手写的时候稍微潦草一些。以下面这张“知”字图片为例子:
![[zhi.png]]
python 推理代码:
def get_labels():
labels = []
with open("data/label.txt", "r", encoding="utf-8") as f:
for line in f:
# line: ! 0
line = line.strip()
label = line.split("\t")[0]
labels.append(label)
return labels
if __name__ == "__main__":
model = HandwritingTrainer.load_from_checkpoint("logs/version_0/checkpoint-epoch=32-val_loss=0.156.ckpt")
model.eval()
model = model.to("cuda")
img = Image.open("./testdata/hui.png")
img = img.convert("RGB")
img = img.resize((96,96))
rans = transforms.Compose([
transforms.Resize((96, 96)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.95], std=[0.2])
])
img = trans(img)
img = img.unsqueeze(0)
img = img.to("cuda")
labels = get_labels()
with torch.no_grad():
output = model(img)
output = torch.nn.functional.softmax(output,dim=1)
# 获取top5的预测结果
top5_prob, top5_idx = torch.topk(output, 5)
top5_prob = top5_prob.cpu().numpy()
top5_idx = top5_idx.cpu().numpy()
for i in range(5):
idx = top5_idx[0][i]
print(f"Top {i+1} 预测标签: {labels[idx]}, 概率: {top5_prob[0][i]:.4f}")
得到的结果如下:
Top 1 预测标签: 知, 概率: 0.9505
Top 2 预测标签: 勉, 概率: 0.0095
Top 3 预测标签: 贮, 概率: 0.0025
Top 4 预测标签: 处, 概率: 0.0025
Top 5 预测标签: ‰, 概率: 0.0025
确定没问题后,就开始导出模型,采用 safetensors 格式:
model = HandwritingTrainer.load_from_checkpoint("logs/version_0/checkpoint-epoch=32-val_loss=0.156.ckpt")
model = model.model
model.eval()
print(model,file=open("model.txt","w"))
save_model(model, "ochw_mobilenetv2.safetensors")
我们可以顺便把模型结构保存到 model.txt, 方便在 rust 层实现。
对于mobilenetv2 模型,之前有写过关于如何使用 rust candle 框架实现的文章,这里不展开。
需要注意的一点是,由于之前训练的时候,我在 mobilenetv2 开关移除了下采样,同样地,在 candle 实现上,也要把 stride 改为 1:
// 需要把 features[0][0] 的 stride 改为 2=>1,
let cbr1 = Conv2dNormActivation::new(vb.pp(0), 3, c_in, 3, 1, 1)?;
(后面通过消融实现,移除了这个操作。)
完整实现见: https://github.com/ximeiorg/ochw/blob/main/ochw-wasm/src/models/mobilenetv2.rs#L169
改用 rust 推理:
fn predict()->anyhow::Result<()> {
let model_path =
Path::new(env!("CARGO_MANIFEST_DIR")).join("ochw_mobilenetv2.safetensors");
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(
&[model_path.as_path()],
candle_core::DType::F32,
&candle_core::Device::Cpu,
)
.unwrap()
};
let nclasses = 4037;
let model = mobilenetv2::Mobilenetv2::new(vb, nclasses).unwrap();
let image_data = include_bytes!("../../../testdata/zhi.png");
let device = &Device::Cpu;
let image = load_image_from_buffer(image_data, device).unwrap();
let image = image.unsqueeze(0).unwrap();
let output = model.forward(&image).unwrap();
//softmax
// top 5, candle 好像没有类似的 torch.topk 的函数,只能自己实现
let output = softmax(&output, 1).unwrap();
println!("{output}");
// 获取 top 5 预测结果(包含索引和概率值)
let mut predictions = output
.flatten_all()?
.to_vec1::<f32>()?
.into_iter()
.enumerate()
.collect::<Vec<_>>();
predictions.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
let label_path = Path::new(env!("CARGO_MANIFEST_DIR")).join("testdata/label.txt");
let labels = get_labels(label_path).unwrap();
let top5 = predictions.iter().take(5).collect::<Vec<_>>();
for (i, (class_idx, prob)) in top5.iter().enumerate() {
println!("{}. Class {}: {:.2}%", i+1, labels[*class_idx], prob * 100.0);
}
Ok(())
}
load_image_from_buffer 会先加载图片,并 resize 到 96x96,再做标准化处理:
/// 加载并预处理一个96x96的RGB图像,将其转换为适合模型输入的张量。
///
/// 该函数接受一个包含图像原始字节数据的向量,并将其转换为一个经过标准化处理的张量。
/// 标准化过程包括将像素值从[0, 255]范围缩放到[0, 1],然后减去均值并除以标准差。
///
/// # 参数
/// - `raw`: 包含图像原始字节数据的向量,长度为96x96x3(即96x96的RGB图像)。
/// - `device`: 用于创建张量的设备(如CPU或GPU)。
///
/// # 返回值
/// 返回一个`Result<Tensor>`,表示处理后的张量。如果过程中出现错误,则返回`Err`。
fn load_image_raw(raw: Vec<u8>, device: &Device) -> Result<Tensor> {
// 将原始字节数据转换为形状为(96, 96, 3)的张量,并进行维度置换,将通道维度放在最前面。
let data = Tensor::from_vec(raw, (96, 96, 3), device)?.permute((2, 0, 1))?;
// 创建均值和标准差的张量,并调整其形状以匹配输入张量的维度。
let mean_array = [0.95f32, 0.95, 0.95];//[0.485f32, 0.456, 0.406]
let std_array = [0.2f32, 0.2, 0.2];//[0.229f32, 0.224, 0.225]
let mean = Tensor::new(&mean_array, device)?.reshape((3, 1, 1))?;
let std = Tensor::new(&std_array, device)?.reshape((3, 1, 1))?;
// 将像素值从[0, 255]缩放到[0, 1],然后进行标准化处理。
(data.to_dtype(DType::F32)? / 255.0)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
/// 从内存中的图像缓冲区加载图像,并将其转换为指定大小的张量。
///
/// 该函数首先从内存中的字节缓冲区加载图像,然后将其调整为96x96像素大小,
/// 最后将图像转换为RGB格式并加载为张量。
///
/// # 参数
/// - `buffer`: 包含图像数据的字节切片。
/// - `device`: 用于创建张量的设备。
///
/// # 返回值
/// - `Result<Tensor>`: 如果成功,返回包含图像数据的张量;如果失败,返回错误。
pub fn load_image_from_buffer(buffer: &[u8], device: &Device) -> Result<Tensor> {
// 从内存中加载图像
let img = image::load_from_memory(buffer)
.map_err(Error::wrap)?
.resize_to_fill(96, 96, image::imageops::FilterType::Triangle);
// 将图像转换为RGB格式
let img = img.to_rgb8();
// 将图像数据加载为张量
load_image_raw(img.into_raw(), device)
}
结果如下:
Tensor[[1, 4037], f32]
1. Class 知: 96.17%
2. Class 勉: 0.62%
3. Class 处: 0.28%
4. Class 贮: 0.23%
5. Class 矩: 0.12%