The hottest field right now is, without question, artificial intelligence (AI). Nearly every programming language is trying to carve out a place in AI, but Python and C++ still firmly dominate: C++ is widely used for low-level implementations thanks to its performance, while Python is the first choice for higher-level development because of its ease of use and rich library ecosystem.
Rust's progress in AI should not be overlooked either, especially with Hugging Face pushing it forward. Hugging Face has released several Rust-based projects, including the large-model serving framework TGI (Text Generation Inference), Tokenizers, and Candle.
Candle is a minimalist machine learning framework written in Rust, focused on performance (including GPU support) and ease of use. Its API is modeled after PyTorch, so developers can pick it up quickly:
# pytorch
torch.Tensor([[1, 2], [3, 4]])
// rust candle
Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?
Anyone who has used PyTorch knows how heavy its dependencies are: a PyTorch Docker image for inference deployment easily runs to 15 GB or more. One of Candle's core goals is serverless inference, deploying lightweight standalone binaries. More importantly, Candle can remove the dependency on Python entirely, since Python's performance is comparatively poor.
In fact, Candle already ships implementations of many well-known large models, such as GPT, Stable Diffusion, and LLaMA.
However, some older models have not been implemented yet, MobileNet among them.
MobileNet is a very lightweight neural network, commonly used on mobile devices such as smartphones and tablets. It is mainly used for image classification, for example recognizing the objects in an image.
Let's start by exporting the model weights from Python:
import torch
import torchvision
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from safetensors.torch import save_file

# Load the pretrained MobileNetV2 and switch to inference mode.
model = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
model.eval()

# Trace the model with a dummy input and save the TorchScript module.
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("mobilenetv2.pt")

# Dump the module structure so we can mirror it layer by layer in Rust.
print(model, file=open("mobilenetv2.txt", "w"))

# Print the name of every weight tensor in the state dict.
weights = model.state_dict()
for key, value in weights.items():
    print(key)

# Save the weights in safetensors format for Candle.
save_file(model.state_dict(), "mobilenetv2.safetensors")
Candle uses the safetensors format. Along the way we also dump the module structure to mobilenetv2.txt and print the name of every weight tensor.
Let's look at the network layers in mobilenetv2.txt:
MobileNetV2(
(features): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(2): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False)
(1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(3): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
(1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(4): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=144, bias=False)
(1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(5): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(6): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(7): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=192, bias=False)
(1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(8): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(9): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(10): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(11): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
(1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(12): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
(1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(13): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
(1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(14): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=576, bias=False)
(1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(15): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
(1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(16): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
(1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(17): InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2dNormActivation(
(0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
(1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(2): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
(3): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(18): Conv2dNormActivation(
(0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
)
(classifier): Sequential(
(0): Dropout(p=0.2, inplace=False)
(1): Linear(in_features=1280, out_features=1000, bias=True)
)
)
It looks like a lot of layers, but most of them are repeats. We generally start from the innermost layers and work outward.
First, implement the Conv2D + BatchNorm2D + ReLU6 block:
(0): Conv2dNormActivation(
(0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
The code for this block looks like this:
/// Conv2D + BatchNorm2D + ReLU6
#[derive(Debug, Clone)]
pub struct Conv2dNormActivation {
    conv2d: Conv2d,
    batch_norm2d: BatchNorm,
}

impl Conv2dNormActivation {
    pub fn new(
        vb: VarBuilder,
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        stride: usize,
        groups: usize,
    ) -> Result<Self> {
        let cfg = candle_nn::Conv2dConfig {
            stride,
            // Same padding rule as torchvision's Conv2dNormActivation.
            padding: (kernel_size - 1) / 2,
            groups,
            ..Default::default()
        };
        let conv2d = conv2d_no_bias(in_channels, out_channels, kernel_size, cfg, vb.pp(0))?;
        let batch_norm2d = batch_norm(out_channels, 1e-5, vb.pp(1))?;
        Ok(Self { conv2d, batch_norm2d })
    }
}

impl Module for Conv2dNormActivation {
    fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        xs.apply(&self.conv2d)?
            // train = false: use the running statistics for inference.
            .apply_t(&self.batch_norm2d, false)?
            // ReLU6 is simply a clamp to [0, 6].
            .clamp(0.0, 6.0)
    }
}
When instantiating a module, vb is the variable builder that binds the weight names. For example, vb.pp("conv") means the weights of the current module live under the prefix conv, i.e. the conv in (conv): Sequential.
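Concretely, the tensor names stored in mobilenetv2.safetensors are the state_dict keys printed by the export script (features.0.0.weight, features.1.conv.0.0.weight, classifier.1.weight, and so on), and nested vb.pp(...) calls build exactly those dotted prefixes. Below is a minimal sketch of creating the root VarBuilder from the exported file; the helper function is just for illustration, the file name matches the export script above.

use candle_core::{DType, Device};
use candle_nn::VarBuilder;

fn load_weights() -> candle_core::Result<VarBuilder<'static>> {
    let device = Device::Cpu;
    // Memory-map the safetensors file exported from Python.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["mobilenetv2.safetensors"], DType::F32, &device)?
    };
    // vb.pp("features").pp(0).pp(0) then resolves tensors named "features.0.0.*",
    // e.g. "features.0.0.weight" for the very first convolution.
    Ok(vb)
}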
Next, we implement InvertedResidual (the inverted residual block).
InvertedResidual(
(conv): Sequential(
(0): Conv2dNormActivation(
(0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
(1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU6(inplace=True)
)
(1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
(2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
#[derive(Debug, Clone)]
struct InvertedResidual {
    conv: ConvSequential,
}
impl InvertedResidual {
fn new(
vb: VarBuilder,
in_channels: usize,
out_channels: usize,
stride: usize,
expand_ratio: usize
) -> Result<Self> {
Ok(
Self { conv: ConvSequential::new(vb.pp("conv"), in_channels, out_channels, stride, expand_ratio)? }
)
}
}
impl Module for InvertedResidual {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.apply(&self.conv)
}
}
Inside it there is a ConvSequential module, which corresponds to the (conv): Sequential part of the block:
#[derive(Debug, Clone)]
struct ConvSequential {
cbr1: Option<Conv2dNormActivation>,
cbr2: Conv2dNormActivation,
conv2d: Conv2d,
batch_norm2d: BatchNorm,
use_res_connect: bool,
}
impl ConvSequential {
fn new(
vb: VarBuilder,
in_channels: usize,
out_channels: usize,
stride: usize,
expand_ratio: usize
) -> Result<Self> {
        let c_hidden = expand_ratio * in_channels;
        let mut id = 0;
        // The 1x1 expansion conv only exists when expand_ratio != 1; its absence
        // shifts the indices of the following layers, hence the running `id`.
        let cbr1 = if expand_ratio != 1 {
            let cbr = Conv2dNormActivation::new(vb.pp(id), in_channels, c_hidden, 1, 1, 1)?;
            id += 1;
            Some(cbr)
        } else {
            None
        };
        // 3x3 depthwise conv: groups equals the number of channels.
        let cbr2 = Conv2dNormActivation::new(vb.pp(id), c_hidden, c_hidden, 3, stride, c_hidden)?;
        // 1x1 projection conv followed by batch norm (no activation).
        let cfg = candle_nn::Conv2dConfig {
            stride: 1,
            ..Default::default()
        };
        let conv2d = conv2d_no_bias(c_hidden, out_channels, 1, cfg, vb.pp(id + 1))?;
        let batch_norm2d = batch_norm(out_channels, 1e-5, vb.pp(id + 2))?;
        // The residual connection is only used when the block keeps the spatial
        // size and channel count unchanged.
        let use_res_connect = stride == 1 && in_channels == out_channels;
Ok(Self {
cbr1,
cbr2,
conv2d,
batch_norm2d,
use_res_connect,
})
}
}
impl Module for ConvSequential {
fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
let mut ys = xs.clone();
if let Some(cbr1) = &self.cbr1 {
ys = ys.apply(cbr1)?;
}
let ys = ys.apply(&self.cbr2)?.apply(&self.conv2d)?.apply_t(&self.batch_norm2d, false)?;
if self.use_res_connect {
xs + ys
} else {
Ok(ys)
}
}
}
At this point the inner building blocks are done; what remains are the outermost Features and Classifier modules.
The Features module (the repeated inverted residual blocks are simply kept in a Vec):
#[derive(Debug, Clone)]
pub struct Features {
    cbr1: Conv2dNormActivation,
    invs: Vec<InvertedResidual>,
    cbr2: Conv2dNormActivation,
}

impl Features {
    fn new(vb: VarBuilder) -> Result<Self> {
        let mut c_in = 32;
        // features.0: the stem conv.
        let cbr1 = Conv2dNormActivation::new(vb.pp(0), 3, c_in, 3, 2, 1)?;
        let mut layer_id = 1;
        let mut invs = Vec::new();
        // features.1 .. features.17: the inverted residual blocks.
        for &(er, c_out, n, stride) in INVERTED_RESIDUAL_SETTINGS.iter() {
            for i in 0..n {
                // Only the first block of each stage downsamples.
                let stride = if i == 0 { stride } else { 1 };
                let inv = InvertedResidual::new(vb.pp(layer_id), c_in, c_out, stride, er)?;
                invs.push(inv);
                c_in = c_out;
                layer_id += 1;
            }
        }
        // features.18: the final 1x1 conv up to 1280 channels.
        let cbr2 = Conv2dNormActivation::new(vb.pp(layer_id), c_in, 1280, 1, 1, 1)?;
        Ok(Self { cbr1, invs, cbr2 })
    }
}

impl Module for Features {
    fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        let mut ys = xs.apply(&self.cbr1)?;
        for inv in self.invs.iter() {
            ys = ys.apply(inv)?;
        }
        ys.apply(&self.cbr2)
    }
}
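The loop above iterates over INVERTED_RESIDUAL_SETTINGS, which is not shown in the snippet. It is the standard MobileNetV2 stage table (expand ratio, output channels, number of blocks, stride of the first block), and its values can be read straight off the layer dump above; a sketch of that constant:

// (expand_ratio, output_channels, num_blocks, first_stride) for each stage,
// in the same order as features.1 .. features.17 in the dump above.
const INVERTED_RESIDUAL_SETTINGS: [(usize, usize, usize, usize); 7] = [
    (1, 16, 1, 1),
    (6, 24, 2, 2),
    (6, 32, 3, 2),
    (6, 64, 4, 2),
    (6, 96, 3, 1),
    (6, 160, 3, 2),
    (6, 320, 1, 1),
];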
The Classifier module:
#[derive(Debug, Clone)]
struct Classifier {
linear: Linear,
}
impl Classifier {
    fn new(vb: VarBuilder, nclasses: usize) -> Result<Self> {
        // classifier.0 is the Dropout (no weights); the linear layer is classifier.1.
        let linear = candle_nn::linear(1280, nclasses, vb.pp(1))?;
        Ok(Self { linear })
    }
}

impl Module for Classifier {
    fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        // Dropout (p = 0.2) is only active during training; for inference it is the
        // identity, so only the linear layer is applied here.
        xs.apply(&self.linear)
    }
}
As you can see, building the model structure is basically like stacking blocks. Finally, we define Mobilenetv2:
#[derive(Debug, Clone)]
pub struct Mobilenetv2 {
features: Features,
classifier: Classifier,
}
impl Mobilenetv2 {
pub fn new(vb: VarBuilder, nclasses: usize) -> Result<Self> {
let features = Features::new(vb.pp("features"))?;
let classifier = Classifier::new(vb.pp("classifier"), nclasses)?;
Ok(Self { features, classifier })
}
}
impl Module for Mobilenetv2 {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Global average pooling: average over the last two dims (H, W),
        // matching x.mean([2, 3]) in the torchvision implementation.
        xs.apply(&self.features)?
            .mean(D::Minus1)?
            .mean(D::Minus1)?
            .apply(&self.classifier)
    }
}
The most important part of the code above is the forward implementation, which needs to follow the Python version.
torchvision has the complete implementation for reference: https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py
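To put everything together, here is a minimal inference sketch. It assumes the modules above are in scope and that mobilenetv2.safetensors from the export script sits in the working directory, and it feeds a random tensor instead of a real preprocessed image:

use candle_core::{D, DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // Load the weights exported by the Python script above.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["mobilenetv2.safetensors"], DType::F32, &device)?
    };
    // 1000 ImageNet classes, matching the pretrained torchvision weights.
    let model = Mobilenetv2::new(vb, 1000)?;

    // A dummy NCHW input; a real application would load and normalize an image here.
    let input = Tensor::randn(0f32, 1f32, (1, 3, 224, 224), &device)?;
    let logits = model.forward(&input)?;
    let probs = candle_nn::ops::softmax(&logits, D::Minus1)?;
    println!("output shape: {:?}", probs.shape());
    Ok(())
}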