Mar 24, 2025
6 min read
Rust,
Candle,
PyTorch,

Rust Candle Framework and PyTorch nn Module Network Layer Conversion (2)

This article compares the implementation of common neural network layers in Rust Candle and PyTorch, including normalization layers (BatchNorm, LayerNorm, RMSNorm), recurrent layers (LSTM, GRU), Transformer, linear layers, Dropout layers, embedding layers, and upsampling layers. It summarizes the functional correspondence and differences between the two frameworks and highlights Candle's support for large model inference scenarios.

Next, we will cover normalization layers, recurrent layers, Transformer layers, linear layers, Dropout layers, sparse (embedding) layers, and upsampling layers.

Table Preview

  • ✅ Indicates implemented
  • 🚫 Indicates not implemented
  • ☢️ Indicates alternative implementation
| Function | PyTorch | Candle | Implemented |
| --- | --- | --- | --- |
| Apply batch normalization to 2D or 3D input | nn.BatchNorm1d | batch_norm | ✅ |
| Apply batch normalization to 4D input | nn.BatchNorm2d | batch_norm | ✅ |
| Apply batch normalization to 5D input | nn.BatchNorm3d | Not implemented | 🚫 |
| Apply layer normalization | nn.LayerNorm | LayerNorm | ✅ |
| Apply root mean square layer normalization | nn.RMSNorm | RMSNorm | ✅ |
| RNN module | nn.RNN | Not implemented | 🚫 |
| Multi-layer Long Short-Term Memory (LSTM) RNN applied to input sequences | nn.LSTM | LSTM | ✅ |
| Multi-layer Gated Recurrent Unit (GRU) RNN applied to input sequences | nn.GRU | GRU | ✅ |
| Transformer model | nn.Transformer | See various models | ☢️ |
| Placeholder layer | nn.Identity | Not needed | 🚫 |
| Linear layer | nn.Linear | Linear, linear_no_bias | ✅ |
| Dropout layer | nn.Dropout / nn.Dropout1d / nn.Dropout2d | Dropout | ✅ |
| Embedding | nn.Embedding | Embedding | ✅ |
| Upsampling | nn.Upsample | upsample_nearest1d, interpolate1d, upsample_nearest2d, interpolate2d | ✅ |

Applying Batch Normalization

The normalization interface in Candle supports both BatchNorm1d and BatchNorm2d. PyTorch:


    import torch
    import torch.nn as nn

    b = nn.BatchNorm1d(3, affine=False)
    input = torch.ones(1, 3, 224, dtype=torch.float32)
    output = b(input)

    b = nn.BatchNorm2d(3, affine=False)
    input = torch.ones(1, 3, 224, 224, dtype=torch.float32)
    output = b(input)

Candle:

    // `vb` is a candle_nn::VarBuilder (see the note below).
    let cfg = candle_nn::BatchNormConfig {
        affine: false,
        ..Default::default()
    };
    let b = candle_nn::batch_norm(3, cfg, vb)?;
    // let x = Tensor::ones((1, 3, 224), DType::F32, &Device::Cpu)?; // BatchNorm1d-style input
    let x = Tensor::ones((1, 3, 224, 224), DType::F32, &Device::Cpu)?;
    let y = b.forward_train(&x)?;
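
All the Candle snippets in this article assume a `vb` variable is already in scope. One possible way to obtain one (my assumption, not taken from the original snippets) is a zero-initialized `VarBuilder`:

    use candle_core::{DType, Device};
    use candle_nn::VarBuilder;

    // Zero-initialized variables on the CPU; in real inference code the
    // VarBuilder would typically be backed by loaded model weights instead.
    let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);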

In the Candle source code, batch_norm explicitly requires the input to have at least three dimensions, where the second dimension (the 3 above) is the number of features and the remaining dimensions are flattened; this is what lets it cover both BatchNorm1d and BatchNorm2d. A 5D input such as (1, 3, 10, 224, 224) does not throw an error in Candle, but the result already differs from PyTorch's BatchNorm3d, so it is effectively unsupported.

Applying Layer Normalization

Layer normalization (LayerNorm) is one of the most common layers in large-model architectures, and since Candle is a framework built specifically for large-model inference, it naturally supports it.

The formula for applying layer normalization is roughly as follows:

$$y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta$$

PyTorch:

    batch, sentence_length, embedding_dim = 20, 5, 10
    embedding = torch.randn(batch, sentence_length, embedding_dim)
    layer_norm = nn.LayerNorm(embedding_dim)
    output = layer_norm(embedding)

Candle:

    let cfg = candle_nn::LayerNormConfig {
        affine: false,
        ..Default::default()
    };
    let b = candle_nn::layer_norm(10, cfg, vb)?;
    let x = Tensor::ones((20, 5, 10), DType::F32, &Device::Cpu)?;
    let y = b.forward(&x)?;

However, Candle’s layer_norm does not support normalizing over image dimensions the way nn.LayerNorm([C, H, W]) does. Of course, CNNs rarely use layer normalization anyway.

Applying Root Mean Square Layer Normalization

Candle’s RMSNorm normalizes only over the last dimension; it does not support a multi-dimensional normalized shape such as nn.RMSNorm([C, H, W]).

PyTorch:

    rms_norm = nn.RMSNorm([3]) # or 3
    input = torch.randn(2, 2, 3)
    output = rms_norm(input)
    print(output)

Candle:

    let b = candle_nn::rms_norm(3, 1e-5, vb)?;
    let x = Tensor::ones((2, 3), DType::F32, &Device::Cpu)?;
    let y = b.forward(&x)?;
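
Extra leading batch dimensions are not a problem, though: as far as I can tell, RmsNorm simply normalizes over the last dimension, so a 3D input matching the PyTorch example above should also work. A minimal sketch:

    // Same module as above; a (2, 2, 3) input is normalized over the last dim of size 3.
    let x3 = Tensor::ones((2, 2, 3), DType::F32, &Device::Cpu)?;
    let y3 = b.forward(&x3)?;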

Recurrent Layers

Among recurrent layers, Candle supports GRU and LSTM. I am not very familiar with RNNs, so I won’t expand on them here; what can be confirmed is that the inference code for recurrent layers in Candle differs from PyTorch’s (see the sketch below).
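
As a rough illustration only (the sizes and shapes here are my assumptions based on candle_nn's rnn module, not from the original article), running an LSTM over a sequence might look like this:

    use candle_core::{DType, Device, Tensor};
    use candle_nn::{RNN, VarBuilder};

    let dev = Device::Cpu;
    let vb = VarBuilder::zeros(DType::F32, &dev);
    // LSTM with input size 10 and hidden size 20.
    let lstm = candle_nn::lstm(10, 20, candle_nn::LSTMConfig::default(), vb)?;
    // Candle expects (batch, seq_len, features), unlike PyTorch's default
    // (seq_len, batch, features) layout.
    let x = Tensor::ones((1, 5, 10), DType::F32, &dev)?;
    let states = lstm.seq(&x)?;                   // one LSTM state per time step
    let hidden = lstm.states_to_tensor(&states)?; // stack the hidden states into a tensor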

Transformer Model

PyTorch’s nn.Transformer implements the original architecture, but none of today’s large models use the original version as-is; they all make extensive internal modifications. That is why the Candle repository contains many model-specific XXXTransformer implementations. Since large models are where Candle’s support is strongest, you can go straight to the following address, which contains the structural implementations of the major models:

https://github.com/huggingface/candle/tree/cb02b389d53a1cf5547dfa69b5168bdc1a50d325/candle-transformers

Linear Layer

The linear layer is the most commonly used layer.

PyTorch:

    m = nn.Linear(512, 512, bias=False)
    input = torch.ones(1, 512)
    output = m(input)
    print(output)

Candle:

    let m = candle_nn::linear_no_bias(512, 512, vb)?;
    let xs = Tensor::ones((1, 512), DType::F32, &Device::Cpu)?;
    let y = m.forward(&xs)?;
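
If a bias term is needed (PyTorch’s default is bias=True), Candle also provides a biased variant; a minimal sketch, assuming the same `vb`:

    // Counterpart of nn.Linear(512, 512) with bias enabled.
    let m = candle_nn::linear(512, 512, vb)?;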

Dropout Layer

The Dropout layer is usually not needed during inference, since it is mainly used during training to prevent overfitting. Candle nevertheless provides a Dropout API.

    let d = candle_nn::Dropout::new(0.5);

    // or use the functional form directly:
    let y = candle_nn::ops::dropout(xs, 0.5)?;
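
Note that Candle’s Dropout takes an explicit training flag when applied; a minimal sketch, assuming `xs` is an existing tensor:

    // With train = false the input is returned unchanged, matching eval-mode behaviour.
    let y_train = d.forward(&xs, true)?;
    let y_eval = d.forward(&xs, false)?;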

Embedding Layer

PyTorch:

    m = nn.Embedding(10, 3)
    input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
    output = m(input)
    print(output)  # shape: torch.Size([2, 4, 3])

Candle:

    let m = candle_nn::embedding(10, 3, vb)?;
    let data = vec![1u32, 2, 4, 5, 4, 3, 2, 9];
    let xs = Tensor::from_vec(data, (2, 4), &Device::Cpu)?;
    let y = m.forward(&xs)?; // Tensor[[2, 4, 3], f32]

Upsampling

The APIs on both sides are somewhat different. PyTorch’s upsampling parameters use scale_factor for magnification, while Candle’s interpolate1d and interpolate2d parameters are target sizes (such as target_size or target_h, target_w).

Additionally, PyTorch’s upsampling supports other algorithms, such as bilinear. However, Candle only supports nearest, so interpolate1d and interpolate2d each have aliases: upsample_nearest1d and upsample_nearest2d.

PyTorch:

    m = nn.Upsample(scale_factor=2, mode='nearest')
    # torch.Size([1, 1, 2, 4])
    input = torch.Tensor([[[[1, 2, 4, 5], [4, 3, 2, 9]]]])
    output = m(input)
    print(output)
    # tensor([[[[1., 1., 2., 2., 4., 4., 5., 5.],
    #           [1., 1., 2., 2., 4., 4., 5., 5.],
    #           [4., 4., 3., 3., 2., 2., 9., 9.],
    #           [4., 4., 3., 3., 2., 2., 9., 9.]]]])

Candle:

    let data = vec![1u32, 2, 4, 5, 4, 3, 2, 9];
    let input = Tensor::from_vec(data, (2, 4), &Device::Cpu)?
        .unsqueeze(0)?
        .unsqueeze(0)?;
    let input_height = input.dim(2)?; // get the input height
    let input_width = input.dim(3)?; // get the input width
    let y = input.interpolate2d(input_height * 2, input_width * 2)?;

    println!("{y}");
    // [[[[1, 1, 2, 2, 4, 4, 5, 5],
    //    [1, 1, 2, 2, 4, 4, 5, 5],
    //    [4, 4, 3, 3, 2, 2, 9, 9],
    //    [4, 4, 3, 3, 2, 2, 9, 9]]]]