SDPA
SDPA (scaled dot-product attention) is exposed in PyTorch as torch.nn.functional.scaled_dot_product_attention and is used to accelerate attention computation in large models.
The formula is roughly as follows:

Attention(Q, K, V) = softmax(QKᵀ / √d_k) · V

where d_k is the size of the last dimension of Q (the head dimension).
In PyTorch this is a highly optimized implementation. Candle’s core library does not provide one, but some models in the candle-transformers crate carry their own. That version is quite simple, essentially a direct transcription of the formula above:
use candle::{D, Result, Tensor};

fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    // Scale the scores by 1/sqrt(d_k), where d_k is the last dimension of q.
    let dim = q.dim(D::Minus1)?;
    let scale_factor = 1.0 / (dim as f64).sqrt();
    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
    candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
}
In practice there are usually also variants that take a causal mask (causal_mask) or an attention mask (attn_mask) into account.
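As a rough sketch of how such a mask could be folded in, the scores get the mask added before the softmax; the function name, signature, and masking convention below are illustrative, not an existing candle-transformers API:

use candle::{D, Result, Tensor};

fn scaled_dot_product_attention_masked(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    attn_mask: Option<&Tensor>,
) -> Result<Tensor> {
    let dim = q.dim(D::Minus1)?;
    let scale_factor = 1.0 / (dim as f64).sqrt();
    let mut attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
    if let Some(mask) = attn_mask {
        // Assumes the mask holds 0.0 at visible positions and a large negative
        // value (e.g. f32::NEG_INFINITY) at masked ones, so they vanish after the softmax.
        attn_weights = attn_weights.broadcast_add(mask)?;
    }
    candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
}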
linspace
In PyTorch, linspace is available as torch.linspace. It generates a 1D tensor of evenly spaced values.
For example, generating 5 evenly spaced values from 0 to 1:
x = torch.linspace(0, 1, 5)  # arguments: start value, end value, number of steps
print(x) # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
Candle does not have this built-in, but implementing it yourself is quite simple:
use candle::{Device, Result, Tensor};

pub fn linspace(start: f64, stop: f64, steps: usize, device: &Device) -> Result<Tensor> {
    if steps == 0 {
        Tensor::from_vec(Vec::<f64>::new(), steps, device)
    } else if steps == 1 {
        Tensor::from_vec(vec![start], steps, device)
    } else {
        let delta = (stop - start) / (steps - 1) as f64;
        let vs = (0..steps)
            .map(|step| start + step as f64 * delta)
            .collect::<Vec<_>>();
        Tensor::from_vec(vs, steps, device)
    }
}
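A quick sanity check of the helper above (assuming a CPU device is fine):

let x = linspace(0.0, 1.0, 5, &Device::Cpu)?;
println!("{:?}", x.to_vec1::<f64>()?); // [0.0, 0.25, 0.5, 0.75, 1.0]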
weight_norm (conv1d)
torch.nn.utils.weight_norm is a weight normalization technique used to stabilize training and accelerate convergence.
The Python usage is roughly as follows:
import torch.nn as nn
from torch.nn.utils import weight_norm

# Apply weight normalization to a convolutional layer
conv = nn.Conv1d(128, 256, kernel_size=3)
conv = weight_norm(conv)  # normalize the weights
Mathematically, weight normalization decomposes the weight parameter w into two parts:

w = g · v / ‖v‖

where g is a learnable magnitude parameter, v is the original weight parameter, and ‖v‖ is the L2 norm of v.
Candle has no built-in implementation, but it is commonly used in some models, so several models in the candle-transformers crate carry their own implementation. It looks roughly like this:
fn conv1d_weight_norm(
    in_c: usize,
    out_c: usize,
    kernel_size: usize,
    bias: bool,
    config: candle_nn::Conv1dConfig,
    vb: VarBuilder,
) -> Result<Conv1d> {
    let weight = if vb.contains_tensor("weight") {
        vb.get((out_c, in_c, kernel_size), "weight")?
    } else {
        // Reconstruct w = g * v / ||v|| from the stored weight_g and weight_v.
        let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
        let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
        let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
        weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?
    };
    let bias = if bias {
        Some(vb.get(out_c, "bias")?)
    } else {
        None
    };
    Ok(Conv1d::new(weight, bias, config))
}
Unlike PyTorch, the shapes of weight_g and weight_v need to be determined from the actual model checkpoint, so you cannot copy the code above verbatim.
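For reference, a call might look roughly like this; the channel counts and the VarBuilder prefix "conv" are placeholders, not taken from a specific model:

let cfg = candle_nn::Conv1dConfig::default();
let conv = conv1d_weight_norm(128, 256, 3, true, cfg, vb.pp("conv"))?;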
Snake Activation Function
The formula is roughly as follows:

snake(x) = x + (1/α) · sin²(αx)

This is a relatively new activation function, better suited to periodic data. In fact, this is the first time I have come across it. It seems to be used mostly with audio data: the Descript Audio Codec (DAC) model in candle implements it, and the AudioVae module in VoxCPM uses it as well.
#[derive(Debug, Clone)]
pub struct Snake1d {
    alpha: Tensor,
}

impl Snake1d {
    pub fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
        let alpha = vb.get((1, channels, 1), "alpha")?;
        Ok(Self { alpha })
    }
}

impl candle::Module for Snake1d {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let xs_shape = xs.shape();
        let xs = xs.flatten_from(2)?;
        // snake(x) = x + (1/alpha) * sin^2(alpha * x); the 1e-9 guards against division by zero.
        let sin = self.alpha.broadcast_mul(&xs)?.sin()?;
        let sin = (&sin * &sin)?;
        (xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape)
    }
}
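A rough usage sketch, assuming a VarBuilder vb and a (batch, channels, time) input xs are already in scope, with 512 as an illustrative channel count:

use candle::Module;

let snake = Snake1d::new(512, vb.pp("snake"))?;
let ys = snake.forward(&xs)?; // same shape as the input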
Product
Both np.prod and torch.prod calculate the product of all elements in a tensor.
arr = np.array([1, 2, 3, 4])
result = np.prod(arr)
print(result) # Output: 24 (1×2×3×4 = 24)
This operation does not exist in candle; you have to pull the values out into a Vec and compute the product with x.iter().product().
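A minimal sketch of that workaround, assuming an f32 tensor; prod_all is just an illustrative helper name:

use candle::{Device, Result, Tensor};

fn prod_all(t: &Tensor) -> Result<f32> {
    // Flatten to 1-D, copy the values into a Vec, then multiply on the host side.
    let values = t.flatten_all()?.to_vec1::<f32>()?;
    Ok(values.iter().product())
}

let arr = Tensor::new(&[1f32, 2., 3., 4.], &Device::Cpu)?;
println!("{}", prod_all(&arr)?); // 24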
slice_assign
slice_assign is an operation that writes one tensor into a slice (a sub-region) of another tensor. In PyTorch, we often see usage similar to the following:
x = torch.arange(0, 4*5).reshape(4, 5)
y = torch.arange(0, 2*3).reshape(3, 2)
x[1:4, 3:5] = y
This operation assigns the elements of tensor y to the specified slice of tensor x.
A more visual explanation:
// Original tensor:
// Row\Col 0 1 2 3 4
// 0 [0, 1, 2, 3, 4]
// 1 [5, 6, 7, 8, 9] ← Select rows 1-3
// 2 [10,11,12,13,14] ← Select rows 1-3
// 3 [15,16,17,18,19] ← Select rows 1-3
// ↑ ↑
// Col 3-4
// Selected region: tensor[1..4, 3..5]
// Shape: (3, 2)
// [[8, 9],
// [13,14],
// [18,19]]
// Replace with src:
// src = [[0, 1],
// [2, 3],
// [4, 5]]
// Result:
// [[ 0, 1, 2, 3, 4],
// [ 5, 6, 7, 0, 1], ← Row 1, cols 3-4 become [0, 1]
// [10, 11, 12, 2, 3], ← Row 2, cols 3-4 become [2, 3]
// [15, 16, 17, 4, 5]] ← Row 3, cols 3-4 become [4, 5]
Rust of course does not support this slicing syntax, but the same effect is achieved with x.slice_assign(&[ranges...], &y):
let dev = Device::Cpu;
let tensor = Tensor::arange(0u32, 4 * 5, &dev)?.reshape((4, 5))?;
let src = Tensor::arange(0u32, 2 * 3, &dev)?.reshape((3, 2))?;
let out = tensor.slice_assign(&[1..4, 3..5], &src)?;
The only somewhat tricky part of this operation is the ranges argument, especially for higher-dimensional tensors. Simply put, the shape of the region selected by the ranges must match the shape of the source tensor being written in.
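For example, with a 3-D tensor every dimension gets its own range, and the region they select must have exactly the shape of the source. A small sketch, reusing dev from above with purely illustrative shapes:

// x: (2, 4, 5); overwrite the block at batch 0..2, rows 1..3, cols 2..4 (shape (2, 2, 2))
let x = Tensor::arange(0f32, 40., &dev)?.reshape((2, 4, 5))?;
let src = Tensor::arange(0f32, 8., &dev)?.reshape((2, 2, 2))?;
let out = x.slice_assign(&[0..2, 1..3, 2..4], &src)?;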