Continuing from the previous discussion.
Reduction operations collapse a data set (such as an array or tensor) into a smaller result according to a specific rule or function. Common reduction operations include summation, mean, maximum, and minimum.
Legend for the table below: ✅ an equivalent implementation exists; 🚫 no equivalent implementation exists; ☢️ an alternative implementation exists.
| Operation | Pytorch | Candle | Status |
|---|---|---|---|
| Indices of Maximum/Minimum Values | argmax, argmin | argmax, argmin | ✅ |
| Slice-wise Maximum/Minimum Values | amax, amin | max, min | ✅ |
| Tensor Minimum and Maximum Values | aminmax | Not implemented, can be achieved using above methods | ☢️ |
| Global Maximum/Minimum Value of All Elements in Tensor | max, min | Not implemented, can be indirectly achieved | ☢️ |
| p-Norm of (input - other) | dist | Not implemented | 🚫 |
| Logarithm of Sum of Exponentials Across Rows | logsumexp | log_sum_exp | ✅ |
| Mean of All Elements | mean | mean/mean_keepdim | ✅ |
| Median Value | median | Not implemented | 🚫 |
| Mode Value and Indices | mode | Not implemented | 🚫 |
| Product of All Elements | prod | Not implemented | 🚫 |
| Standard Deviation | std | Not implemented, see below | ☢️ |
| Standard Deviation and Mean | std_mean | Not implemented, refer to std and mean | ☢️ |
| Summation | sum | sum, sum_all | ✅ |
| Unique Elements | unique | Not implemented | 🚫 |
| Variance | var | var/var_keepdim | ✅ |
| Variance and Mean | var_mean | Not implemented, refer to var and mean | ☢️ |
| Total Number of Elements in Tensor | numel | elem_count | ✅ |
Preliminaries
In Pytorch, many operations support parameters like -1, for example:
torch.mean(a, -1)
-1 generally represents the last dimension. In Candle, due to type restrictions, -1 cannot be used directly but instead uses an enumeration:
let mean = a.mean(D::Minus1)?; // D::Minus1 represents -1
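For a 2-D tensor the last dimension is dimension 1, so D::Minus1 and an explicit 1 produce the same result. A minimal sketch (assuming candle_core::{D, Device, Tensor} are in scope):
let a = Tensor::from_vec(vec![1.0, -2.0, 3.0, 0.1, -0.2, 0.01], (2, 3), &Device::Cpu)?;
let m1 = a.mean(D::Minus1)?; // mean over the last dimension
let m2 = a.mean(1)?;         // same thing for a 2-D tensor
println!("{m1} {m2}");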
Indices of Maximum/Minimum Values
argmax and argmin return the indices of the maximum or minimum values along a given dimension.
Pytorch:
a = torch.tensor([
[1.0, -2.0, 3.0],
[0.1, -0.2, 0.01]
])
# Indicates operation along the 0th dimension (vertical direction), returning the row indices of the maximum values in each column.
print(a.argmax(dim=0)) # tensor([0, 1, 0])
# Indicates operation along the 1st dimension (horizontal direction), returning the column indices of the maximum values in each row.
print(a.argmax(dim=1)) # tensor([2, 0])
print(a.argmin(dim=0)) # tensor([1, 0, 1])
print(a.argmin(dim=1)) # tensor([1, 1])
Candle:
let a_data = vec![1.0, -2.0, 3.0, 0.1, -0.2, 0.01];
let x = Tensor::from_vec(a_data, (2, 3), &Device::Cpu)?;
// Indicates operation along the 0th dimension (vertical direction), returning the row indices of the maximum values in each column.
let y = x.argmax(0)?;
println!("{:?}", y); // Tensor[2; u32]
// Indicates operation along the 1st dimension (horizontal direction), returning the column indices of the maximum values in each row.
let y = x.argmax(1)?;
println!("{:?}", y);
let y = x.argmin(0)?;
println!("{:?}", y); // Tensor[1, 0, 1; u32]
let y = x.argmin(1)?;
println!("{:?}", y); // Tensor[1, 1; u32]
Slice-wise Maximum/Minimum Values
Candle does not have amax or amin. However, Candle’s max and min are largely consistent with Pytorch’s amax and amin.
Pytorch:
a = torch.tensor([
[1.0, -2.0, 3.0],
[0.1, -0.2, 0.01]
])
print(a.amax(dim=1)) # tensor([3.0000, 0.1000])
Candle:
let a_data = vec![1.0, -2.0, 3.0, 0.1, -0.2, 0.01];
let x = Tensor::from_vec(a_data, (2, 3), &Device::Cpu)?;
let y = x.max(1)?;
println!("y = {y}"); // y = [3.0000, 0.1000]
Candle can also implement amax using argmax and stack:
let a_data = vec![1.0, -2.0, 3.0, 0.1, -0.2, 0.01];
let x = Tensor::from_vec(a_data, (2, 3), &Device::Cpu)?;
let y = x.argmax(1)?;
let indices = y.to_vec1::<u32>()?;
let mut max_list = Vec::with_capacity(x.dims()[0]);
for (i, &idx) in indices.iter().enumerate() {
    let max = x.i((i, idx as usize))?;
    max_list.push(max);
}
let amax = candle_core::Tensor::stack(&max_list, 0)?;
println!("amax:{:?}", amax); // amax: Tensor[3, 0.1; f64]
Since amin is similar to amax, it will not be expanded here.
Tensor Minimum and Maximum Values (aminmax)
This is essentially the combination of the above amax and amin.
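A minimal sketch along a dimension, reusing the same 2x3 tensor x as above:
let min = x.min(1)?; // row-wise minima
let max = x.max(1)?; // row-wise maxima
println!("min: {min}, max: {max}");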
Global Maximum/Minimum Value of All Elements in Tensor
The max and min in Pytorch differ from those in Candle: when given a dim, Pytorch’s max and min return both the values and the indices:
print(a.max(dim=1))
# torch.return_types.max(
# values=tensor([3.0000, 0.1000]),
# indices=tensor([2, 0]))
Candle requires both argmax and max:
let a_data = vec![1.0, -2.0, 3.0, 0.1, -0.2, 0.01];
let x = Tensor::from_vec(a_data, (2, 3), &Device::Cpu)?;
let y = x.max(1)?;
let idx = x.argmax(1)?;
println!("{:?}", y);
println!("{:?}", idx);
Similarly, Candle requires min and argmin to achieve the equivalent of Pytorch’s min.
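A minimal sketch of that combination:
let vals = x.min(1)?;   // row-wise minimum values
let idx = x.argmin(1)?; // their column indices
println!("{vals:?} {idx:?}");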
Differences Between amax, amin and max, min
In Pytorch, amax/amin do not return indices, whereas max/min do.
In Candle, max/min are roughly equivalent to Pytorch’s amax/amin, though they cover fewer cases than their Pytorch counterparts.
Logarithm of Sum of Exponentials Across Rows
Formula: $\operatorname{logsumexp}(x)_i = \log \sum_{j} \exp\left(x_{ij}\right)$
Pytorch:
a = torch.tensor([
[1.0, -2.0, 3.0],
[0.1, -0.2, 0.01]
])
print(torch.logsumexp(a, dim=1)) # tensor([3.1328, 1.0764])
Candle:
let a_data = vec![1.0, -2.0, 3.0, 0.1, -0.2, 0.01];
let x = Tensor::from_vec(a_data, (2, 3), &Device::Cpu)?;
let y = x.log_sum_exp(1)?;
println!("{:?}", y); // Tensor[3.1328452337275756, 1.076350264534186; f64]
Both `amax` and `amin` have a `keepdim` parameter, corresponding to `max_keepdim`/`min_keepdim` in Candle.
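A quick sketch of the difference, using the same 2x3 tensor x as above (the keepdim variant keeps the reduced dimension with size 1):
let y = x.max(1)?;         // shape (2,)
let z = x.max_keepdim(1)?; // shape (2, 1)
println!("{:?} {:?}", y.dims(), z.dims()); // [2] [2, 1]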
Mean of All Elements
Although both Pytorch and Candle use mean for averaging, they are only equivalent when both specify the dim parameter.
When dim is not specified, Pytorch returns the average of all elements in the tensor.
Pytorch:
a = torch.tensor([
[1.0, -2.0, 3.0],
[0.1, -0.2, 0.01]
])
# `a.mean()` and `torch.mean(a)` are equivalent
print(a.mean()) # tensor(0.3183), returns the average of all elements if `dim` is not specified
print(a.mean(dim=0)) # tensor([ 0.5500, -1.1000, 1.5050])
print(a.mean(dim=1)) # tensor([ 0.6667, -0.0300])
Candle:
let a_data = vec![1.0, -2.0, 3.0, 0.1, -0.2, 0.01];
let x = Tensor::from_vec(a_data, (2, 3), &Device::Cpu)?;
let y = x.mean(0)?;
println!("{:?}", y); // Tensor[0.55, -1.1, 1.505; f64]
let y = x.mean(1)?;
println!("{:?}", y); // Tensor[0.6666666666666666, -0.030000000000000002; f64]
Since Candle’s mean requires a dimension argument, to average all elements you can either compute it yourself or use x.mean_all().
let z = y.to_vec1::<f64>()?;
// Calculate mean of z
let z = z.iter().sum::<f64>() / z.len() as f64;
println!("{:?}", z); // 0.3183333333333333
// The following is equivalent to the above
let y = x.mean_all()?;
println!("y:{:?}", y);
For other operations that likewise omit the dimension and reduce over all elements, check whether Candle provides an `xxx_all()` method. For example, `sum` has `sum_all()`.
Pytorch’s mean also has a keepdim parameter, which corresponds to Candle’s mean_keepdim() method:
# Python
a.mean(dim=0, keepdim=True)
// Rust
x.mean_keepdim(0)?;
Standard Deviation
The formula for the (biased) standard deviation is:
$\sigma = \sqrt{\frac{1}{N}\sum_{i=1}^{N}\left(x_i - \bar{x}\right)^2}$
Candle does not have a std function, but it can be calculated using the formula above. For example, when dim = 0 and unbiased = False:
Pytorch:
a = torch.tensor([
[1.0, -2.0, 3.0],
[0.1, -0.2, 0.01]
])
print(a.std(dim=0, unbiased=False)) # tensor([0.4500, 0.9000, 1.4950])
Equivalent Candle code:
let a_data = vec![1.0, -2.0, 3.0, 0.1, -0.2, 0.01];
let x = Tensor::from_vec(a_data, (2, 3), &Device::Cpu)?;
// Calculate mean
let mean = x.mean(0)?;
// Calculate difference
let diff = x.broadcast_sub(&mean)?;
// Calculate squared difference
let diff_sq = diff.sqr()?;
// Calculate mean of squared differences
let mean_sq = diff_sq.mean(0)?;
// Calculate standard deviation
let std = mean_sq.sqrt()?;
println!("std:{:?}", std); // std: Tensor[0.45, 0.9, 1.495; f64]
Pytorch’s std defaults to unbiased = True, meaning a.std(dim=0, unbiased=True) calculates the standard deviation per column using an unbiased estimate.
The difference lies in the denominator: the sum of squared differences is divided by n - 1 rather than n, where n is the size along the reduced dimension:
variance = torch.sum(squared_diff, dim=0) / (a.shape[0] - 1)
Candle implementation:
let a_data = vec![1.0, -2.0, 3.0, 0.1, -0.2, 0.01];
let x = Tensor::from_vec(a_data, (2, 3), &Device::Cpu)?;
// Calculate mean
let mean = x.mean(0)?;
println!("x.dims()[0]:{:?}", x.dims()[0]);
// Calculate difference
let diff = x.broadcast_sub(&mean)?;
// Calculate squared difference
let diff_sq = diff.sqr()?;
// Divide the sum of squared differences by (n - 1): the unbiased variance
let var = (diff_sq.sum(0)? / ((x.dims()[0] - 1) as f64))?;
// Calculate standard deviation
let std = var.sqrt()?;
println!("std:{:?}", std); // std ≈ [0.6364, 1.2728, 2.1142]
x.dims()[0] is the size of dimension 0 (the number of rows), so dividing by x.dims()[0] - 1 gives the unbiased estimate along each column.
Standard Deviation and Mean
Candle does not implement this, refer to the implementation above.
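If you need both results at once, a minimal sketch for dim = 0 with unbiased = False, reusing the tensor x from above, simply returns the two intermediate results together:
let mean = x.mean(0)?;
let std = x.broadcast_sub(&mean)?.sqr()?.mean(0)?.sqrt()?;
println!("std: {std}, mean: {mean}");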
Summation
Summation in Pytorch and Candle is basically consistent. The only difference is that when sum takes no arguments, Candle uses sum_all instead:
Pytorch:
a = torch.tensor([
[1.0, -2.0, 3.0],
[0.1, -0.2, 0.01]
])
print(a.sum()) # tensor(1.9100)
print(a.sum(dim=0)) # tensor([ 1.1000, -2.2000, 3.0100])
print(a.sum(dim=1)) # tensor([ 2.0000, -0.0900])
Candle:
let a_data = vec![1.0, -2.0, 3.0, 0.1, -0.2, 0.01];
let x = Tensor::from_vec(a_data, (2, 3), &Device::Cpu)?;
let y = x.sum_all()?;
println!("{:?}", y); // Tensor[1.9100000000000001; f64]
let y = x.sum(0)?;
println!("{:?}", y); // Tensor[1.1, -2.2, 3.01; f64]
let y = x.sum(1)?;
println!("{:?}", y); // Tensor[2, -0.09000000000000001; f64]
Variance
Variance formula (unbiased, the Pytorch default):
$\mathrm{Var}(x) = \frac{1}{N-1}\sum_{i=1}^{N}\left(x_i - \bar{x}\right)^2$
The difference is that Pytorch can omit specifying dim:
a = torch.tensor([
[1.0, -2.0, 3.0],
[0.1, -0.2, 0.01]
])
print(a.var()) # tensor(2.6884)
print(a.var(dim=0)) # tensor([0.4050, 1.6200, 4.4700])
print(a.var(dim=1)) # tensor([6.3333, 0.0237])
print(a.var(dim=0, keepdim=True)) # tensor([[0.4050, 1.6200, 4.4700]])
Candle does not have an equivalent operation for a.var():
let a_data = vec![1.0, -2.0, 3.0, 0.1, -0.2, 0.01];
let x = Tensor::from_vec(a_data, (2, 3), &Device::Cpu)?;
let y = x.var(0)?;
println!("{:?}", y); // Tensor[0.405, 1.