Mar 20, 2025
3 min read
Rust,
Candle,
Pytorch,

Rust AI 推理框架 Candle 与 Pytorch 的张量的基本等价操作

本文介绍了Rust Candle与Pytorch在张量基本操作上的等价实现,包括张量初始化、形状操作、加减乘除运算及加速器支持,适合初学者参考。

由于 Candle 框架加载的是 safetensors 格式权重,safetensors 只保存了权重信息,并没有计算图等信息,因此需要外部实现模型的结构,由于现在大部分模型都是通过 Pytorch 训练的,因此了解 Pytorch 与 Candle 之间的等价操作是非常有必要的。通过了解他们之间的差异,我们才比较方便地进行移植模型的工作。

TLDR

这篇文章内容基本讲述 candle 与 pytorch 之间的张量基本操作,适合初学者查阅。

基本操作

Tensors  张量

直接从数组中初始化张量,pytorch 会自动推断类型:

x = torch.tensor([1,1])
print(x) #tensor([1, 1])

candle 需要指定数据类型:

let data: [u32; 2] = [1u32, 1];
let x = candle_core::Tensor::new(&data, &Device::Cpu).unwrap();
println!("{x}");
// [1, 1]
// Tensor[[2], u32]

基于其他张量形状生成张量,pytorch 如下:

	x = torch.tensor([0.1,0.2])
    zero_tensor = torch.zeros_like(x) # 填充 0
    ones_tensor = torch.ones_like(x) # 填充 1
    random_tensor = torch.rand_like(x) # 填充随机数
    
    print(zero_tensor) # tensor([0., 0.])
    print(ones_tensor) # tensor([1., 1.])
    print(random_tensor) #tensor([0.1949, 0.4253])

candle 如下:

	let data: [f32; 2] = [0.1,0.2];
    let x = candle_core::Tensor::new(&data, &Device::Cpu)?;
    let zero_tensor = x.zeros_like()?;
    let ones_tensor = x.ones_like()?;
    let random_tensor = x.rand_like(0.0, 1.0)?;
    
    println!("zero_tensor: {zero_tensor}"); //zero_tensor: [0., 0.]  Tensor[[2], f32]
    println!("ones_tensor: {ones_tensor}"); //ones_tensor: [1., 1.]  Tensor[[2], f32]
    println!("random_tensor: {random_tensor}"); //random_tensor: [0.9306, 0.5341]  Tensor[[2], f32]

检查张量维度, pytorch 可用两种方法:

	x = torch.tensor([0.1,0.2])
    print(x.shape) # torch.Size([2])
    print(x.size()) # torch.Size([2])

candle 直接打印或者调用 shape 方法, 过大的张量最好还是用 shape:

println!("{:?}",x.shape()); //[2]
println!("{x}"); //Tensor[[2], f32]

张量运算

张量运算,加减乘除,pytorch 一般使用运算符,比较少调用方法。

	x = torch.tensor([0.1,0.2])
    y = torch.tensor([0.3,0.4])
    a1 = x + y
    a2 = x - y
    a3 = x * y
    a4 = x / y
    
    print(a1,a2,a3,a4)
    #tensor([0.4000, 0.6000]) tensor([-0.2000, -0.2000]) tensor([0.0300, 0.0800]) tensor([0.3333, 0.5000])

candle 这里由于也重载了运算符,因此也是可以使用两种方法调用:

    let data: [f32; 2] = [0.1,0.2];
    let x = candle_core::Tensor::new(&data, &Device::Cpu)?;

    let data:[f32; 2] = [0.3,0.4];
    let y = candle_core::Tensor::new(&data, &Device::Cpu)?;

    let a1 = x.add(&y)?;
    let a2 = x.sub(&y)?;
    let a3 = x.mul(&y)?;
    let a4 = x.div(&y)?;
    
    println!("{a1}");//[0.4000, 0.6000]
    println!("{a2}");//[-0.2000, -0.2000]
    println!("{a3}");//[0.0300, 0.0800]
    println!("{a4}");//[0.3333, 0.5000]

    let a1 = (&x + &y)?;
    let a2 = (&x - &y)?;
    let a3 = (&x * &y)?;
    let a4 = (&x / &y)?;

    println!("{a1}");//[0.4000, 0.6000]
    println!("{a2}");//[-0.2000, -0.2000]
    println!("{a3}");//[0.0300, 0.0800]
    println!("{a4}");//[0.3333, 0.5000]
 

加速器

pytorch 支持非常多的加速器,如CPU,CUDA,MPS,ROCm 等,基于还支持自定义后端。

    torch.device("cpu")
    torch.device("mps")
    torch.device("cuda")
    torch.device("cuda:0")
    torch.device("cuda:1")
    x = torch.tensor([0.1,0.2])
    x.cuda()
    x.cpu()
    x.to("cpu")
    x.to("cuda")
    x.xpu()

candle 从枚举定义就可以看出, 只支持 CPU,CUDA,MPS,初始化时就需要指定。 默认 crate 只开启了 cpu。如果你想在 nvidia 显卡或者在 macbook m 系列芯片上使用,需要开启相应的 features。

pytorch 的 to => candle to_device

pub enum Device {
    Cpu,
    Cuda(crate::CudaDevice),
    Metal(crate::MetalDevice),
}
	// 调用 cpu
    let data: [f32; 2] = [0.1,0.2];
    let x = candle_core::Tensor::new(&data, &Device::Cpu)?;

	// 调用 cuda
    let cuda = Device::cuda(0)?;
    let x = candle_core::Tensor::new(&data, &cuda)?;

	// 调用 mps
    let mps = Device::mps(0)?;
    let x = candle_core::Tensor::new(&data, &mps)?;

	let device = Device::new_cuda(0)?;
	x.to_device(&device);

判断相应的加速器是否可用,

可以看到以上的操作几乎是相似的 API, 并没有特别之处。