Dec 10, 2025
4 min read
Rust,
Candle,
PyTorch,

Rust Candle 的 PyTorch 等价操作记录

SDPA

SDPA 模块在 pytorch 的 api 是 torch.nn.functional.scaled_dot_product_attention。它被用来加速大模型的Attention 计算。

公式大概如下:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V

在 pytorch 里,它是一个高度优化的实现。 candle 核心库里没有这个实现。不过在 transformers 模块里有一些模型有实现,实现的版本也相当的简单,基本就是上面公式的复刻:

fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let dim = q.dim(D::Minus1)?;
    let scale_factor = 1.0 / (dim as f64).sqrt();
    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
    candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
}

正常来说,可能还会有因果掩码causal_maskattn_mask 的实现。

linspace

linspace 模块在 pytorch 的 api 是 torch.linspace。它被用来生成一个等间隔数值序列的一维张量。 比如,生成从 0 到 1 的 5 个等间隔值:

x = torch.linspace(0, 1, 5)# 参数分别为 开始值, 结束值, 步数
print(x)  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])

candle 也没有自带,不过自己实现也很简单:

pub fn linspace(start: f64, stop: f64, steps: usize,device: &Device) -> Result<Tensor> {
    if steps == 0 {
        Tensor::from_vec(Vec::<f64>::new(), steps, device)
    } else if steps == 1 {
        Tensor::from_vec(vec![start], steps, device)
    } else {
        let delta = (stop - start) / (steps - 1) as f64;
        let vs = (0..steps)
            .map(|step| start + step as f64 * delta)
            .collect::<Vec<_>>();
        Tensor::from_vec(vs, steps, device)
    }
}

weight_norm (conv1d)

torch.nn.utils.weight_norm 是一种权重归一化技术,用于稳定训练和加速收敛。

python 大概用法如下:

from torch.nn.utils import weight_norm

# 应用权重归一化到卷积层
conv = nn.Conv1d(128, 256, kernel_size=3)
conv = weight_norm(conv)  # 对权重进行归一化

数学意义上,权重归一化会把权重参数 w 分解为两个部分:

w=gvvw = g \odot \frac{v}{\|v\|}

其中 gg 是一个可学习的参数,vv 是原始的权重参数。v||v||vv 的 L2 范数。

candle 无内置,但是这个在一些 LLM 里比较常用,因此 transformers 模块里一些模型有自己的实现。实现大概如下:

fn conv1d_weight_norm(
    in_c: usize,
    out_c: usize,
    kernel_size: usize,
    bias: bool,
    config: candle_nn::Conv1dConfig,
    vb: VarBuilder,
) -> Result<Conv1d> {
    let weight = if vb.contains_tensor("weight") {
        vb.get((out_c, in_c, kernel_size), "weight")?
    } else {
        let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
        let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
        let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
        weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
    };
    let bias = if bias {
        Some(vb.get(out_c, "bias")?)
    } else {
        None
    };
    Ok(Conv1d::new(weight, bias, config))
}

与 pytorch 不同的是,weight_gweight_v 的 shape 需要根据实际的情况而定。无法完全照抄上面代码。

Snake 激活函数

公式大概如下

snake(x)=x+(1/α)sin2(αx)snake(x) = x + (1/α) * sin²(αx)

这是一个比较新的激活函数,比较适合周期性的数据。 事实上,我是第一次看到这个激活函数。看起来,这个激活函数更多用于音频数据,因为 candle 里的 Descript Audio Codec (DAC) 模型就实现了这个激活函数。VoxCPMAudioVae 模块也用了这个激活函数。

#[derive(Debug, Clone)]
pub struct Snake1d {
    alpha: Tensor,
}

impl Snake1d {
    pub fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
        let alpha = vb.get((1, channels, 1), "alpha")?;
        Ok(Self { alpha })
    }
}

impl candle::Module for Snake1d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs_shape = xs.shape();
        let xs = xs.flatten_from(2)?;
        let sin = self.alpha.broadcast_mul(&xs)?.sin()?;
        let sin = (&sin * &sin)?;
        (xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape)
    }
}

乘积

不管是 np.prod 还是 torch.prod,都是计算张量中所有元素的乘积。

arr = np.array([1, 2, 3, 4])
result = np.prod(arr)
print(result)  # 输出: 24 (1×2×3×4 = 24)

candle 中不存在这个操作,需要转为数组,再使用x.iter().product(); 计算。

slice_assign

slice_assign 是一个用于将一个张量赋给另一个张量的操作。 在 pytorch 中,我们经常会看到类似下面这种用法:

    x = torch.arange(0, 4*5).reshape(4, 5)
    y = torch.arange(0, 2*3).reshape(3, 2)
    x[1:4, 3:5] = y

这个操作将张量 y 的元素赋给张量 x 的指定切片。

更可视化的解释:

// 原始 tensor:
// 行\列 0  1  2  3  4
// 0     [0, 1, 2, 3, 4]
// 1     [5, 6, 7, 8, 9]       ← 选择行 1-3
// 2     [10,11,12,13,14]      ← 选择行 1-3  
// 3     [15,16,17,18,19]      ← 选择行 1-3
//                ↑  ↑
//               列 3-4

// 选择的区域: tensor[1..4, 3..5]
// 形状: (3, 2)
// [[8, 9],
//  [13,14],
//  [18,19]]

// 用 src 替换:
// src = [[0, 1],
//        [2, 3],
//        [4, 5]]

// 结果:
// [[ 0,  1,  2,  3,  4],
//  [ 5,  6,  7,  0,  1],  ← 第1行,列3-4变为 [0, 1]
//  [10, 11, 12,  2,  3],  ← 第2行,列3-4变为 [2, 3]
//  [15, 16, 17,  4,  5]]  ← 第3行,列3-4变为 [4, 5]

Rust 中当然不支持这种语法,但是可以通过 x.slice_assign([start, end], y) 来实现。

let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
let out = tensor.slice_assign(&[1..4, 3..5], &src)?;

这个操作唯一比较难理解的地方是 shape 参数,特别是高维张量。 简单来说,slice_assign 选择区域的 shape 必须要和被赋值的张量的 shape 一致。