Apr 02, 2025
4 min read
Rust,
ONNX,
Pytorch,

Rust 实现 RMBG 推理

本文介绍了如何使用 Rust 实现 RMBG(移除图片背景)推理。通过加载 ONNX 模型,对输入图片进行预处理(resize、归一化),完成模型推理后生成掩码图,并基于掩码实现背景去除或透明度混合处理。最后优化了掩码反归一化算法,提升背景纯净度,适用于海报、游戏等领域。

玩过 comfyui 的人应该知道 RMBG 移除图片背景模型的惊艳效果。长期以来,它都是开源 matting 模型中效果最好的模型之一。RMBG 模型可以做到发丝级别的背景分割,这在海报、游戏和广告内容方向上有很大的使用。

RMBG 模型(权重开源,训练不开源)是开源的。地址如下:

https://huggingface.co/briaai/RMBG-2.0

Rust ? RMBG?

RMBG 有官方版本的 onnx 权重,我们可以使用 onnx 在 Rust 中实现 RMBG 的模型推理。推理开始前,我们需要理清 RMBG 的一些基本信息,方便我们如何去进行代码的编写。

RMBG 本质上就是一个分割模型。它的网络结构应该是基于 BiRefNet(https://github.com/ZhengPeng7/BiRefNet/blob/main/models/birefnet.py)。

分割模型最出名的架构是 以 unet 为代表的编码器-解码器结构。它的输入和输出的尺寸相对应,并且模型的输出是一般是一个二值掩码。

大概是这样(不是绝对的,要看模型和训练方法): 输入:[1,3,224,224] 输出:[1,1,224,224]

RMBG 的输入是:[1,3,1024,1024],输出是:[1,1,1024,1024]

知道了输入和输出,我们就可以知道前处理和后处理应该怎么做,大概流程如下:

  1. 加载模型。
  2. 加载图片,并把图片 resize 到[1024,1024]
  3. 图片归一化/标准化到: [0,255] =>[-1,1]
  4. 模型推理。
  5. 反归一化 [-1,1]=> [0,255]
  6. resize 尺寸[1024,1024]到原图上尺寸。
  7. 这时已经得到灰度的掩码图,可以按实现任务做处理。

步骤

以这张猫图片为例子:

![[cat1.jpeg]]

需要的库:

[dependencies]
ndarray = "0.16.1"
anyhow = "1.0.97"
ort = {version = "=2.0.0-rc.9", git="https://github.com/pykeio/ort", features = ["cuda"]}
image = "0.25.6"

加载模型

fn load_onnx<P: AsRef<Path>>(model_file: P) -> anyhow::Result<Session> {
    let model = Session::builder()?
        .with_intra_threads(4)?
        .with_execution_providers(vec![
            CPUExecutionProvider::default().build(),
            CUDAExecutionProvider::default().with_device_id(0).build()
        ])?
        .commit_from_file(model_file)?;

    Ok(model)
}

加载图片和 resize 需要用到 image-rs

// 加载图片
let original_img = image::open(image_path)?;
    let (img_width, img_height) = (original_img.width(), original_img.height());

resize 有多种算法,不同的算法有一些差异,这里使用 Linear Filter。

let img = original_img.resize_exact(1024, 1024, image::imageops::FilterType::Triangle);

归一化需要转到 ndarray来处理,这里的处理是先转到[0,1],再转到 [-1,1]

	let mut input = Array::zeros((1, 3, 1024, 1024));
    for pixel in img.pixels() {
		let x = pixel.0 as _;
		let y = pixel.1 as _;
		let [r, g, b, _] = pixel.2.0;

		input[[0, 0, y, x]] = ((r as f32) / 255. - 0.5) / 0.5;
		input[[0, 1, y, x]] = ((g as f32) / 255. - 0.5) / 0.5;
		input[[0, 2, y, x]] = ((b as f32) / 255. - 0.5) / 0.5;
	}

归一化完后前处理已经完成,接下来就是推理:

    let outputs = model.run(inputs![
        "pixel_values"=>TensorRef::from_array_view(&input)?
    ])?;
    let output = outputs["alphas"].try_extract_array::<f32>()?;

pixel_values 可以通过打印 model.inputs 获得。这样我们得到了模型输出的掩码向量。之后就是反向归一化到[0,1],再恢复到图片所需要的[0,255]

	// 反向操作, # Normalize to [0, 1]
    let output = output.map(|x| (x + 1.) * 0.5);

    // im_array 是一个灰度图片遮罩
    let im_array = output.map(|x| (x * 255.) as u8);

到这一步,我们就得到了 RMBG 分割图,转成灰度图片保存看看结果:

	let dim = im_array.dim();//dim = [1, 1, 1024, 1024]
    let height = dim[2];
    let width = dim[3];

	// get raw vec
    let (buf,_) = im_array.into_raw_vec_and_offset();
// 把原始向量转成灰度图片
    let gray = image::GrayImage::from_vec(width as u32, height as u32, buf).unwrap();
    
    // 把灰度图恢复成原图的尺寸
    // 实际就是 input 的反向操作,rmbg 的输入和输出尺寸相同,都是 1024x1024。本质上还是一个分割模型
    let gray = image::imageops::resize(&gray, img_width, img_height, image::imageops::FilterType::Nearest);
    let _ = gray.save("gray.png");

结果显示: ![[gray.png]]

后面我们就可以基于实际的任务对图片进行处理。比如我们想要基于掩码图扣出上面这只猫,我们可以基于掩码图进行简单的二值运算:

// 根据掩膜图像对原始图像进行二值化处理
// 此函数的目的是根据掩膜图像的灰度值,将原始图像的像素保留或设置为透明
// 参数:
//   original_img: &DynamicImage - 引用原始图像
//   mask: &GrayImage - 引用掩膜图像,应为灰度图像
// 返回值:
//   anyhow::Result<DynamicImage> - 返回处理后的图像,如果发生错误则返回Result中的错误
fn apply_mask(original_img: &DynamicImage, mask: &image::GrayImage) -> anyhow::Result<DynamicImage> {
    // 获取原始图像的尺寸
    let (width, height) = original_img.dimensions();
    // 创建一个与原始图像尺寸相同的可变图像,使用RGBA格式
    let mut cropped_img = DynamicImage::ImageRgba8(ImageBuffer::new(width, height));

    // 遍历原始图像的每个像素
    for (x, y, pixel) in original_img.pixels() {
        // 获取掩膜图像中对应位置的像素
        let mask_pixel = mask.get_pixel(x, y);
        // 获取掩膜像素的灰度值,假设掩膜为灰度图像
        let mask_value = mask_pixel[0];

        // 根据掩膜像素的灰度值决定是否保留原始像素
        if mask_value > 128 { // 如果灰度值大于128,保留原始像素
            cropped_img.put_pixel(x, y, pixel);
        } else { // 否则,将像素设置为透明
            cropped_img.put_pixel(x, y, Rgba([0, 0, 0, 0]));
        }
    }

    // 返回处理后的图像
    Ok(cropped_img)
}

效果如下: ![[cropped_image.png]]

基于二值判断边缘处理太粗糙了,我们再优化一下,这次基于混合的权重实现:

// 带透明度的权重混合
/**
 * 根据提供的掩膜(mask),将原始图像(original_img)与掩膜进行混合。
 * 掩膜的灰度值被用作混合的权重,以决定原始像素的透明度。
 * 
 * 参数:
 * - original_img: &DynamicImage - 原始图像的引用。
 * - mask: &image::GrayImage - 掩膜图像的引用,其灰度值用作混合权重。
 * 
 * 返回:
 * - anyhow::Result<DynamicImage> - 返回一个新的图像,其中原始图像的像素根据掩膜的权重进行了混合。
 *   如果发生错误,将返回一个错误结果。
 */
fn apply_mask2(original_img: &DynamicImage, mask: &image::GrayImage) -> anyhow::Result<DynamicImage> {
    // 获取原始图像的尺寸
    let (width, height) = original_img.dimensions();
    // 创建一个新的RGBA图像,用于后续的像素操作
    let mut cropped_img = DynamicImage::ImageRgba8(ImageBuffer::new(width, height));

    // 遍历原始图像的每个像素
    for (x, y, pixel) in original_img.pixels() {
        // 获取掩膜中对应位置的像素
        let mask_pixel = mask.get_pixel(x, y);
        // 将掩膜的灰度值归一化到[0, 1],作为混合权重
        let mask_value = mask_pixel[0] as f32 / 255.0;

        // 计算加权后的像素值
        let alpha = mask_value; // 使用掩膜值作为alpha透明度
        // 根据alpha值和原始像素值计算新的像素值
        let new_pixel = Rgba([
            ((pixel[0] as f32 * alpha) + (0.0 * (1.0 - alpha))) as u8,
            ((pixel[1] as f32 * alpha) + (0.0 * (1.0 - alpha))) as u8,
            ((pixel[2] as f32 * alpha) + (0.0 * (1.0 - alpha))) as u8,
            ((pixel[3] as f32 * alpha) + (0.0 * (1.0 - alpha))) as u8,
        ]);

        // 将计算得到的新像素值放置到新图像中
        cropped_img.put_pixel(x, y, new_pixel);
    }

    // 返回处理后的新图像
    Ok(cropped_img)
}

结果如下:

![[cropped_image2.png]]

改进

如果你使用其他复杂背景的话,会发现上面的代码的背景不够纯净,原因是掩码的反向归一化算法直接线性映射 (x+1)*0.5,可能保留无效中间值。

let output = output.map(|x| (x + 1.) * 0.5);

我们再优化一下,基于 (output - min)/(max - min),由于 rust 的 f32 未实现 std::cmp::Ord,无法直接取最小值和最大值,下面代码直接使用自定义实现。

let max = output.iter().filter(|&&x| !x.is_nan()).fold(f32::MIN, |a, &b| a.max(b));
    let min = output.iter().filter(|&&x| !x.is_nan()).fold(f32::MAX, |a, &b| a.min(b));
    let normalized = output.map(|x| (x - min) / (max - min));