Anyone who has used ComfyUI has probably seen how well the RMBG background-removal model works. For a long time it has been one of the best-performing open-source matting models. It can separate foreground from background down to individual strands of hair, which is very useful in poster design, gaming, and advertising.
The RMBG model weights are open-source (the training code is not) and are available at:
https://huggingface.co/briaai/RMBG-2.0
Rust? RMBG?
RMBG ships official ONNX weights, so we can run inference in Rust through ONNX Runtime. Before writing any inference code, let's pin down some basic facts about the model.
RMBG is essentially a segmentation model; its architecture is based on BiRefNet (https://github.com/ZhengPeng7/BiRefNet/blob/main/models/birefnet.py).
The best-known segmentation architecture is the encoder-decoder structure exemplified by U-Net. Its output has the same spatial size as its input, and that output is usually a per-pixel mask.
The shapes typically look something like this (`[batch, channels, height, width]`; the exact sizes depend on the model and how it was trained):
- Input: `[1, 3, 224, 224]`
- Output: `[1, 1, 224, 224]`

RMBG's input is `[1, 3, 1024, 1024]` and its output is `[1, 1, 1024, 1024]`.
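Because ndarray arrays created with `Array::zeros` use row-major (C) order by default, an index such as `input[[0, c, y, x]]` corresponds to a fixed offset in the underlying flat buffer. A minimal std-only sketch of that arithmetic; `nchw_offset` and `numel` are hypothetical helper names used only for illustration:

```rust
/// Flat offset of element (n, c, y, x) in a contiguous, row-major NCHW tensor
/// with shape [batch, channels, height, width].
fn nchw_offset(shape: [usize; 4], n: usize, c: usize, y: usize, x: usize) -> usize {
    let [_, channels, height, width] = shape;
    // Each channel is a full height x width plane; each row holds `width` values.
    ((n * channels + c) * height + y) * width + x
}

/// Total number of elements in the tensor.
fn numel(shape: [usize; 4]) -> usize {
    shape.iter().product()
}
```

For RMBG's input shape, the green plane starts at offset 1024 * 1024, and the whole tensor holds 3 * 1024 * 1024 floats.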
Knowing the input and output shapes tells us how to do the pre-processing and post-processing. The overall pipeline is:
- Load the model.
- Load the image and resize it to `[1024, 1024]`.
- Normalize the pixel values from `[0, 255]` to `[-1, 1]`.
- Run inference.
- Denormalize the output from `[-1, 1]` to `[0, 255]`.
- Resize the mask from `[1024, 1024]` back to the original image size.
- At this point we have a grayscale mask, which can be processed further depending on the task.
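The two value mappings in this pipeline are simple enough to check in isolation. A minimal std-only sketch of the per-value arithmetic; `normalize` and `denormalize` are hypothetical names, and the explicit clamp is an addition that makes the out-of-range behavior obvious (a plain `as u8` cast in Rust also saturates):

```rust
/// Map an 8-bit channel value from [0, 255] to [-1, 1]:
/// divide by 255 to get [0, 1], then apply (v - 0.5) / 0.5.
fn normalize(v: u8) -> f32 {
    (v as f32 / 255.0 - 0.5) / 0.5
}

/// Inverse mapping for the mask: [-1, 1] -> [0, 1] -> [0, 255].
/// Values outside [-1, 1] are clamped before quantizing to u8.
fn denormalize(y: f32) -> u8 {
    let unit = ((y + 1.0) * 0.5).clamp(0.0, 1.0);
    (unit * 255.0).round() as u8
}
```

Every 8-bit value survives a round trip through `normalize` and `denormalize` unchanged.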
Steps
Let’s take this cat image as an example:
![[cat1.jpeg]]
Dependencies (`Cargo.toml`):
```toml
[dependencies]
ndarray = "0.16.1"
anyhow = "1.0.97"
ort = { version = "=2.0.0-rc.9", git = "https://github.com/pykeio/ort", features = ["cuda"] }
image = "0.25.6"
```
Load the model:
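```rust
use std::path::Path;

use ort::execution_providers::{CPUExecutionProvider, CUDAExecutionProvider};
use ort::session::Session;

fn load_onnx<P: AsRef<Path>>(model_file: P) -> anyhow::Result<Session> {
    let model = Session::builder()?
        .with_intra_threads(4)?
        // Execution providers are registered in priority order:
        // list CUDA first and keep CPU as the fallback.
        .with_execution_providers(vec![
            CUDAExecutionProvider::default().with_device_id(0).build(),
            CPUExecutionProvider::default().build(),
        ])?
        .commit_from_file(model_file)?;
    Ok(model)
}
```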
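Loading and resizing the image use the `image` crate (image-rs):
```rust
// Load the image and keep its original size for the final resize
let original_img = image::open(image_path)?;
let (img_width, img_height) = (original_img.width(), original_img.height());
```
There are several resampling filters, each with different trade-offs between speed and smoothness. Here we use `FilterType::Triangle`, which is linear (bilinear) filtering:
```rust
// resize_exact ignores the aspect ratio and produces exactly 1024x1024
let img = original_img.resize_exact(1024, 1024, image::imageops::FilterType::Triangle);
```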
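To see how resampling filters differ, here is a minimal std-only sketch of nearest-neighbor and bilinear interpolation on a single-channel, row-major buffer. It is conceptually close to `FilterType::Nearest` and `FilterType::Triangle`, but the `image` crate's own implementation differs in details (for example, how the kernel is widened when downscaling); `resize_nearest` and `resize_bilinear` are hypothetical names, and both assume non-empty source dimensions.

```rust
/// Nearest-neighbor resize: copy the source pixel whose center is closest.
fn resize_nearest(src: &[u8], sw: usize, sh: usize, dw: usize, dh: usize) -> Vec<u8> {
    let mut dst = vec![0u8; dw * dh];
    for y in 0..dh {
        // Map the destination pixel center back into source coordinates
        let sy = (((y as f32 + 0.5) * sh as f32 / dh as f32) as usize).min(sh - 1);
        for x in 0..dw {
            let sx = (((x as f32 + 0.5) * sw as f32 / dw as f32) as usize).min(sw - 1);
            dst[y * dw + x] = src[sy * sw + sx];
        }
    }
    dst
}

/// Bilinear resize: weighted average of the 4 surrounding source pixels.
fn resize_bilinear(src: &[u8], sw: usize, sh: usize, dw: usize, dh: usize) -> Vec<u8> {
    let mut dst = vec![0u8; dw * dh];
    for y in 0..dh {
        // Pixel-center alignment, clamped to the valid source range
        let fy = ((y as f32 + 0.5) * sh as f32 / dh as f32 - 0.5).clamp(0.0, (sh - 1) as f32);
        let y0 = fy.floor() as usize;
        let y1 = (y0 + 1).min(sh - 1);
        let wy = fy - y0 as f32;
        for x in 0..dw {
            let fx = ((x as f32 + 0.5) * sw as f32 / dw as f32 - 0.5).clamp(0.0, (sw - 1) as f32);
            let x0 = fx.floor() as usize;
            let x1 = (x0 + 1).min(sw - 1);
            let wx = fx - x0 as f32;
            let p = |yy: usize, xx: usize| src[yy * sw + xx] as f32;
            // Interpolate horizontally on two rows, then vertically
            let top = p(y0, x0) * (1.0 - wx) + p(y0, x1) * wx;
            let bottom = p(y1, x0) * (1.0 - wx) + p(y1, x1) * wx;
            dst[y * dw + x] = (top * (1.0 - wy) + bottom * wy).round() as u8;
        }
    }
    dst
}
```

Upscaling a two-pixel row `[0, 255]` to four pixels gives `[0, 0, 255, 255]` with nearest-neighbor and `[0, 64, 191, 255]` with bilinear: nearest keeps hard steps, while linear filtering produces a gradient. That difference is what shows up at mask edges.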
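Normalization is done with ndarray. Each channel value is first scaled from `[0, 255]` to `[0, 1]` and then shifted to `[-1, 1]`, while the interleaved RGBA pixels are rearranged into the planar NCHW layout the model expects:
```rust
use image::GenericImageView; // provides `pixels()` on DynamicImage
use ndarray::Array4;

// NCHW input tensor: [batch, channel, height, width]
let mut input = Array4::<f32>::zeros((1, 3, 1024, 1024));
for (x, y, pixel) in img.pixels() {
    let (x, y) = (x as usize, y as usize);
    let [r, g, b, _] = pixel.0; // RGBA; the alpha channel is dropped
    // [0, 255] -> [0, 1] -> [-1, 1]
    input[[0, 0, y, x]] = (r as f32 / 255. - 0.5) / 0.5;
    input[[0, 1, y, x]] = (g as f32 / 255. - 0.5) / 0.5;
    input[[0, 2, y, x]] = (b as f32 / 255. - 0.5) / 0.5;
}
```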
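The same conversion can be written without ndarray or image types, which makes the per-channel math easy to test. A minimal sketch, assuming an interleaved RGBA8 buffer such as the one returned by `img.to_rgba8().into_raw()`; `rgba_to_nchw` is a hypothetical helper. With `mean = [0.5; 3]` and `stddev = [0.5; 3]` it reproduces the `[-1, 1]` mapping above; if a model's preprocessing config lists different per-channel statistics, those would be passed instead.

```rust
/// Convert an interleaved RGBA8 buffer (row-major, 4 bytes per pixel) into a
/// planar NCHW f32 buffer of shape [1, 3, height, width], dropping alpha.
/// Each channel is mapped as (v / 255 - mean[c]) / stddev[c].
fn rgba_to_nchw(rgba: &[u8], width: usize, height: usize, mean: [f32; 3], stddev: [f32; 3]) -> Vec<f32> {
    assert_eq!(rgba.len(), width * height * 4, "buffer size must be width * height * 4");
    let plane = width * height;
    let mut out = vec![0.0f32; 3 * plane];
    for (i, px) in rgba.chunks_exact(4).enumerate() {
        // i == y * width + x, so the same index addresses each channel plane
        for c in 0..3 {
            out[c * plane + i] = (px[c] as f32 / 255.0 - mean[c]) / stddev[c];
        }
    }
    out
}
```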
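With normalization done, pre-processing is complete. Next comes inference:
```rust
use ort::{inputs, value::TensorRef};

// "pixel_values" and "alphas" are the input/output names of this ONNX export
let outputs = model.run(inputs![
    "pixel_values" => TensorRef::from_array_view(&input)?
])?;
// Raw mask tensor with shape [1, 1, 1024, 1024]
let output = outputs["alphas"].try_extract_array::<f32>()?;
```
The input name `pixel_values` can be found by printing `model.inputs` (and the output name `alphas` by printing `model.outputs`). This gives us the model's raw output mask tensor. Next we reverse the normalization, mapping it back to `[0, 1]` and then to the `[0, 255]` range an 8-bit image needs:
```rust
// Reverse the normalization: [-1, 1] -> [0, 1]
let output = output.map(|x| (x + 1.) * 0.5);
// [0, 1] -> [0, 255]; `as u8` saturates out-of-range values to 0 or 255.
// im_array is the grayscale mask.
let im_array = output.map(|x| (x * 255.) as u8);
```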
At this step, we have obtained the RMBG segmentation map. Let’s convert it into a grayscale image and save it to check the result:
```rust
// im_array has shape [1, 1, 1024, 1024]
let dim = im_array.dim();
let height = dim[2];
let width = dim[3];
// Take the underlying buffer; it is contiguous and row-major here,
// so it can be handed directly to an image buffer
let (buf, _offset) = im_array.into_raw_vec_and_offset();
// Wrap the raw buffer as a grayscale image
let gray = image::GrayImage::from_vec(width as u32, height as u32, buf)
    .ok_or_else(|| anyhow::anyhow!("mask buffer does not match {width}x{height}"))?;
// Resize the mask back to the original image size, reversing the input resize
// (RMBG's input and output are both 1024x1024; it is still a segmentation model)
let gray = image::imageops::resize(&gray, img_width, img_height, image::imageops::FilterType::Nearest);
gray.save("gray.png")?;
```
Result: ![[gray.png]]
From here, the mask can be used however the task requires. For instance, to cut the cat out of the original image, we can threshold (binarize) the mask:
```rust
use image::{DynamicImage, GenericImage, GenericImageView, GrayImage, Rgba};

/// Apply a hard (binary) mask: keep a pixel of `original_img` where the mask
/// value is above 128 and make it fully transparent elsewhere.
///
/// `mask` must be a grayscale image with the same dimensions as `original_img`.
/// Returns the masked RGBA image, or an error if the sizes do not match.
fn apply_mask(original_img: &DynamicImage, mask: &GrayImage) -> anyhow::Result<DynamicImage> {
    let (width, height) = original_img.dimensions();
    anyhow::ensure!(mask.dimensions() == (width, height), "mask size must match the image size");
    // Transparent RGBA canvas with the same size as the original
    let mut cropped_img = DynamicImage::new_rgba8(width, height);
    for (x, y, pixel) in original_img.pixels() {
        // Grayscale value of the mask at the same position
        let mask_value = mask.get_pixel(x, y)[0];
        if mask_value > 128 {
            // Foreground: keep the original pixel
            cropped_img.put_pixel(x, y, pixel);
        } else {
            // Background: fully transparent
            cropped_img.put_pixel(x, y, Rgba([0, 0, 0, 0]));
        }
    }
    Ok(cropped_img)
}
```
Result: ![[cropped_image.png]]
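The thresholding step can also be expressed on raw buffers, which makes the cutoff easy to experiment with. A minimal sketch, assuming an interleaved RGBA8 image and a one-byte-per-pixel mask of the same size (for example from `to_rgba8().into_raw()` and `GrayImage::into_raw()`); `apply_threshold` is a hypothetical helper, and its `threshold` parameter replaces the hard-coded 128:

```rust
/// Hard-threshold matting on raw buffers: keep a pixel when its mask value is
/// above `threshold`, otherwise make it fully transparent.
/// `rgba` is interleaved RGBA8; `mask` has one byte per pixel, same size.
fn apply_threshold(rgba: &[u8], mask: &[u8], threshold: u8) -> Vec<u8> {
    assert_eq!(rgba.len(), mask.len() * 4, "mask must have one value per RGBA pixel");
    let mut out = Vec::with_capacity(rgba.len());
    for (px, &m) in rgba.chunks_exact(4).zip(mask) {
        if m > threshold {
            out.extend_from_slice(px); // foreground: copy RGBA unchanged
        } else {
            out.extend_from_slice(&[0, 0, 0, 0]); // background: transparent
        }
    }
    out
}
```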
A hard threshold leaves jagged edges. Let's refine it with weighted blending, using the mask value as a per-pixel transparency weight:
```rust
// Imports: use image::{DynamicImage, GenericImage, GenericImageView, GrayImage, Rgba};

/// Weighted blending with transparency: the mask's grayscale value is used as
/// a per-pixel alpha weight instead of a hard threshold.
///
/// PNG stores straight (non-premultiplied) alpha, so the RGB channels are kept
/// as-is and only the alpha channel is scaled by the mask; also multiplying RGB
/// by the weight would darken the soft edges.
///
/// `mask` must be a grayscale image with the same dimensions as `original_img`.
fn apply_mask2(original_img: &DynamicImage, mask: &GrayImage) -> anyhow::Result<DynamicImage> {
    let (width, height) = original_img.dimensions();
    anyhow::ensure!(mask.dimensions() == (width, height), "mask size must match the image size");
    let mut cropped_img = DynamicImage::new_rgba8(width, height);
    for (x, y, pixel) in original_img.pixels() {
        // Mask value normalized to [0, 1], used as the blending weight
        let alpha = mask.get_pixel(x, y)[0] as f32 / 255.0;
        let [r, g, b, a] = pixel.0;
        // Combine the pixel's original alpha with the mask weight
        let new_alpha = (a as f32 * alpha).round() as u8;
        cropped_img.put_pixel(x, y, Rgba([r, g, b, new_alpha]));
    }
    Ok(cropped_img)
}
```
Results shown below:
![[cropped_image2.png]]
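When the result should be an opaque image rather than one with transparency, a common next step is to composite the foreground over a solid background color with the standard "over" formula, `out = fg * a + bg * (1 - a)`. A minimal std-only sketch; `composite_over` is a hypothetical helper, and the effective opacity `a` combines the mask weight with the pixel's own alpha:

```rust
/// Composite an RGBA8 foreground over a solid RGB background using a per-pixel
/// mask as weight: out = fg * a + bg * (1 - a), a = (mask / 255) * (alpha / 255).
/// Returns an opaque, interleaved RGB8 buffer.
fn composite_over(rgba: &[u8], mask: &[u8], background: [u8; 3]) -> Vec<u8> {
    assert_eq!(rgba.len(), mask.len() * 4, "mask must have one value per RGBA pixel");
    let mut out = Vec::with_capacity(mask.len() * 3);
    for (px, &m) in rgba.chunks_exact(4).zip(mask) {
        // Effective opacity: mask weight times the pixel's own alpha
        let a = (m as f32 / 255.0) * (px[3] as f32 / 255.0);
        for c in 0..3 {
            let v = px[c] as f32 * a + background[c] as f32 * (1.0 - a);
            out.push(v.round().clamp(0.0, 255.0) as u8);
        }
    }
    out
}
```

Passing `[255, 255, 255]` as the background, for example, places the cut-out subject on white.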
Improvements
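With more complex backgrounds, you may find that the background left by the code above is not completely clean. The cause is the fixed denormalization of the mask, `(x + 1) * 0.5`, which can leave residual intermediate values in background regions:
```rust
let output = output.map(|x| (x + 1.) * 0.5);
```
Instead, we rescale the output with min-max normalization, `(output - min) / (max - min)`. Because Rust's `f32` does not implement `std::cmp::Ord` (NaN has no total order), `Iterator::min`/`max` cannot be used directly, so the code below folds with `f32::min`/`f32::max` instead:
```rust
// `output` is the raw tensor from `try_extract_array`; this replaces the
// fixed `(x + 1.) * 0.5` mapping. NaNs are skipped when finding the range.
let max = output.iter().filter(|&&x| !x.is_nan()).fold(f32::MIN, |a, &b| a.max(b));
let min = output.iter().filter(|&&x| !x.is_nan()).fold(f32::MAX, |a, &b| a.min(b));
// Min-max normalization to [0, 1]; guard against a constant mask (max == min)
let range = (max - min).max(f32::EPSILON);
let normalized = output.map(|x| (x - min) / range);
// Then scale to [0, 255] as before: normalized.map(|x| (x * 255.) as u8)
```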
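The same min-max step can be written as a std-only function so it can be tested in isolation. A minimal sketch; `min_max_normalize` is a hypothetical name. It finds the range in a single fold and returns zeros when the range is zero or undefined instead of dividing by zero. NaN entries remain NaN in the output.

```rust
/// Min-max normalize raw mask values to [0, 1], ignoring NaNs when finding
/// the range. Returns all zeros if the range is zero or undefined.
fn min_max_normalize(values: &[f32]) -> Vec<f32> {
    // One pass computes both the minimum and the maximum
    let (min, max) = values
        .iter()
        .filter(|x| !x.is_nan())
        .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &x| (lo.min(x), hi.max(x)));
    let range = max - min;
    if range.is_nan() || range <= 0.0 {
        return vec![0.0; values.len()];
    }
    values.iter().map(|&x| (x - min) / range).collect()
}
```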