之前从零开始完整地实现了 Yolov10 模型,并成功在 CPU 上运行。但是使用 cuda 加速时,会出现了以下错误:
DriverError(CUDA_ERROR_INVALID_VALUE, "invalid argument")
通过错误堆栈信息排查,最终定位到了 topk 函数上。
具体来说, yolov10 在 v10postprocess 模块中,使用了两次 topk 函数。
第一个 topk 主要用于选择置信度最高的检测结果:
max_scores, index = torch.topk(max_scores, max_det, dim=-1)
- 对每个预测框的最大类别得分(max_scores)进行排序.
- 选择得分最高的前 max_det 个预测框.
- 减少需要处理的数量,提高效率.
第二个 topk 主要在已经筛选出的 max_det 个预测框中,再次按得分排序。
scores, index = torch.topk(scores.flatten(1), max_det, dim=-1)
在 candle 框架上, topk 函数大体实现如下:
pub trait TopKLastDimOp {
/// Note: this implements torch.topk with sorted=True.
fn topk(&self, topk: usize) -> Result<TopKOutput>;
/// Note: this implements torch.topk with sorted=False.
fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput>;
}
impl TopKLastDimOp for Tensor {
fn topk(&self, topk: usize) -> Result<TopKOutput> {
// Sorted descending
let sorted_indices = self.arg_sort_last_dim(false)?;
let topk_indices = sorted_indices.narrow(D::Minus1, 0, topk)?.contiguous()?;
Ok(TopKOutput {
values: self.gather(&topk_indices, D::Minus1)?,
indices: topk_indices,
})
}
fn topk_unsorted(&self, topk: usize) -> Result<TopKOutput> {
// Sorted descending
let sorted_indices_all = self.arg_sort_last_dim(false)?;
let topk_indices_sorted = sorted_indices_all
.narrow(D::Minus1, 0, topk)?
.contiguous()?;
let topk_values_sorted = self.gather(&topk_indices_sorted, D::Minus1)?;
// Reorder the indices ascending
let reorder_indices = topk_indices_sorted.arg_sort_last_dim(true)?;
let topk_indices_unsorted = topk_indices_sorted.gather(&reorder_indices, D::Minus1)?;
let topk_values_unsorted = topk_values_sorted.gather(&reorder_indices, D::Minus1)?;
Ok(TopKOutput {
values: topk_values_unsorted,
indices: topk_indices_unsorted,
})
}
}
topk 本身没什么问题,有问题的是 topk 实现里的 arg_sort_last_dim 函数。在第一个topk里,要处理的max_det 形状是 [1,8400]。这是一个非常大的张量。但是 arg_sort_last_dim 在 cuda 上,不支持这到大的张量尺寸。
通过下面的测试发现, arg_sort_last_dim 函数甚至无法处理大于 1024 的张量。
fn main() {
let a = Tensor::zeros(
1025,
DType::F32,
&Device::cuda_if_available(0).unwrap(),
)
.unwrap();
dbg!(&a.arg_sort_last_dim(true));
}
事实上,arg_sort_last_dim 这个问题很早就有人提过了,但是几乎没有人解决。
所以暂时无解。