Dec 06, 2024
12 min read
Rust,
MobileNet,
Candle

How to Implement a CNN in Rust?

How to implement the MobileNetV2 neural network in Rust with the Candle framework, covering the construction of the model structure and the implementation of the forward pass.

Currently, the hottest field in software is undoubtedly artificial intelligence (AI). Almost every programming language is trying to gain a foothold in the AI domain, but Python and C++ still dominate: C++ is widely used for low-level implementations because of its performance, while Python is the preferred choice for high-level development thanks to its ease of use and rich library ecosystem.

However, Rust's progress in the AI domain should not be ignored, especially with the push from Hugging Face, which has launched several Rust-based projects such as the large-model serving framework TGI (Text Generation Inference), Tokenizers, and Candle.

Candle is a minimal machine learning framework written in Rust, focusing on high performance (including GPU support) and ease of use. Its API design is similar to PyTorch, making it easy for developers to get started.

# PyTorch
torch.Tensor([[1, 2], [3, 4]])

// Rust + Candle
Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)?

Anyone who has used PyTorch knows how heavy its dependencies are: a PyTorch Docker image used for inference deployment can easily reach 15 GB. The core goal of Candle is to enable serverless inference by shipping lightweight binaries. More importantly, Candle removes the dependency on Python, and with it Python's comparatively slow runtime, entirely.

In fact, Candle already provides implementations of many well-known models, such as GPT, Stable Diffusion, and LLaMA.

However, some older models, such as MobileNet, have not yet been implemented.

MobileNet is a very lightweight neural network designed for mobile devices such as smartphones and tablets, and it is mainly used for image classification tasks like recognizing objects in pictures.

We start by exporting the model weights from Python:

import torch
import torchvision
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
from safetensors.torch import save_file

# Load the pretrained MobileNetV2 and switch to inference mode.
model = torchvision.models.mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)
model.eval()

# Optional: also export a TorchScript trace of the model.
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("mobilenetv2.pt")

# Dump the layer structure, list the weight names, and save the weights as safetensors.
print(model, file=open("mobilenetv2.txt", "w"))
weights = model.state_dict()
for key in weights:
    print(key)
save_file(weights, "mobilenetv2.safetensors")

The weights are exported in the safetensors format. The script also prints the model's layer structure and saves it to mobilenetv2.txt.

Looking at the contents of mobilenetv2.txt, we can see that there are many layers, but most of them are repetitive. Generally, we start implementing from the deepest layer.

MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=96, bias=False)
          (1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (3): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (4): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=144, bias=False)
          (1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (5): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (6): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (7): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(192, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=192, bias=False)
          (1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (8): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (9): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (10): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(384, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (11): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(64, 384, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=384, bias=False)
          (1): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(384, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (12): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (13): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (14): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(576, 576, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=576, bias=False)
          (1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(576, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (15): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (16): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(960, 160, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(160, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (17): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(160, 960, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2dNormActivation(
          (0): Conv2d(960, 960, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=960, bias=False)
          (1): BatchNorm2d(960, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (2): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (3): BatchNorm2d(320, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (18): Conv2dNormActivation(
      (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (1): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
  )
  (classifier): Sequential(
    (0): Dropout(p=0.2, inplace=False)
    (1): Linear(in_features=1280, out_features=1000, bias=True)
  )
)
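
Before writing any Rust modules, it can also help to confirm the exported weight names from the Rust side. Below is a minimal sketch (assuming the mobilenetv2.safetensors file produced by the export script sits in the working directory) that loads the file with candle_core's safetensors helper and prints every tensor name and shape:

use candle_core::Device;

fn main() -> candle_core::Result<()> {
    // Load every tensor from the exported file into a HashMap<String, Tensor>.
    let tensors = candle_core::safetensors::load("mobilenetv2.safetensors", &Device::Cpu)?;
    for (name, tensor) in tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }
    Ok(())
}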

First, let’s implement the Conv2d + BatchNorm2d + ReLU6 block, Conv2dNormActivation.

    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )

The code for this block is as follows:

use candle_core::{Result, Tensor, D};
use candle_nn::{batch_norm, conv2d_no_bias, BatchNorm, Conv2d, Linear, Module, VarBuilder};

/// Conv2d + BatchNorm2d + ReLU6
#[derive(Debug, Clone)]
pub struct Conv2dNormActivation {
    conv2d: Conv2d,
    batch_norm2d: BatchNorm,
}

impl Conv2dNormActivation {
    pub fn new(
        vb: VarBuilder,
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        stride: usize,
        groups: usize
    ) -> Result<Self> {
        let cfg = candle_nn::Conv2dConfig {
            stride,
            // "Same" padding for odd kernel sizes, as used by torchvision.
            padding: (kernel_size - 1) / 2,
            groups,
            ..Default::default()
        };
        // The sub-module indices (0 and 1) match the printed model structure.
        let conv2d = conv2d_no_bias(in_channels, out_channels, kernel_size, cfg, vb.pp(0))?;
        let batch_norm2d = batch_norm(out_channels, 1e-5, vb.pp(1))?;

        Ok(Self { conv2d, batch_norm2d })
    }
}

impl Module for Conv2dNormActivation {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.conv2d)?
            .apply_t(&self.batch_norm2d, false)? // false: inference mode
            .clamp(0.0, 6.0) // ReLU6 = clamp to [0, 6]
    }
}

When instantiating a module, vb is a VarBuilder: it binds the module's parameters to a prefix in the weight file, and each pp(...) call pushes one more path segment onto that prefix. For example, vb.pp("conv") means the current module's parameters live under the conv prefix, i.e., the (conv): Sequential entry in the printed model.
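
To make the prefixing concrete, here is a rough sketch of how pp calls compose into the parameter keys stored in mobilenetv2.safetensors (the exact names can be cross-checked against the key listing printed by the Python export script):

// Each pp() call appends one path segment; a parameter is then looked up under the
// dot-joined prefix plus its own name, e.g.:
//   vb.pp("features").pp(0).pp(0)                  -> "features.0.0.weight"
//   vb.pp("features").pp(1).pp("conv").pp(0).pp(1) -> "features.1.conv.0.1.weight"
// The second example is the BatchNorm inside the first InvertedResidual's Conv2dNormActivation.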

Next, we will implement the InvertedResidual (inverted residual module).

InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )

#[derive(Debug, Clone)]
struct InvertedResidual {
    conv: ConvSequential
}

impl InvertedResidual {
    fn new(
        vb: VarBuilder,
        in_channels: usize,
        out_channels: usize,
        stride: usize,
        expand_ratio: usize
    ) -> Result<Self> {
        Ok(
            Self { conv: ConvSequential::new(vb.pp("conv"), in_channels, out_channels, stride, expand_ratio)? }
        )
    }
}

impl Module for InvertedResidual {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.conv)
    }
}

Here is the ConvSequential module:

#[derive(Debug, Clone)]
struct ConvSequential {
    cbr1: Option<Conv2dNormActivation>,
    cbr2: Conv2dNormActivation,
    conv2d: Conv2d,
    batch_norm2d: BatchNorm,
    use_res_connect: bool,
}

impl ConvSequential {
    fn new(
        vb: VarBuilder,
        in_channels: usize,
        out_channels: usize,
        stride: usize,
        expand_ratio: usize
    ) -> Result<Self> {
        let c_hidden = expand_ratio * in_channels;
        // The sub-module indices inside `conv` shift by one depending on whether the
        // 1x1 expansion layer is present, so track the current index in `id`.
        let mut id = 0;
        // When expand_ratio == 1 (only the very first InvertedResidual), the 1x1
        // expansion Conv2dNormActivation is omitted, as seen in the printed structure.
        let cbr1 = if expand_ratio != 1 {
            let cbr = Conv2dNormActivation::new(vb.pp(id), in_channels, c_hidden, 1, 1, 1)?;
            id += 1;
            Some(cbr)
        } else {
            None
        };
        // 3x3 depthwise convolution: groups == number of channels.
        let cbr2 = Conv2dNormActivation::new(vb.pp(id), c_hidden, c_hidden, 3, stride, c_hidden)?;
        // 1x1 pointwise projection back down to out_channels, followed by BatchNorm.
        let cfg = candle_nn::Conv2dConfig {
            stride: 1,
            ..Default::default()
        };
        let conv2d = conv2d_no_bias(c_hidden, out_channels, 1, cfg, vb.pp(id + 1))?;
        let batch_norm2d = batch_norm(out_channels, 1e-5, vb.pp(id + 2))?;
        // The residual connection is only used when the block keeps both the spatial
        // resolution (stride == 1) and the channel count.
        let use_res_connect = stride == 1 && in_channels == out_channels;
        Ok(Self {
            cbr1,
            cbr2,
            conv2d,
            batch_norm2d,
            use_res_connect,
        })
    }
}

impl Module for ConvSequential {
    fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        // Optional 1x1 expansion, then depthwise conv and pointwise projection.
        let mut ys = xs.clone();
        if let Some(cbr1) = &self.cbr1 {
            ys = ys.apply(cbr1)?;
        }
        let ys = ys.apply(&self.cbr2)?.apply(&self.conv2d)?.apply_t(&self.batch_norm2d, false)?;

        if self.use_res_connect {
            // Residual connection: add the block input back to its output.
            xs + ys
        } else {
            Ok(ys)
        }
    }
}

With the deeper modules implemented, the next step is to implement the outermost Features and Classifier modules.

Features Module

#[derive(Debug, Clone)]
pub struct Features {
    cbr1: Conv2dNormActivation,
    invs: Vec<InvertedResidual>,
    cbr2: Conv2dNormActivation,
}

impl Features {
    fn new(vb: VarBuilder) -> Result<Self> {
        let mut c_in = 32;
        // features.0: the initial 3x3 stride-2 Conv2dNormActivation (3 -> 32 channels).
        let cbr1 = Conv2dNormActivation::new(vb.pp(0), 3, c_in, 3, 2, 1)?;
        let mut layer_id = 1;
        let mut invs = Vec::new();
        // features.1 .. features.17: the InvertedResidual stages.
        for &(er, c_out, n, stride) in INVERTED_RESIDUAL_SETTINGS.iter() {
            for i in 0..n {
                // Only the first block of a stage uses the stage stride.
                let stride = if i == 0 { stride } else { 1 };
                let inv = InvertedResidual::new(vb.pp(layer_id), c_in, c_out, stride, er)?;
                invs.push(inv);
                c_in = c_out;
                layer_id += 1;
            }
        }
        // features.18: the final 1x1 Conv2dNormActivation (320 -> 1280 channels).
        let cbr2 = Conv2dNormActivation::new(vb.pp(layer_id), c_in, 1280, 1, 1, 1)?;

        Ok(Self { cbr1, invs, cbr2 })
    }
}

impl Module for Features {
    fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        let mut ys = xs.apply(&self.cbr1)?;
        for inv in self.invs.iter() {
            ys = ys.apply(inv)?;
        }
        ys.apply(&self.cbr2)
    }
}
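
Features::new iterates over INVERTED_RESIDUAL_SETTINGS, which is not shown in the snippet above. It is the standard MobileNetV2 stage table (expansion factor, output channels, number of blocks, stride of the first block); one way to declare it is:

// (expand_ratio, output_channels, num_blocks, first_block_stride) per stage,
// matching the MobileNetV2 paper and the torchvision implementation.
const INVERTED_RESIDUAL_SETTINGS: [(usize, usize, usize, usize); 7] = [
    (1, 16, 1, 1),
    (6, 24, 2, 2),
    (6, 32, 3, 2),
    (6, 64, 4, 2),
    (6, 96, 3, 1),
    (6, 160, 3, 2),
    (6, 320, 1, 1),
];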

Classifier Module

#[derive(Debug, Clone)]
struct Classifier {
    linear: Linear,
}

impl Classifier {
    fn new(vb: VarBuilder, nclasses: usize) -> Result<Self> {
        // classifier.0 is Dropout (no weights); classifier.1 is the Linear layer.
        let linear = candle_nn::linear(1280, nclasses, vb.pp(1))?;
        Ok(Self { linear })
    }
}

impl Module for Classifier {
    fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        // Dropout(p=0.2) is only active during training; at inference time it is an
        // identity op, so only the linear layer is applied here.
        xs.apply(&self.linear)
    }
}

We can see that building the model structure is like stacking blocks. Finally, let’s define Mobilenetv2.

#[derive(Debug, Clone)]
pub struct Mobilenetv2 {
    features: Features,
    classifier: Classifier,
}

impl Mobilenetv2 {
    pub fn new(vb: VarBuilder, nclasses: usize) -> Result<Self> {
        let features = Features::new(vb.pp("features"))?;
        let classifier = Classifier::new(vb.pp("classifier"), nclasses)?;
        Ok(Self { features, classifier })
    }
}

impl Module for Mobilenetv2 {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // features -> global average pooling over height and width -> classifier,
        // mirroring torchvision's MobileNetV2.forward.
        xs.apply(&self.features)?
            .mean(D::Minus1)?
            .mean(D::Minus1)?
            .apply(&self.classifier)
    }
}

The most important part is the implementation of forward, which should mirror the Python version: run the features, apply global average pooling over the two spatial dimensions (the two mean(D::Minus1) calls, replacing torchvision's adaptive_avg_pool2d plus flatten), and then run the classifier.

The complete implementation can be found in torchvision: https://github.com/pytorch/vision/blob/main/torchvision/models/mobilenetv2.py
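
Finally, here is a minimal, untested sketch of how the exported weights might be loaded and the model run on a dummy input. It assumes the mobilenetv2.safetensors file produced by the export script and the 1000-class ImageNet head; a real pipeline would decode and normalize an image instead of using zeros:

use candle_core::{DType, Device, Tensor};
use candle_nn::{Module, VarBuilder};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // Memory-map the weights exported by the Python script above.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["mobilenetv2.safetensors"], DType::F32, &device)?
    };
    // 1000 output classes, matching the ImageNet-pretrained torchvision weights.
    let model = Mobilenetv2::new(vb, 1000)?;

    // Dummy 1x3x224x224 input; replace with a preprocessed image for real inference.
    let input = Tensor::zeros((1, 3, 224, 224), DType::F32, &device)?;
    let logits = model.forward(&input)?;
    println!("logits shape: {:?}", logits.shape());
    Ok(())
}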