Mar 21, 2025
5 min read
Rust,
Candle,
Pytorch,

Rust Candle 框架与 Pytorch 的张量的索引、切片、连接、变异等价操作

本文对比了Rust Candle与Pytorch在张量索引、切片、连接、变异等操作上的实现差异,涵盖常见张量操作及其等价方法。

书接上回, 深度学习中,张量还存在大量的索引、切片、连接、变异等操作。由于 Python 的特性,这些张量的操作相对来说是比较方便的。但是在 Candle 中可能会有和 Pytorch 中有差异的地方。

索引与切片

定义一个3x3的矩阵,下面有些操作不提示的话,则会基于这个矩阵:

	[
       [1,2,3],
       [4,5,6],
       [7,8,9], 
    ]

pytorch 中取第一行的数据、取第一列的数据分别为:

    print(x[0]) #tensor([1, 2, 3])
    print(x[:,:1])  #tensor([[1],
                    #        [4],
                    #        [7]])  

candle 中取第一行的数据,取第一列的数据分别为:

    let data = vec![1i64, 2, 3, 4, 5, 6, 7, 8, 9];
    let x = Tensor::from_vec(data, &[3, 3], &Device::Cpu)?;

    let y = x.i(0)?;
    println!("{y}"); //[1, 2, 3]

    let y =  x.i((.., ..1))?;
    println!("{y}",);   //[[1],
                        // [4],
                        // [7]]

select

select 等价于切片。

在 pytorch 中,tensor.select(0, index) 等价于 tensor[index],而 tensor.select(2, index) 等价于 tensor[:,:,index]

在 candle 中, 等价于 i() 方法。

改变张量的形状

pytorch 中使用 view或者 reshape 重构张量的形状:

    x = torch.tensor(data=[1,2,3,4,5,6,7,8,9])
    y = x.view(3,3)
    y = x.reshape(3,3)
    
    # 输出
    # tensor([[1, 2, 3],
	#         [4, 5, 6],
	#         [7, 8, 9]])
    print(y)

candle 中只有 reshape api 可以重构张量的形状:

    let data = vec![1i64, 2, 3, 4, 5, 6, 7, 8, 9];
    let x = Tensor::new(data, &Device::Cpu)?;
    let y = x.reshape((3,3))?;
    
    // 输出的结果为:
    //[[1, 2, 3],
    //[4, 5, 6],
    //[7, 8, 9]]
    println!("{y}");

连接张量

很多时候我们需要在给定维度中连接 tensors 中给定的张量序列。连接 操作有 2 种。

cat

假设两个张量 a 和 b 形状均为 (2, 3),分别按第 0 维和第 1 维拼接,pytorch 代码如下:

    a = torch.tensor([[1, 2, 3], [4, 5, 6]])
    b = torch.tensor([[7, 8, 9], [10, 11, 12]])
    
    # 沿行(dim=0)拼接 → 形状 (4, 3)
    y = torch.cat([a,b], dim=0)
    # 结果为:
    # tensor([[ 1,  2,  3],
    #         [ 4,  5,  6],
    #         [ 7,  8,  9],
    #         [10, 11, 12]])
    print(y)
    
    # 沿行(dim=0)拼接 → 形状 (2,6)
    y = torch.cat([a,b], dim=1)
    # 结果为:
    # tensor([[ 1,  2,  3,  7,  8,  9],
    #         [ 4,  5,  6, 10, 11, 12]])
    print(y)

candle 大致一样:

    let a_data = vec![1i64, 2, 3, 4, 5, 6];
    let b_data = vec![7i64, 8, 9,10,11,12];
    let a = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    let b = Tensor::from_vec(b_data, (2,3),&Device::Cpu)?;
    
    // 沿行(dim=0)拼接 → 形状 (4, 3)
    let y = candle_core::Tensor::cat(&[&a,&b], 0)?;
    // 结果如下:
    // [[ 1,  2,  3],
    // [ 4,  5,  6],
    // [ 7,  8,  9],
    // [10, 11, 12]]
    println!("{y}");
    
    // 沿行(dim=1)拼接 → 形状 (2,6)
    let y = candle_core::Tensor::cat(&[&a,&b], 1)?;
    
    // 结果如下:
    // [[ 1,  2,  3,  7,  8,  9],
    // [ 4,  5,  6, 10, 11, 12]]
    println!("{y}");

stack

stack 是另一个拼接方法,同样假设两个张量 a 和 b 形状均为 (2, 3),分别按第 0 维和第 1 维拼接,pytorch 代码如下:

    a = torch.tensor([[1, 2, 3], [4, 5, 6]])
    b = torch.tensor([[7, 8, 9], [10, 11, 12]])
    
    # 沿行(dim=0)拼接 → 形状 (2, 2, 3)
    y = torch.stack([a,b], dim=0)
    # 结果为:
    # [
    #   [[ 1,  2,  3],[ 4,  5,  6]],
    #   [[ 7,  8,  9],[ 10, 11, 12]]
    # ]
    print(y)
    
    # 沿行(dim=0)拼接 → 形状 (2, 2, 3)
    y = torch.stack([a,b], dim=1)
    # 结果为:
    # [
    #   [[ 1,  2,  3], [ 7,  8,  9]],
    #   [[ 4,  5,  6], [10, 11, 12]]
    # ]
    print(y)

candle 同理:

    let a_data = vec![1i64, 2, 3, 4, 5, 6];
    let b_data = vec![7i64, 8, 9,10,11,12];
    let a = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    let b = Tensor::from_vec(b_data, (2,3),&Device::Cpu)?;
    
    // 沿行(dim=0)拼接 → 形状 (2, 2, 3)
    let y = candle_core::Tensor::stack(&[&a,&b], 0)?;
    // 结果为:
    // [
    //   [[ 1,  2,  3],[ 4,  5,  6]],
    //   [[ 7,  8,  9],[ 10, 11, 12]]
    // ]
    println!("{y}");
    
    // 沿行(dim=0)拼接 → 形状 (2, 2, 3)
    let y = candle_core::Tensor::stack(&[&a,&b], 1)?;

    // 结果为:
    // [
    //   [[ 1,  2,  3], [ 7,  8,  9]],
    //   [[ 4,  5,  6], [10, 11, 12]]
    // ]
    println!("{y}");

cat 与 stack 的区别?

在PyTorch中,torch.cat 和 torch.stack 都用于张量拼接,但关键区别在于是否创建新维度。 可以看到,stack 直接增加了一个维度。

不过,stack 可以使用 cat 代替,因为我们可以手动增加维度。

pytorch:

    # 下面这两个操作等价
    y1 = torch.stack([a,b], dim=0)
    y2 = torch.cat([a.unsqueeze(0),b.unsqueeze(0)], dim=0)
    # 结果为:
    # [
    #   [[ 1,  2,  3],[ 4,  5,  6]],
    #   [[ 7,  8,  9],[ 10, 11, 12]]
    # ]
    assert y1.equal(y2)

candle:

// 下面这两个操作等价
    let y1 = candle_core::Tensor::stack(&[&a,&b], 1)?;
    let y2 = candle_core::Tensor::cat(&[&a.unsqueeze(0)?,&b.unsqueeze(0)?], 1)?;

    println!("{y1}");
    println!("{y2}");

从上面也可以看出, 两边的升维 API 都是 unsqueeze

有长维就有降维,降维操作是 squeeze:

pytorch

    a = torch.tensor([[1, 2, 3]])
    
    # (2,3) -> (3)
    b = a.squeeze()
    
    # tensor([1, 2, 3])
    print(b.shape)
   

candle

    let a = Tensor::new(vec![&[1i64,2,3]],  &Device::Cpu)?;
    // 降维,参数为维度,默认为0
    // Tensor[2, 3; i64]
    let y = a.squeeze(0)?;

    //Tensor[1, 2, 3; i64]
    println!("{:?}",y);

分割张量

把张量分类为指定的块,默认会返回指定大小的块,但是 如果最后一块不能被指定块大小整除,可能会小于指定的块,可以理解为 除法中的余数。

pytorch

	a = torch.tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]])
    # 从第 0 维分割
    b = a.chunk(2, dim=0)
    # (tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]]),)
    print(b)
    # 从第 1 维分割
    c = a.chunk(2, dim=1)
    # (tensor([[0, 1, 2, 3, 4, 5]]), tensor([[ 6,  7,  8,  9, 10]]))
    print(c)

candle

    let a = Tensor::new(vec![&[[0i64,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]]],  &Device::Cpu)?;
    // 从第 0 维分割
    let b = a.chunk(2, 0)?;
    // (tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]]),)
    println!("a: {a}");
    // 从第 1 维分割
    let c = a.chunk(2, 1)?;
    // [tensor([[0, 1, 2, 3, 4, 5]]), tensor([[ 6,  7,  8,  9, 10]])]
    println!("c: {:?}",c);

转置张量

通过给定张量的 2 个维度,对张量进行转置,使用 transpose。 pytorch

	data = [
       [1,2,3],
       [4,5,6],
    ]
    a = torch.tensor(data)
    # 对第 0 维和第 1 维进行转置
    b = a.transpose(0,1)
    
    # [[1, 4],
    #  [2, 5],
    #  [3, 6]]
    print(b)

candle

    let a_data = vec![1i64, 2, 3, 4, 5, 6];
    let a = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    
    // 对矩阵进行转置
    let y = a.transpose(0,1)?;

    // 打打印结果:
    //     [[1, 4],
    //      [2, 5],
    //      [3, 6]]
    println!("{y}");

以上操作相对来说是比较常见的一些张量操作,还有更多不太常见的操作不在这里展开。这些是模型推理的一些基础,我们了解到这些基础才能更好地了解模型上的一些操作意义。