Dec 06, 2024
11 min read
Rust,
Mobilenet,
Candle,

How to Implement a CNN in Rust?

How to implement the MobileNetV2 neural network with Rust and the Candle framework, covering both building the model structure and implementing the forward pass.

Artificial intelligence (AI) is currently the hottest field around. Almost every programming language is trying to carve out a place in AI, but Python and C++ still firmly dominate: C++ is widely used for low-level implementations thanks to its performance, while Python is the first choice for higher-level development because of its ease of use and rich library ecosystem.

That said, Rust's progress in AI should not be overlooked, especially with Hugging Face pushing it forward. Hugging Face has released several Rust-based projects, such as the LLM serving framework TGI (Text Generation Inference), Tokenizers, and Candle.

Candle is a minimalist machine learning framework written in Rust that focuses on performance (including GPU support) and ease of use. Its API is modeled after PyTorch, so developers can get up to speed quickly.


# pytorch
torch.Tensor([[1, 2], [3, 4]])
// rust candle
Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?

Anyone who has used PyTorch knows how heavy its dependencies are: deploying inference with a PyTorch Docker image takes at least 15 GB. Candle's core goal is serverless inference, letting you ship lightweight binaries. More importantly, Candle removes the dependency on Python entirely, sidestepping Python's comparatively poor performance.

In fact, Candle already ships implementations of many well-known large models, such as GPT, Stable Diffusion, and LLaMA.

However, some older models have not been implemented yet, such as MobileNet.

MobileNet is a very lightweight neural network, commonly deployed on mobile devices such as smartphones and tablets. It is mainly used for image classification tasks, e.g. recognizing objects in an image.

First, let's export the model weights from Python:

import torch
import torchvision
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from safetensors.torch import save_file

# Load the pretrained MobileNetV2 and switch it to eval mode.
model = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
model.eval()

# Trace the model and also save a TorchScript version.
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("mobilenetv2.pt")

# Dump the model structure for reference and print the weight names.
print(model, file=open("mobilenetv2.txt", "w"))
weights = model.state_dict()
for key, value in weights.items():
    print(key)

# Export the weights in safetensors format for Candle.
save_file(model.state_dict(), "mobilenetv2.safetensors")

Candle uses the safetensors format. While we're at it, we also print every weight name and save the model structure to mobilenetv2.txt.
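
As a quick sanity check, the exported file can be loaded back in Rust and its tensor names listed. Below is a minimal sketch (assuming the file sits in the working directory); the names it prints are exactly what the VarBuilder paths later have to match.

use candle_core::{safetensors, Device};

fn main() -> candle_core::Result<()> {
    // Load every tensor from the exported file onto the CPU and print its
    // name and shape.
    let tensors = safetensors::load("mobilenetv2.safetensors", &Device::Cpu)?;
    for (name, tensor) in tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}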

Let's take a look at the network layers in mobilenetv2.txt:

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (3): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (4): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=144, bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (5): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (6): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (7): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=192, bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (8): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (9): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (10): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (11): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (12): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (13): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (14): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=576, bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (15): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (16): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (17): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (18): Conv2dNormActivation(
      (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
  )
  (classifier): Sequential(
    (0): Dropout(p=0.2, inplace=False)
    (1): Linear(in_features=1280, out_features=1000, bias=True)
  )
)

It looks like a lot of layers, but most of them are repetitions. We generally start implementing from the innermost modules.

First, implement the Conv2D + BatchNorm2D + ReLU6 block.

    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )

The code for this block looks like this:

/// Conv2D + BatchNorm2D + ReLU6
#[derive(Debug, Clone)]
pub struct Conv2dNormActivation {
    conv2d: Conv2d,
    batch_norm2d: BatchNorm,
}

impl Conv2dNormActivation {
    pub fn new(
        vb: VarBuilder,
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        stride: usize,
        groups: usize
    ) -> Result<Self> {
        let cfg = candle_nn::Conv2dConfig {
            stride,
            padding: (kernel_size - 1) / 2,
            groups,
            ..Default::default()
        };
        // The sub-layers are indexed: 0 = Conv2d, 1 = BatchNorm2d.
        let conv2d = conv2d_no_bias(in_channels, out_channels, kernel_size, cfg, vb.pp(0))?;
        let batch_norm2d = batch_norm(out_channels, 1e-5, vb.pp(1))?;

        Ok(Self { conv2d, batch_norm2d })
    }
}

impl Module for Conv2dNormActivation {
    fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        // ReLU6 is equivalent to clamp(0, 6).
        xs.apply(&self.conv2d)?
            .apply_t(&self.batch_norm2d, false)?
            .clamp(0.0, 6.0)
    }
}

When instantiating a module, vb is the VarBuilder that binds the weight parameters to a name prefix (pp is short for "push prefix"). For example, vb.pp("conv") means the current module's weights live under the name conv, i.e. the conv in (conv): Sequential.
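
For reference, here is a minimal sketch of how such a VarBuilder can be created from the exported file and how nested pp calls compose into the dotted key names. The file path is the one produced by the Python export above; the rest is an assumption about how you might wire it up.

use candle_core::{DType, Device};
use candle_nn::VarBuilder;

let device = Device::Cpu;
// Memory-map the safetensors file exported from PyTorch.
let vb = unsafe {
    VarBuilder::from_mmaped_safetensors(&["mobilenetv2.safetensors"], DType::F32, &device)?
};
// Nested pp calls build dotted prefixes: this VarBuilder will look up keys
// such as "features.0.0.weight" when a Conv2d is created from it.
let vb_first_conv = vb.pp("features").pp(0).pp(0);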

Next, let's implement InvertedResidual (the inverted residual block).

InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )

Its Rust counterpart simply wraps a ConvSequential:

#[derive(Debug, Clone)]
struct InvertedResidual {
    conv: ConvSequential,
}

impl InvertedResidual {
    fn new(
        vb: VarBuilder,
        in_channels: usize,
        out_channels: usize,
        stride: usize,
        expand_ratio: usize
    ) -> Result<Self> {
        Ok(
            Self { conv: ConvSequential::new(vb.pp("conv"), in_channels, out_channels, stride, expand_ratio)? }
        )
    }
}

impl Module for InvertedResidual {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.conv)
    }
}

This relies on a ConvSequential module:

#[derive(Debug, Clone)]
struct ConvSequential {
    cbr1: Option<Conv2dNormActivation>,
    cbr2: Conv2dNormActivation,
    conv2d: Conv2d,
    batch_norm2d: BatchNorm,
    use_res_connect: bool,
}

impl ConvSequential {
    fn new(
        vb: VarBuilder,
        in_channels: usize,
        out_channels: usize,
        stride: usize,
        expand_ratio: usize
    ) -> Result<Self> {
        let c_hidden = expand_ratio * in_channels;
        // Sub-layer indices shift depending on whether the 1x1 expansion
        // layer exists (expand_ratio != 1): 0/1/2/3 vs 0/1/2.
        let mut id = 0;
        let cbr1 = if expand_ratio != 1 {
            // 1x1 pointwise expansion: in_channels -> c_hidden
            let cbr = Conv2dNormActivation::new(vb.pp(id), in_channels, c_hidden, 1, 1, 1)?;
            id += 1;
            Some(cbr)
        } else {
            None
        };
        // 3x3 depthwise convolution (groups == c_hidden).
        let cbr2 = Conv2dNormActivation::new(vb.pp(id), c_hidden, c_hidden, 3, stride, c_hidden)?;
        let cfg = candle_nn::Conv2dConfig {
            stride: 1,
            ..Default::default()
        };
        // 1x1 pointwise projection back to out_channels, followed by BatchNorm
        // (no activation after the projection).
        let conv2d = conv2d_no_bias(c_hidden, out_channels, 1, cfg, vb.pp(id + 1))?;
        let batch_norm2d = batch_norm(out_channels, 1e-5, vb.pp(id + 2))?;
        // The residual connection is only used when the block preserves the
        // spatial size and channel count.
        let use_res_connect = stride == 1 && in_channels == out_channels;
        Ok(Self {
            cbr1,
            cbr2,
            conv2d,
            batch_norm2d,
            use_res_connect,
        })
    }
}

impl Module for ConvSequential {
    fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        let mut ys = xs.clone();
        if let Some(cbr1) = &self.cbr1 {
            ys = ys.apply(cbr1)?;
        }

        let ys = ys.apply(&self.cbr2)?.apply(&self.conv2d)?.apply_t(&self.batch_norm2d, false)?;

        if self.use_res_connect {
            xs + ys
        } else {
            Ok(ys)
        }
    }
}

At this point the inner modules are done; what's left are the outermost Features and Classifier modules.

The Features module:

#[derive(Debug, Clone)]
pub struct Features {
    cbr1: Conv2dNormActivation,
    invs: Vec<InvertedResidual>,
    cbr2: Conv2dNormActivation,
}

impl Features {
    fn new(vb: VarBuilder) -> Result<Self> {
        let mut c_in = 32;
        // features.0: the stem Conv2dNormActivation (3 -> 32, stride 2).
        let cbr1 = Conv2dNormActivation::new(vb.pp(0), 3, c_in, 3, 2, 1)?;
        let mut layer_id = 1;
        let mut invs = Vec::new();
        // features.1 .. features.17: the inverted residual blocks.
        for &(er, c_out, n, stride) in INVERTED_RESIDUAL_SETTINGS.iter() {
            for i in 0..n {
                // Only the first block of each stage uses the configured stride.
                let stride = if i == 0 { stride } else { 1 };
                let inv = InvertedResidual::new(vb.pp(layer_id), c_in, c_out, stride, er)?;
                invs.push(inv);
                c_in = c_out;
                layer_id += 1;
            }
        }
        // features.18: the final 1x1 Conv2dNormActivation (320 -> 1280).
        let cbr2 = Conv2dNormActivation::new(vb.pp(layer_id), c_in, 1280, 1, 1, 1)?;

        Ok(Self { cbr1, invs, cbr2 })
    }
}

impl Module for Features {
    fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        let mut ys = xs.apply(&self.cbr1)?;
        for inv in self.invs.iter() {
            ys = ys.apply(inv)?;
        }
        ys.apply(&self.cbr2)
    }
}
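
The INVERTED_RESIDUAL_SETTINGS table referenced above is not shown in the snippet. Reading the stages off the layer printout (the (t, c, n, s) table from the MobileNetV2 paper), it could be defined like this:

// (expand_ratio, out_channels, num_blocks, first_stride) per stage,
// covering features.1 through features.17 in the printout above.
const INVERTED_RESIDUAL_SETTINGS: [(usize, usize, usize, usize); 7] = [
    (1, 16, 1, 1),
    (6, 24, 2, 2),
    (6, 32, 3, 2),
    (6, 64, 4, 2),
    (6, 96, 3, 1),
    (6, 160, 3, 2),
    (6, 320, 1, 1),
];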

The Classifier module:

#[derive(Debug, Clone)]
struct Classifier {
    linear: Linear,
}

impl Classifier {
    fn new(vb: VarBuilder, nclasses: usize) -> Result<Self> {
        // classifier.0 is the Dropout layer (no weights); the Linear layer's
        // weights live under classifier.1, hence vb.pp(1).
        let linear = candle_nn::linear(1280, nclasses, vb.pp(1))?;
        Ok(Self { linear })
    }
}

impl Module for Classifier {
    fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        // Dropout is only active during training; for inference it is the
        // identity, so we only need to apply the linear layer here.
        xs.apply(&self.linear)
    }
}

As you can see, writing the model structure is basically like stacking building blocks. Finally, define Mobilenetv2:

#[derive(Debug, Clone)]
pub struct Mobilenetv2 {
    features: Features,
    classifier: Classifier,
}

impl Mobilenetv2 {
    pub fn new(vb: VarBuilder, nclasses: usize) -> Result<Self> {
        let features = Features::new(vb.pp("features"))?;
        let classifier = Classifier::new(vb.pp("classifier"), nclasses)?;
        Ok(Self { features, classifier })
    }
}

impl Module for Mobilenetv2 {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Two means over the last dimensions implement the global average
        // pooling between features and classifier: (N, C, H, W) -> (N, C).
        xs.apply(&self.features)?
            .mean(D::Minus1)?
            .mean(D::Minus1)?
            .apply(&self.classifier)
    }
}

In the code above, the most important part is the forward implementation, which we need to write by following the Python version.

torchvision has a complete reference implementation: https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py
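
Putting it all together, here is a minimal inference sketch. It assumes the modules above live in the same crate and the weight file from the Python export is in the working directory; real image preprocessing (resize to 224x224, normalization) is omitted.

use candle_core::{DType, Device, Module, Tensor};
use candle_nn::VarBuilder;

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // Load the weights exported from PyTorch.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["mobilenetv2.safetensors"], DType::F32, &device)?
    };
    // 1000 classes, matching the ImageNet-pretrained torchvision checkpoint.
    let model = Mobilenetv2::new(vb, 1000)?;
    // A dummy NCHW input standing in for a preprocessed image.
    let input = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
    let logits = model.forward(&input)?;
    println!("{:?}", logits.shape()); // [1, 1000]
    Ok(())
}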

Full implementation: https://github.com/kingzcheung/candle-models