Mar 24, 2025
6 min read
Rust,
Candle,
Pytorch,

Rust Candle 框架与 Pytorch nn 模块网络层转换(1)

本文对比了 Rust Candle 框架与 PyTorch 在神经网络层实现上的异同,涵盖顺序容器、卷积层(1D/2D 及转置卷积)、池化层(最大池化与平均池化)及常见激活函数。重点分析了两者在功能实现、参数配置及使用方式上的对应关系,并指出 Candle 暂不支持 3D 卷积等部分功能。

之前大体上了解完 pytorch 与 candle 之间张量的一些操作,现在来到了图的构建块: nn 模块的一些相同或者不同的地方。

torch.nn 模块是构建和训练神经网络的核心工具包,里面有大量的网络层、损失函数、优化器等的定义。

由于 candle 重推理而不是训练,因此这里重点讲讲 candle 和 pytorch 网络层之间的相互转换。

表格预览

  • ✅ 表示有对应实现
  • 🚫 表示无对应实现
  • ☢️ 表示有代替实现
功能pytorchcandle是否实现
顺序容器SequentialSequential
1D 卷积nn.Conv1dconv1d/conv1d_no_bias
2D 卷积nn.Conv2dconv2d/conv2d_no_bias
3D 卷积nn.Conv3d未实现🚫
1D 转置卷积nn.ConvTranspose1dconv_transpose1d
conv_transpose1d_no_bias
2D 转置卷积nn.ConvTranspose2dconv_transpose2d
conv_transpose2d_no_bias
3D 转置卷积nn.ConvTranspose3d未实现🚫
1D 最大池化nn.MaxPool1d未实现🚫
2D 最大池化nn.MaxPool2dmax_pool2d
max_pool2d_with_stride
3D 最大池化nn.MaxPool3d未实现🚫
1D 平均池化nn.AvgPool1d未实现🚫
2D 平均池化nn.AvgPool2davg_pool2d
avg_pool2d_with_stride
3D 平均池化nn.AvgPool3d未实现🚫
逐元素应用修正线性单元函数nn.ReLUrelu
逐元素应用 ReLU6 函数nn.ReLU6relu6
逐元素应用随机泄漏修正线性单元函数nn.RReLU未实现
应用高斯误差线性单元函数nn.GELUgelu
Sigmoid 函数nn.Sigmoidsigmoid
Sigmoid 线性单元 (SiLU) 函数nn.SiLUsilu
逐元素应用双曲正切 (Tanh) 函数nn.Tanhtanh
逐元素应用指数线性单元 (ELU) 函数。nn.ELUelu

开始之前

为了对比两边的操作是否等效,我们需要找到一个方法固定并共享 pytorch 和 candle 两个框架的测试权重。 因为有一些参数如果初始化时不设定权重,会默认随机分配权重值。比如下面这个代码:

m = nn.Conv1d(3, 16, 3)
input = torch.ones(1, 3, 224,dtype=torch.float32)
output = m(input)

上面的代码由于没有初始化权重,事实上 output 的结果每次都不一样,因此为了方便复现,我们需要保存权重,把随机的参数都固定下来。

最好的办法是把python 的权重保存到文件,然后在 candle 中复现。 pytorch 保存权重,这里使用 safetensors 格式:

from safetensors.torch import save_model
m = nn.Conv1d(3, 16, 3, stride=2,bias=False)
save_model(m, "model.safetensors")

candle 中,直接读取model.safetensors 权重即可复现,同时这也是最接近使用 candle 实现 pytorch 模型的方式:

let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["./model.safetensors"], DType::F32, &Device::Cpu)?
    };

Sequential

Sequential 只是容器,本身并无任何功能,而且也不会影响权重的加载。因此我们可以使用 candle 官方的,也可以自己定义。 pytorch

seq = torch.nn.Sequential(
        nn.Conv2d(3, 64, 3, 1, 1),
        nn.ReLU(),
        nn.Conv2d()
    )

candle 官方的定义本质上也是内部维护一个 vec:

    let mut seq = candle_nn::seq();
    seq.add(line);

因此如果官方的不满足(为什么不满足? rust 由于类型系统存在中没有那么方便,比如官方的实现类型是:Vec<Box<dyn Module>> ),可以自己实现:

use candle_core::{ Module, Tensor, Result };

#[derive(Debug, Clone)]
pub struct Sequential<T: Module> {
    layers: Vec<T>,
}

pub fn seq<T: Module>(cnt: usize) -> Sequential<T> {
    let v = if cnt == 0 { vec![] } else { Vec::with_capacity(cnt) };
    Sequential { layers: v }
}

impl<T: Module> Sequential<T> {
    pub fn len(&self) -> usize {
        self.layers.len()
    }

    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }

    pub fn push(&mut self, layer: T) {
        self.layers.push(layer);
    }

    pub fn add(&mut self, layer: T) {
        self.layers.push(layer);
    }
}
impl<T: Module> Module for Sequential<T> {
    fn forward(&self, xs: &candle_core::Tensor) -> Result<Tensor> {
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = xs.apply(layer)?;
        }
        Ok(xs)
    }
}

1D 卷积

conv1d 主要用于处理 序列数据 ,其核心是通过滑动窗口在单个维度(通常是时间或序列维度)上提取局部特征。比如传感器数据(如温度、股票价格)、医疗信号(如 ECG)、工业监控等。

Conv1d 的公式大概如下:

out(Ni,Coutj​​)=bias(Coutj)+k=0Cin1weight(Coutj​​​​,k)input(Ni,k)out(N_i​,C_{out_j}​​)=bias(C_{out_j})+ \sum_{k=0}^{C_{in}​−1}​weight(C{out_j}​​​​,k)⋆input(N_i​,k)

输入(批次,通道,长度):

 (N,Cin,Lin) (N,C_{in},L_{in})

输出(批次,通道,长度)

 (N,Cout,Lout) (N,C_{out},L_{out})

其中 LoutL_{out} 公式如下:

Lout=Lin+2paddingdilation(kernel_size1)1stride+1L_{out} = \frac{L_{in}+2*padding-dilation*(kernel\_size - 1) - 1}{stride} + 1

pytorch

    # 输入通道为3,输出通道为16,卷积核大小为3,步长为2
    # Conv1d(3, 16, kernel_size=(3,), stride=(2,))
    m = nn.Conv1d(3, 16, 3, stride=2,bias=False)
    input = torch.ones(1, 3, 224,dtype=torch.float32)
    output = m(input)
    print(output)
    print(output.size()) #Tensor[dims 1, 16, 111; f32]

candle

    // Conv1d(3, 16, kernel_size=(3,), stride=(2,))
    let cfg = candle_nn::Conv1dConfig{
        stride: 2,
        padding:0,
        dilation:1,
        groups:1,
    };
    let conv1d = candle_nn::conv1d_no_bias(3, 16, 3, cfg, vb)?;
    let x = Tensor::ones((1,3,224), DType::F32, &Device::Cpu)?;
    let y = conv1d.forward(&x)?;
    println!("{y}"); //Tensor[[1, 16, 111], f32]

这里需要注意的是,pytorch 里 bias =False 时,对应的是 candle 中的 conv1d_no_bias,当 bias = True 时,才对应 candle 中的 conv1d。

2D 卷积

conv2d 主要用于主处理 图像数据 ,通过在二维空间(高度和宽度)上滑动滤波器(kernel)提取局部特征。

公式如下:

out(Ni,Coutj​​)=bias(Coutj)+k=0Cin1weight(Coutj,k)input(Ni,k)out(N_i​,C_{out_j}​​) = bias(C_{out_j}) + \sum^{C_{in}-1}_{k=0} weight(C_{out_j},k)⋆input(N_i,k)

输入和输出都是 (N,C,H,W)。 其中输出的尺寸 HoutH_{out}WoutW_{out} 分别如下:

Hout=Hin+2padding[0]dilation[0](kernel_size[0]1)1stride[0]+1H_{out} = \frac{H_{in}+2*padding[0]-dilation[0]*(kernel\_size[0]-1)-1}{stride[0]} + 1 Wout=Win+2padding[1]dilation[1](kernel_size[1]1)1stride[1]+1W_{out} = \frac{W_{in}+2*padding[1]-dilation[1]*(kernel\_size[1]-1)-1}{stride[1]} + 1

pytorch

    # 输入通道为3,输出通道为16,卷积核大小为3,步长为2
    # Conv1d(3, 16, kernel_size=(3,), stride=(2,))
    m = nn.Conv2d(3, 16, 3, stride=2,bias=False)
    # 输入参数: (N,C,H,W)
    input = torch.ones(1, 3, 224,224,dtype=torch.float32)
    output = m(input)
    print(output)
    print(output.size()) #torch.Size([1, 16, 111, 111])

candle:

    // Conv1d(3, 16, kernel_size=(3,), stride=(2,))
    let cfg = candle_nn::Conv2dConfig{
        stride: 2,
        padding:0,
        dilation:1,
        groups:1,
    };
    let conv1d = candle_nn::conv2d_no_bias(3, 16, 3, cfg, vb)?;
    let x = Tensor::ones((1,3,224,224), DType::F32, &Device::Cpu)?;
    let y = conv1d.forward(&x)?;
    println!("{y}"); // Tensor[[1, 16, 111, 111], f32]

bias 上和 conv1d 同理。

3D 卷积

conv3d 主要用于处理具有 4 维的 时空数据体积数据,比如视频。很遗憾的是,candle 不支持3D 卷积, 甚至不支持任何3D 的操作[2025年 3 月前]。

1D 转置卷积/反卷积

一维转置卷积层(也称为反卷积层),主要用于 上采样 或 特征图尺寸恢复 。

输入输出同 1D 卷积。输入(批次,通道,长度):

 (N,Cin,Lin) (N,C_{in},L_{in})

输出(批次,通道,长度)

 (N,Cout,Lout) (N,C_{out},L_{out})

其中长度LoutL_{out}的计算公式如下:

Lout=(Lin1)×stride2×padding+dilation×(kernel_size1)+output_padding+1L_{out​}=(L_{in}​−1)×stride−2×padding+dilation×(kernel\_size−1)+output\_padding+1

pytorch:

    m = nn.ConvTranspose1d(3, 16, 3, stride=2,bias=False)
    input = torch.ones(1, 3, 224,dtype=torch.float32)
    output = m(input)
    print(output)
    print(output.size()) #torch.Size([1, 16, 449])

candle:

    let cfg = candle_nn::ConvTranspose1dConfig{
        stride: 2,
        padding:0,
        dilation:1,
        groups:1,
        ..Default::default()
    };
    let conv1d = candle_nn::conv_transpose1d_no_bias(3, 16, 3, cfg, vb)?;
    let x = Tensor::ones((1,3,224), DType::F32, &Device::Cpu)?;
    let y = conv1d.forward(&x)?;
    println!("{y}"); // Tensor[[1, 16, 449], f32]

2D 转置卷积/反卷积

二维转置卷积层(也称为反卷积层),主要用于 上采样恢复图像分辨率 。它通过转置卷积操作将低分辨率特征图映射到高分辨率空间,常用于图像生成(如 GAN)、语义分割等任务。 和2D 卷积一样,输入和输出都是 (N,C,H,W)。

其中输出的高H 和宽 W 公式如下

Hout=(Hin1)×stride[0]2×padding[0]+dilation[0]×(kernel_size[0]1)+output_padding[0]+1H_{out​}=(H_{in}​−1)×stride[0]−2×padding[0]+dilation[0]×(kernel\_size[0]−1)+output\_padding[0]+1 Wout=(Win1)×stride[1]2×padding[1]+dilation[1]×(kernel_size[1]1)+output_padding[1]+1W_{out}​=(W_{in}​−1)×stride[1]−2×padding[1]+dilation[1]×(kernel\_size[1]−1)+output\_padding[1]+1

pytorch:

    m = nn.ConvTranspose2d(3, 16, 3, stride=2,bias=False)
    input = torch.ones(1, 3, 224,224,dtype=torch.float32)
    output = m(input)
    print(output)
    print(output.size()) #torch.Size([1, 16, 449, 449])

candle:

    let cfg = candle_nn::ConvTranspose2dConfig{
        stride: 2,
        padding:0,
        dilation:1,
        ..Default::default()
    };
    let conv1d = candle_nn::conv_transpose2d_no_bias(3, 16, 3, cfg, vb)?;
    let x = Tensor::ones((1,3,224,224), DType::F32, &Device::Cpu)?;
    let y = conv1d.forward(&x)?;
    println!("{y}"); // Tensor[[1, 16, 449, 449], f32]

2D 最大池化

最大池化在 pytorch 可以做为层存在,但是 candle 是直接对 tensor 进行操作的。 pytorch

    p = nn.MaxPool2d(3, stride=2)
    input = torch.ones(1, 3, 224,224,dtype=torch.float32)
    output = p(input)

candle

    let x = Tensor::ones((1,3,224,224), DType::F32, &Device::Cpu)?;
    let y = x.max_pool2d_with_stride(3,2)?;

2D 平均池化

2D 平均池化操作和上面 2D 最大池化一样,只是 max_pool2d 变成了avg_pool2d,max_pool2d_with_stride变成了avg_pool2d_with_stride

激活层

激活层主要是引入非线性变换,使模型能够学习复杂的模式。

pytorch 也有函数的调用方式,这里不展开。

pytorch

    p = nn.ReLU()
    p = nn.GELU()
    p = nn.SiLU()
    p = nn.Tanh()
    p = nn.ELU()
    p = nn.Sigmoid()
    p = nn.ReLU6()
    p = nn.LeakyReLU(0.01)

candle 两种方法都可用,本质上,candle_nn::Activation 实现了 Module,是对前一种方法的封装。

    let x = Tensor::ones((1,3,224,224), DType::F32, &Device::Cpu)?;
    let y = x.relu()?; //relu
    let y = x.gelu()?;//gelu
    let y = x.silu()?;//gelu
    let y = x.tanh()?;//tanh
    let y = x.elu(1f64)?;//elu
    let y = candle_nn::ops::sigmoid(&x)?;//sigmoid
    let y = x.clamp(0f32, 6f32)?; //relu6
    let y = x.relu()?.sqr()?;//Relu2
    let y = candle_nn::ops::leaky_relu(&x, 0.01)?;
	// 或者下面这种方式。
    let activation = candle_nn::Activation::Relu;
    let activation = candle_nn::Activation::Gelu;
    let activation = candle_nn::Activation::Silu;
    let activation = candle_nn::Activation::Elu(0.01);
    let activation = candle_nn::Activation::Sigmoid;
    let activation = candle_nn::Activation::Relu6;
    //...

    let y = activation.forward(&x)?;

更多层见下篇内容。