Jan 06, 2025
Rust

Advanced Rust: How to Overload Operators Elegantly in Rust

A simple implementation of operator overloading in Rust to achieve tensor addition operations.

Previously, when learning Swift, I was amazed to see that Swift not only supports operator overloading but also allows custom operators, which was quite impressive. This flexibility is indeed eye-catching, but it also brings certain risks: excessive use of custom operators can make code difficult to maintain and increase learning costs.

In contrast, operator overloading on its own is a more balanced choice. In certain scenarios, overloading operators can significantly simplify code and improve its readability and expressiveness. Tensor operations in deep learning frameworks are a typical example: without operator overloading, the code would become long and hard to read, hurting both development efficiency and maintainability.

By using operator overloading reasonably, we can make complex mathematical operations and data processing more intuitive while keeping the code concise and readable. This not only helps to improve development speed but also reduces the difficulty of subsequent maintenance.

For example, tensor addition in PyTorch:

import torch

# Create two tensors
tensor_a = torch.tensor([1.0, 2.0, 3.0])
tensor_b = torch.tensor([4.0, 5.0, 6.0])

# Use + operator for addition
tensor_sum = tensor_a + tensor_b
print("Using + operator:", tensor_sum)

# Use the torch.add() function for addition
tensor_sum_add = torch.add(tensor_a, tensor_b)
print("Using torch.add():", tensor_sum_add)

Can Rust achieve similar functionality? The answer is yes. In Rust, operators are syntactic sugar for trait method calls, so many operators can be overloaded by implementing the corresponding trait. Most of the overloadable operators live in the core::ops module (also available as std::ops). For more details, see the documentation: https://doc.rust-lang.org/core/ops/
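
To make the "syntactic sugar" point concrete, here is a minimal sketch showing that the operator form and the explicit trait call give the same result. The Meters type is hypothetical and exists only for this illustration:

```rust
use std::ops::Add;

// Hypothetical newtype, used only to illustrate the desugaring.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Meters(f64);

impl Add for Meters {
    type Output = Meters;

    fn add(self, rhs: Meters) -> Meters {
        Meters(self.0 + rhs.0)
    }
}

fn main() {
    let a = Meters(1.5);
    let b = Meters(2.0);

    // `a + b` is resolved through the Add trait, so both forms agree.
    let with_operator = a + b;
    let with_method = Add::add(a, b);

    assert_eq!(with_operator, with_method);
    println!("{:?}", with_operator);
}
```

Meters derives Copy here only so that a and b can be used twice; it is not required for operator overloading.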

Practice

Taking the above tensor as an example, let’s see how to implement the addition of two tensors.

First, define a rudimentary version of a tensor:

#[derive(Debug, PartialEq)]
struct Tensor {
    data: Vec<f64>,
}

impl Tensor {
    // Constructor: Create a new Tensor
    pub fn new(data: Vec<f64>) -> Self {
        Tensor { data }
    }

    // Get the length of the tensor
    pub fn len(&self) -> usize {
        self.data.len()
    }
}

Implement std::ops::Add:

use std::ops::Add;

// Implement element-wise tensor addition
// `add` consumes both operands and returns a new tensor holding their sum
// Note: it panics if the two tensors have different lengths
impl Add for Tensor {
    type Output = Self;

    fn add(self, other: Self) -> Self {
        // Check if the lengths of the two tensors are the same, if not, panic
        // This is an important precondition check to ensure the legality of subsequent operations
        if self.len() != other.len() {
            panic!("Tensors must have the same length to be added");
        }

        // Use iterators and zip method to add elements of the two tensors one by one
        // This demonstrates how to efficiently traverse two collections and perform element-level operations
        let result_data: Vec<f64> = self.data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| a + b)
            .collect();

        // Create and return a new tensor containing the addition result
        // This is the final output of the entire addition operation
        Tensor::new(result_data)
    }
}
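
Panicking on a length mismatch keeps the + syntax simple, but it may not suit every caller. As a hedged sketch, one possible complement is a named method that reports the mismatch as an error instead. The name checked_add and the String error type are assumptions made for this example, not part of the implementation above:

```rust
#[derive(Debug, PartialEq)]
struct Tensor {
    data: Vec<f64>,
}

impl Tensor {
    // Hypothetical fallible counterpart to `+`: returns Err instead of
    // panicking when the lengths differ. Borrows both operands.
    pub fn checked_add(&self, other: &Tensor) -> Result<Tensor, String> {
        if self.data.len() != other.data.len() {
            return Err(format!(
                "length mismatch: {} vs {}",
                self.data.len(),
                other.data.len()
            ));
        }

        let data: Vec<f64> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| a + b)
            .collect();

        Ok(Tensor { data })
    }
}

fn main() {
    let a = Tensor { data: vec![1.0, 2.0] };
    let b = Tensor { data: vec![3.0, 4.0, 5.0] };

    // Mismatched lengths produce an error rather than a panic.
    assert!(a.checked_add(&b).is_err());

    // Matching lengths produce the element-wise sum.
    assert_eq!(a.checked_add(&a), Ok(Tensor { data: vec![2.0, 4.0] }));
}
```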

Example of usage:

fn main() {
    let tensor_a = Tensor::new(vec![1.0, 2.0, 3.0]);
    let tensor_b = Tensor::new(vec![4.0, 5.0, 6.0]);

    // Use + operator for addition
    let tensor_sum = tensor_a + tensor_b;
    println!("tensor_a + tensor_b = {:?}", tensor_sum);

    // Verify the result
    assert_eq!(tensor_sum, Tensor::new(vec![5.0, 7.0, 9.0]));
    println!("Test passed!");
}
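
Note that tensor_a + tensor_b moves both operands, because add takes them by value, so neither can be used afterwards. If the operands need to stay usable, one option is to also implement Add for references. The following is a minimal sketch under that assumption, with the Tensor definition repeated so that it compiles on its own:

```rust
use std::ops::Add;

#[derive(Debug, PartialEq)]
struct Tensor {
    data: Vec<f64>,
}

impl Tensor {
    pub fn new(data: Vec<f64>) -> Self {
        Tensor { data }
    }
}

// &Tensor + &Tensor: borrows both operands and returns a new Tensor.
impl<'a, 'b> Add<&'b Tensor> for &'a Tensor {
    type Output = Tensor;

    fn add(self, other: &'b Tensor) -> Tensor {
        assert_eq!(
            self.data.len(),
            other.data.len(),
            "Tensors must have the same length to be added"
        );

        let data: Vec<f64> = self
            .data
            .iter()
            .zip(other.data.iter())
            .map(|(a, b)| a + b)
            .collect();

        Tensor::new(data)
    }
}

fn main() {
    let tensor_a = Tensor::new(vec![1.0, 2.0, 3.0]);
    let tensor_b = Tensor::new(vec![4.0, 5.0, 6.0]);

    let tensor_sum = &tensor_a + &tensor_b;
    assert_eq!(tensor_sum, Tensor::new(vec![5.0, 7.0, 9.0]));

    // Both operands are still available after the addition.
    println!("{:?} + {:?} = {:?}", tensor_a, tensor_b, tensor_sum);
}
```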

Tensor and Scalar Operations

Operations between tensors and scalars can also be implemented:

// Implement Tensor + f32
impl Add<f32> for Tensor {
    type Output = Self;

    fn add(self, scalar: f32) -> Self {
        // Convert f32 to f64
        let scalar_f64 = scalar as f64;

        // Use iterators to add the scalar to each element
        let result_data: Vec<f64> = self.data
            .iter()
            .map(|&x| x + scalar_f64)
            .collect();

        Tensor::new(result_data)
    }
}

Example of usage:

fn main() {
    let tensor_a = Tensor::new(vec![1.0, 2.0, 3.0]);

    // Use + operator for addition
    let tensor_sum = tensor_a + 4.0;
    println!("tensor_a + 4.0f32 = {:?}", tensor_sum);

    // Verify the result
    assert_eq!(tensor_sum, Tensor::new(vec![5.0, 6.0, 7.0]));
    println!("Test passed!");
}

Reverse Operations

One problem remains: with only the implementations above, the following does not compile:

let tensor_sum = 4.0 + tensor_a;

Rust does not derive the reverse operation (f32 + Tensor) from Tensor + f32, so we need to implement Add<Tensor> for f32 explicitly. The trait is declared as Add<Rhs = Self>, where Rhs is the type of the right-hand operand and defaults to Self; writing Add<Tensor> sets it to Tensor. This impl on a primitive type is permitted because Tensor is defined in our own crate.

// Implement f32 + Tensor
impl Add<Tensor> for f32 {
    type Output = Tensor;

    fn add(self, tensor: Tensor) -> Tensor {
        // Convert f32 to f64
        let scalar_f64 = self as f64;

        // Use iterators to add the scalar to each element
        let result_data: Vec<f64> = tensor.data
            .iter()
            .map(|&x| x + scalar_f64)
            .collect();

        Tensor::new(result_data)
    }
}
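
Example of usage, as a self-contained sketch. To keep it short, the Tensor definition is repeated without its constructor, and the reverse impl delegates to Tensor + f32 on the assumption that the operation is commutative. That delegation is a variation for illustration, not the implementation shown above:

```rust
use std::ops::Add;

#[derive(Debug, PartialEq)]
struct Tensor {
    data: Vec<f64>,
}

// Tensor + f32
impl Add<f32> for Tensor {
    type Output = Tensor;

    fn add(self, scalar: f32) -> Tensor {
        let scalar_f64 = scalar as f64;
        let data: Vec<f64> = self.data.iter().map(|&x| x + scalar_f64).collect();
        Tensor { data }
    }
}

// f32 + Tensor
impl Add<Tensor> for f32 {
    type Output = Tensor;

    fn add(self, tensor: Tensor) -> Tensor {
        // Scalar addition is commutative, so reuse Tensor + f32.
        tensor + self
    }
}

fn main() {
    let tensor_a = Tensor { data: vec![1.0, 2.0, 3.0] };

    // The _f32 suffix makes the scalar type explicit.
    let tensor_sum = 4.0_f32 + tensor_a;
    println!("4.0f32 + tensor_a = {:?}", tensor_sum);

    // Verify the result
    assert_eq!(tensor_sum, Tensor { data: vec![5.0, 6.0, 7.0] });
    println!("Test passed!");
}
```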

Conclusion

Directly overloading operators makes tensor operations more concise and more in line with mathematical semantics.