Mar 21, 2025
7 min read
Rust,
Candle,
PyTorch,

Equivalent Tensor Operations for Indexing, Slicing, Concatenation, and Mutation in the Rust Candle Framework and PyTorch

This article compares how the Rust Candle framework and PyTorch implement tensor indexing, slicing, concatenation, mutation, and related operations, covering the most common operations and their equivalents on each side.

Continuing from the previous topic: deep learning involves a great deal of indexing, slicing, concatenating, and mutating of tensors. Thanks to Python's syntax, these operations are very convenient in PyTorch, but their Candle equivalents can look somewhat different.

Indexing and Slicing

Define a 3x3 matrix; some operations below will be based on this matrix:

    [
       [1,2,3],
       [4,5,6],
       [7,8,9],
    ]

In PyTorch, retrieving the first row and the first column is done as follows:

    x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

    print(x[0])     # tensor([1, 2, 3])
    print(x[:, :1]) # tensor([[1],
                    #         [4],
                    #         [7]])

In Candle, the same retrieval looks like this (note that the i() indexing method requires the IndexOp trait to be in scope):

    use candle_core::{Device, IndexOp, Tensor};

    let data = vec![1i64, 2, 3, 4, 5, 6, 7, 8, 9];
    let x = Tensor::from_vec(data, (3, 3), &Device::Cpu)?;

    let y = x.i(0)?;
    println!("{y}"); // [1, 2, 3]

    let y = x.i((.., ..1))?;
    println!("{y}"); // [[1],
                     //  [4],
                     //  [7]]

Select

select picks a single index along a given dimension and is therefore a special case of indexing/slicing.

In PyTorch, tensor.select(0, index) is equivalent to tensor[index], and tensor.select(2, index) is equivalent to tensor[:, :, index].

In Candle, the same effect is achieved with the i() method.
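As a rough illustration (reusing the 3x3 tensor x defined earlier, and assuming IndexOp is in scope), the following Candle snippet mirrors PyTorch's select:

    // Equivalent of PyTorch's x.select(0, 1) / x[1]: take row 1
    let row1 = x.i(1)?;
    println!("{row1}"); // [4, 5, 6]

    // Equivalent of PyTorch's x.select(1, 2) / x[:, 2]: take column 2
    let col2 = x.i((.., 2))?;
    println!("{col2}"); // [3, 6, 9]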

Reshaping Tensors

In PyTorch, use view or reshape to reshape a tensor:

    x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])
    y = x.view(3, 3)
    y = x.reshape(3, 3)

    # Output
    # tensor([[1, 2, 3],
    #         [4, 5, 6],
    #         [7, 8, 9]])
    print(y)

In Candle, only the reshape API can be used to reshape tensors:

    let data = vec![1i64, 2, 3, 4, 5, 6, 7, 8, 9];
    let x = Tensor::new(data, &Device::Cpu)?;
    let y = x.reshape((3,3))?;
    
    // Output result:
    //[[1, 2, 3],
    //[4, 5, 6],
    //[7, 8, 9]]
    println!("{y}");

Concatenating Tensors

We often need to concatenate tensors along a specified dimension. There are two common concatenation operations: cat and stack.

Cat

Assume two tensors a and b, both of shape (2, 3), are concatenated along the 0th and then the 1st dimension. The PyTorch code is as follows:

    a = torch.tensor([[1, 2, 3], [4, 5, 6]])
    b = torch.tensor([[7, 8, 9], [10, 11, 12]])
    
    # Concatenate along rows (dim=0) → Shape (4, 3)
    y = torch.cat([a,b], dim=0)
    # Result:
    # tensor([[ 1,  2,  3],
    #         [ 4,  5,  6],
    #         [ 7,  8,  9],
    #         [10, 11, 12]])
    print(y)
    
    # Concatenate along columns (dim=1) → Shape (2, 6)
    y = torch.cat([a,b], dim=1)
    # Result:
    # tensor([[ 1,  2,  3,  7,  8,  9],
    #         [ 4,  5,  6, 10, 11, 12]])
    print(y)

Candle works similarly:

    let a_data = vec![1i64, 2, 3, 4, 5, 6];
    let b_data = vec![7i64, 8, 9,10,11,12];
    let a = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    let b = Tensor::from_vec(b_data, (2,3),&Device::Cpu)?;
    
    // Concatenate along rows (dim=0) → Shape (4, 3)
    let y = candle_core::Tensor::cat(&[&a,&b], 0)?;
    // Result:
    // [[ 1,  2,  3],
    // [ 4,  5,  6],
    // [ 7,  8,  9],
    // [10, 11, 12]]
    println!("{y}");
    
    // Concatenate along columns (dim=1) → Shape (2, 6)
    let y = candle_core::Tensor::cat(&[&a,&b], 1)?;
    
    // Result:
    // [[ 1,  2,  3,  7,  8,  9],
    // [ 4,  5,  6, 10, 11, 12]]
    println!("{y}");

Stack

Stack is another concatenation method. Again, assume two tensors a and b, both of shape (2, 3), stacked along the 0th and then the 1st dimension. The PyTorch code is as follows:

    a = torch.tensor([[1, 2, 3], [4, 5, 6]])
    b = torch.tensor([[7, 8, 9], [10, 11, 12]])
    
    # Stack along a new dim=0 → Shape (2, 2, 3)
    y = torch.stack([a,b], dim=0)
    # Result:
    # [
    #   [[ 1,  2,  3],[ 4,  5,  6]],
    #   [[ 7,  8,  9],[ 10, 11, 12]]
    # ]
    print(y)
    
    # Stack along a new dim=1 → Shape (2, 2, 3)
    y = torch.stack([a,b], dim=1)
    # Result:
    # [
    #   [[ 1,  2,  3], [ 7,  8,  9]],
    #   [[ 4,  5,  6], [10, 11, 12]]
    # ]
    print(y)

Candle works similarly:

    let a_data = vec![1i64, 2, 3, 4, 5, 6];
    let b_data = vec![7i64, 8, 9,10,11,12];
    let a = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    let b = Tensor::from_vec(b_data, (2,3),&Device::Cpu)?;
    
    // Stack along a new dim 0 → Shape (2, 2, 3)
    let y = candle_core::Tensor::stack(&[&a,&b], 0)?;
    // Result:
    // [
    //   [[ 1,  2,  3],[ 4,  5,  6]],
    //   [[ 7,  8,  9],[ 10, 11, 12]]
    // ]
    println!("{y}");
    
    // Stack along a new dim 1 → Shape (2, 2, 3)
    let y = candle_core::Tensor::stack(&[&a,&b], 1)?;

    // Result:
    // [
    //   [[ 1,  2,  3], [ 7,  8,  9]],
    //   [[ 4,  5,  6], [10, 11, 12]]
    // ]
    println!("{y}");

What Is the Difference Between Cat and Stack?

In PyTorch, torch.cat and torch.stack both concatenate tensors, but the key difference lies in whether a new dimension is created. As we can see, stack directly adds a new dimension.
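To make the difference concrete, here is a minimal Candle sketch (assuming the same a and b of shape (2, 3) defined above) that compares the resulting shapes:

    // cat keeps the number of dimensions: (2, 3) + (2, 3) -> (4, 3)
    let cat_y = candle_core::Tensor::cat(&[&a, &b], 0)?;
    println!("{:?}", cat_y.dims()); // [4, 3]

    // stack inserts a new dimension at the given position: (2, 3) + (2, 3) -> (2, 2, 3)
    let stack_y = candle_core::Tensor::stack(&[&a, &b], 0)?;
    println!("{:?}", stack_y.dims()); // [2, 2, 3]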

However, stack can be expressed in terms of cat, because we can add the new dimension manually:

PyTorch:

    # The following two operations are equivalent
    y1 = torch.stack([a,b], dim=0)
    y2 = torch.cat([a.unsqueeze(0),b.unsqueeze(0)], dim=0)
    # Result:
    # [
    #   [[ 1,  2,  3],[ 4,  5,  6]],
    #   [[ 7,  8,  9],[ 10, 11, 12]]
    # ]
    assert y1.equal(y2)

Candle:

    // The following two operations are equivalent
    let y1 = candle_core::Tensor::stack(&[&a, &b], 0)?;
    let y2 = candle_core::Tensor::cat(&[&a.unsqueeze(0)?, &b.unsqueeze(0)?], 0)?;

    println!("{y1}");
    println!("{y2}");

As the examples above show, the dimension-expanding API in both frameworks is unsqueeze.
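For instance, a minimal Candle sketch (again assuming a has shape (2, 3)) showing where unsqueeze inserts the new size-1 dimension:

    // Insert a size-1 dimension at position 0: (2, 3) -> (1, 2, 3)
    let y0 = a.unsqueeze(0)?;
    println!("{:?}", y0.dims()); // [1, 2, 3]

    // Insert a size-1 dimension at position 1: (2, 3) -> (2, 1, 3)
    let y1 = a.unsqueeze(1)?;
    println!("{:?}", y1.dims()); // [2, 1, 3]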

Where dimensions can be added, they can also be removed; the dimension-reducing operation is squeeze:

PyTorch

    a = torch.tensor([[1, 2, 3]])

    # (1, 3) -> (3,)
    b = a.squeeze()

    # torch.Size([3])
    print(b.shape)

Candle

    let a = Tensor::new(&[[1i64, 2, 3]], &Device::Cpu)?;
    // squeeze removes the given dimension when its size is 1
    // (1, 3) -> (3)
    let y = a.squeeze(0)?;

    // [1, 2, 3]
    println!("{y}");

Splitting Tensors

chunk splits a tensor into a specified number of chunks along a given dimension. The chunks are all the same size when possible; if the size along that dimension is not evenly divisible by the number of chunks, the last chunk is smaller, much like the remainder in a division.

PyTorch

    a = torch.tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]])
    # Split along the 0th dimension
    b = a.chunk(2, dim=0)
    # (tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10]]),)
    print(b)
    # Split along the 1st dimension
    c = a.chunk(2, dim=1)
    # (tensor([[0, 1, 2, 3, 4, 5]]), tensor([[ 6,  7,  8,  9, 10]]))
    print(c)

Candle

    let a = Tensor::new(&[[0i64, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]], &Device::Cpu)?;
    // Split along the 0th dimension (its size is 1, so a single chunk comes back)
    let b = a.chunk(2, 0)?;
    // a vec containing one tensor of shape (1, 11)
    println!("b: {:?}", b);
    // Split along the 1st dimension
    let c = a.chunk(2, 1)?;
    // a vec containing two tensors: [[0, 1, 2, 3, 4, 5]] and [[6, 7, 8, 9, 10]]
    println!("c: {:?}", c);

Transposing Tensors

Given two dimensions of a tensor, transpose swaps them.

PyTorch

    data = [
       [1,2,3],
       [4,5,6],
    ]
    a = torch.tensor(data)
    # Transpose along the 0th and 1st dimensions
    b = a.transpose(0,1)
    
    # [[1, 4],
    #  [2, 5],
    #  [3, 6]]
    print(b)

Candle

    let a_data = vec![1i64, 2, 3, 4, 5, 6];
    let a = Tensor::from_vec(a_data, (2,3),&Device::Cpu)?;
    
    // Transpose the matrix
    let y = a.transpose(0,1)?;

    // Print result:
    //     [[1, 4],
    //      [2, 5],
    //      [3, 6]]
    println!("{y}");

The operations above are among the most common tensor operations; rarer ones are not covered here. They form the foundation of model inference, and a solid grasp of them makes it much easier to understand why models perform certain operations.