Oct 13, 2025
candle, Rust

YOLOv10 in Candle

Porting YOLOv10 to the Candle framework

Porting YOLOv10 to Candle isn’t especially complicated. The main difficulty is the poor readability of the ultralytics project itself. To accommodate many variants, such as YOLOv5 through v12, YOLOE, and YOLO-World, the ultralytics codebase is littered with if statements, which makes the code hard to follow.

Personally, I don’t like this way of organizing code; it’s just too messy. Wouldn’t it be better to give up some abstraction in exchange for readability? After all, the different model versions share relatively little.

However, after some effort, I successfully ported YOLOv10 to the Rust AI framework Candle.

Analysis

ultralytics defines the model structure in YAML configuration files and parses them with yaml_model_load(cfg) to build the model. Here we’ll use YOLOv10s as the example, because it is small yet reasonably accurate.

The structure definition for YOLOv10s is as follows:

backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]] # 0-P1/2
  - [-1, 1, Conv, [128, 3, 2]] # 1-P2/4
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]] # 3-P3/8
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, SCDown, [512, 3, 2]] # 5-P4/16
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, SCDown, [1024, 3, 2]] # 7-P5/32
  - [-1, 3, C2fCIB, [1024, True, True]]
  - [-1, 1, SPPF, [1024, 5]] # 9
  - [-1, 1, PSA, [1024]] # 10

# YOLOv10.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 6], 1, Concat, [1]] # cat backbone P4
  - [-1, 3, C2f, [512]] # 13

  - [-1, 1, nn.Upsample, [None, 2, "nearest"]]
  - [[-1, 4], 1, Concat, [1]] # cat backbone P3
  - [-1, 3, C2f, [256]] # 16 (P3/8-small)

  - [-1, 1, Conv, [256, 3, 2]]
  - [[-1, 13], 1, Concat, [1]] # cat head P4
  - [-1, 3, C2f, [512]] # 19 (P4/16-medium)

  - [-1, 1, SCDown, [512, 3, 2]]
  - [[-1, 10], 1, Concat, [1]] # cat head P5
  - [-1, 3, C2fCIB, [1024, True, True]] # 22 (P5/32-large)

  - [[16, 19, 22], 1, v10Detect, [nc]] # Detect(P3, P4, P5)

This defines a 24-layer model, with layer indices 0 to 23: the backbone is layers 0-10 and the head section is layers 11-23.
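The P-level comments in the YAML (0-P1/2 up to 7-P5/32) encode each stage’s downsampling factor. As a quick worked check, here is a plain-Rust sketch (no Candle involved) that applies the usual convolution output-size formula to a 640x640 input. It assumes ultralytics’ default padding of kernel / 2 for these 3x3, stride-2 layers.

```rust
/// Output size of a convolution along one spatial axis (dilation 1).
pub fn conv_out(input: usize, kernel: usize, stride: usize, padding: usize) -> usize {
    (input + 2 * padding - kernel) / stride + 1
}

/// Feature-map sizes after the five stride-2, 3x3 layers of the backbone:
/// Conv at layers 0, 1, 3 and SCDown at layers 5, 7 (its depthwise 3x3 conv
/// does the downsampling). Assumes padding = kernel / 2.
pub fn backbone_sizes(input: usize) -> Vec<usize> {
    let (kernel, stride) = (3, 2);
    let mut size = input;
    let mut sizes = Vec::new();
    for _ in 0..5 {
        size = conv_out(size, kernel, stride, kernel / 2);
        sizes.push(size);
    }
    sizes
}

/// Number of anchor points seen by the detect layer, which reads
/// P3, P4 and P5 (the last three sizes above).
pub fn anchor_count(input: usize) -> usize {
    backbone_sizes(input)[2..].iter().map(|s| s * s).sum()
}
```

For a 640x640 input the sizes are 320, 160, 80, 40 and 20, so the three detection scales hold 80² + 40² + 20² = 8400 anchor points. That is the number of candidates v10postprocess ranks.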

YOLOv10 is an NMS-free model. During training it uses dual label assignments, one-to-many (one2many) and one-to-one (one2one), and optimizes both heads jointly. At inference only the one-to-one head is used, so no NMS is needed. We can therefore ignore the one2many branch.

Implementation Approach

First, YOLOv8 already has a pure Rust implementation in the Candle repository (https://github.com/huggingface/candle). Unfortunately, the official Candle YOLOv8 was not written against the ultralytics code, probably for licensing reasons. As a result, its module and weight names don’t match those in ultralytics, so ultralytics weights can’t be loaded by Candle directly (although renaming the nodes would make it work).

However, some modules can be reused directly. The porting work therefore consists roughly of:

  1. Exporting the ultralytics weights to safetensors
  2. Implementing the modules YOLOv8 lacks, such as SCDown, C2fCIB, PSA, and v10Detect
  3. Mapping the ultralytics weight names onto the Rust modules
  4. Implementing the v10 post-processing

Exporting Weights

Candle uses the safetensors weight format, so we first need to convert the weights exported by ultralytics to safetensors.

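The core code is as follows:

    from safetensors.torch import save_model
    from ultralytics import YOLO

    model_path = "yolov10s.pt"            # ultralytics checkpoint
    output_path = "yolov10s.safetensors"  # converted weights for Candle

    model = YOLO(model_path)
    print(model.model)  # module hierarchy

    tensors = model.model.state_dict() # type: ignore

    # Every weight key with its shape
    for k, v in tensors.items():
        print(str(k), v.shape)

    # Save as safetensors format
    save_model(model.model, output_path) # type: ignore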

Save the output of both prints. The first shows the module hierarchy, and the second lists every weight key with its shape. We’ll need both later to map the weight names, so this step is very important.

Implementing Modules

SCDown and C2fCIB are fairly standard convolutional modules, so I won’t go into detail.

In the ultralytics code, PSA stands for Position-Sensitive Attention (the YOLOv10 paper calls it partial self-attention). It runs self-attention on half of the feature channels to strengthen feature extraction. Before digging in, I thought YOLOv10 had no attention modules; the port proved me wrong. Still, YOLOv10 has only this one attention module, unlike YOLOv12, where attention runs throughout the entire model.

The core of PSA is the Attention module, a multi-head attention block. The Rust implementation is as follows:

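#[derive(Clone, Debug)]
pub struct Attention {
    qkv: ConvBlock,   // 1x1 conv producing q, k and v for all heads
    proj: ConvBlock,  // 1x1 output projection
    pe: ConvBlock,    // 3x3 depthwise conv applied to v as a positional encoding
    num_heads: usize,
    key_dim: usize,   // per-head channels of q and k
    scale: f64,       // key_dim^-0.5, applied to the attention logits
    head_dim: usize,  // per-head channels of v
}

impl Attention {
    /// ultralytics defaults: num_heads=8, attn_ratio=0.5
    pub fn load(vb: VarBuilder, dim: usize, num_heads: usize, attn_ratio: f64) -> Result<Self> {
        let head_dim = dim / num_heads;
        let key_dim = (head_dim as f64 * attn_ratio) as usize;
        let scale = (key_dim as f64).powf(-0.5);
        let nh_kd = key_dim * num_heads;
        // q and k need nh_kd channels each, v needs dim channels
        let h = dim + nh_kd * 2;

        // Conv(dim, h, 1, act=False)
        let qkv = ConvBlock::load(vb.pp("qkv"), dim, h, 1, 1, None, None, false)?;
        // Conv(dim, dim, 1, act=False)
        let proj = ConvBlock::load(vb.pp("proj"), dim, dim, 1, 1, None, None, false)?;
        // Conv(dim, dim, 3, 1, g=dim, act=False), i.e. depthwise
        let pe = ConvBlock::load(vb.pp("pe"), dim, dim, 3, 1, None, Some(dim), false)?;

        Ok(Self {
            qkv,
            proj,
            pe,
            num_heads,
            key_dim,
            scale,
            head_dim,
        })
    }
}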
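The sizes computed in Attention::load have to agree with the view and split in the forward pass. The qkv conv outputs h channels, which are later reshaped into num_heads groups of 2 * key_dim + head_dim. Below is a plain-Rust sketch of that bookkeeping, with no Candle types. The example values dim = 256 and num_heads = 4 are what I’d expect for YOLOv10s’s PSA (512 channels halved by PSA’s e = 0.5, with num_heads = c // 64), but treat them as an assumption.

```rust
/// Channel bookkeeping of the Attention block, mirroring Attention::load.
#[derive(Debug, PartialEq)]
pub struct AttnDims {
    pub head_dim: usize,
    pub key_dim: usize,
    pub qkv_out: usize,  // `h`: output channels of the qkv conv
    pub per_head: usize, // channels each head sees after the reshape
    pub scale: f64,
}

pub fn attn_dims(dim: usize, num_heads: usize, attn_ratio: f64) -> AttnDims {
    let head_dim = dim / num_heads;
    let key_dim = (head_dim as f64 * attn_ratio) as usize;
    let nh_kd = key_dim * num_heads;
    AttnDims {
        head_dim,
        key_dim,
        qkv_out: dim + nh_kd * 2,
        // q and k take key_dim channels each, v takes head_dim
        per_head: 2 * key_dim + head_dim,
        scale: (key_dim as f64).powf(-0.5),
    }
}
```

When dim is divisible by num_heads, num_heads * per_head always equals qkv_out, which is why the reshape in the forward pass is valid.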

Other modules don’t have anything particularly special.

Implementation Issues Encountered

Candle and PyTorch are different frameworks, and many operations have different APIs. I’ve documented some of these differences before, and I ran into a few more during this port.

nn.ModuleList

This container doesn’t exist in Candle. I usually replace it with a Vec<Box<dyn Module>>, or with a plain Vec of the concrete type when all entries are the same module.

For example, the Python version:

self.m = nn.ModuleList(CIB(self.c, self.c, shortcut, e=1.0, lk=lk) for _ in range(n))

The Rust version:

let mut cib = Vec::with_capacity(n);
for idx in 0..n {
    // CIB(self.c, self.c, shortcut, e=1.0, lk=lk)
    let b = CIB::load(vb.pp(format!("m.{idx}")), c, c, shortcut, 1f64, lk)?;
    cib.push(b)
}
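Here is a minimal sketch of the Vec<Box<dyn ...>> pattern. It uses a hypothetical Layer trait over plain f32 values in place of candle_nn::Module and Tensor so that it runs on its own. It also shows the m.0, m.1, ... weight prefixes that nn.ModuleList produces, which is why the loading code uses format!("m.{idx}").

```rust
/// Hypothetical stand-in for candle_nn::Module, using f32 instead of Tensor.
pub trait Layer {
    fn forward(&self, x: f32) -> f32;
}

pub struct Scale(pub f32);
impl Layer for Scale {
    fn forward(&self, x: f32) -> f32 {
        x * self.0
    }
}

pub struct Shift(pub f32);
impl Layer for Shift {
    fn forward(&self, x: f32) -> f32 {
        x + self.0
    }
}

/// Weight prefixes that an nn.ModuleList stored as `self.m` produces.
pub fn module_list_prefixes(n: usize) -> Vec<String> {
    (0..n).map(|idx| format!("m.{idx}")).collect()
}

/// Apply the layers in order, like iterating over an nn.ModuleList in forward().
pub fn run_all(layers: &[Box<dyn Layer>], mut x: f32) -> f32 {
    for layer in layers {
        x = layer.forward(x);
    }
    x
}
```

Boxing is only needed when the entries have different types. For a list of identical CIB blocks, a Vec<CIB> avoids the dynamic dispatch.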

torch.view

Candle doesn’t have the view function, so we need to use the reshape function instead.
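Both view and reshape reinterpret the existing row-major element order; neither moves data. There is one practical difference: PyTorch’s view fails on a non-contiguous tensor, while Candle’s reshape copies it into a contiguous buffer first. The plain-Rust sketch below computes row-major offsets. It shows that the element at (b, c, h, w) of a (B, C, H, W) tensor lands at (b, c / per_head, c % per_head, h * W + w) after the reshape to (B, num_heads, per_head, H * W) that the Attention forward pass performs.

```rust
/// Row-major flat offset of multi-index `idx` in a tensor of shape `dims`.
pub fn flat_offset(dims: &[usize], idx: &[usize]) -> usize {
    assert_eq!(dims.len(), idx.len(), "rank mismatch");
    dims.iter().zip(idx).fold(0, |acc, (&d, &i)| {
        assert!(i < d, "index out of bounds");
        acc * d + i
    })
}

/// Where element (b, c, h, w) of a (B, C, H, W) tensor ends up after
/// reshaping to (B, num_heads, C / num_heads, H * W).
pub fn reshaped_index(c_total: usize, num_heads: usize, w_size: usize, idx: [usize; 4]) -> [usize; 4] {
    let per_head = c_total / num_heads;
    let [b, c, h, w] = idx;
    [b, c / per_head, c % per_head, h * w_size + w]
}
```

Since both index forms map to the same flat offset, the reshape is free for a contiguous tensor: only the shape metadata changes.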

split

Similarly, Candle doesn’t have split. In YOLOv10 it is mostly called with a list of sizes, for example:

q,k,v = x.split([self.key_dim, self.key_dim, self.head_dim], dim=2)

The closest equivalent I could find in Candle is narrow(dim, start, len), called once per chunk:

let q = rs.narrow(2, 0, self.key_dim)?;  // Starting from position 0 of dim=2, take key_dim elements
let k = rs.narrow(2, self.key_dim, self.key_dim)?;  // Starting from position key_dim of dim=2, take key_dim elements
let v = rs.narrow(2, self.key_dim * 2, self.head_dim)?;  // Starting from position key_dim*2 of dim=2, take head_dim elements
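Converting split sizes into narrow arguments is just a running sum of the sizes. Here is a small plain-Rust sketch. split_slice applies the ranges to a slice as a stand-in for Tensor::narrow, and the sizes 32, 32, 64 are illustrative values for key_dim, key_dim, head_dim.

```rust
/// Convert split([s0, s1, ...], dim) sizes into the (start, len) pairs
/// that successive narrow(dim, start, len) calls need.
pub fn split_to_narrow(sizes: &[usize]) -> Vec<(usize, usize)> {
    let mut start = 0;
    sizes
        .iter()
        .map(|&len| {
            let range = (start, len);
            start += len; // the next chunk begins where this one ends
            range
        })
        .collect()
}

/// Apply the ranges to a slice, standing in for Tensor::narrow on one dimension.
pub fn split_slice<'a, T>(data: &'a [T], sizes: &[usize]) -> Vec<&'a [T]> {
    split_to_narrow(sizes)
        .into_iter()
        .map(|(start, len)| &data[start..start + len])
        .collect()
}
```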

Matrix multiplication operations

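Python writes matrix multiplication as k @ v. In Candle we call matmul, which takes the right-hand side by reference. Unlike @, matmul expects matching batch dimensions; use broadcast_matmul when they need to broadcast.

let kv = k.matmul(&v)?;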
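For reference, here is what @ computes in the 2D case, as a plain-Rust sketch with a hypothetical row-major Mat type (no Candle). The attention logits in the PSA forward pass are q.transpose(-2, -1) @ k, so the sketch includes a transpose. It also rejects mismatched inner dimensions, just as matmul does.

```rust
/// Hypothetical row-major matrix: `rows x cols` values in `data`.
#[derive(Debug, Clone, PartialEq)]
pub struct Mat {
    pub rows: usize,
    pub cols: usize,
    pub data: Vec<f32>,
}

impl Mat {
    pub fn new(rows: usize, cols: usize, data: Vec<f32>) -> Self {
        assert_eq!(data.len(), rows * cols, "data length must equal rows * cols");
        Mat { rows, cols, data }
    }

    /// Swap the two axes, like x.transpose(-2, -1) on a 2D tensor.
    pub fn t(&self) -> Mat {
        let mut data = vec![0.0; self.data.len()];
        for r in 0..self.rows {
            for c in 0..self.cols {
                data[c * self.rows + r] = self.data[r * self.cols + c];
            }
        }
        Mat::new(self.cols, self.rows, data)
    }

    /// Matrix product `self @ other`; None if the inner dimensions differ.
    pub fn matmul(&self, other: &Mat) -> Option<Mat> {
        if self.cols != other.rows {
            return None;
        }
        let mut data = vec![0.0; self.rows * other.cols];
        for i in 0..self.rows {
            for j in 0..other.cols {
                data[i * other.cols + j] = (0..self.cols)
                    .map(|k| self.data[i * self.cols + k] * other.data[k * other.cols + j])
                    .sum::<f32>();
            }
        }
        Some(Mat::new(self.rows, other.cols, data))
    }
}
```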

nn.MaxPool2d with padding

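nn.MaxPool2d with padding, such as nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2), pads with -inf, not zero. Zero padding therefore changes border values whenever the input is negative, which can happen in SPPF because its input comes from a SiLU convolution. For a stride-1 pool, replicate padding gives the same result as -inf padding, because the edge element already lies inside every window that reaches into the padding. In Rust this can be implemented as:

x
  .pad_with_same(2, self.k / 2, self.k / 2)?
  .pad_with_same(3, self.k / 2, self.k / 2)?
  .max_pool2d_with_stride(self.k, 1)?;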
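Which padding value matches PyTorch matters at the borders. This plain-Rust 1D sketch compares zero, -inf, and replicate (edge) padding for a stride-1 max pool with window k and k / 2 padding per side. The 2D case pads each spatial dimension the same way.

```rust
/// Value a stride-1 max pool sees outside the input.
#[derive(Debug, Clone, Copy)]
pub enum Padding {
    Zero,
    NegInf,
    Replicate,
}

/// Stride-1 1D max pool with an odd window `k` and `k / 2` padding per side.
pub fn max_pool_1d(x: &[f32], k: usize, pad: Padding) -> Vec<f32> {
    let n = x.len() as isize;
    let p = (k / 2) as isize;
    (0..n)
        .map(|i| {
            ((i - p)..=(i + p))
                .map(|j| {
                    if (0..n).contains(&j) {
                        x[j as usize]
                    } else {
                        match pad {
                            Padding::Zero => 0.0,
                            Padding::NegInf => f32::NEG_INFINITY,
                            // Reuse the nearest edge element
                            Padding::Replicate => x[j.clamp(0, n - 1) as usize],
                        }
                    }
                })
                .fold(f32::NEG_INFINITY, f32::max)
        })
        .collect()
}
```

With all-negative inputs, zero padding turns every border window into 0.0, while replicate padding reproduces the -inf result exactly. With non-negative inputs, all three agree. In Candle, replicate padding is Tensor::pad_with_same(dim, left, right).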

v10postprocess

The most difficult part to port in YOLOv10 is actually the v10postprocess module.

First, Candle doesn’t have topk. Fortunately, the DeepSeek model implementation in Candle’s candle-transformers includes a topk helper. It behaves close to PyTorch’s torch.topk along the last dimension and meets our needs:


use candle_core::{Result, Tensor, D};

pub struct TopKOutput {
    pub values: Tensor,
    pub indices: Tensor,
}
pub trait TopKLastDimOp {
    fn topk(&self, topk: usize) -> Result<TopKOutput>;
}

impl TopKLastDimOp for Tensor {
    fn topk(&self, topk: usize) -> Result<TopKOutput> {
        // Indices that sort the last dimension in descending order (asc = false)
        let sorted_indices = self.arg_sort_last_dim(false)?;
        // Get the size of the last dimension
        let last_dim_size = sorted_indices.dim(D::Minus1)?;
        // Clamp k to the size of the last dimension
        let actual_topk = topk.min(last_dim_size);
        let topk_indices = sorted_indices
            .narrow(D::Minus1, 0, actual_topk)?
            .contiguous()?;
        Ok(TopKOutput {
            values: self.gather(&topk_indices, D::Minus1)?,
            indices: topk_indices,
        })
    }
}
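The helper’s logic (sort descending, keep the first k, gather the values) is easy to check on plain data. Here is a plain-Rust sketch for a single row. A full sort does more work than a partial selection would, but for a few thousand anchors it is unlikely to matter.

```rust
/// Values and indices of the `k` largest entries of `xs`, largest first.
/// Like the Candle helper, `k` is clamped to the length of the input.
pub fn topk(xs: &[f32], k: usize) -> (Vec<f32>, Vec<usize>) {
    let mut indices: Vec<usize> = (0..xs.len()).collect();
    // Sort indices by value, descending (total_cmp gives a total order, even with NaN)
    indices.sort_by(|&a, &b| xs[b].total_cmp(&xs[a]));
    indices.truncate(k.min(xs.len()));
    let values = indices.iter().map(|&i| xs[i]).collect();
    (values, indices)
}
```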

Second, I couldn’t find a modulo (remainder) operation in Candle. It’s needed for this step:

result = index % nc  # Calculate remainder of index ÷ nc

It also has to support broadcasting, since the single value nc applies to every index. My implementation is as follows:

use candle_core::{Result, Tensor};

pub trait TensorRemOps {
    fn broadcast_rem(&self, other: &Tensor) -> Result<Tensor>;
}
impl TensorRemOps for Tensor {
    /// Element-wise `self % other` for u32 tensors of any rank, with broadcasting.
    /// The data is copied to the host, so this is only suitable for small tensors.
    fn broadcast_rem(&self, other: &Tensor) -> Result<Tensor> {
        // Get the broadcasted shape
        let broadcast_shape = self
            .shape()
            .broadcast_shape_binary_op(other.shape(), "broadcast_rem")?;

        // Broadcast both tensors and flatten them to 1D (works for any rank)
        let self_data = self
            .broadcast_as(broadcast_shape.dims())?
            .flatten_all()?
            .to_vec1::<u32>()?;
        let other_data = other
            .broadcast_as(broadcast_shape.dims())?
            .flatten_all()?
            .to_vec1::<u32>()?;

        // Perform element-wise modulo operation
        let result: Vec<u32> = self_data
            .iter()
            .zip(other_data.iter())
            .map(|(a, b)| a % b)
            .collect();

        // Rebuild a tensor with the broadcasted shape on the original device
        Tensor::from_vec(result, broadcast_shape, self.device())
    }
}
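The host round trip can probably be avoided. For unsigned integers, a % b equals a - (a / b) * b, and Candle has broadcasting division and multiplication. Something like index.sub(&index.broadcast_div(&nc_tensor)?.broadcast_mul(&nc_tensor)?)? should therefore compute index % nc on the device. This assumes Candle’s u32 division truncates like Rust’s /, which v10postprocess already relies on for index // nc. I haven’t verified this end to end. The plain-Rust sketch below only checks the identity itself.

```rust
/// a % b using only division, multiplication and subtraction.
/// Unsigned division truncates, so a - (a / b) * b is the remainder.
pub fn rem_via_div(a: u32, b: u32) -> u32 {
    a - (a / b) * b
}

/// index % nc for a row of indices and a single divisor (the broadcast case).
pub fn broadcast_rem(index: &[u32], nc: u32) -> Vec<u32> {
    index.iter().map(|&i| rem_via_div(i, nc)).collect()
}
```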

At this point, the structural obstacles are basically resolved.

Next, we stack the modules as the original architecture defines them, like building blocks.

Printing the model with print(model.model) shows that YOLOv10 has a very flat structure: one list of layers numbered 0 to 23. Mirroring that flat list directly isn’t a convenient way to organize Rust code, so I divided YOLOv10 into three major parts: backbone, neck, and head.

  • backbone: Layers 0-10
  • neck: Layers 11-22
  • head: Layer 23

pub struct YoloV10 {
    backbone: Backbone,
    neck: YoloNeck,
    head: V10DetectionHead,
}

We can organize the code differently from the original because vb.pp() accepts dotted paths. vb.pp("model.0") addresses the same weights as vb.pp("model").pp("0"), so a key such as model.model.0.0.0.weight can be reached from whichever Rust struct owns that layer.

This organization also makes it easy to line the weights up with the PyTorch version.
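Here is a sketch of the bookkeeping behind this split, in plain Rust with hypothetical names. Each layer index maps to one of the three parts, while its weight prefix stays model.<index> regardless of which Rust struct loads it. The prefix format assumes the weights were saved from model.model as in the export step, where the layers live in an nn.Sequential attribute named model.

```rust
/// The three parts the Rust model is split into.
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum Part {
    Backbone,
    Neck,
    Head,
}

/// Which Rust part owns ultralytics layer `layer` (0..=23 for YOLOv10s).
pub fn part_of(layer: usize) -> Option<Part> {
    match layer {
        0..=10 => Some(Part::Backbone),
        11..=22 => Some(Part::Neck),
        23 => Some(Part::Head),
        _ => None,
    }
}

/// Weight-key prefix for a layer; it is the same whichever struct loads it.
pub fn layer_prefix(layer: usize) -> String {
    format!("model.{layer}")
}
```

In Candle terms, a neck holding the root VarBuilder could then load layer 13 through vb.pp("model.13").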

Aligning Nodes

Node alignment relies mainly on three sources:

  1. The yolov10s.yaml file, for the overall structure
  2. print(model.model), for the node hierarchy
  3. for k, v in tensors.items(): print(str(k), v.shape), for the node details

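Then write a quick test case:

  use candle_core::{DType, Device, Tensor};
  use candle_nn::VarBuilder;

  let device = Device::Cpu;
  // Memory-map the converted weights (unsafe: the file must not change while mapped)
  let vb = unsafe {
      VarBuilder::from_mmaped_safetensors(
          &["yolov10s.safetensors"],
          DType::F32,
          &device,
      )
  }?;
  // Build `yolo` (the YoloV10 model) from `vb` here; the constructor is project-specific.
  // All-ones dummy input: batch 1, 3 channels, 640x640
  let xs = vec![1f32; 640 * 640 * 3];
  let image_t = Tensor::from_vec(xs, (1, 3, 640, 640), &device)?;
  let output = yolo.forward(&image_t)?;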

If this runs without errors, congratulations! Every weight the Rust model requested was found in the file with the expected shape, so the model nodes are aligned.

Detection Head

As mentioned earlier, YOLOv10 uses two strategies, one-to-many (one2many) and one-to-one (one2one). During training, the detection head returns both:

if self.training:  # Training path
    return {"one2many": x, "one2one": one2one}

Since we only run inference, only the one2one branch is processed, and the one2many part isn’t needed.

Post-processing

When I previously ran YOLOv10 through ONNX, post-processing was very simple, because the ONNX export includes the v10postprocess module in the computation graph. Once v10postprocess is counted, though, I’d argue YOLOv10’s post-processing is more complex than YOLOv8’s. As shown above, it needs two operations that Candle doesn’t have.

In Candle, we have to implement v10postprocess ourselves.

fn v10postprocess(preds: &Tensor, max_det: usize, nc: usize) -> Result<Tensor> {
    let preds_shape = preds.dims();
    assert!(4 + nc == preds_shape[preds_shape.len() - 1]);

    // Split boxes and scores
    let boxes = preds.i((.., .., ..4))?;
    let scores = preds.i((.., .., 4..))?;

    let amax_scores = scores.max(D::Minus1)?;

    // max_scores, index = torch.topk(max_scores, max_det, dim=-1)
    let TopKOutput {
        values: _max_scores,
        indices: topk_indices,
    } = amax_scores.topk(max_det)?;

    let index = topk_indices.unsqueeze(D::Minus1)?; // Equivalent to index.unsqueeze(-1)
    let boxes = boxes.contiguous()?.gather(&index.repeat((1, 1, 4))?, 1)?;
    let scores = scores.contiguous()?.gather(&index.repeat((1, 1, nc))?, 1)?;

    // scores, index = torch.topk(scores.flatten(1), max_det, dim=-1)
    let scores_flat = scores.flatten(1, 2)?;
    let TopKOutput {
        values: scores,
        indices: index,
    } = scores_flat.topk(max_det)?;

    // println!("index: {:?}", index.shape()); // index: [1, 300]
    // println!("scores: {:?}", scores.shape()); // scores: [1, 300]

    let nc_tensor = Tensor::from_slice(&[nc as u32], 1, scores.device())?;
    let index_div = index.broadcast_div(&nc_tensor)?;

    // Use gather instead of advanced indexing boxes[i, index // nc]
    let boxes_indices = index_div.unsqueeze(2)?.repeat((1, 1, 4))?; // [batch_size, max_det, 4]
    let boxes_gathered = boxes.gather(&boxes_indices, 1)?; // [batch_size, max_det, 4]
    // scores[..., None] - Add new axis
    let scores_expanded = scores.unsqueeze(2)?; // [batch_size, max_det, 1]

    // (index % nc)[..., None].float() - Modulo and add new axis
    let index_mod = index.broadcast_rem(&nc_tensor)?; // index % nc
    let index_mod_expanded = index_mod.unsqueeze(2)?.to_dtype(DType::F32)?; // [batch_size, max_det, 1]

    // Concatenate all tensors on the last dimension
    // python: torch.cat([boxes[i, index // nc], scores[..., None], (index % nc)[..., None].float()], dim=-1)
    // The last dimension should be 4 + 1 + 1 = 6
    let result = Tensor::cat(&[&boxes_gathered, &scores_expanded, &index_mod_expanded], 2)?;

    Ok(result)
}
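The index arithmetic in v10postprocess is easy to get wrong, so here is the same two-stage selection on plain Vecs for a single image. It is a sketch with hypothetical names and no Candle. Stage one keeps the max_det anchors with the highest best-class score. Stage two ranks every (anchor, class) pair among them, and a flat index i decodes to anchor i / nc and class i % nc.

```rust
/// One detection: the box as produced by the model, its score, and its class id.
#[derive(Debug, Clone, PartialEq)]
pub struct Detection {
    pub bbox: [f32; 4],
    pub score: f32,
    pub label: usize,
}

/// Indices of the `k` largest values, largest first (stable for ties).
fn topk_indices(values: &[f32], k: usize) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..values.len()).collect();
    idx.sort_by(|&a, &b| values[b].total_cmp(&values[a]));
    idx.truncate(k);
    idx
}

/// Single-image version of the v10postprocess selection.
/// boxes[i] and scores[i] belong to anchor i; each scores[i] holds nc class scores.
pub fn v10postprocess_single(
    boxes: &[[f32; 4]],
    scores: &[Vec<f32>],
    max_det: usize,
    nc: usize,
) -> Vec<Detection> {
    // Stage 1: keep the max_det anchors whose best class score is highest
    let best: Vec<f32> = scores
        .iter()
        .map(|s| s.iter().copied().fold(f32::NEG_INFINITY, f32::max))
        .collect();
    let anchors = topk_indices(&best, max_det);

    // Gather the kept anchors' boxes and flatten their class scores into one row
    let kept_boxes: Vec<[f32; 4]> = anchors.iter().map(|&a| boxes[a]).collect();
    let flat: Vec<f32> = anchors
        .iter()
        .flat_map(|&a| scores[a].iter().copied())
        .collect();

    // Stage 2: rank all (anchor, class) pairs; flat index i = anchor_pos * nc + class
    topk_indices(&flat, max_det)
        .into_iter()
        .map(|i| Detection {
            bbox: kept_boxes[i / nc], // index // nc
            score: flat[i],
            label: i % nc, // index % nc
        })
        .collect()
}
```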

Note that Candle’s gather requires contiguous tensors. Slicing with .i() produces a strided view rather than a copy, so v10postprocess calls contiguous() on boxes and scores before gathering.
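To see why .i((.., .., ..4)) is not contiguous, compare strides. Slicing the last dimension keeps the parent’s strides. For predictions of shape [1, 8400, 84] (assuming 80 classes, so 84 = 4 + 80), the box view [1, 8400, 4] therefore still steps 84 elements between rows instead of 4. This plain-Rust sketch encodes the general row-major rule, not Candle’s internal code.

```rust
/// Row-major (contiguous) strides for a shape.
pub fn contiguous_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }
    strides
}

/// A view is contiguous when its strides equal the row-major strides of its
/// own shape (size-1 dimensions are ignored, since their stride never matters).
pub fn is_contiguous(dims: &[usize], strides: &[usize]) -> bool {
    let expected = contiguous_strides(dims);
    dims.iter()
        .zip(strides.iter().zip(expected.iter()))
        .all(|(&d, (&s, &e))| d == 1 || s == e)
}
```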

Aligning Results

Because the ultralytics code is quite complex, comparing results on the same image isn’t very practical: it’s hard to reproduce ultralytics’ preprocessing exactly. To keep preprocessing errors out of the comparison, it’s best to feed both models identical input tensors.

Fortunately, ultralytics accepts a raw tensor as input, in which case it skips most of its preprocessing:

import torch

one = torch.ones(1, 3, 640, 640)
results = model(one)

Similarly, in Candle, we can implement it as:

let image_t = candle_core::Tensor::ones((1, 3, 640, 640), candle_core::DType::F32, &device)?;

How do we compare the two outputs?

The tricky side is ultralytics, where intermediate outputs are buried inside the model. There are two approaches. The first is to register a forward hook:

import torch

# Define a hook that prints a layer's input and output
def hook(module, input, output):
    print(f"Layer: {module.__class__.__name__}")
    print(f"Input shape: {input[0].shape}")
    print(f"Output shape: {output.shape}")
    print(f"Output sample: {output}")
    return output

layer = model.model.model[0]  # type: ignore # model.model is the actual network structure
handle = layer.register_forward_hook(hook)

model(torch.ones(1, 3, 640, 640))  # the hook fires during this forward pass
handle.remove()  # detach the hook when done

In model.model.model[0], the index 0 selects the first layer of the model. YOLOv10s has 24 layers, numbered 0-23.

The second approach is to add a print inside the loop that runs the layers, which is _predict_once in ultralytics/nn/tasks.py:

# Around line 180 (the exact line varies between versions)
x = m(x)  # run
if m.i == 0:
    print("x::", x)

Here m.i is the layer index (0-23).

Error

When aligning results, there’s no need to demand identical values; bit-exact agreement is practically impossible. Differences come from hardware (CPU vs GPU), floating-point precision, and framework and operator implementations, and they accumulate as the computation goes deeper through the layers. In my current Rust implementation, for example, the discrepancy after post-processing is on the order of 0.1.
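To put a number on the discrepancy, compare the two outputs element-wise. This plain-Rust sketch computes two common measures: the maximum absolute difference, and the torch.allclose criterion |a - b| <= atol + rtol * |b|. torch.allclose defaults to rtol = 1e-5 and atol = 1e-8, which is much stricter than a cross-framework port is likely to meet, so in practice the tolerances are a judgment call that should grow with how deep in the model you compare.

```rust
/// Largest absolute element-wise difference between two outputs.
pub fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "outputs must have the same number of elements");
    a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0, f32::max)
}

/// Same criterion as torch.allclose: |a - b| <= atol + rtol * |b| for every element.
pub fn allclose(a: &[f32], b: &[f32], rtol: f32, atol: f32) -> bool {
    a.len() == b.len() && a.iter().zip(b).all(|(x, y)| (x - y).abs() <= atol + rtol * y.abs())
}
```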

The Candle implementation’s detections are essentially consistent with those of ultralytics.

[Figure: detection results of the Candle implementation]

If you are interested in this implementation, you can view the source code here: https://github.com/kingzcheung/yolov10