Apr 01, 2025
3 min read
Rust,
Candle,
Pytorch,

rust candle 的一些杂项

本文总结了 Rust Candle 框架中的一些杂项功能,包括自定义实现的 `masked_fill`、广播机制(`broadcast_add`)、矩阵乘法(`matmul` 与逐元素乘法的区别)以及模块容器 `ModuleList` 的替代实现。通过对比 PyTorch,分析了 Candle 在张量操作和模型加载上的差异及解决方案。

写这么多篇关于 candle 与 pytorch 的文章,其实是我学习与整理的一些笔记。之前使用 pytorch 几乎是囫囵吞枣,从来没有关注过 pytorch 的一些 api 细节。而 candle 这边约等于没有文档。

masked_fill

masked_fill 是一种用于条件化张量填充 的操作,其作用是根据指定的布尔掩码(mask),将张量中满足条件的位置替换为给定的值。 candle 中没有关于masked_fill官方 接口的实现,不过有在 transformer 模块中找到一些自定义实现。

pytorch

    x = torch.tensor([[1.0, 0.0], [0.3, -0.4]])
    mask = x.to(torch.bool)
    c = x.masked_fill(mask, torch.finfo(x.dtype).min)
    print(c) #tensor([[-3.4028e+38,  0.0000e+00],
                    #[-3.4028e+38, -3.4028e+38]])

candle

// 自定义实现
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    let m = mask.where_cond(&on_true, on_false)?;
    Ok(m)
}

	//调用示例
    let data = vec![1.0f32, 0.0,0.3, -0.4];
    let x = Tensor::from_vec(data, (2,2), &Device::Cpu)?;
    let mask = x.ne(0.0)?;

    let y = masked_fill(&x, &mask, f32::MIN)?;

    println!("mask:{y}");
//     mask:[[-3.4028e38,   0.0000e0],
//      [-3.4028e38, -3.4028e38]]
//      Tensor[[2, 2], f32]

广播机制

PyTorch 的广播机制允许不同形状的张量进行逐元素操作(如加减乘除),只要它们的形状满足以下条件:

  1. 从尾部维度开始比较,两个张量的维度大小必须相等,或者其中一个维度的大小为 1。
  2. 如果两个张量的维度数不同,则会在较小张量的前面补 1,直到两者的维度数相同。

假设我们有两个张量:

  • A 的形状为 [1, 1, 64, 64]
  • B 的形状为 [64, 64]

这两个张量在 pytorch 中是可以直接相加的:

    a = torch.ones(1,1,64,64)
    b = torch.ones(64,64)
    print(a+b)

但是,在 candle 中,由于 rust 的特性,不同尺寸的张量无法实现 a+b 这种操作。因此我们需要使用 broadcast_add 达到相同的目地。

    let device = Device::Cpu;
    let a = Tensor::ones((1,1,64,64), DType::F32, &device)?;
    let b = Tensor::ones((64,64), DType::F32, &device)?;
    // 相加
    let c  = a.broadcast_add(&b)?;
    println!("c::{c}");

矩阵乘法

pytorch 中 a@b 等价于 torch.matmul(a, b)

那和 a*b有什么区别?

以下面 2 个矩阵为例子: a=[1234]a = \begin{bmatrix} 1&2\\3&4\\ \end{bmatrix} b=[5678]b = \begin{bmatrix} 5&6\\7&8\\ \end{bmatrix} a*b 其实叫`逐元素乘法 ,它要求 a 和 b 的尺寸必须相同,并会按相同的元素逐个相乘。既

= \begin{bmatrix} 5&12\\21&32\\ \end{bmatrix} $$ 而矩阵相乘 a @ b 的计算过程是逐行与逐列的点积 。 $$ a @ b = \begin{bmatrix} 1*5+2*7&1*6+2*8\\3*5+4*7&3*6+4*8\\ \end{bmatrix} = \begin{bmatrix} 19&22\\43&50\\ \end{bmatrix} $$ pytorch : ```python a = torch.tensor([[1.0, 2.0], [3.0, 4.0]]) b = torch.tensor([[5.0, 6.0], [7.0, 8.0]]) #tensor([[19., 22.],[43., 50.]]) print(a @ b) #tensor([[ 5., 12.],[21., 32.]]) print(a * b) ``` candle: ```rust let a_data = vec![1.0f32, 2.0,3.0,4.0]; let b_data = vec![5.0f32, 6.0,7.0,8.0]; let a = Tensor::from_vec(a_data, (2,2), &Device::Cpu)?; let b = Tensor::from_vec(b_data, (2,2), &Device::Cpu)?; let x = a.matmul(&b)?; //[[19., 22.],[43., 50.]] println!("x:{x}"); //[[ 5., 12.],[21., 32.]] let y = (a * b)?; println!("y:{y}"); ``` ## ModuleList ModuleList 也是一个容器,只是提供列表容器,并没有实质上的功能。这个在 candle 上是没有实现的。 但是我们有时候会碰到这种结构: ``` (albert_layer_groups): ModuleList( (0): AlbertLayerGroup( (albert_layers): ModuleList( (0): AlbertLayer( ``` 里面的那个 `0` 也是 `key`,这意味着如果我们想单纯的使用 Vec<...> 来表示是不行的。原因是Vec<...> 无法为模型结构构建这个 key 名,因为我们需要在构建器里使用 `vb.pp("0")`,而 ModuleList 本身也需要一个 key,比如上面的 `albert_layers` 和 `albert_layer_groups`。 下面这个主要会报类似于 ` cannot find tensor albert.encoder.albert_layer_groups.0.0.full_layer_layer_norm.weight` 的错误,也就是 key 的路径对不上了。 ```rust // #[derive(Debug, Clone)] struct AlbertLayerGroup { albert_layers: Vec<AlbertLayer>, } ``` 我自己的习惯做法是,把 `Vec<AlbertLayer>` 变成自定义的 struct,内部维护。这种做法在功能上没什么意义,主要是为了让权重的 key 能对应上。 ```rust #[derive(Debug, Clone)] struct AlbertLayers { layers: Vec<AlbertLayer>, } impl AlbertLayers { pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> { let mut layers = vec![]; for i in 0..config.inner_group_num { layers.push(AlbertLayer::load(vb.pp(i), config)?); } Ok(Self { layers }) } } ``` AlbertLayerGroup 变成这样: ```rust #[derive(Debug, Clone)] struct AlbertLayerGroup { albert_layers: AlbertLayers, } impl AlbertLayerGroup { fn load(vb: VarBuilder, config: &Config) -> Result<Self> { let albert_layers = AlbertLayers::load(vb.pp("albert_layers"), config)?; Ok(Self { albert_layers }) } } ``` 这种做法在代码上比较啰嗦,不过胜在清晰。不加这种结构的做法也是可以的,`vb.pp("")` 是可以用 `.` 语法的,我们一样可以用 `vb.pp("albert_layers.0")` 取到权重。