Dec 22, 2024
Rust,
Candle,

Advanced Rust AI (Part 1): Converting PyTorch Models to Rust Candle

A guide on converting PyTorch models to Rust Candle models

Candle is a minimalist machine learning framework developed by Hugging Face, specifically tailored for the Rust programming language. It aims to combine high performance with ease of use.

One of Candle’s core objectives is to enable serverless inference, making it easier for developers to deploy machine learning models to the cloud without worrying about managing the underlying infrastructure.

This article provides a detailed explanation of how to convert PyTorch model code into Rust Candle code.

Key Points for Model Code Conversion

To successfully convert model code, you need to:

  1. Understand the corresponding APIs between PyTorch and Candle.
  2. Comprehend the structure of the model.
  3. Validate that the output of the converted model matches the original model.
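Step 3 is worth automating. Below is a minimal sketch, assuming you have run both models on the same input and dumped their outputs as flat f32 vectors (in Candle, for example, with output.flatten_all()?.to_vec1::<f32>()?). The helper names and the tolerance rule are my own choices for illustration, not part of Candle.

```rust
/// Largest absolute element-wise difference between two flattened outputs.
/// Returns None if the lengths differ (the shapes cannot match).
/// Note: f32::max skips NaN, so a NaN difference does not show up here.
fn max_abs_diff(reference: &[f32], candidate: &[f32]) -> Option<f32> {
    if reference.len() != candidate.len() {
        return None;
    }
    Some(
        reference
            .iter()
            .zip(candidate)
            .map(|(a, b)| (a - b).abs())
            .fold(0.0_f32, f32::max),
    )
}

/// True if every candidate element lies within atol + rtol * |reference|
/// of the matching reference element (the rule numpy.allclose(candidate, reference) uses).
/// Any NaN makes the comparison fail.
fn outputs_match(reference: &[f32], candidate: &[f32], rtol: f32, atol: f32) -> bool {
    reference.len() == candidate.len()
        && reference
            .iter()
            .zip(candidate)
            .all(|(r, c)| (r - c).abs() <= atol + rtol * r.abs())
}
```

Exact equality is not a realistic target, because different backends accumulate sums in different orders; comparing with a tolerance and reporting the maximum difference tells you whether a mismatch is rounding noise or a real conversion bug.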

PyTorch vs. Candle

Below is a comparison of common equivalent code snippets. The table is not exhaustive, but it serves as a reference; for other equivalents you may need to consult the Candle source code or official examples, because the Candle team has no plans yet to write comprehensive documentation for this.

| | Using PyTorch | Using Candle |
|---|---|---|
| Creation | torch.Tensor([[1, 2], [3, 4]]) | Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)? |
| Creation | torch.zeros((2, 2)) | Tensor::zeros((2, 2), DType::F32, &Device::Cpu)? |
| Indexing | tensor[:, :4] | tensor.i((.., ..4))? |
| Operations | tensor.view((2, 2)) | tensor.reshape((2, 2))? |
| Operations | a.matmul(b) | a.matmul(&b)? |
| Arithmetic | a + b | (&a + &b)? |
| Device | tensor.to(device="cuda") | tensor.to_device(&Device::new_cuda(0)?)? |
| Dtype | tensor.to(dtype=torch.float16) | tensor.to_dtype(DType::F16)? |
| Saving | torch.save({"A": A}, "model.bin") | candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")? |
| Loading | weights = torch.load("model.bin") | candle::safetensors::load("model.safetensors", &device)? |

In these snippets, candle refers to the candle-core crate (use candle_core as candle;), and .i(...) requires the IndexOp trait to be in scope (use candle::IndexOp;).
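The view/reshape row deserves a note: for a contiguous tensor, both calls only reinterpret the same row-major buffer under a new shape, which is valid only when the element count is unchanged (PyTorch's view additionally refuses tensors whose strides are incompatible with the new shape, while reshape may copy). The std-only sketch below models that bookkeeping with strides and flat indices; it illustrates the idea and is not Candle's internal code.

```rust
/// Row-major (C-contiguous) strides: how many elements to skip per step in each dimension.
fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    // Walk backwards: each stride is the product of all later dimensions.
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}

/// Position of a multi-dimensional index inside the flat buffer.
fn flat_index(index: &[usize], strides: &[usize]) -> usize {
    index.iter().zip(strides).map(|(i, s)| i * s).sum()
}

/// Reshaping a contiguous buffer is valid only if the element count is unchanged;
/// the data stays where it is, only the shape (and therefore the strides) changes.
fn can_reshape(from: &[usize], to: &[usize]) -> bool {
    from.iter().product::<usize>() == to.iter().product::<usize>()
}
```

For example, reshaping a 4-element vector to (2, 2) gives strides (2, 1), so element [1, 0] is the third value of the original buffer.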

More equivalent code will be provided in the following examples.

Understanding the Model Structure

To understand a model’s architecture, we can print its structure. In PyTorch, this can be done simply by printing the model:

from torchvision import models
from torchvision.models import ResNet50_Weights
model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
print(model)
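Candle modules do not, as far as I know, offer a structural printout like print(model). A practical substitute is to inspect the tensor names in the exported weight file, because they mirror the module tree (for example layer2.0.downsample.0.weight). In practice the names could come from the keys of the HashMap returned by candle_core::safetensors::load; the sketch below uses a hard-coded sample list and a hypothetical summarize helper to count the numbered blocks under each top-level module.

```rust
use std::collections::{BTreeMap, BTreeSet};

/// Group tensor names such as "layer2.0.downsample.0.weight" by top-level module
/// ("layer2") and count the distinct numbered blocks under each one.
fn summarize(names: &[&str]) -> BTreeMap<String, usize> {
    let mut blocks: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
    for name in names {
        let mut parts = name.split('.');
        // split always yields at least one piece: the top-level module name.
        let top = parts.next().unwrap_or_default().to_string();
        let entry = blocks.entry(top).or_default();
        // Only stages such as "layer1" have a numeric second component (the block index).
        if let Some(idx) = parts.next().filter(|p| p.parse::<usize>().is_ok()) {
            entry.insert(idx.to_string());
        }
    }
    blocks.into_iter().map(|(module, ids)| (module, ids.len())).collect()
}
```

Applied to the full list of ResNet18 names, this reports two blocks under each of layer1 to layer4, which matches the printout further down.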

In previous articles, we used MobileNetV2 as an example. This time, we will use ResNet.

ResNet Overview

ResNet, short for “Residual Network,” is a deep convolutional neural network architecture introduced by Kaiming He and colleagues at Microsoft Research in 2015. Its key idea is the residual block: a shortcut (skip) connection adds a block's input to its output, so the stacked layers only need to learn the residual mapping F(x) = H(x) - x instead of the full mapping H(x). This makes it possible to train very deep networks effectively.
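The residual idea fits in a few lines. This std-only sketch uses plain Vec<f32> values instead of tensors (an assumption made for illustration) and computes y = ReLU(F(x) + shortcut(x)), with the ReLU applied after the addition as in torchvision's blocks. The shortcut is the identity unless the block changes the shape, in which case ResNet uses a 1x1 convolution plus batch norm (the downsample branch in the printout below).

```rust
/// Element-wise ReLU.
fn relu(x: &[f32]) -> Vec<f32> {
    x.iter().map(|&v| v.max(0.0)).collect()
}

/// One residual step: y = ReLU(F(x) + shortcut(x)).
/// `residual_fn` stands in for F (the stacked conv/BN layers);
/// `shortcut` is the identity, or a projection when shapes change.
fn residual_block<F, S>(x: &[f32], residual_fn: F, shortcut: S) -> Vec<f32>
where
    F: Fn(&[f32]) -> Vec<f32>,
    S: Fn(&[f32]) -> Vec<f32>,
{
    let fx = residual_fn(x);
    let sx = shortcut(x);
    assert_eq!(fx.len(), sx.len(), "F(x) and the shortcut must have the same shape");
    let summed: Vec<f32> = fx.iter().zip(&sx).map(|(a, b)| a + b).collect();
    relu(&summed)
}
```

Because the identity path passes x through unchanged, a block whose F outputs zeros still forwards its input, which is what makes very deep stacks trainable.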

Just as the Transformer underpins today's large models, ResNet's design philosophy has influenced almost every modern CNN architecture, and you can find traces of it in many CNN models.

The official PyTorch implementation of ResNet can be found here:

https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py

Unlike MobileNetV2, ResNet has an official implementation in Candle, which can be found here:

https://github.com/huggingface/candle/blob/main/candle-transformers/src/models/resnet.rs

Interestingly, the official Candle implementation is not based on the PyTorch code but on tch-rs, another project by the same author: Rust bindings for libtorch, PyTorch's C++ library:

https://github.com/LaurentMazare/tch-rs/blob/main/src/vision/resnet.rs

The official Candle ResNet implementation extensively uses closures for code encapsulation. My implementation predates the official version and references both the official PyTorch implementation and the tch version, leading to some differences in the implementation approach.
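To illustrate the closure style without pulling in Candle, here is a std-only sketch in which every layer is a boxed closure and a hypothetical sequential helper chains them, much like nn.Sequential. In Candle itself, if I read the source correctly, the corresponding tool is candle_nn::func, which wraps a closure over &Tensor into a value implementing Module; the Vec<f32> layers here are stand-ins.

```rust
/// A layer is any function from an activation vector to a new activation vector.
type Layer = Box<dyn Fn(Vec<f32>) -> Vec<f32>>;

/// Chain layers into a single closure, applied in order (like nn.Sequential).
fn sequential(layers: Vec<Layer>) -> impl Fn(Vec<f32>) -> Vec<f32> {
    move |x: Vec<f32>| layers.iter().fold(x, |acc, layer| layer(acc))
}

/// Hypothetical "scale" layer: its parameter is captured by the closure,
/// the way a Candle closure captures the weights it loaded.
fn scale(factor: f32) -> Layer {
    Box::new(move |x: Vec<f32>| x.into_iter().map(|v| v * factor).collect::<Vec<f32>>())
}

/// Hypothetical parameter-free ReLU layer.
fn relu_layer() -> Layer {
    Box::new(|x: Vec<f32>| x.into_iter().map(|v| v.max(0.0)).collect::<Vec<f32>>())
}
```

The trade-off is that closures keep each block compact, but the captured weights are no longer reachable as named struct fields, which is one reason a struct-per-layer port can look quite different from the official code.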

ResNet Model Structure

ResNet has several versions. Here, we will focus on ResNet18 and ResNet50. Let’s take a look at the structure of ResNet18:

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
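The printed hyperparameters are enough to trace how the spatial size shrinks through the network. This sketch applies the standard formula floor((H + 2p - k) / s) + 1 (dilation 1, ceil_mode=False) to a 224x224 input; the input size is my assumption, chosen because it is the usual ImageNet resolution.

```rust
/// Output size of a conv or max-pool along one spatial dimension
/// (dilation 1, floor rounding). Assumes input + 2 * padding >= kernel.
fn out_size(input: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (input + 2 * padding - kernel) / stride + 1
}

/// Trace the spatial size through ResNet18's stem and four stages,
/// using the kernel/stride/padding values from the printed structure.
fn resnet18_spatial_sizes(input: usize) -> Vec<(&'static str, usize)> {
    let mut sizes = Vec::new();
    let mut h = out_size(input, 7, 2, 3); // conv1: 7x7, stride 2, padding 3
    sizes.push(("conv1", h));
    h = out_size(h, 3, 2, 1); // maxpool: 3x3, stride 2, padding 1
    sizes.push(("maxpool", h));
    // layer1 uses stride 1 everywhere, so the size is unchanged.
    sizes.push(("layer1", h));
    // layer2..layer4 halve the size in the first block's 3x3, stride 2, padding 1 conv.
    for name in ["layer2", "layer3", "layer4"] {
        h = out_size(h, 3, 2, 1);
        sizes.push((name, h));
    }
    sizes
}
```

The result (112, 56, 56, 28, 14, 7) means avgpool receives 7x7 maps with 512 channels, hence fc's in_features=512. The downsample branch (1x1, stride 2, no padding) yields the same size as the strided 3x3 conv, so the two paths can be added.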

The main difference between ResNet18 and ResNet50 lies in the residual block. ResNet18 uses BasicBlock, while ResNet50 uses the more complex Bottleneck, in which each residual block contains three convolutional layers: a 1x1 convolution to reduce the number of channels, a 3x3 convolution for feature extraction, and another 1x1 convolution that expands the channels by a factor of 4 (64 to 256 in the first block below). The stages also differ in depth: ResNet18 has 2, 2, 2, 2 blocks per stage, ResNet50 has 3, 4, 6, 3.

Bottleneck(
  (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
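The channel bookkeeping of both block types can be written down explicitly. This is a sketch with hypothetical names (BlockKind, block_convs); the rule for adding a downsample branch (stride != 1, or input channels != planes x expansion) mirrors the condition in torchvision's _make_layer.

```rust
#[derive(Debug, Clone, Copy)]
enum BlockKind {
    Basic,      // two 3x3 convs, expansion 1 (ResNet18/34)
    Bottleneck, // 1x1 -> 3x3 -> 1x1, expansion 4 (ResNet50/101/152)
}

impl BlockKind {
    fn expansion(self) -> usize {
        match self {
            BlockKind::Basic => 1,
            BlockKind::Bottleneck => 4,
        }
    }
}

/// (in_channels, out_channels) of each conv in one block, plus whether the
/// shortcut needs a 1x1 conv + BN "downsample" to match the block's output.
fn block_convs(kind: BlockKind, in_ch: usize, planes: usize, stride: usize) -> (Vec<(usize, usize)>, bool) {
    let out_ch = planes * kind.expansion();
    let convs = match kind {
        BlockKind::Basic => vec![(in_ch, planes), (planes, planes)],
        BlockKind::Bottleneck => vec![(in_ch, planes), (planes, planes), (planes, out_ch)],
    };
    let needs_downsample = stride != 1 || in_ch != out_ch;
    (convs, needs_downsample)
}
```

This explains why the Bottleneck printed above has a downsample branch even at stride 1: its output has 256 channels while its input has 64.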

By correctly distinguishing between these two structures, you can implement ResNet18 and ResNet50 separately.
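Apart from the block type, the two variants differ in how many blocks each stage contains: ResNet18 uses 2, 2, 2, 2 (as in the printout) and ResNet50 uses 3, 4, 6, 3. A hypothetical config struct capturing this is enough to derive the fc input width and the layerN.i weight-name prefixes that the loading code has to match.

```rust
/// Hypothetical configuration for one ResNet variant.
struct ResNetConfig {
    expansion: usize,   // 1 for BasicBlock, 4 for Bottleneck
    blocks: [usize; 4], // number of blocks in layer1..layer4
}

const RESNET18: ResNetConfig = ResNetConfig { expansion: 1, blocks: [2, 2, 2, 2] };
const RESNET50: ResNetConfig = ResNetConfig { expansion: 4, blocks: [3, 4, 6, 3] };

impl ResNetConfig {
    /// Width of the feature vector fed to fc after avgpool (layer4 has 512 planes).
    fn fc_in_features(&self) -> usize {
        512 * self.expansion
    }

    /// Weight-name prefix of every residual block, e.g. "layer2.0".
    fn block_prefixes(&self) -> Vec<String> {
        self.blocks
            .iter()
            .enumerate()
            .flat_map(|(stage, &count)| (0..count).map(move |i| format!("layer{}.{}", stage + 1, i)))
            .collect()
    }
}
```

With one config per variant, the same construction code can serve both models, switching only the block constructor.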

Special Considerations in Candle

VarBuilder

VarBuilder, from the candle-nn crate, builds and manages the variables (such as weights and biases) of a neural network model. It is used when creating layers such as convolutional and fully connected layers: it supplies their parameters, either freshly initialized or loaded from a pre-trained model.

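For example, when creating the conv1 layer in Candle, you pass the layer constructor a VarBuilder scoped to that layer's name with pp (short for push_prefix), so that its weight is looked up as conv1.weight:

let cfg = candle_nn::Conv2dConfig { stride: 2, padding: 3, ..Default::default() };
let conv1 = candle_nn::conv2d_no_bias(3, 64, 7, cfg, vb.pp("conv1"))?;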
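pp loads nothing by itself; it returns a VarBuilder whose lookups are prefixed, so a layer built from vb.pp("conv1") reads conv1.weight, and nested calls produce dotted names such as layer2.0.conv2.weight. The std-only PrefixPath type below is a hypothetical model of just this naming behavior, not Candle's implementation. Like Candle's pp (if I recall its signature correctly), it accepts anything implementing ToString, so block indices can be passed as integers.

```rust
/// Hypothetical stand-in for VarBuilder's prefix handling: each pp call
/// returns a new value whose path has one more dot-separated component.
#[derive(Clone, Debug, Default)]
struct PrefixPath {
    parts: Vec<String>,
}

impl PrefixPath {
    fn pp<S: ToString>(&self, s: S) -> Self {
        let mut parts = self.parts.clone();
        parts.push(s.to_string());
        PrefixPath { parts }
    }

    /// Full tensor name of a parameter under this prefix, e.g. "conv1.weight".
    fn name(&self, param: &str) -> String {
        let mut parts = self.parts.clone();
        parts.push(param.to_string());
        parts.join(".")
    }
}
```

Getting these names exactly right matters, because the loader looks tensors up by name: a typo in a prefix surfaces as a missing-tensor error rather than a compile error.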

If the shape your code requests for a layer does not match the weight stored under that name, loading fails with an error like this:

Error: WithBacktrace { inner: UnexpectedShape { msg: "shape mismatch for layer2.0.conv2.weight", expected: [128, 64, 3, 3], got: [128, 128, 3, 3] },
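The error is easier to read once you know two things: Conv2d weights are stored as [out_channels, in_channels, kernel_h, kernel_w], and (as I read VarBuilder's source) expected is the shape the model code asked for, while got is the shape stored in the file. So here the code built layer2.0.conv2 with 64 input channels, but the second convolution of a block consumes the block's output channels (128). The std-only sketch below reproduces that check with hypothetical helper names.

```rust
/// Shape of a Conv2d weight as PyTorch and Candle store it (groups = 1):
/// [out_channels, in_channels, kernel_h, kernel_w].
fn conv2d_weight_shape(in_ch: usize, out_ch: usize, k: usize) -> [usize; 4] {
    [out_ch, in_ch, k, k]
}

/// Compare the shape the model code requests with the one in the checkpoint,
/// producing a message similar in spirit to Candle's UnexpectedShape error.
fn check_shape(name: &str, requested: &[usize], in_file: &[usize]) -> Result<(), String> {
    if requested == in_file {
        Ok(())
    } else {
        Err(format!("shape mismatch for {name}, expected: {requested:?}, got: {in_file:?}"))
    }
}
```

The fix in such a case is in the model code, not the weights: build conv2 from the block's output channel count.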

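You also use VarBuilder to load the pre-trained weights, here by memory-mapping a safetensors file:

use candle_core::{DType, Device};
use candle_nn::VarBuilder;
let model_file = "./testdata/resnet18.safetensors";
let device = Device::Cpu;
// Safety: the file must not be modified while it is memory-mapped (see below).
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };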

This call must be wrapped in an unsafe block because VarBuilder::from_mmaped_safetensors maps the file through memmap2::MmapOptions, and every file-backed mapping constructor in memmap2::MmapOptions is marked unsafe: if the underlying file is modified while it is mapped (by this process or another), the result is undefined behavior (UB). Applications must account for this risk and take appropriate precautions, such as file permissions, locks, or process-private (e.g., unlinked) files.

Due to space limitations, I will provide detailed explanations of the specific implementations of each module in “Advanced Rust AI (Part 2): Converting PyTorch Models to Rust Candle.”