The Candle framework loads weights in the safetensors format, which stores only the weight data and no computational-graph information. The model structure therefore has to be implemented separately in code. Since most models today are trained with PyTorch, understanding the equivalent operations between PyTorch and Candle is essential: knowing where they differ makes model migration much easier.
TL;DR
This article walks through the basic tensor operations in Candle and their PyTorch equivalents; it is aimed at beginners.
Basic Operations
Tensors
Initializing a tensor directly from an array, PyTorch infers the type automatically:
x = torch.tensor([1,1])
print(x) #tensor([1, 1])
Candle requires specifying the data type:
let data: [u32; 2] = [1u32, 1];
let x = candle_core::Tensor::new(&data, &Device::Cpu).unwrap();
println!("{x}");
// [1, 1]
// Tensor[[2], u32]
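Candle can also construct tensors from an explicit shape and data type rather than a Rust array. A minimal sketch, assuming the candle-core crate (Tensor::zeros, Tensor::ones, and Tensor::arange are the constructors shown):

```rust
use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    // 2x3 tensor filled with zeros, dtype given explicitly
    let zeros = Tensor::zeros((2, 3), DType::F32, &device)?;
    // Same shape, filled with ones
    let ones = Tensor::ones((2, 3), DType::F32, &device)?;
    // 1-D tensor with values 0..4; dtype is inferred from the bounds
    let range = Tensor::arange(0u32, 4u32, &device)?;
    println!("{zeros}\n{ones}\n{range}");
    Ok(())
}
```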
Tensors can also be generated from the shape of another tensor. In PyTorch:
x = torch.tensor([0.1,0.2])
zero_tensor = torch.zeros_like(x) # Fill with 0
ones_tensor = torch.ones_like(x) # Fill with 1
random_tensor = torch.rand_like(x) # Fill with random numbers
print(zero_tensor) # tensor([0., 0.])
print(ones_tensor) # tensor([1., 1.])
print(random_tensor) #tensor([0.1949, 0.4253])
In Candle:
let data: [f32; 2] = [0.1,0.2];
let x = candle_core::Tensor::new(&data, &Device::Cpu)?;
let zero_tensor = x.zeros_like()?;
let ones_tensor = x.ones_like()?;
let random_tensor = x.rand_like(0.0, 1.0)?;
println!("zero_tensor: {zero_tensor}"); //zero_tensor: [0., 0.] Tensor[[2], f32]
println!("ones_tensor: {ones_tensor}"); //ones_tensor: [1., 1.] Tensor[[2], f32]
println!("random_tensor: {random_tensor}"); //random_tensor: [0.9306, 0.5341] Tensor[[2], f32]
To check a tensor's dimensions, PyTorch provides two equivalent accessors:
x = torch.tensor([0.1,0.2])
print(x.shape) # torch.Size([2])
print(x.size()) # torch.Size([2])
In Candle, you can either print the tensor directly (the output includes its shape) or call the shape method. For very large tensors, shape is preferable because it avoids formatting the whole tensor:
println!("{:?}",x.shape()); //[2]
println!("{x}"); //Tensor[[2], f32]
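Beyond shape, Candle's Tensor exposes a few related accessors. A sketch, assuming the candle-core crate (dims, rank, elem_count, and dtype are the methods shown):

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let data: [f32; 2] = [0.1, 0.2];
    let x = Tensor::new(&data, &Device::Cpu)?;
    println!("{:?}", x.dims());     // dimensions as a slice: [2]
    println!("{}", x.rank());       // number of dimensions: 1
    println!("{}", x.elem_count()); // total number of elements: 2
    println!("{:?}", x.dtype());    // element type: F32
    Ok(())
}
```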
Tensor Arithmetic
For tensor arithmetic such as addition, subtraction, multiplication, and division, PyTorch generally uses operators rather than method calls:
x = torch.tensor([0.1,0.2])
y = torch.tensor([0.3,0.4])
a1 = x + y
a2 = x - y
a3 = x * y
a4 = x / y
print(a1,a2,a3,a4)
#tensor([0.4000, 0.6000]) tensor([-0.2000, -0.2000]) tensor([0.0300, 0.0800]) tensor([0.3333, 0.5000])
In Candle, the arithmetic operators are also overloaded, so both styles work. Note that Candle's operators return a Result (a shape mismatch is a runtime error), which is why the ? is still needed:
let data: [f32; 2] = [0.1,0.2];
let x = candle_core::Tensor::new(&data, &Device::Cpu)?;
let data:[f32; 2] = [0.3,0.4];
let y = candle_core::Tensor::new(&data, &Device::Cpu)?;
let a1 = x.add(&y)?;
let a2 = x.sub(&y)?;
let a3 = x.mul(&y)?;
let a4 = x.div(&y)?;
println!("{a1}");//[0.4000, 0.6000]
println!("{a2}");//[-0.2000, -0.2000]
println!("{a3}");//[0.0300, 0.0800]
println!("{a4}");//[0.3333, 0.5000]
let a1 = (&x + &y)?;
let a2 = (&x - &y)?;
let a3 = (&x * &y)?;
let a4 = (&x / &y)?;
println!("{a1}");//[0.4000, 0.6000]
println!("{a2}");//[-0.2000, -0.2000]
println!("{a3}");//[0.0300, 0.0800]
println!("{a4}");//[0.3333, 0.5000]
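One difference worth noting when migrating models: PyTorch's operators broadcast automatically, while Candle's elementwise operators require matching shapes, and broadcasting must be requested explicitly through the broadcast_* methods. A sketch, assuming the candle-core crate:

```rust
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::Cpu;
    let x = Tensor::new(&[[1f32, 2.], [3., 4.]], &device)?; // shape [2, 2]
    let y = Tensor::new(&[10f32, 20.], &device)?;           // shape [2]
    // x.add(&y) would fail with a shape-mismatch error;
    // broadcast_add expands y across the rows instead.
    let z = x.broadcast_add(&y)?;
    println!("{z}"); // [[11., 22.], [13., 24.]]
    Ok(())
}
```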
Accelerators
PyTorch supports a wide variety of accelerators, such as CPU, CUDA, MPS, and ROCm, and even custom backends:
torch.device("cpu")
torch.device("mps")
torch.device("cuda")
torch.device("cuda:0")
torch.device("cuda:1")
x = torch.tensor([0.1,0.2])
x.cuda()
x.cpu()
x.to("cpu")
x.to("cuda")
x.xpu()
Judging from its Device enum, Candle supports only CPU, CUDA, and Metal (Apple's GPU API, which backs MPS in PyTorch), and the device must be specified when a tensor is created. The default crate enables only the CPU backend; to use NVIDIA GPUs or Apple M-series chips, you need to enable the corresponding cuda or metal features.
PyTorch's to corresponds to Candle's to_device.
pub enum Device {
Cpu,
Cuda(crate::CudaDevice),
Metal(crate::MetalDevice),
}
// On the CPU
let data: [f32; 2] = [0.1, 0.2];
let x = candle_core::Tensor::new(&data, &Device::Cpu)?;
// On CUDA (NVIDIA GPU)
let cuda = Device::new_cuda(0)?;
let x = candle_core::Tensor::new(&data, &cuda)?;
// On Metal (Apple Silicon GPU)
let metal = Device::new_metal(0)?;
let x = candle_core::Tensor::new(&data, &metal)?;
// Moving an existing tensor to another device
let device = Device::new_cuda(0)?;
let x = x.to_device(&device)?;
Both frameworks let you check at runtime whether an accelerator is available: PyTorch offers torch.cuda.is_available() and torch.backends.mps.is_available(), while Candle provides cuda_is_available() and metal_is_available() in candle_core::utils.
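A minimal sketch of such a runtime check in Candle, assuming the candle-core crate built with the cuda and metal features (cuda_is_available and metal_is_available live in candle_core::utils):

```rust
use candle_core::Device;
use candle_core::utils::{cuda_is_available, metal_is_available};

// Pick the best available device, falling back to the CPU.
fn pick_device() -> candle_core::Result<Device> {
    if cuda_is_available() {
        Device::new_cuda(0)      // NVIDIA GPU
    } else if metal_is_available() {
        Device::new_metal(0)     // Apple Silicon GPU
    } else {
        Ok(Device::Cpu)          // CPU fallback
    }
}
```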
Overall, the device APIs of the two frameworks are closely analogous, with no major differences.