由于 Candle 框架加载的是 safetensors 格式权重,safetensors 只保存了权重信息,并没有计算图等信息,因此需要外部实现模型的结构,由于现在大部分模型都是通过 Pytorch 训练的,因此了解 Pytorch 与 Candle 之间的等价操作是非常有必要的。通过了解他们之间的差异,我们才比较方便地进行移植模型的工作。
TLDR
这篇文章内容基本讲述 candle 与 pytorch 之间的张量基本操作,适合初学者查阅。
基本操作
Tensors 张量
直接从数组中初始化张量,pytorch 会自动推断类型:
x = torch.tensor([1,1])
print(x) #tensor([1, 1])
candle 需要指定数据类型:
let data: [u32; 2] = [1u32, 1];
let x = candle_core::Tensor::new(&data, &Device::Cpu).unwrap();
println!("{x}");
// [1, 1]
// Tensor[[2], u32]
基于其他张量形状生成张量,pytorch 如下:
x = torch.tensor([0.1,0.2])
zero_tensor = torch.zeros_like(x) # 填充 0
ones_tensor = torch.ones_like(x) # 填充 1
random_tensor = torch.rand_like(x) # 填充随机数
print(zero_tensor) # tensor([0., 0.])
print(ones_tensor) # tensor([1., 1.])
print(random_tensor) #tensor([0.1949, 0.4253])
candle 如下:
let data: [f32; 2] = [0.1,0.2];
let x = candle_core::Tensor::new(&data, &Device::Cpu)?;
let zero_tensor = x.zeros_like()?;
let ones_tensor = x.ones_like()?;
let random_tensor = x.rand_like(0.0, 1.0)?;
println!("zero_tensor: {zero_tensor}"); //zero_tensor: [0., 0.] Tensor[[2], f32]
println!("ones_tensor: {ones_tensor}"); //ones_tensor: [1., 1.] Tensor[[2], f32]
println!("random_tensor: {random_tensor}"); //random_tensor: [0.9306, 0.5341] Tensor[[2], f32]
检查张量维度, pytorch 可用两种方法:
x = torch.tensor([0.1,0.2])
print(x.shape) # torch.Size([2])
print(x.size()) # torch.Size([2])
candle 直接打印或者调用 shape 方法, 过大的张量最好还是用 shape:
println!("{:?}",x.shape()); //[2]
println!("{x}"); //Tensor[[2], f32]
张量运算
张量运算,加减乘除,pytorch 一般使用运算符,比较少调用方法。
x = torch.tensor([0.1,0.2])
y = torch.tensor([0.3,0.4])
a1 = x + y
a2 = x - y
a3 = x * y
a4 = x / y
print(a1,a2,a3,a4)
#tensor([0.4000, 0.6000]) tensor([-0.2000, -0.2000]) tensor([0.0300, 0.0800]) tensor([0.3333, 0.5000])
candle 这里由于也重载了运算符,因此也是可以使用两种方法调用:
let data: [f32; 2] = [0.1,0.2];
let x = candle_core::Tensor::new(&data, &Device::Cpu)?;
let data:[f32; 2] = [0.3,0.4];
let y = candle_core::Tensor::new(&data, &Device::Cpu)?;
let a1 = x.add(&y)?;
let a2 = x.sub(&y)?;
let a3 = x.mul(&y)?;
let a4 = x.div(&y)?;
println!("{a1}");//[0.4000, 0.6000]
println!("{a2}");//[-0.2000, -0.2000]
println!("{a3}");//[0.0300, 0.0800]
println!("{a4}");//[0.3333, 0.5000]
let a1 = (&x + &y)?;
let a2 = (&x - &y)?;
let a3 = (&x * &y)?;
let a4 = (&x / &y)?;
println!("{a1}");//[0.4000, 0.6000]
println!("{a2}");//[-0.2000, -0.2000]
println!("{a3}");//[0.0300, 0.0800]
println!("{a4}");//[0.3333, 0.5000]
加速器
pytorch 支持非常多的加速器,如CPU,CUDA,MPS,ROCm 等,基于还支持自定义后端。
torch.device("cpu")
torch.device("mps")
torch.device("cuda")
torch.device("cuda:0")
torch.device("cuda:1")
x = torch.tensor([0.1,0.2])
x.cuda()
x.cpu()
x.to("cpu")
x.to("cuda")
x.xpu()
candle 从枚举定义就可以看出, 只支持 CPU,CUDA,MPS,初始化时就需要指定。 默认 crate 只开启了 cpu。如果你想在 nvidia 显卡或者在 macbook m 系列芯片上使用,需要开启相应的 features。
pytorch 的
to=> candleto_device
pub enum Device {
Cpu,
Cuda(crate::CudaDevice),
Metal(crate::MetalDevice),
}
// 调用 cpu
let data: [f32; 2] = [0.1,0.2];
let x = candle_core::Tensor::new(&data, &Device::Cpu)?;
// 调用 cuda
let cuda = Device::cuda(0)?;
let x = candle_core::Tensor::new(&data, &cuda)?;
// 调用 mps
let mps = Device::mps(0)?;
let x = candle_core::Tensor::new(&data, &mps)?;
let device = Device::new_cuda(0)?;
x.to_device(&device);
判断相应的加速器是否可用,
可以看到以上的操作几乎是相似的 API, 并没有特别之处。