Next, we will cover normalization layers, recurrent layers, Transformer layers, linear layers, Dropout layers, sparse (embedding) layers, and upsampling.
Table Preview
- ✅ Indicates implemented
- 🚫 Indicates not implemented
- ☢️ Indicates alternative implementation
| Function | PyTorch | Candle | Implemented |
|---|---|---|---|
| Apply batch normalization to 2D or 3D input | nn.BatchNorm1d | batch_norm | ✅ |
| Apply batch normalization to 4D input | nn.BatchNorm2d | batch_norm | ✅ |
| Apply batch normalization to 5D input | nn.BatchNorm3d | Not implemented | 🚫 |
| Apply layer normalization | nn.LayerNorm | LayerNorm | ✅ |
| Apply root mean square layer normalization | nn.RMSNorm | RMSNorm | ✅ |
| RNN module | nn.RNN | Not implemented | 🚫 |
| Multi-layer Long Short-Term Memory (LSTM) RNN applied to input sequences | nn.LSTM | LSTM | ✅ |
| Multi-layer Gated Recurrent Unit (GRU) RNN applied to input sequences | nn.GRU | GRU | ✅ |
| Transformer model | nn.Transformer | See various models | ☢️ |
| Placeholder layer | nn.Identity | Not needed | 🚫 |
| Linear layer | nn.Linear | Linear, linear_no_bias | ✅ |
| Dropout layer | nn.Dropout/nn.Dropout1d/nn.Dropout2d | Dropout | ✅ |
| Embedding | nn.Embedding | Embedding | ✅ |
| Upsampling | nn.Upsample | upsample_nearest1d, interpolate1d, upsample_nearest2d, interpolate2d | ✅ |
Applying Batch Normalization
Candle's normalization interface covers both the BatchNorm1d and BatchNorm2d cases. PyTorch:
import torch
from torch import nn

b = nn.BatchNorm1d(3, affine=False)
input = torch.ones(1, 3, 224, dtype=torch.float32)
output = b(input)

b = nn.BatchNorm2d(3, affine=False)
input = torch.ones(1, 3, 224, 224, dtype=torch.float32)
output = b(input)
Candle:
// `vb` is a candle_nn::VarBuilder prepared beforehand.
let cfg = candle_nn::BatchNormConfig {
    affine: false,
    ..Default::default()
};
let b = candle_nn::batch_norm(3, cfg, vb)?;
// A 3D input (the BatchNorm1d case) works the same way:
// let x = Tensor::ones((1, 3, 224), DType::F32, &Device::Cpu)?;
let x = Tensor::ones((1, 3, 224, 224), DType::F32, &Device::Cpu)?;
let y = b.forward_train(&x)?;
In the source code, batch_norm explicitly requires at least three input dimensions; the second dimension (the 3 above) is the number of features, and all remaining dimensions are flattened, which is why both the BatchNorm1d and BatchNorm2d cases work. A 5D input such as (1, 3, 10, 224, 224) (the BatchNorm3d case) does not raise an error in Candle, but the result differs from PyTorch's, so BatchNorm3d is effectively unsupported.
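To make the dimension handling concrete, here is a quick sketch (reusing `b` from the snippet above) of feeding a 5D tensor through batch_norm; it runs, but per the above it should not be treated as an nn.BatchNorm3d replacement:
// BatchNorm3d-shaped input: accepted because the trailing dims are flattened,
// but the result is not verified to match PyTorch's nn.BatchNorm3d.
let x5 = Tensor::ones((1, 3, 10, 224, 224), DType::F32, &Device::Cpu)?;
let y5 = b.forward_train(&x5)?;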
Applying Layer Normalization
Layer normalization (LayerNorm) is one of the most common layers in large-model architectures, and since Candle is a framework built specifically for large-model inference, it naturally supports it.
The layer normalization formula is roughly y = (x − E[x]) / √(Var[x] + ε) · γ + β, where the mean and variance are computed over the normalized (last) dimension.
PyTorch:
batch, sentence_length, embedding_dim = 20, 5, 10
embedding = torch.randn(batch, sentence_length, embedding_dim)
layer_norm = nn.LayerNorm(embedding_dim)
output = layer_norm(embedding)
Candle:
let cfg = candle_nn::LayerNormConfig {
    affine: false,
    ..Default::default()
};
let b = candle_nn::layer_norm(10, cfg, vb)?;
let x = Tensor::ones((20, 5, 10), DType::F32, &Device::Cpu)?;
let y = b.forward(&x)?;
However, Candle’s layer_norm normalizes only over the last dimension, so multi-dimensional normalized shapes like nn.LayerNorm([C, H, W]) are not supported. Then again, CNN architectures generally do not use LayerNorm anyway.
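If the [C, H, W] behavior is really needed, a hypothetical workaround sketch (not a built-in API; `vb` as above) is to flatten the normalized axes, normalize over the flattened size, and reshape back, which matches PyTorch's joint normalization over the last three dims:
// Emulate nn.LayerNorm([C, H, W]) by flattening the normalized axes.
let (n, c, h, w) = (2, 3, 8, 8);
let cfg = candle_nn::LayerNormConfig { affine: false, ..Default::default() };
let ln = candle_nn::layer_norm(c * h * w, cfg, vb)?;
let x = Tensor::ones((n, c, h, w), DType::F32, &Device::Cpu)?;
let y = ln.forward(&x.reshape((n, c * h * w))?)?.reshape((n, c, h, w))?;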
Applying Root Mean Square Layer Normalization
Candle’s RMSNorm likewise normalizes only over the last dimension; multi-dimensional normalized shapes are not supported.
PyTorch:
rms_norm = nn.RMSNorm([3]) # or 3
input = torch.randn(2, 2, 3)
output = rms_norm(input)
print(output)
Candle:
let b = candle_nn::rms_norm(3, 1e-5, vb)?;
let x = Tensor::ones((2, 2, 3), DType::F32, &Device::Cpu)?; // matches the PyTorch shape
let y = b.forward(&x)?;
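For reference, RMSNorm skips LayerNorm's mean subtraction; a hand-rolled sketch of the same computation (the gamma weight omitted, i.e. treated as all ones) could look like this:
use candle_core::D;
// y = x / sqrt(mean(x^2, last dim) + eps), then scaled by gamma (omitted here)
let ms = x.sqr()?.mean_keepdim(D::Minus1)?; // mean of squares over the last dim
let y_manual = x.broadcast_div(&(ms + 1e-5)?.sqrt()?)?;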
Recurrent Layers
For recurrent layers, Candle supports GRU and LSTM. I am not very familiar with RNNs, so I won't expand on them here; what can be confirmed is that Candle's inference code for RNN layers differs from PyTorch's: instead of a single forward call over the whole sequence, Candle exposes an RNN trait that you drive step by step (see the sketch below).
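A minimal sketch of driving Candle's LSTM (assuming `vb` as in the earlier snippets; the seq and states_to_tensor methods come from candle_nn's RNN trait):
use candle_nn::{LSTMConfig, RNN};

let lstm = candle_nn::lstm(10, 20, LSTMConfig::default(), vb)?; // in_dim = 10, hidden_dim = 20
let x = Tensor::ones((1, 5, 10), DType::F32, &Device::Cpu)?;    // (batch, seq_len, features)
let states = lstm.seq(&x)?;                  // one LSTMState per time step
let h = lstm.states_to_tensor(&states)?;     // stack the hidden states into one tensor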
Transformer Model
PyTorch's nn.Transformer implements the original architecture, but none of today's large models use that original version unmodified; each makes extensive internal changes. That is why the Candle repository contains many model-specific XXXTransformer implementations rather than one generic module.
Since large-model support is where Candle is strongest, you can browse the following directory directly, which contains the architecture implementations of the major models:
https://github.com/huggingface/candle/tree/cb02b389d53a1cf5547dfa69b5168bdc1a50d325/candle-transformers
Linear Layer
The linear (fully connected) layer is the most commonly used layer of all.
PyTorch:
m = nn.Linear(512, 512, bias=False)
input = torch.ones(1, 512)
output = m(input)
print(output)
Candle:
let m = candle_nn::linear_no_bias(512, 512, vb)?;
let xs = Tensor::ones((1, 512), DType::F32, &Device::Cpu)?;
let y = m.forward(&xs)?;
Dropout Layer
A Dropout layer is usually not needed at inference time, since it exists to prevent overfitting during training, but Candle provides a Dropout API as well:
let d = candle_nn::Dropout::new(0.5);
let y = d.forward(&xs, true)?; // the second argument is the training flag; pass false at inference
// or the functional form:
let y = candle_nn::ops::dropout(&xs, 0.5)?;
Embedding Layer
PyTorch:
m = nn.Embedding(10, 3)
input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
output = m(input)
print(output)  # shape: torch.Size([2, 4, 3])
Candle:
let m = candle_nn::embedding(10, 3, vb)?;
let data = vec![1u32, 2, 4, 5, 4, 3, 2, 9];
let xs = Tensor::from_vec(data, (2, 4), &Device::Cpu)?;
let y = m.forward(&xs)?; // output shape: (2, 4, 3)
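Conceptually, the embedding layer is just a row lookup into the (10, 3) weight matrix; a sketch of the equivalent computation (assuming the embeddings() accessor on candle_nn's Embedding):
// Pick row i of the weight matrix for each index, then restore the batch shape.
let w = m.embeddings();
let flat = xs.flatten_all()?;
let y_manual = w.index_select(&flat, 0)?.reshape((2, 4, 3))?;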
Upsampling
The two APIs differ somewhat. PyTorch's nn.Upsample takes a scale_factor magnification, while Candle's interpolate1d and interpolate2d take target sizes (target_size, or target_h and target_w).
Additionally, PyTorch's upsampling supports other algorithms such as bilinear, whereas Candle only supports nearest; accordingly, interpolate1d and interpolate2d have the aliases upsample_nearest1d and upsample_nearest2d.
PyTorch:
m = nn.Upsample(scale_factor=2, mode='nearest')
# torch.Size([1, 1, 2, 4])
input = torch.Tensor([[[[1, 2, 4, 5], [4, 3, 2, 9]]]])
output = m(input)
print(output)
# tensor([[[[1., 1., 2., 2., 4., 4., 5., 5.],
# [1., 1., 2., 2., 4., 4., 5., 5.],
# [4., 4., 3., 3., 2., 2., 9., 9.],
# [4., 4., 3., 3., 2., 2., 9., 9.]]]])
Candle:
let data = vec![1u32, 2, 4, 5, 4, 3, 2, 9];
let input = Tensor::from_vec(data, (2, 4), &Device::Cpu)?.unsqueeze(0)?.unsqueeze(0)?;
let input_height = input.dim(2)?; // Get input height
let input_width = input.dim(3)?; // Get input width
let y = input.interpolate2d(input_height * 2, input_width * 2)?;
println!("{y}");
// [[[[1, 1, 2, 2, 4, 4, 5, 5],
// [1, 1, 2, 2, 4, 4, 5, 5],
// [4, 4, 3, 3, 2, 2, 9, 9],
// [4, 4, 3, 3, 2, 2, 9, 9]]]]
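The 1D variant works the same way with a single target size; a quick sketch (the expected output is inferred from the nearest-neighbor behavior shown above):
let x1 = Tensor::from_vec(vec![1u32, 2, 4, 5], (1, 1, 4), &Device::Cpu)?; // (N, C, L)
let y1 = x1.interpolate1d(8)?; // nearest-neighbor; effectively scale factor 2
// expected: [[[1, 1, 2, 2, 4, 4, 5, 5]]]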