Candle 是由 Hugging Face 开发的一个极简的机器学习框架,它专为 Rust 语言打造,旨在提供高性能和易用性的完美结合。
Candle 的核心目标之一是使无服务器推理成为可能。这意味着开发者可以更容易地将机器学习模型部署到云端,而无需担心底层基础设施的管理。
本文详细说明如何将 PyTorch 模型代码转换为 Rust Candle 代码。
进行模型代码转换需要掌握以下几点:
- 掌握 PyTorch 和 Candle 的对应 API
- 理解模型的结构
- 验证转换后的模型输出与原模型一致
Pytorch vs Candle
以下是一些常见的等效代码对照。这个对照表虽然不完整,但可以作为参考。要了解更多等效写法,需要查看源代码或官方示例。这也是一个难点,因为目前 Candle 官方尚未制定相关文档计划。
| Using PyTorch | Using Candle | |
|---|---|---|
| Creation | torch.Tensor([[1, 2], [3, 4]]) | Tensor::new(&[[1f32, 2.], [3., 4.]], &Device::Cpu)? |
| Creation | torch.zeros((2, 2)) | Tensor::zeros((2, 2), DType::F32, &Device::Cpu)? |
| Indexing | tensor[:, :4] | tensor.i((.., ..4))? |
| Operations | tensor.view((2, 2)) | tensor.reshape((2, 2))? |
| Operations | a.matmul(b) | a.matmul(&b)? |
| Arithmetic | a + b | &a + &b |
| Device | tensor.to(device="cuda") | tensor.to_device(&Device::new_cuda(0)?)? |
| Dtype | tensor.to(dtype=torch.float16) | tensor.to_dtype(&DType::F16)? |
| Saving | torch.save({"A": A}, "model.bin") | candle::safetensors::save(&HashMap::from([("A", A)]), "model.safetensors")? |
| Loading | weights = torch.load("model.bin") | candle::safetensors::load("model.safetensors", &device) |
更多的等效代码我会在后面的例子中给出来。
理解模型的结构 (Model Structure)
我们只有打印出模型的结构,才知道一个模型到底有哪些模块。在pytorch 中,使用print即可打印出模型结构:
model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
print(model)
前面的文章中,我们使用 mobilenetv2 作为尝试对象,这一次,我将会使用ResNet 作为例子。
ResNet即“残差网络”(Residual Network) 是何凯明在2015年于微软研究院提出的一种深度卷积神经网络架构。它的最大特点是引入了“残差块”(residual block)的概念,通过引入捷径连接(shortcut connections)或跳跃连接(skip connections),使得网络能够学习输入与输出之间的映射差异,即残差(residual)。这样,即使网络很深,也能有效训练。
就像 Transformer 是大模型的基石,ResNet 的设计理念也影响了几乎所有现代传统 CNN 网络的架构。我们可以在很多CNN模型中看到 ResNet的身影。
Pytorch 官方实现的Resnet 地址在这里:
https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
与MobileNet 不同的是,Resnet 是在Candle 里是有官方版本实现的。Candle 的官方版本可以看这里:
https://github.com/huggingface/candle/blob/main/candle-transformers/src/models/resnet.rs
有意思的是,Candle 官方的实现,并不是基于Pytorch 的代码参考,而是同一个作者的另一项工作:基于 Pytorch 的C++ 版本 libtorch 的 Rust 绑定:
https://github.com/LaurentMazare/tch-rs/blob/main/src/vision/resnet.rs
Candle 官方的 ResNet 实现广泛使用了闭包进行代码封装。我的实现版本早于官方版本,并同时参考了 PyTorch 官方版本和 tch 版本,因此两个版本的实现方式存在一些差异。
Resnet的模型结构
Resnet 有几个版本,这里会以Resnet 18 和Resnet 50为例子,下面我们看看Resnet18的结构:
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(1): BasicBlock(
(conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer2): Sequential(
(0): BasicBlock(
(conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer3): Sequential(
(0): BasicBlock(
(conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(layer4): Sequential(
(0): BasicBlock(
(conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(downsample): Sequential(
(0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(1): BasicBlock(
(conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=512, out_features=1000, bias=True)
)
Resnet50 与 Resnet 15 的差别在于 BasicBlock 块的不同:
Bottleneck(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(downsample): Sequential(
(0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
可以看到,Resnet 50 采用了更为复杂的Bottleneck结构,每个残差块包含三个卷积层:首先是一个1x1的卷积层用于减少通道数,接着是一个3x3的卷积层进行特征提取,最后再用一个1x1的卷积层恢复原来的通道数。
只要我们在实现时正确区分这两种结构,就可以分别实现 Resnet18 和 Resnet50。
Candle 中的特殊部分
VarBuilder
VarBuilder 是 candle 库中的一个工具,用于帮助构建和管理神经网络模型的变量(如权重和偏置)。它在创建层(例如卷积层、全连接层等)时非常有用,因为它可以方便地初始化这些层的参数,并且支持从预训练模型加载权重。
以下是一个示例,其中 maxpool 是模型的一个变量。在 candle 初始化时,我们需要使用 VarBuilder 来绑定这个变量:
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
vb.pp("conv1")
如果绑定有问题,会得到下面这个错误:
Error: WithBacktrace { inner: UnexpectedShape { msg: "shape mismatch for layer2.0.conv2.weight", expected: [128, 64, 3, 3], got: [128, 128, 3, 3] },
我们还需要使用它来加载权重:
let model_file = "./testdata/resnet18.safetensors";
let device = candle_core::Device::Cpu;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
这个方法需要在 unsafe 块中调用,原因是 VarBuilder::from_mmaped_safetensors 使用了 memmap2::MmapOptions,而 memmap2::MmapOptions 中所有基于文件的内存映射构造函数都被标记为 unsafe,因为如果底层文件在映射后被修改(无论是内部还是外部进程),可能会导致未定义行为(Undefined Behavior, UB)。应用程序必须考虑这种风险,并采取适当的预防措施,例如使用文件权限、锁或者进程私有的(如已取消链接的)文件等。
由于篇幅限制,我会在《Rust AI 进阶(下): PyTorch 模型转换为 Rust candle 模型教程》中详细讲解每个模块的具体实现。