Jan 06, 2025
3 min read
Rust,

进阶 Rust:如何在 Rust 中优雅地重载操作符

简单实现 Rust 中的重载操作符,以实现张量加法运算

以前学 swift 的时候看到 swift 不但支持重载操作符,甚至还可以自定义操作符,感觉到大为震撼。这种灵活性确实令人印象深刻,但也带来了一定的风险:过度使用自定义操作符可能会导致代码难以维护,并增加学习成本。

相比之下,操作符重载提供了一个更为平衡的选择。在某些应用场景中,重载操作符能够显著简化代码,提升可读性和表达力。例如,在深度学习框架中,张量运算是一个典型的场景。如果不能重载操作符,代码将变得冗长且难以阅读,影响开发效率和代码的可维护性。

通过合理地使用操作符重载,我们可以使复杂的数学运算和数据处理变得更加直观,同时保持代码的简洁和易读。这不仅有助于提高开发速度,还能降低后续维护的难度。

比如 pytorch 张量相加:

import torch

# 创建两个张量
tensor_a = torch.tensor([1.0, 2.0, 3.0])
tensor_b = torch.tensor([4.0, 5.0, 6.0])

# 使用 + 操作符进行相加
tensor_sum = tensor_a + tensor_b
print("Using + operator:", tensor_sum)

# 使用 .add() 方法进行相加
tensor_sum_add = torch.add(tensor_a, tensor_b)
print("Using torch.add():", tensor_sum_add)

Rust 能不能实现类似的操作?答案自然是可以的。在rust 中因为运算符是方法调用的语法糖。因此许多 Operator 可以通过 trait 重载。能够重载的运算符,基本都在core::ops 命名空间下。更多的可以看看这个地址: https://doc.rust-lang.org/core/ops/

实践

以上面的张量为例子,我们看看怎么实现两个张量相加。

先定义一个简陋版本的张量:

#[derive(Debug, PartialEq)]
struct Tensor {
    data: Vec<f64>,
}

impl Tensor {
    // 构造函数:创建一个新的 Tensor
    pub fn new(data: Vec<f64>) -> Self {
        Tensor { data }
    }

    // 获取张量的长度
    pub fn len(&self) -> usize {
        self.data.len()
    }
}

实现 std::ops::Add

// 实现张量的加法操作
// 该函数接收另一个张量作为参数,并返回两个张量相加的结果
// 注意:这个函数的实现假设调用者已经确保两个张量长度相同
fn add(self, other: Self) -> Self {
    // 检查两个张量的长度是否相同,如果不同则抛出异常
    // 这是一个重要的前置条件检查,确保后续操作的合法性
    if self.len() != other.len() {
        panic!("Tensors must have the same length to be added");
    }

    // 使用迭代器和 zip 方法将两个张量的元素逐个相加
    // 这里展示了如何高效地遍历两个集合并进行元素级别的操作
    let result_data: Vec<f64> = self.data
        .iter()
        .zip(other.data.iter())
        .map(|(a, b)| a + b)
        .collect();

    // 创建并返回一个新的张量,包含相加后的结果
    // 这是整个加法操作的最终产出
    Tensor::new(result_data)
}

调用示例如下:

fn main() {
    let tensor_a = Tensor::new(vec![1.0, 2.0, 3.0]);
    let tensor_b = Tensor::new(vec![4.0, 5.0, 6.0]);

    // 使用 + 操作符进行相加
    let tensor_sum = tensor_a + tensor_b;
    println!("tensor_a + tensor_b = {:?}", tensor_sum);

    // 验证结果
    assert_eq!(tensor_sum, Tensor::new(vec![5.0, 7.0, 9.0]));
    println!("Test passed!");
}

张量和标量运算

张量和标量之间自然也可以实现:

// 实现 Tensor + f32
impl Add<f32> for Tensor {
    type Output = Self;

    fn add(self, scalar: f32) -> Self {
        // 将 f32 转换为 f64
        let scalar_f64 = scalar as f64;

        // 使用迭代器将每个元素加上标量
        let result_data: Vec<f64> = self.data
            .iter()
            .map(|&x| x + scalar_f64)
            .collect();

        Tensor::new(result_data)
    }
}

调用示例:

fn main() {
    let tensor_a = Tensor::new(vec![1.0, 2.0, 3.0]);

    // 使用 + 操作符进行相加
    let tensor_sum = tensor_a + 4.0;
    println!("tensor_a + 4.0f32 = {:?}", tensor_sum);

    // 验证结果
    assert_eq!(tensor_sum, Tensor::new(vec![5.0, 6.0, 7.0]));
    println!("Test passed!");
}

反向操作

一个问题,下面这样是不行的:

let tensor_sum = 4.0 + tensor_a ;

默认情况下,Rust 不会自动处理反向操作(即 f32 + Tensor),因此我们需要显式地为 f32 实现 Add<Tensor> 特质。为此,我们可以使用 std::ops::Add<Rhs = Tensor> 来指定右边的操作数类型为 Tensor

// 实现 f32 + Tensor
impl Add<Tensor> for f32 {
    type Output = Tensor;

    fn add(self, tensor: Tensor) -> Tensor {
        // 将 f32 转换为 f64
        let scalar_f64 = self as f64;

        // 使用迭代器将每个元素加上标量
        let result_data: Vec<f64> = tensor.data
            .iter()
            .map(|&x| x + scalar_f64)
            .collect();

        Tensor::new(result_data)
    }
}

总结

直接重载运算符会让张量运算调用更加简洁,更加符合数学语义。