Mar 23, 2025
6 min read
Rust,
Candle,
Pytorch,

Rust Candle 框架与 Pytorch 张量等价操作之归约运算

本文对比了Pytorch和Rust Candle 框架在张量归约运算上的实现差异,涵盖求和、均值、最大值、最小值等常见操作。

书接上回。

归约运算是将一个数据集合(如数组、张量等)通过某种规则或函数逐步“压缩”为一个更小的结果。像求和、均值、最大值、最小值这些都算是归约运算。

表格预览 ✅ 表示有对应实现 🚫 表示无对应实现 ☢️ 表示有代替实现

操作PytorchCandle
最大值、最小值的索引argmax、argminargmax、argmin
切片的最大值、最小值amax、aminmax,min
张量的最小值和最大值aminmax未实现,可通过上面实现☢️
张量中所有元素的最大值、最小值max、min未实现,可间接实现☢️
(input - other) 的 p 范数dict未实现🚫
每行求和指数的对数logsumexplog_sum_exp
所有元素的平均值meanmean/mean_keepdim
中位值median未实现🚫
众数值以及索引mode未实现🚫
所有元素的乘积prod未实现🚫
标准差std未实现,见下面☢️
标准差和均值std_mean未实现☢️
求和sumsum、sum_all
唯一元素unique未实现🚫
方差varvar/var_keepdim
方差和均值var_mean未实现,可参考var 和 mean☢️
张量中的元素总数(补上)numelelem_count

开始之前

pytorch 中有很多操作支持 -1 这样的参数,比如

torch.mean(a,-1)

-1 一般表示最后一维,candle 中由于类型限制,无法使用-1,而是使用枚举:

let mean = a.mean(D::Minus1)?; //D::Minus1 表示 -1

最大值、最小值的索引

主要的功能是沿着某一个维度按最大值(最小值)。

pytorch

	a = torch.tensor([
        [1.0, -2.0,3.0],
        [0.1, -0.2, 0.01]
    ])
    # 表示矩阵中按第 0 维度、沿着行的方向(即垂直方向)进行操作,返回每一列中最大值的行索引。
    print(a.argmax(dim=0)) # tensor([0, 1, 0]) 
    # 表示矩阵中按第 1 维度、沿着列的方向(即水平方向)进行操作,返回每一行中最大值的列索引。
    print(a.argmax(dim=1)) # tensor([2, 0])
    
    print(a.argmin(dim=0)) #tensor([1, 0, 1])
    print(a.argmin(dim=1)) #tensor([1, 1])

candle

    let a_data = vec![1.0, -2.0,3.0,0.1, -0.2, 0.01];
    let x = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    
    //表示矩阵中按第 0 维度、沿着行的方向(即垂直方向)进行操作,返回每一列中最大值的行索引。
    let y = x.argmax(0)?;
    println!("{:?}",y);//Tensor[2; u32]
    //表示矩阵中按第 1 维度、沿着列的方向(即水平方向)进行操作,返回每一行中最大值的列索引。
    let y = x.argmax(1)?;
    println!("{:?}",y);

    let y = x.argmin(0)?;
    println!("{:?}",y); //Tensor[1, 0, 1; u32]
    let y = x.argmin(1)?;
    println!("{:?}",y); //Tensor[1, 1; u32]

切片的最大值、最小值

candle 里没有 amax、amin。不过 candle 的 max 和 min 基本上和 torch 里的 amax 和 min 比较一致。

pytorch

    a = torch.tensor([
        [1.0, -2.0,3.0],
        [0.1, -0.2, 0.01]
    ])

    print(a.amax(dim=1)) #tensor([3.0000, 0.1000])

candle

    let a_data = vec![1.0, -2.0,3.0,0.1, -0.2, 0.01];
    let x = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    
    let y = x.max(1)?;
    println!("y = {y}");//y = [3.0000, 0.1000]

candle 也可以基于 argmax 和 stack 实现 amax。

    let a_data = vec![1.0, -2.0,3.0,0.1, -0.2, 0.01];
    let x = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    
    let y = x.argmax(1)?;
    let indexs = y.to_vec1::<u32>()?;

    let mut max_list = Vec::with_capacity(x.dims()[0]);
    for (i,_) in indexs.iter().enumerate(){
        let max = x.i((i,indexs[i] as usize))?;
        max_list.push(max);
    }

    let amax = candle_core::Tensor::stack(&max_list, 0)?;
    println!("amax:{:?}",amax); //amax:Tensor[3, 0.1; f64]

由于 amin 和 amax 类似,这里不做展开。

张量的最小值和最大值(aminmax)

这个其实就是上面这个 amax、amin 的结合,不再展开。

张量中所有元素的最大值、最小值

pytorch 中的 max、min 和 candle 中的 max、min 是不一样的。pytorch 中的 max、min 会直接返回值和索引:

    print(a.max(dim=1)) 
    # torch.return_types.max(
    # values=tensor([3.0000, 0.1000]),
    # indices=tensor([2, 0]))

candle 需要同时使用 argmax 和 max:

    let a_data = vec![1.0, -2.0,3.0,0.1, -0.2, 0.01];
    let x = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    
    let y = x.max(1)?;
    let idx = x.argmax(1)?;
    
    println!("{:?}",y);
    println!("{:?}",idx);

同理,candle 也需要 min 和 argmax 获取等效上的 torch.min。

amax、amin 和 max 、min 的区别

pytorch 中的 amax/amin 不返回索引,max/min 会返回索引。 candle 中的 max/min 基本使用上大致上等于 pytorch 的 amax/amin。但是 candle 的功能肯定不如 pytorch 版本。

每行求和指数的对数

公式如下:

logsumexp(x)i=logiexp(xij)logsumexp(x)_i​=log\sum_i{exp(x_{ij}​)}

pytorch:

    a = torch.tensor([
        [1.0, -2.0,3.0],
        [0.1, -0.2, 0.01]
    ])
    print(torch.logsumexp(a, dim=1)) #tensor([3.1328, 1.0764])

candle

    let a_data = vec![1.0, -2.0,3.0,0.1, -0.2, 0.01];
    let x = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    
    let y = x.log_sum_exp(1)?;
    println!("{:?}",y);//Tensor[3.1328452337275756, 1.076350264534186; f64]

amax 和 amin 都有 keepdim参数,和 candle 中的 max_keepdim /min_keepdim 对应。

所有元素的平均值

虽然 pytorch 和 candle 的平均值都是 mean,但是,当 2 者都指定 dim 参数时,才等同操作。 当 pytorch 不指定 dim 维度时,默认返回的是整个 Tensor 的 平均值。

pytorch

    a = torch.tensor([
        [1.0, -2.0,3.0],
        [0.1, -0.2, 0.01]
    ])
    # a.mean() 和 torch.mean(a) 等同
    print(a.mean()) #tensor(0.3183),如果不指定 dim,会返回所有元素的平均值
    print(a.mean(dim=0)) #tensor([ 0.5500, -1.1000,  1.5050])
    print(a.mean(dim=1)) #tensor([ 0.6667, -0.0300])

candle

    let a_data = vec![1.0, -2.0,3.0,0.1, -0.2, 0.01];
    let x = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    
    let y = x.mean(0)?;
    println!("{:?}",y);//Tensor[0.55, -1.1, 1.505; f64]
    let y = x.mean(1)?;
    println!("{:?}",y);//Tensor[0.6666666666666666, -0.030000000000000002; f64]

因为 rust 方法必须指定参数,不过可以通过下面自行实现,或者使用 x.mean_all()

    let z = y.to_vec1::<f64>()?;
    // mean z
    let z = z.iter().sum::<f64>()/z.len() as f64; 
    println!("{:?}",z); //0.3183333333333333

	// 下面等同于上面
	let y = x.mean_all()?;
    println!("y:{:?}",y);

其他操作如果也有类似不指定维度的,直接返回所有的,都可以找找是否有 xxx_all() 方法。比如 sum 也有一个 sum_all()

pytorch 的中的 mean 还有一个 keepdim 参数,这个与 candle 中的 mean_keepdim()方法等同:

//python
a.mean(dim=0,keepdim=True)
//rust 
x.mean_keepdim(0)?;

标准差

标准差的公式如下:

1max(0, NδN)i=0N1(xix)2\frac{1}{max(0, N−δN)}\sum_{i=0}^{N−1}​(x_i​−\overline{x})^2​

candle 没有 std 函数。不过,可以通过上面公式算出来,比如,当 dim = 0,unbiased=False 时: pytorch:

	a = torch.tensor([
        [1.0, -2.0,3.0],
        [0.1, -0.2, 0.01]
    ])
    print(a.std(dim=0,unbiased=False)) # tensor([0.4500, 0.9000, 1.4950])

candle 等效代码如下 :

    let a_data = vec![1.0, -2.0,3.0,0.1, -0.2, 0.01];
    let x = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    
    // 计算均值
    let mean = x.mean(0)?;
    // 计算偏差
    let diff = x.broadcast_sub(&mean)?;
    // 计算平方偏差
    let diff_sq = diff.sqr()?;
    // 计算平方偏差的均值
    let mean_sq = diff_sq.mean(0)?;
    // 计算标准差
    let std = mean_sq.sqrt()?;
    println!("std:{:?}",std); //std:Tensor[0.45, 0.9, 1.495; f64]

pytorch 的 std 默认 unbiased = True ,a.std(dim=0,unbiased=True) 按列计算偏差,并使用无偏估计。

差异的地方大概算法如下:

variance = torch.sum(squared_diff) / (x.numel() - 1)

candle 实现如下:

    let a_data = vec![1.0, -2.0,3.0,0.1, -0.2, 0.01];
    let x = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    
    // 计算均值
    let mean = x.mean(0)?;
    
    println!("x.dims()[0]:{:?}",x.dims()[0]);
    // 计算偏差
    let diff = x.broadcast_sub(&mean)?;
    // 计算平方偏差
    let diff_sq = diff.sqr()?;
    // 计算平方偏差的均值
    let mean_sq = diff_sq.sum(0)? /  ((x.dims()[0] - 1) as f64);
    // 计算标准差
    let std = mean_sq?.sqrt()?;
    println!("std:{:?}",std); //std:Tensor[0.45, 0.9, 1.495; f64]
    

x.dims()[0] 表示按列的长度计算无偏估计。

标准差和均值

candle 未实现,参考上面实现即可。

求和

pytorch 和 candle 的求和基本一致,唯一不同的是,当 sum 无参数时,candle 使用 sum_all 代替: pytorch

    a = torch.tensor([
        [1.0, -2.0,3.0],
        [0.1, -0.2, 0.01]
    ])
    print(a.sum()) #tensor(1.9100)
    print(a.sum(dim=0)) #tensor([ 1.1000, -2.2000,  3.0100])
    print(a.sum(dim=1)) #tensor([ 2.0000, -0.0900])

candle:

    let a_data = vec![1.0, -2.0,3.0,0.1, -0.2, 0.01];
    let x = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    
    let y = x.sum_all()?;
    println!("{:?}",y);//Tensor[1.9100000000000001; f64]
    let y = x.sum(0)?;
    println!("{:?}",y);//Tensor[1.1, -2.2, 3.01; f64]
    let y = x.sum(1)?;
    println!("{:?}",y);//Tensor[2, -0.09000000000000001; f64]

方差

方差的公式如下:

σ2=1max(0, NδN)i=0N1(xix)2σ^2=\frac{1}{max(0, N−δN)​}\sum_{i=0}^{N−1}​(x_i​−\overline{x})^2

区别是,pytorch 可不指定 dim:

    a = torch.tensor([
        [1.0, -2.0,3.0],
        [0.1, -0.2, 0.01]
    ])
    print(a.var()) #tensor(2.6884)
    print(a.var(dim=0)) #tensor([0.4050, 1.6200, 4.4700])
    print(a.var(dim=1)) #tensor([6.3333, 0.0237])
    print(a.var(dim=0,keepdim=True)) #tensor([[0.4050, 1.6200, 4.4700]])

candle 没有 a.var() 的等价操作:

    let a_data = vec![1.0, -2.0,3.0,0.1, -0.2, 0.01];
    let x = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;

    let y = x.var(0)?;
    println!("{:?}",y);//Tensor[0.405, 1.62, 4.4700500000000005; f64]
    let y = x.var(1)?;
    println!("{:?}",y);//Tensor[6.333333333333334, 0.023700000000000002; f64]
    let y = x.var_keepdim(0)?;
    println!("{y}"); //[[0.4050, 1.6200, 4.4701]]

张量中的元素总数(补上)

pytorch

    a = torch.tensor([
        [1.0, -2.0,3.0],
        [0.1, -0.2, 0.01]
    ])
    print(a.numel())

candle

    let a_data = vec![1.0, -2.0,3.0,0.1, -0.2, 0.01];
    let x = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    println!("count:{:?}",x.elem_count()); //std:6