Mar 24, 2025
4 min read
Rust,
Candle,
Pytorch,

Rust Candle 框架与 Pytorch nn 模块网络层转换(2)

本文对比了 Rust Candle 框架与 PyTorch 中常见神经网络层的实现,包括归一化层(BatchNorm、LayerNorm、RMSNorm)、循环层(LSTM、GRU)、Transformer、线性层、Dropout 层、嵌入层及上采样层。总结了两者的功能对应关系及差异,并指出 Candle 在大模型推理场景下的支持情况。

接下来是归一化层、循环层、Transformer 层、线性层、Dropout 层、稀疏层。

表格预览

  • ✅ 表示有对应实现
  • 🚫 表示无对应实现
  • ☢️ 表示有代替实现
功能pytorchcandle是否实现
对 2D 或 3D 输入应用批归一化nn.BatchNorm1dbatch_norm
对 4D 输入应用批归一化nn.BatchNorm2dbatch_norm
对 5D 输入应用批归一化nn.BatchNorm3d未实现🚫
应用层归一化nn.LayerNormLayerNorm
应用均方根层归一化nn.RMSNormRMSNorm
RNN 模块nn.RNN未实现
多层长短期记忆 (LSTM) RNN 应用于输入序列nn.LSTMLSTM
多层门控循环单元 (GRU) RNN 应用于输入序列nn.GRUGRU
Transformer 模型nn.Transformer见各大模型☢️
占位层nn.Identity不需要实现🚫
连接层nn.LinearLinear,linear_no_bias
Dropout 层nn.Dropout/nn.Dropout1d/nn.Dropout2dDropout
嵌入nn.EmbeddingEmbedding
上采样nn.Upsampleupsample_nearest1d
interpolate1d
upsample_nearest2d
interpolate2d

应用批归一化

candle 的归一化接口同时支持 BatchNorm1d 和 BatchNorm2d。 pytorch:


	b = nn.BatchNorm1d(3,affine=False)
    input = torch.ones(1, 3, 224,dtype=torch.float32)
    output = b(input)
    
	b = nn.BatchNorm2d(3,affine=False)
    input = torch.ones(1, 3,224, 224,dtype=torch.float32)
    output = b(input)

candle:

	let cfg = candle_nn::BatchNormConfig{
        affine:false,
        ..Default::default()
    };
    let b = candle_nn::batch_norm(3, cfg, vb)?;
    //let x = Tensor::ones((1,3,224), DType::F32, &Device::Cpu)?;
    let x = Tensor::ones((1,3,224,224), DType::F32, &Device::Cpu)?;
    let y = b.forward_train(&x)?;

源代码中,明确要求至少需要三个维度,其中第二个维度(也就是上面的 3)为特征数,其他维度被展平,刚好 BatchNorm1d 和 BatchNorm2d 都支持。candle 中虽然 BatchNorm3d(1,3,10,224,224) 也不会报错,但是结果已经和 pytorch 不太一样,说明不支持。

应用层归一化

由于应用层归一化(LayerNorm)是大模型结构中最常见的一个层,而 candle 其实是专门用于做大模型推理的框架,因此 candle 是肯定支持的。

应用层归一化的公式大概如下:

y=xE[x]Var[x]+ϵγ+βy=\frac{​x−E[x]}{Var[x]+ϵ}​∗γ+β

pytorch

	batch, sentence_length, embedding_dim = 20, 5, 10
    embedding = torch.randn(batch, sentence_length, embedding_dim)
    layer_norm = nn.LayerNorm(embedding_dim)
    output = layer_norm(embedding)

candle:

let cfg = candle_nn::LayerNormConfig{
        affine:false,
        ..Default::default()
    };
    let b = candle_nn::layer_norm(10, cfg, vb)?;
    let x = Tensor::ones((20, 5, 10), DType::F32, &Device::Cpu)?;
    let y = b.forward(&x)?;

不过, candle 的 layer_norm 并不支持图片上的维度nn.LayerNorm([C, H, W])。当然,一般 cnn 网络也不太会使用 应用层归一化(LayerNorm)。

应用均方根层归一化

candle 中 RMSNorm 不支持更高的归一化维度输入。

pytorch

    rms_norm = nn.RMSNorm([3]) # or 3
    input = torch.randn(2, 2, 3)
    output = rms_norm(input)
    print(output)

candle:

	let b = candle_nn::rms_norm(3, 1e-5, vb)?;
    let x = Tensor::ones((2,3), DType::F32, &Device::Cpu)?;
    let y = b.forward(&x)?;

循环层

RNN 中,candle 支持 GRU 和 LSTM。不过我对 RNN 不了解,这里不展开。可以确定的是,candle 中的 RNN 层推理代码和 pytorch 写法不一样。

Transformer 模型

pytorch 中的 nn.Transformer 应该是原始版本,但是现在的所有大模型应该都没有使用原始版本,而是在内部做了大量的改动。因此我们在 candle 仓库原代码中会见到大量的 XXXTransformer。 由于 candle 对大模型的支持度最高,因此可以直接查看这个地址,里面有各大模型的结构实现:

https://github.com/huggingface/candle/tree/cb02b389d53a1cf5547dfa69b5168bdc1a50d325/candle-transformers

连接层

连接层是最觉用和最常见的层。

pytorch

	 m = nn.Linear(512, 512,bias=False)
    input = torch.ones(1, 512)
    output = m(input)
    print(output)

candle:

	let m = linear_no_bias(512, 512, vb)?;
    let xs = Tensor::ones((1,512), DType::F32, &Device::Cpu)?;
    let y = m.forward(&xs)?;

Dropout 层

推理阶段通常不需要 Dropout 层,因为Dropout 层一般用于训练阶段防止过拟合使用。不过,candle 中也有 Dropout 层 api。

let d = candle_nn::Dropout::new(0.5);

//or 
let y = candle_nn::ops::dropout(xs, 0.5)?;

嵌入层

pytorch

	m =  nn.Embedding(10, 3)
    input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
    output = m(input)
    print(output)#Tensor[[2, 4, 3], f32]

candle:

	let m = candle_nn::embedding(10, 3, vb)?;
    let data = vec![1u32, 2, 4, 5,4, 3, 2, 9];
    let xs = Tensor::from_vec(data, (2,4), &Device::Cpu)?;
    let y = m.forward(&xs)?;

上采样

两边 api 不太一样,pytorch 的上采样参数用的是 scale_factor 倍数放大,而 candle interpolate1d 和 interpolate2d 的参数是目标尺寸(如 target_size 或 target_h, target_w)。

还有一个是 pytorch 的上采样支持其他算法,比如:bilinear。而 candle 中只支持 nearest,所以,interpolate1d 和 interpolate2d 分别都有别名:upsample_nearest1dupsample_nearest2d

pytorch:

	m = nn.Upsample(scale_factor=2, mode='nearest')
    # torch.Size([1, 1, 2, 4])
    input = torch.Tensor([[[[1, 2, 4, 5], [4, 3, 2, 9]]]])
    output = m(input)
    print(output)
    # tensor([[[[1., 1., 2., 2., 4., 4., 5., 5.],
    #       [1., 1., 2., 2., 4., 4., 5., 5.],
    #       [4., 4., 3., 3., 2., 2., 9., 9.],
    #       [4., 4., 3., 3., 2., 2., 9., 9.]]]])

candle:

    let data = vec![1u32, 2, 4, 5,4, 3, 2, 9];
    let input = Tensor::from_vec(data, (2, 4), &Device::Cpu)?.unsqueeze(0)?.unsqueeze(0)?;
    let input_height = input.dim(2)?; // 获取输入的高度
    let input_width = input.dim(3)?; // 获取输入的宽度
    let y = input.interpolate2d(input_height * 2, input_width * 2)?;
    
    println!("{y}");
//    [[[[1, 1, 2, 2, 4, 4, 5, 5],
//    [1, 1, 2, 2, 4, 4, 5, 5],
//    [4, 4, 3, 3, 2, 2, 9, 9],
//    [4, 4, 3, 3, 2, 2, 9, 9]]]]