写这么多篇关于 candle 与 pytorch 的文章,其实是我学习与整理的一些笔记。之前使用 pytorch 几乎是囫囵吞枣,从来没有关注过 pytorch 的一些 api 细节。而 candle 这边约等于没有文档。
masked_fill
masked_fill 是一种用于条件化张量填充 的操作,其作用是根据指定的布尔掩码(mask),将张量中满足条件的位置替换为给定的值。 candle 中没有关于masked_fill 的官方 接口的实现,不过有在 transformer 模块中找到一些自定义实现。
pytorch
x = torch.tensor([[1.0, 0.0], [0.3, -0.4]])
mask = x.to(torch.bool)
c = x.masked_fill(mask, torch.finfo(x.dtype).min)
print(c) #tensor([[-3.4028e+38, 0.0000e+00],
#[-3.4028e+38, -3.4028e+38]])
candle
// 自定义实现
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
}
//调用示例
let data = vec![1.0f32, 0.0,0.3, -0.4];
let x = Tensor::from_vec(data, (2,2), &Device::Cpu)?;
let mask = x.ne(0.0)?;
let y = masked_fill(&x, &mask, f32::MIN)?;
println!("mask:{y}");
// mask:[[-3.4028e38, 0.0000e0],
// [-3.4028e38, -3.4028e38]]
// Tensor[[2, 2], f32]
广播机制
PyTorch 的广播机制允许不同形状的张量进行逐元素操作(如加减乘除),只要它们的形状满足以下条件:
- 从尾部维度开始比较,两个张量的维度大小必须相等,或者其中一个维度的大小为 1。
- 如果两个张量的维度数不同,则会在较小张量的前面补 1,直到两者的维度数相同。
假设我们有两个张量:
A 的形状为 [1, 1, 64, 64]
B 的形状为 [64, 64]
这两个张量在 pytorch 中是可以直接相加的:
a = torch.ones(1,1,64,64)
b = torch.ones(64,64)
print(a+b)
但是,在 candle 中,由于 rust 的特性,不同尺寸的张量无法实现 a+b 这种操作。因此我们需要使用 broadcast_add 达到相同的目地。
let device = Device::Cpu;
let a = Tensor::ones((1,1,64,64), DType::F32, &device)?;
let b = Tensor::ones((64,64), DType::F32, &device)?;
// 相加
let c = a.broadcast_add(&b)?;
println!("c::{c}");
矩阵乘法
pytorch 中 a@b 等价于 torch.matmul(a, b)。
那和 a*b有什么区别?
以下面 2 个矩阵为例子:
a=[1324]
b=[5768]
a*b 其实叫`逐元素乘法 ,它要求 a 和 b 的尺寸必须相同,并会按相同的元素逐个相乘。既
= \begin{bmatrix} 5&12\\21&32\\ \end{bmatrix} $$
而矩阵相乘 a @ b 的计算过程是逐行与逐列的点积 。
$$ a @ b = \begin{bmatrix} 1*5+2*7&1*6+2*8\\3*5+4*7&3*6+4*8\\ \end{bmatrix}
= \begin{bmatrix} 19&22\\43&50\\ \end{bmatrix} $$
pytorch :
```python
a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
#tensor([[19., 22.],[43., 50.]])
print(a @ b)
#tensor([[ 5., 12.],[21., 32.]])
print(a * b)
```
candle:
```rust
let a_data = vec![1.0f32, 2.0,3.0,4.0];
let b_data = vec![5.0f32, 6.0,7.0,8.0];
let a = Tensor::from_vec(a_data, (2,2), &Device::Cpu)?;
let b = Tensor::from_vec(b_data, (2,2), &Device::Cpu)?;
let x = a.matmul(&b)?;
//[[19., 22.],[43., 50.]]
println!("x:{x}");
//[[ 5., 12.],[21., 32.]]
let y = (a * b)?;
println!("y:{y}");
```
## ModuleList
ModuleList 也是一个容器,只是提供列表容器,并没有实质上的功能。这个在 candle 上是没有实现的。
但是我们有时候会碰到这种结构:
```
(albert_layer_groups): ModuleList(
(0): AlbertLayerGroup(
(albert_layers): ModuleList(
(0): AlbertLayer(
```
里面的那个 `0` 也是 `key`,这意味着如果我们想单纯的使用 Vec<...> 来表示是不行的。原因是Vec<...> 无法为模型结构构建这个 key 名,因为我们需要在构建器里使用 `vb.pp("0")`,而 ModuleList 本身也需要一个 key,比如上面的 `albert_layers` 和 `albert_layer_groups`。
下面这个主要会报类似于 ` cannot find tensor albert.encoder.albert_layer_groups.0.0.full_layer_layer_norm.weight` 的错误,也就是 key 的路径对不上了。
```rust
//
#[derive(Debug, Clone)]
struct AlbertLayerGroup {
albert_layers: Vec<AlbertLayer>,
}
```
我自己的习惯做法是,把 `Vec<AlbertLayer>` 变成自定义的 struct,内部维护。这种做法在功能上没什么意义,主要是为了让权重的 key 能对应上。
```rust
#[derive(Debug, Clone)]
struct AlbertLayers {
layers: Vec<AlbertLayer>,
}
impl AlbertLayers {
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let mut layers = vec![];
for i in 0..config.inner_group_num {
layers.push(AlbertLayer::load(vb.pp(i), config)?);
}
Ok(Self { layers })
}
}
```
AlbertLayerGroup 变成这样:
```rust
#[derive(Debug, Clone)]
struct AlbertLayerGroup {
albert_layers: AlbertLayers,
}
impl AlbertLayerGroup {
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
let albert_layers = AlbertLayers::load(vb.pp("albert_layers"), config)?;
Ok(Self { albert_layers })
}
}
```
这种做法在代码上比较啰嗦,不过胜在清晰。不加这种结构的做法也是可以的,`vb.pp("")` 是可以用 `.` 语法的,我们一样可以用 `vb.pp("albert_layers.0")` 取到权重。