Mar 24, 2025
9 min read
Rust,
Candle,
PyTorch,

Rust Candle Framework and PyTorch nn Module Network Layer Conversion (1)

This article compares how Rust Candle and PyTorch implement neural network layers, covering sequential containers, convolutional layers (1D/2D and transposed convolution), pooling layers (max and average pooling), and common activation functions. It focuses on how the two correspond in functionality, parameter configuration, and usage, and notes that Candle does not yet support some features, such as 3D convolution.

Previously, we got a rough picture of tensor operations in PyTorch and Candle; now we move on to the building blocks of network graphs: the similarities and differences in the nn module.

The torch.nn module is the core toolkit for building and training neural networks, providing definitions for a large number of network layers, loss functions, containers, and related utilities.

Since Candle emphasizes inference rather than training, this article focuses on converting network layers between Candle and PyTorch.

Table Preview

  • ✅ Indicates implemented
  • 🚫 Indicates not implemented
  • ☢️ Indicates alternative implementation
| Function | PyTorch | Candle | Implemented |
| --- | --- | --- | --- |
| Sequential Container | Sequential | Sequential | ✅ |
| 1D Convolution | nn.Conv1d | conv1d / conv1d_no_bias | ✅ |
| 2D Convolution | nn.Conv2d | conv2d / conv2d_no_bias | ✅ |
| 3D Convolution | nn.Conv3d | Not implemented | 🚫 |
| 1D Transposed Convolution | nn.ConvTranspose1d | conv_transpose1d / conv_transpose1d_no_bias | ✅ |
| 2D Transposed Convolution | nn.ConvTranspose2d | conv_transpose2d / conv_transpose2d_no_bias | ✅ |
| 3D Transposed Convolution | nn.ConvTranspose3d | Not implemented | 🚫 |
| 1D Max Pooling | nn.MaxPool1d | Not implemented | 🚫 |
| 2D Max Pooling | nn.MaxPool2d | max_pool2d / max_pool2d_with_stride | ✅ |
| 3D Max Pooling | nn.MaxPool3d | Not implemented | 🚫 |
| 1D Average Pooling | nn.AvgPool1d | Not implemented | 🚫 |
| 2D Average Pooling | nn.AvgPool2d | avg_pool2d / avg_pool2d_with_stride | ✅ |
| 3D Average Pooling | nn.AvgPool3d | Not implemented | 🚫 |
| Apply Rectified Linear Unit Function Element-wise | nn.ReLU | relu | ✅ |
| Apply ReLU6 Function Element-wise | nn.ReLU6 | relu6 | ✅ |
| Apply Randomized Leaky Rectified Linear Unit Function Element-wise | nn.RReLU | Not implemented | 🚫 |
| Apply Gaussian Error Linear Unit Function | nn.GELU | gelu | ✅ |
| Sigmoid Function | nn.Sigmoid | sigmoid | ✅ |
| Sigmoid Linear Unit (SiLU) Function | nn.SiLU | silu | ✅ |
| Apply Hyperbolic Tangent (Tanh) Function Element-wise | nn.Tanh | tanh | ✅ |
| Apply Exponential Linear Unit (ELU) Function Element-wise | nn.ELU | elu | ✅ |

Before We Start

To verify that the operations on both sides are equivalent, we need a way to fix the test weights and share them between PyTorch and Candle, because parameters that are not explicitly initialized are assigned random values by default. For example, the following code:

import torch
from torch import nn

m = nn.Conv1d(3, 16, 3)
input = torch.ones(1, 3, 224, dtype=torch.float32)
output = m(input)

Because the weights above are randomly initialized, output is different on every run. To make the results reproducible, we need to save the weights so that the random parameters are fixed.

The most practical approach is to save the weights from Python to a file and then load them in Candle. PyTorch can save weights in the safetensors format:

from safetensors.torch import save_model

m = nn.Conv1d(3, 16, 3, stride=2, bias=False)
save_model(m, "model.safetensors")

In Candle, we simply read the model.safetensors weights back to reproduce the layer; this is also the closest way to re-implement a PyTorch model in Candle:

let vb = unsafe {
    VarBuilder::from_mmaped_safetensors(&["./model.safetensors"], DType::F32, &Device::Cpu)?
};
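If the tensor names or shapes ever don't line up, a quick sanity check is to load the file as a plain map of tensors and print what it contains. This is a minimal sketch using candle_core::safetensors::load; the key names (here "weight", plus "bias" when the layer has one) depend on how the PyTorch module was saved.

    // Sketch: list the tensors stored in model.safetensors to verify that the
    // names and shapes match what VarBuilder will look up.
    let tensors = candle_core::safetensors::load("./model.safetensors", &Device::Cpu)?;
    for (name, tensor) in tensors.iter() {
        println!("{name}: {:?}", tensor.shape());
    }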

Sequential

Sequential is just a container: it performs no computation of its own and does not affect how weights are loaded. We can therefore use the official Candle version or define our own.

PyTorch

seq = torch.nn.Sequential(
        nn.Conv2d(3, 64, 3, 1, 1),
        nn.ReLU(),
        nn.Conv2d(64, 64, 3, 1, 1)
    )

The official Candle definition essentially maintains a vec internally:

    // Builder-style API: add() takes ownership and returns the container.
    let seq = candle_nn::seq()
        .add(layer)                  // any value whose type implements Module
        .add_fn(|xs| xs.relu());     // or a closure, via add_fn

If the official container doesn't meet your needs (why wouldn't it? Because of Rust's type system it isn't always convenient to work with: the official implementation stores layers as Vec<Box<dyn Module>>), we can implement our own:

use candle_core::{ Module, Tensor, Result };

#[derive(Debug, Clone)]
pub struct Sequential<T: Module> {
    layers: Vec<T>,
}

pub fn seq<T: Module>(cnt: usize) -> Sequential<T> {
    let v = if cnt == 0 { vec![] } else { Vec::with_capacity(cnt) };
    Sequential { layers: v }
}

impl<T: Module> Sequential<T> {
    pub fn len(&self) -> usize {
        self.layers.len()
    }

    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }

    pub fn push(&mut self, layer: T) {
        self.layers.push(layer);
    }

    pub fn add(&mut self, layer: T) {
        self.layers.push(layer);
    }
}
impl<T: Module> Module for Sequential<T> {
    fn forward(&self, xs: &candle_core::Tensor) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = xs.apply(layer)?;
        }
        Ok(xs)
    }
}
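A short usage sketch of this generic container, stacking two linear layers read from the same VarBuilder; the weight prefixes "fc1"/"fc2" and the dimensions are made up for illustration:

    // Hypothetical usage: every layer must share one concrete type (here Linear),
    // which is the trade-off of the generic version versus Box<dyn Module>.
    let mut mlp: Sequential<candle_nn::Linear> = seq(2);
    mlp.add(candle_nn::linear(784, 128, vb.pp("fc1"))?);
    mlp.add(candle_nn::linear(128, 10, vb.pp("fc2"))?);

    let x = Tensor::ones((1, 784), DType::F32, &Device::Cpu)?;
    let y = mlp.forward(&x)?; // Tensor[[1, 10], f32]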

1D Convolution

Conv1d is mainly used for sequence data; its core idea is to extract local features with a window sliding along a single dimension (usually the time or sequence dimension). Typical examples include sensor data (temperature, stock prices), medical signals (ECG), and industrial monitoring.

The formula for Conv1d is roughly as follows:

out(N_i, C_{out_j}) = bias(C_{out_j}) + \sum_{k=0}^{C_{in}-1} weight(C_{out_j}, k) \star input(N_i, k)

Input (batch, channels, length):

(N, C_{in}, L_{in})

Output (batch, channels, length):

(N, C_{out}, L_{out})

where L_{out} is calculated as follows:

L_{out} = \left\lfloor \frac{L_{in} + 2 \times padding - dilation \times (kernel\_size - 1) - 1}{stride} \right\rfloor + 1
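Plugging in the example below (L_{in} = 224, kernel\_size = 3, stride = 2, padding = 0, dilation = 1):

L_{out} = \left\lfloor \frac{224 + 2 \times 0 - 1 \times (3 - 1) - 1}{2} \right\rfloor + 1 = \lfloor 110.5 \rfloor + 1 = 111

which matches the output shape [1, 16, 111] in the code that follows.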

PyTorch

    # Input channels 3, output channels 16, kernel size 3, stride 2
    # Conv1d(3, 16, kernel_size=(3,), stride=(2,), bias=False)
    m = nn.Conv1d(3, 16, 3, stride=2, bias=False)
    input = torch.ones(1, 3, 224, dtype=torch.float32)
    output = m(input)
    print(output)
    print(output.size()) # torch.Size([1, 16, 111])

Candle

    // Conv1d(3, 16, kernel_size=(3,), stride=(2,), bias=False)
    let cfg = candle_nn::Conv1dConfig {
        stride: 2,
        padding: 0,
        dilation: 1,
        groups: 1,
    };
    let conv1d = candle_nn::conv1d_no_bias(3, 16, 3, cfg, vb)?;
    let x = Tensor::ones((1, 3, 224), DType::F32, &Device::Cpu)?;
    let y = conv1d.forward(&x)?;
    println!("{y}"); // Tensor[[1, 16, 111], f32]

It should be noted here that in PyTorch, bias=False corresponds to conv1d_no_bias in Candle, while bias=True corresponds to conv1d in Candle.
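For reference, a minimal sketch of the bias=True counterpart; it assumes the safetensors file also stores a "bias" tensor for this layer:

    // bias=True on the PyTorch side corresponds to conv1d in Candle; the
    // VarBuilder must find both a "weight" and a "bias" tensor.
    let conv1d = candle_nn::conv1d(3, 16, 3, cfg, vb)?;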

2D Convolution

Conv2d is mainly used for image data, extracting local features by sliding filters (kernels) over the two spatial dimensions (height and width).

The formula is as follows:

out(N_i, C_{out_j}) = bias(C_{out_j}) + \sum_{k=0}^{C_{in}-1} weight(C_{out_j}, k) \star input(N_i, k)

Both input and output have shape (N, C, H, W). The output dimensions H_{out} and W_{out} are:

H_{out} = \left\lfloor \frac{H_{in} + 2 \times padding[0] - dilation[0] \times (kernel\_size[0] - 1) - 1}{stride[0]} \right\rfloor + 1

W_{out} = \left\lfloor \frac{W_{in} + 2 \times padding[1] - dilation[1] \times (kernel\_size[1] - 1) - 1}{stride[1]} \right\rfloor + 1
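With the example below (a 224×224 input, kernel size 3, stride 2, no padding), both dimensions work out to \lfloor (224 - 2 - 1) / 2 \rfloor + 1 = 111, matching the output shape [1, 16, 111, 111].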

PyTorch

    # Input channels 3, output channels 16, kernel size 3, stride 2
    # Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), bias=False)
    m = nn.Conv2d(3, 16, 3, stride=2, bias=False)
    # Input shape: (N, C, H, W)
    input = torch.ones(1, 3, 224, 224, dtype=torch.float32)
    output = m(input)
    print(output)
    print(output.size()) # torch.Size([1, 16, 111, 111])

Candle:

    // Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), bias=False)
    let cfg = candle_nn::Conv2dConfig {
        stride: 2,
        padding: 0,
        dilation: 1,
        groups: 1,
    };
    let conv2d = candle_nn::conv2d_no_bias(3, 16, 3, cfg, vb)?;
    let x = Tensor::ones((1, 3, 224, 224), DType::F32, &Device::Cpu)?;
    let y = conv2d.forward(&x)?;
    println!("{y}"); // Tensor[[1, 16, 111, 111], f32]

Bias works the same way as conv1d.

3D Convolution

Conv3d is mainly used for 4-dimensional spatiotemporal or volumetric data, such as video. Unfortunately, Candle does not support 3D convolution; in fact, it does not support any 3D operations at all (as of March 2025).

1D Transposed Convolution / Deconvolution

One-dimensional transposed convolutional layers (also known as deconvolution layers) are mainly used for upsampling or restoring feature map sizes.

The input and output shapes are the same as for 1D convolution. Input (batch, channels, length):

(N, C_{in}, L_{in})

Output (batch, channels, length):

(N, C_{out}, L_{out})

The output length L_{out} is calculated as follows:

L_{out} = (L_{in} - 1) \times stride - 2 \times padding + dilation \times (kernel\_size - 1) + output\_padding + 1
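For the example below (L_{in} = 224, kernel\_size = 3, stride = 2, padding = 0, dilation = 1, output\_padding = 0):

L_{out} = (224 - 1) \times 2 - 2 \times 0 + 1 \times (3 - 1) + 0 + 1 = 449

which matches the output shape [1, 16, 449].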

PyTorch:

    m = nn.ConvTranspose1d(3, 16, 3, stride=2, bias=False)
    input = torch.ones(1, 3, 224, dtype=torch.float32)
    output = m(input)
    print(output)
    print(output.size()) # torch.Size([1, 16, 449])

Candle:

    let cfg = candle_nn::ConvTranspose1dConfig {
        stride: 2,
        padding: 0,
        dilation: 1,
        groups: 1,
        ..Default::default()
    };
    let conv_t1d = candle_nn::conv_transpose1d_no_bias(3, 16, 3, cfg, vb)?;
    let x = Tensor::ones((1, 3, 224), DType::F32, &Device::Cpu)?;
    let y = conv_t1d.forward(&x)?;
    println!("{y}"); // Tensor[[1, 16, 449], f32]

2D Transposed Convolution / Deconvolution

Two-dimensional transposed convolutional layers (also known as deconvolution layers) are mainly used for upsampling or restoring image resolution. They map low-resolution feature maps to high-resolution space through transposed convolution operations, commonly used in tasks such as image generation (e.g., GANs) and semantic segmentation. Like 2D convolution, inputs and outputs are (N,C,H,W).

The formulas for output height H and width W are as follows:

H_{out} = (H_{in} - 1) \times stride[0] - 2 \times padding[0] + dilation[0] \times (kernel\_size[0] - 1) + output\_padding[0] + 1

W_{out} = (W_{in} - 1) \times stride[1] - 2 \times padding[1] + dilation[1] \times (kernel\_size[1] - 1) + output\_padding[1] + 1
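With a 224×224 input, kernel size 3, stride 2, and no padding, both dimensions become (224 - 1) \times 2 + 2 + 1 = 449, matching the output shape [1, 16, 449, 449].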

PyTorch:

    m = nn.ConvTranspose2d(3, 16, 3, stride=2, bias=False)
    input = torch.ones(1, 3, 224, 224, dtype=torch.float32)
    output = m(input)
    print(output)
    print(output.size()) # torch.Size([1, 16, 449, 449])

Candle:

    let cfg = candle_nn::ConvTranspose2dConfig {
        stride: 2,
        padding: 0,
        dilation: 1,
        ..Default::default()
    };
    let conv_t2d = candle_nn::conv_transpose2d_no_bias(3, 16, 3, cfg, vb)?;
    let x = Tensor::ones((1, 3, 224, 224), DType::F32, &Device::Cpu)?;
    let y = conv_t2d.forward(&x)?;
    println!("{y}"); // Tensor[[1, 16, 449, 449], f32]

2D Max Pooling

Max pooling exists as a layer in PyTorch, but in Candle it is an operation performed directly on tensors.

PyTorch

    p = nn.MaxPool2d(3, stride=2)
    input = torch.ones(1, 3, 224, 224, dtype=torch.float32)
    output = p(input)

Candle

    let x = Tensor::ones((1, 3, 224, 224), DType::F32, &Device::Cpu)?;
    let y = x.max_pool2d_with_stride(3, 2)?;
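One detail worth noting: when no stride is given, PyTorch's nn.MaxPool2d defaults the stride to the kernel size, and the matching Candle call is max_pool2d. A minimal sketch, reusing x from above:

    // max_pool2d(k) pools with stride == kernel size, like nn.MaxPool2d(k)
    // with its default stride.
    let y = x.max_pool2d(2)?; // equivalent to nn.MaxPool2d(2)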

2D Average Pooling

2D average pooling operations are similar to the above 2D max pooling, except that max_pool2d becomes avg_pool2d, and max_pool2d_with_stride becomes avg_pool2d_with_stride.
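A minimal sketch of the same pattern for average pooling:

    let x = Tensor::ones((1, 3, 224, 224), DType::F32, &Device::Cpu)?;
    let y = x.avg_pool2d(2)?;                 // nn.AvgPool2d(2)
    let y = x.avg_pool2d_with_stride(3, 2)?;  // nn.AvgPool2d(3, stride=2)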

Activation Layers

Activation layers mainly introduce non-linear transformations, enabling models to learn complex patterns.

PyTorch also has functional calling methods, which we won’t expand upon here.

PyTorch

    p = nn.ReLU()
    p = nn.GELU()
    p = nn.SiLU()
    p = nn.Tanh()
    p = nn.ELU()
    p = nn.Sigmoid()
    p = nn.ReLU6()
    p = nn.LeakyReLU(0.01)

Both styles are available in Candle as well; essentially, candle_nn::Activation implements Module and wraps the tensor-level functions.

    let x = Tensor::ones((1, 3, 224, 224), DType::F32, &Device::Cpu)?;
    let y = x.relu()?;                              // relu
    let y = x.gelu()?;                              // gelu
    let y = x.silu()?;                              // silu
    let y = x.tanh()?;                              // tanh
    let y = x.elu(1f64)?;                           // elu
    let y = candle_nn::ops::sigmoid(&x)?;           // sigmoid
    let y = x.clamp(0f32, 6f32)?;                   // relu6
    let y = x.relu()?.sqr()?;                       // relu squared (relu2)
    let y = candle_nn::ops::leaky_relu(&x, 0.01)?;  // leaky_relu (nn.LeakyReLU)
    // Or the following way:
    let activation = candle_nn::Activation::Relu;
    let activation = candle_nn::Activation::Gelu;
    let activation = candle_nn::Activation::Silu;
    let activation = candle_nn::Activation::Elu(0.01);
    let activation = candle_nn::Activation::Sigmoid;
    let activation = candle_nn::Activation::Relu6;
    //...

    let y = activation.forward(&x)?;

More layers will be covered in the next article.