Due to the length, this tutorial is divided into two parts.
Sequential
Sequential in PyTorch is a container module that simply encapsulates multiple layers in the order they are added, forming a linear stack. Using Sequential simplifies model definition, especially when the network structure is a simple linear sequence where each layer’s output directly serves as the input for the next layer without branches or more complex connections.
Candle does not have this module, so we need to implement it ourselves. Below is my implementation:
use candle_core::{Module, Result, Tensor};

/// A generic container that applies its layers in order.
/// T must implement the Module trait.
#[derive(Debug, Clone)]
pub struct Sequential<T: Module> {
    // The layers, stored in the order they are applied
    layers: Vec<T>,
}

/// Create an empty Sequential, pre-allocating room for `cnt` layers
pub fn seq<T: Module>(cnt: usize) -> Sequential<T> {
    // Vec::with_capacity(0) does not allocate, so cnt == 0 needs no special case
    Sequential { layers: Vec::with_capacity(cnt) }
}

impl<T: Module> Sequential<T> {
    /// Return the number of layers in the sequence
    pub fn len(&self) -> usize {
        self.layers.len()
    }

    /// Check whether the sequence is empty
    pub fn is_empty(&self) -> bool {
        self.layers.is_empty()
    }

    /// Append a layer to the sequence
    pub fn push(&mut self, layer: T) {
        self.layers.push(layer);
    }

    /// Append a layer to the sequence (alias for push)
    pub fn add(&mut self, layer: T) {
        self.layers.push(layer);
    }
}

impl<T: Module> Module for Sequential<T> {
    /// Forward pass: feed the output of each layer into the next one
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Cloning a Tensor is cheap: it only bumps a reference count
        let mut xs = xs.clone();
        for layer in self.layers.iter() {
            xs = xs.apply(layer)?;
        }
        // Return the output of the last layer
        Ok(xs)
    }
}
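Because Sequential<T> stores a single layer type T, all layers in one container must share that type, which is fine for a stack of BasicBlocks. If you need to mix layer types, one option is to store boxed trait objects instead. The sketch below shows that variant in plain Rust; the Layer trait, the Scale/Shift/Relu layers and the use of Vec<f32> are hypothetical stand-ins for Candle's Module and Tensor, chosen only so the example runs without Candle.

```rust
/// Hypothetical stand-in for Candle's Module trait, working on Vec<f32> instead of Tensor.
pub trait Layer {
    fn forward(&self, xs: &[f32]) -> Vec<f32>;
}

/// Multiplies every element by a constant (hypothetical layer).
pub struct Scale(pub f32);

/// Adds a constant to every element (hypothetical layer).
pub struct Shift(pub f32);

/// Element-wise ReLU: max(0, x).
pub struct Relu;

impl Layer for Scale {
    fn forward(&self, xs: &[f32]) -> Vec<f32> {
        xs.iter().map(|x| x * self.0).collect()
    }
}

impl Layer for Shift {
    fn forward(&self, xs: &[f32]) -> Vec<f32> {
        xs.iter().map(|x| x + self.0).collect()
    }
}

impl Layer for Relu {
    fn forward(&self, xs: &[f32]) -> Vec<f32> {
        xs.iter().map(|x| x.max(0.0)).collect()
    }
}

/// Sequential container holding boxed trait objects, so different layer
/// types can be mixed (unlike Sequential<T>, which holds one type T).
#[derive(Default)]
pub struct DynSequential {
    layers: Vec<Box<dyn Layer>>,
}

impl DynSequential {
    pub fn new() -> Self {
        Self { layers: Vec::new() }
    }

    /// Builder-style add that returns self, so calls can be chained.
    pub fn add<L: Layer + 'static>(mut self, layer: L) -> Self {
        self.layers.push(Box::new(layer));
        self
    }

    pub fn len(&self) -> usize {
        self.layers.len()
    }
}

impl Layer for DynSequential {
    /// Feeds the output of each layer into the next one, in insertion order.
    fn forward(&self, xs: &[f32]) -> Vec<f32> {
        let mut out = xs.to_vec();
        for layer in &self.layers {
            out = layer.forward(&out);
        }
        out
    }
}
```

For example, `DynSequential::new().add(Scale(2.0)).add(Shift(-3.0)).add(Relu)` maps `[1.0, 2.0, 0.5]` to `[0.0, 1.0, 0.0]`. The boxed version pays a dynamic call per layer for that flexibility; for a homogeneous stack such as a ResNet stage, the generic Sequential<T> avoids the boxing.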
This is a generic module that can be used by many models.
2D Convolution (Conv2d)
The first convolution in ResNet has the following structure: 3 input channels, 64 output channels, a 7x7 kernel, stride 2, padding 3, and no bias.
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
Candle provides a corresponding API, candle_nn::conv2d_no_bias, which we wrap in a small helper:
use candle_core::Result;
use candle_nn::VarBuilder;

/// Create a 2D convolutional layer (Conv2d).
///
/// This function creates a 2D convolutional layer without a bias term using the specified input channels, output channels, kernel size, padding, stride, and variable builder.
///
/// # Arguments
/// - `in_planes` (usize): Number of input channels. For RGB images, it is usually 3; for grayscale images, it is usually 1.
/// - `out_planes` (usize): Number of output feature maps, i.e. how many feature maps the convolutional layer produces.
/// - `ksize` (usize): Kernel size. For example, 7 represents a 7x7 kernel.
/// - `padding` (usize): Amount of zero-padding added to each border of the input. This helps keep the output feature map size close to the input size.
/// - `stride` (usize): Step of the convolution kernel over the input. A larger stride reduces the size of the output feature map.
/// - `vb` (VarBuilder): Variable builder used to obtain the layer's weights, either loaded from a pre-trained checkpoint or freshly initialized.
///
/// # Returns
/// - `Result<Conv2d>`: A `Conv2d` instance if the layer is created successfully; otherwise, an error.
///
/// # Example
/// ```rust,ignore
/// use candle_core::{DType, Device};
/// // Zero-initialized weights: enough to build the layer and check shapes
/// let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
/// let conv_layer = conv2d(3, 64, 7, 3, 2, vb.pp("conv1"))?;
/// ```
fn conv2d(
    in_planes: usize,
    out_planes: usize,
    ksize: usize,
    padding: usize,
    stride: usize,
    vb: VarBuilder,
) -> Result<candle_nn::Conv2d> {
    // Set stride and padding; the other fields (dilation, groups) keep their default values
    let conv2d_cfg = candle_nn::Conv2dConfig {
        stride,
        padding,
        ..Default::default()
    };
    // conv2d_no_bias takes the input channels, output channels, kernel size,
    // configuration, and variable builder, and creates a convolution without a bias term
    candle_nn::conv2d_no_bias(in_planes, out_planes, ksize, conv2d_cfg, vb)
}
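To check the shapes these arguments produce, the standard output-size formula for a convolution (pooling uses the same windowing) is floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1. The helper below is a plain-Rust sketch of that formula, not a Candle API; applied to conv1, it turns a 224x224 input into 112x112.

```rust
/// Output size along one spatial dimension of a convolution or pooling window:
/// floor((input + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
pub fn conv_out_size(input: usize, kernel: usize, stride: usize, padding: usize, dilation: usize) -> usize {
    assert!(kernel > 0 && stride > 0 && dilation > 0, "kernel, stride and dilation must be positive");
    // Span covered by a (possibly dilated) kernel
    let effective_kernel = dilation * (kernel - 1) + 1;
    assert!(input + 2 * padding >= effective_kernel, "kernel larger than padded input");
    // Integer division floors, matching the formula above
    (input + 2 * padding - effective_kernel) / stride + 1
}
```

For instance, `conv_out_size(56, 3, 2, 1, 1)` and `conv_out_size(56, 1, 2, 0, 1)` both give 28, which is why a 3x3 stride-2 main branch and a 1x1 stride-2 downsample branch produce matching shapes for the residual addition.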
Batch Normalization 2D (BatchNorm2d)
BatchNorm2d is typically placed after a convolutional layer and before an activation function to improve gradient propagation during training and accelerate network convergence.
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
In Candle, there is an equivalent operation API:
let bn1 = nn::batch_norm(out_planes, 1e-5, vb.pp("bn1"))?; // out_planes = 64, eps = 1e-5
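In the forward passes later in this article, batch normalization is applied with apply_t(..., false), i.e. in inference mode, where the layer uses the running mean and variance stored in the weights instead of batch statistics. Per channel, that mode computes y = (x - running_mean) / sqrt(running_var + eps) * gamma + beta. A minimal plain-Rust sketch of the formula for one channel (the function name is illustrative, not a Candle API):

```rust
/// Inference-mode batch normalization for the values of one channel:
/// y = (x - running_mean) / sqrt(running_var + eps) * gamma + beta
pub fn batch_norm_eval(
    xs: &[f32],
    running_mean: f32,
    running_var: f32,
    gamma: f32,
    beta: f32,
    eps: f32,
) -> Vec<f32> {
    // 1 / standard deviation, computed once for the whole channel
    let inv_std = 1.0 / (running_var + eps).sqrt();
    xs.iter()
        .map(|&x| (x - running_mean) * inv_std * gamma + beta)
        .collect()
}
```

Because the statistics are fixed at inference time, the whole layer reduces to a per-channel scale and shift.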
ReLU Activation Function
The ReLU formula is as follows: f(x) = max(0, x)
(relu): ReLU(inplace=True)
In Candle, relu does not need to be defined as a structure; it can be called directly within the forward method.
A simple example:
impl Module for Block {
    fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        let ys = xs.relu()?;
        Ok(ys)
    }
}
MaxPool2d Max Pooling
Max pooling is mainly used to downsample feature maps, which reduces computation, enlarges the receptive field of later layers, and keeps the strongest activation in each window.
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
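Similarly, max pooling in Candle does not require defining a struct; it can be called directly within the forward method. Note that max_pool2d_with_stride takes no padding argument, so the padding=1 of the PyTorch layer has to be applied separately (the forward pass of the full model does this with pad_with_same):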
let xs = xs.max_pool2d_with_stride(3, 2)?; // kernel_size=3 stride=2
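PyTorch pads max pooling with negative infinity, so padded positions never win. With kernel 3 and padding 1, every window that touches a padded position also contains the adjacent edge element, so padding with a copy of that element (replicate padding, which is what pad_with_same does) yields the same maximum. The plain-Rust 1D sketch below compares the two choices; the helper names are illustrative, not Candle APIs.

```rust
/// Pads `xs` with `p` copies of negative infinity on each side.
pub fn pad_neg_inf(xs: &[f32], p: usize) -> Vec<f32> {
    let mut out = vec![f32::NEG_INFINITY; p];
    out.extend_from_slice(xs);
    out.extend(std::iter::repeat(f32::NEG_INFINITY).take(p));
    out
}

/// Pads `xs` by repeating its first and last element `p` times (replicate padding).
pub fn pad_replicate(xs: &[f32], p: usize) -> Vec<f32> {
    assert!(!xs.is_empty(), "cannot replicate-pad an empty slice");
    let mut out = vec![xs[0]; p];
    out.extend_from_slice(xs);
    out.extend(std::iter::repeat(xs[xs.len() - 1]).take(p));
    out
}

/// Max pooling over an already padded slice (floor mode):
/// window i covers [i * stride, i * stride + kernel).
pub fn max_pool1d(xs: &[f32], kernel: usize, stride: usize) -> Vec<f32> {
    assert!(kernel > 0 && stride > 0 && xs.len() >= kernel);
    let n_out = (xs.len() - kernel) / stride + 1;
    (0..n_out)
        .map(|i| {
            xs[i * stride..i * stride + kernel]
                .iter()
                .copied()
                .fold(f32::NEG_INFINITY, f32::max)
        })
        .collect()
}
```

Both `max_pool1d(&pad_neg_inf(&xs, 1), 3, 2)` and `max_pool1d(&pad_replicate(&xs, 1), 3, 2)` give `[5.0, 8.0, 8.0]` for `xs = [1.0, 5.0, 2.0, 8.0, 3.0, 7.0]`. The equivalence relies on the padding being smaller than the kernel.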
As mentioned earlier, the main difference between ResNet18 and ResNet50 lies in the depth of the network, mainly in the design of the residual blocks and in how the number of channels changes.
Residual Block Design
ResNet18 uses the BasicBlock structure: each residual block contains two 3x3 convolutional layers, each followed by a BN layer, with a ReLU after the first BN and another after the residual addition.
BasicBlock(
  (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
We define the BasicBlock structure using Candle:
use candle_nn as nn;

/// Basic residual block structure
/// Contains two convolutional layers and two batch normalization layers
/// Includes an optional downsampling layer if the input and output channel numbers differ or the stride is not 1
#[derive(Debug, Clone)]
pub struct BasicBlock {
    /// First convolutional layer
    conv1: nn::Conv2d,
    /// Batch normalization layer after the first convolutional layer
    bn1: nn::BatchNorm,
    /// Second convolutional layer
    conv2: nn::Conv2d,
    /// Batch normalization layer after the second convolutional layer
    bn2: nn::BatchNorm,
    /// Optional downsampling layer to adjust the size and number of channels of feature maps
    downsample: Option<Downsample>,
}
Downsample
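The downsample field is optional because only some blocks need it. When the stride is not 1 or the number of channels changes, the identity branch no longer has the same shape as the main branch, so it is projected with a 1x1 convolution (followed by batch normalization) before the two are added.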
// The Downsample type is used below, but its definition was not shown.
// A minimal definition consistent with its constructor and its use in forward (an assumption):
#[derive(Debug, Clone)]
pub struct Downsample {
    conv2d: nn::Conv2d,
    bn2: nn::BatchNorm,
    in_planes: usize,
    out_planes: usize,
    stride: usize,
}

impl Module for Downsample {
    // 1x1 convolution followed by batch normalization in inference mode
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.conv2d)?.apply_t(&self.bn2, false)
    }
}

/// Downsample function for input feature maps
/// Uses a 1x1 convolution to adjust the number of channels and perform downsampling when the input and output channel numbers differ or the stride is not 1
///
/// # Arguments
/// - `in_planes`: Number of input feature map channels
/// - `out_planes`: Number of output feature map channels
/// - `stride`: Convolution stride
/// - `vb`: Variable builder used to initialize the convolution and batch normalization layers
///
/// # Returns
/// - If the stride is not 1 or the input and output channel numbers differ, returns a Some(Downsample) instance
/// - Otherwise, returns None, indicating that no downsampling is needed
fn downsample(in_planes: usize, out_planes: usize, stride: usize, vb: VarBuilder) -> Result<Option<Downsample>> {
    // Check whether downsampling is required
    if stride != 1 || in_planes != out_planes {
        // 1x1 convolution to adjust the number of channels and the spatial size (weights under "downsample.0")
        let conv = conv2d(in_planes, out_planes, 1, 0, stride, vb.pp(0))?;
        // Batch normalization after the projection (weights under "downsample.1")
        let bn = batch_norm(out_planes, 1e-5, vb.pp(1))?;
        // Return the downsampling module
        Ok(Some(Downsample { conv2d: conv, bn2: bn, in_planes, out_planes, stride }))
    } else {
        // No downsampling needed
        Ok(None)
    }
}
The constructor, BasicBlock::new, is implemented as follows:
impl BasicBlock {
    /// Create a new BasicBlock instance.
    ///
    /// # Arguments
    ///
    /// - `vb`: Variable builder used to construct variables.
    /// - `in_planes`: Number of input channels.
    /// - `out_planes`: Number of output channels.
    /// - `stride`: Convolution stride.
    ///
    /// # Returns
    ///
    /// - `Result<Self>`: The constructed BasicBlock instance.
    ///
    /// # Description
    ///
    /// This function constructs a residual block with two convolutional layers, two batch normalization layers, and an optional downsampling layer.
    /// Each convolutional layer is followed by a batch normalization layer to speed up training.
    /// The downsampling layer, when present, adjusts the input dimensions to match the output for the residual addition.
    pub fn new(vb: VarBuilder, in_planes: usize, out_planes: usize, stride: usize) -> Result<Self> {
        // First 3x3 convolution: in_planes -> out_planes, padding 1, using the block's stride
        let conv1 = conv2d(in_planes, out_planes, 3, 1, stride, vb.pp("conv1"))?;
        // Batch normalization for the output of the first convolution
        let bn1 = batch_norm(out_planes, 1e-5, vb.pp("bn1"))?;
        // Second 3x3 convolution: out_planes -> out_planes, padding 1, stride 1
        let conv2 = conv2d(out_planes, out_planes, 3, 1, 1, vb.pp("conv2"))?;
        // Batch normalization for the output of the second convolution
        let bn2 = batch_norm(out_planes, 1e-5, vb.pp("bn2"))?;
        // Optional downsampling branch so the identity matches the output shape for the residual addition
        let downsample = downsample(in_planes, out_planes, stride, vb.pp("downsample"))?;
        // Return the constructed BasicBlock instance
        Ok(Self { conv1, bn1, conv2, bn2, downsample })
    }
}
There is not much to explain here: the parameters simply follow the model structure, and the names passed to vb.pp(...) must match the keys in the weight file.
Each component must also implement the Module trait. For the forward pass, we refer to the original PyTorch implementation:
def forward(self, x: Tensor) -> Tensor:
    identity = x
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)
    out = self.conv2(out)
    out = self.bn2(out)
    if self.downsample is not None:
        identity = self.downsample(x)
    out += identity
    out = self.relu(out)
    return out
Conversion to Candle Version
impl Module for BasicBlock {
    // Define the forward propagation function
    // Parameter xs: input tensor
    // Return value: tensor processed by the block
    fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        // Apply convolution, batch normalization, ReLU, convolution, and batch normalization in sequence
        let ys = xs
            .apply(&self.conv1)?
            .apply_t(&self.bn1, false)?
            .relu()?
            .apply(&self.conv2)?
            .apply_t(&self.bn2, false)?;
        // If a downsample branch exists, project the input first; then add it to the result and apply ReLU
        // The conditional branch handles the case where the dimensions of the input tensor change
        if let Some(downsample) = &self.downsample {
            (xs.apply(downsample) + ys)?.relu()
        } else {
            (ys + xs)?.relu()
        }
    }
}
AdaptiveAvgPool2d Adaptive Average Pooling
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
This is the most troublesome part, because Candle has no corresponding API (at least there was none in candle 0.3, the version the author originally used for this implementation; at the time of writing this article, Candle has reached version 0.8).
So we first need to understand how AdaptiveAvgPool2d works in PyTorch. The core idea of adaptive average pooling is that, regardless of the size of the input feature map, it dynamically adjusts the size and stride of the pooling windows so that the output feature map has the preset target size. Specifically, AdaptiveAvgPool2d computes the size and position of each pooling window from the actual input size and the user-specified output size, then averages the elements within each window to obtain the corresponding output value.
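PyTorch computes those windows with a simple rule: along a dimension of input length L and output length O, output i averages the inputs in [floor(i*L/O), ceil((i+1)*L/O)). The plain-Rust 1D sketch below (not a Candle API) implements that rule; with O = 1 the single window spans the whole input, which is why an output size of (1, 1) reduces to a plain mean.

```rust
/// Adaptive average pooling along one dimension, using PyTorch's window rule:
/// output i averages xs[start..end) with start = floor(i*L/O), end = ceil((i+1)*L/O).
pub fn adaptive_avg_pool1d(xs: &[f32], out_len: usize) -> Vec<f32> {
    let in_len = xs.len();
    assert!(in_len > 0 && out_len > 0, "input and output lengths must be positive");
    (0..out_len)
        .map(|i| {
            let start = i * in_len / out_len; // floor division
            let end = ((i + 1) * in_len + out_len - 1) / out_len; // ceiling division
            let window = &xs[start..end];
            window.iter().sum::<f32>() / window.len() as f32
        })
        .collect()
}
```

With `[1.0, 2.0, 3.0, 4.0, 5.0]` and an output length of 2, the windows are `[1, 2, 3]` and `[3, 4, 5]`, giving `[2.0, 4.0]`: when the lengths do not divide evenly, neighboring windows overlap.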
Essentially, each output value is the average over one pooling window. Since the output size here is (1, 1), the single window covers the entire H x W feature map, so the operation reduces to taking the mean over the two spatial dimensions. We therefore need Candle's interface for computing means, such as mean:
/// Averages the elements over the dimensions given in `mean_dims`.
/// The reduced dimensions are removed from the output shape; the separate `mean_keepdim` method keeps them with size 1.
pub fn mean<D: Dims>(&self, mean_dims: D) -> Result<Self> {
    // Convert mean_dims to indexes and check that they are valid for this shape
    let mean_dims = mean_dims.to_indexes(self.shape(), "mean")?;
    // Number of averaged elements: the product of the reduced dimension sizes
    let reduced_dim: usize = mean_dims.iter().map(|i| self.dims()[*i]).product();
    // Scaling factor 1 / (number of averaged elements)
    let scale = 1f64 / (reduced_dim as f64);
    // Sum over the reduced dimensions (keepdim = false) and scale to obtain the mean
    self.sum_impl(mean_dims, false)? * scale
}
Let’s look at the forward propagation in Python, where the following calls are made:
x = self.avgpool(x)
x = torch.flatten(x, 1)
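First the tensor is averaged, then flattened. For an input of shape (N, C, H, W), this is exactly equivalent to taking the mean over the last dimension twice: the first call reduces W, giving (N, C, H), and the second reduces H, giving (N, C), which is the same shape torch.flatten(x, 1) produces from the (N, C, 1, 1) pooled output: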
let xs = xs.mean(D::Minus1)?;
let xs = xs.mean(D::Minus1)?;
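To see why two means over the last dimension give the same result as a global average over H and W, here is a plain-Rust sketch on an NCHW tensor stored as a flat row-major buffer (the function name and layout handling are illustrative). Because every row has the same width, the mean of the row means equals the mean of all H*W elements.

```rust
/// Global average pooling of an NCHW tensor stored as a flat row-major buffer.
/// Mirrors calling mean(D::Minus1) twice: first over W -> (N, C, H), then over H -> (N, C).
pub fn global_avg_pool_nchw(data: &[f32], n: usize, c: usize, h: usize, w: usize) -> Vec<f32> {
    assert_eq!(data.len(), n * c * h * w, "buffer size must match the shape");
    assert!(h > 0 && w > 0, "spatial dimensions must be positive");
    // Step 1: mean over the last dimension (W); each chunk of w values is one row
    let over_w: Vec<f32> = data
        .chunks(w)
        .map(|row| row.iter().sum::<f32>() / w as f32)
        .collect();
    // Step 2: mean over the new last dimension (H); each chunk of h values is one channel
    over_w
        .chunks(h)
        .map(|col| col.iter().sum::<f32>() / h as f32)
        .collect()
}
```

For one image with two 2x2 channels, `[1, 2, 3, 4]` and `[10, 20, 30, 40]`, this returns `[2.5, 25.0]`, the same as averaging each channel's four values directly.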
Linear Fully Connected Layer
The fully connected layer mainly integrates features and ultimately outputs a fixed-size vector. This vector is usually used for final decision-making in classification or regression tasks.
(fc): Linear(in_features=512, out_features=1000, bias=True)
This has a corresponding API in candle.
let fc = nn::linear(in_dim, out_dim, vb.pp("fc"))?; // in_dim = 512, out_dim = 1000
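nn::linear creates a layer with a bias whose weight has shape (out_dim, in_dim), the same layout as PyTorch's nn.Linear, and computes y = x W^T + b. A plain-Rust sketch of that computation for a single input vector (illustrative names, not the Candle API):

```rust
/// Fully connected layer for one input vector: y = x W^T + b,
/// with W stored row by row as (out_features, in_features).
pub fn linear_forward(x: &[f32], weight: &[Vec<f32>], bias: &[f32]) -> Vec<f32> {
    assert_eq!(weight.len(), bias.len(), "one bias per output feature");
    weight
        .iter()
        .zip(bias)
        .map(|(row, b)| {
            assert_eq!(row.len(), x.len(), "weight row must match the input size");
            // Dot product of one weight row with the input, plus that output's bias
            row.iter().zip(x).map(|(w, xi)| w * xi).sum::<f32>() + b
        })
        .collect()
}
```

Each output feature owns one row of the weight matrix, so a (1000, 512) weight maps the 512 pooled features to 1000 class scores.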
layer1-layer4
(layer1): Sequential(
  ...
)
(layer2): Sequential(
  ...
)
(layer3): Sequential(
  ...
)
(layer4): Sequential(
  ...
)
Each of these layers is a stack of BasicBlocks, so we can wrap their construction in a helper function:
/// Create a basic layer consisting of multiple basic blocks
///
/// # Parameters
///
/// - `vb`: Variable builder for creating variables
/// - `in_planes`: Number of input channels
/// - `out_planes`: Number of output channels
/// - `stride`: Stride, applied only to the first block
/// - `cnt`: Number of basic blocks in this layer
///
/// # Returns
///
/// Returns a sequence containing multiple basic blocks
fn basic_layer(
    vb: VarBuilder,
    in_planes: usize,
    out_planes: usize,
    stride: usize,
    cnt: usize,
) -> Result<Sequential<BasicBlock>> {
    // Initialize a sequence to store the basic blocks
    let mut layers = seq(cnt);
    // Create each basic block and add it to the sequence
    for block_index in 0..cnt {
        // Input channels: the first block uses in_planes, the other blocks use out_planes
        let l_in = if block_index == 0 { in_planes } else { out_planes };
        // Stride: the first block uses the given stride, the other blocks use 1
        let stride = if block_index == 0 { stride } else { 1 };
        // vb.pp(block_index) gives each block its own variable name prefix ("0", "1", ...)
        let layer = BasicBlock::new(vb.pp(block_index.to_string()), l_in, out_planes, stride)?;
        // Add the created basic block to the sequence
        layers.push(layer);
    }
    // Return the sequence of basic blocks
    Ok(layers)
}
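The loop above, together with the condition in downsample, determines which blocks get a projection branch. The plain-Rust sketch below mirrors that logic without Candle (BlockCfg and stage_configs are hypothetical names), which makes it easy to check the resulting configuration for each stage:

```rust
/// Configuration of one residual block, as produced by the stage-building loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockCfg {
    pub in_planes: usize,
    pub out_planes: usize,
    pub stride: usize,
    pub has_downsample: bool,
}

/// Mirrors basic_layer and the downsample condition: only the first block
/// receives the stage's input channels and stride.
pub fn stage_configs(in_planes: usize, out_planes: usize, stride: usize, cnt: usize) -> Vec<BlockCfg> {
    (0..cnt)
        .map(|block_index| {
            let l_in = if block_index == 0 { in_planes } else { out_planes };
            let s = if block_index == 0 { stride } else { 1 };
            BlockCfg {
                in_planes: l_in,
                out_planes,
                stride: s,
                // Same condition as in downsample()
                has_downsample: s != 1 || l_in != out_planes,
            }
        })
        .collect()
}
```

For ResNet-18, only the first block of layer2, layer3 and layer4 needs a downsample branch, which matches the torchvision checkpoint, where only layer2.0, layer3.0 and layer4.0 contain downsample weights.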
Overall Structure
The ResNet structure common to resnet18/resnet34:
/// Define the ResNet network structure
///
/// # Fields
///
/// - `conv1`: The first convolutional layer, used for initial feature extraction
/// - `bn1`: The first batch normalization layer, used for normalizing the output of the convolutional layer
/// - `layer1`: The first basic layer, consisting of multiple basic blocks
/// - `layer2`: The second basic layer, consisting of multiple basic blocks
/// - `layer3`: The third basic layer, consisting of multiple basic blocks
/// - `layer4`: The fourth basic layer, consisting of multiple basic blocks
/// - `linear`: An optional fully connected layer, used for final classification
#[derive(Debug, Clone)]
pub struct ResNet {
    conv1: Conv2d,
    bn1: nn::BatchNorm,
    layer1: Sequential<BasicBlock>,
    layer2: Sequential<BasicBlock>,
    layer3: Sequential<BasicBlock>,
    layer4: Sequential<BasicBlock>,
    linear: Option<Linear>,
}
Create Instance
The parameters c1 to c4 exist because ResNet variants differ in depth: they set the number of BasicBlocks in layer1 to layer4, and this is exactly where resnet18 (2, 2, 2, 2) and resnet34 (3, 4, 6, 3) differ.
impl ResNet {
    /// Create a new ResNet model instance.
    ///
    /// # Parameters
    /// - `vb`: Variable builder for constructing variables.
    /// - `nclasses`: Optional number of classes; when `None`, the final fully connected layer is omitted.
    /// - `c1`, `c2`, `c3`, `c4`: Number of basic blocks in layer1, layer2, layer3, and layer4.
    ///
    /// # Returns
    /// The constructed ResNet model instance.
    pub fn new(
        vb: VarBuilder,
        nclasses: Option<usize>,
        c1: usize,
        c2: usize,
        c3: usize,
        c4: usize,
    ) -> Result<Self> {
        // First convolution: 3 -> 64 channels, 7x7 kernel, padding 3, stride 2
        let conv1 = conv2d(3, 64, 7, 3, 2, vb.pp("conv1"))?;
        // Batch normalization for the output of the first convolution
        let bn1 = batch_norm(64, 1e-5, vb.pp("bn1"))?;
        // layer1: `c1` basic blocks, 64 -> 64 channels, stride 1
        let layer1 = basic_layer(vb.pp("layer1"), 64, 64, 1, c1)?;
        // layer2: `c2` basic blocks, 64 -> 128 channels, stride 2
        let layer2 = basic_layer(vb.pp("layer2"), 64, 128, 2, c2)?;
        // layer3: `c3` basic blocks, 128 -> 256 channels, stride 2
        let layer3 = basic_layer(vb.pp("layer3"), 128, 256, 2, c3)?;
        // layer4: `c4` basic blocks, 256 -> 512 channels, stride 2
        let layer4 = basic_layer(vb.pp("layer4"), 256, 512, 2, c4)?;
        // Add the fully connected layer only when the number of classes is given
        let linear = if let Some(n) = nclasses {
            // 512 features -> n classes
            Some(nn::linear(512, n, vb.pp("fc"))?)
        } else {
            // No fully connected layer: the model outputs the 512-dimensional features
            None
        };
        // Construct and return the ResNet model instance
        Ok(Self {
            conv1,
            bn1,
            layer1,
            layer2,
            layer3,
            layer4,
            linear,
        })
    }
}
Forward Propagation
In Python, the forward propagation of ResNet is as follows:
def _forward_impl(self, x: Tensor) -> Tensor:
    # See note [TorchScript super()]
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.avgpool(x)
    x = torch.flatten(x, 1)
    x = self.fc(x)
    return x
Convert to Candle forward propagation:
impl Module for ResNet {
    // Forward propagation of the network
    // Parameters:
    // * `xs`: reference to the input tensor, shape (N, 3, H, W)
    // Return value:
    // * the logits (N, nclasses), or the pooled features (N, 512) when there is no fully connected layer
    fn forward(&self, xs: &candle_core::Tensor) -> Result<candle_core::Tensor> {
        // First convolutional layer and batch normalization, followed by ReLU
        let xs = xs.apply(&self.conv1)?;
        let xs = xs.apply_t(&self.bn1, false)?;
        let xs = xs.relu()?;
        // max_pool2d_with_stride has no padding argument: replicate-pad W and H by 1 on each side,
        // which matches PyTorch's MaxPool2d(kernel_size=3, stride=2, padding=1)
        let xs = xs.pad_with_same(D::Minus1, 1, 1)?;
        let xs = xs.pad_with_same(D::Minus2, 1, 1)?;
        let xs = xs.max_pool2d_with_stride(3, 2)?;
        // The four stages of residual blocks
        let xs = xs.apply(&self.layer1)?;
        let xs = xs.apply(&self.layer2)?;
        let xs = xs.apply(&self.layer3)?;
        let xs = xs.apply(&self.layer4)?;
        // Global average pooling: mean over W, then over H, giving shape (N, 512)
        let xs = xs.mean(D::Minus1)?;
        let xs = xs.mean(D::Minus1)?;
        // Apply the fully connected layer if it exists, otherwise return the features
        match &self.linear {
            Some(fc) => xs.apply(fc),
            None => Ok(xs),
        }
    }
}
Checking whether the fully connected layer exists is not strictly necessary. It is there because in practice we sometimes only need the model's feature extraction layers and not the fully connected layer, so an option is provided to leave it out.
At this point, the model structure of resnet18/34 is complete, and we can provide some convenience functions:
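/// Build a ResNet from BasicBlocks; c1..c4 are the numbers of blocks in layer1..layer4
fn resnet(
    vb: VarBuilder,
    nclasses: Option<usize>,
    c1: usize,
    c2: usize,
    c3: usize,
    c4: usize,
) -> Result<ResNet> {
    ResNet::new(vb, nclasses, c1, c2, c3, c4)
}

/// ResNet-18 with a final fully connected layer for `num_classes` classes
pub fn resnet18(vb: VarBuilder, num_classes: usize) -> Result<ResNet> {
    resnet(vb, Some(num_classes), 2, 2, 2, 2)
}

/// ResNet-18 without the final fully connected layer (feature extractor)
pub fn resnet18_no_final_layer(vb: VarBuilder) -> Result<ResNet> {
    resnet(vb, None, 2, 2, 2, 2)
}

/// ResNet-34 with a final fully connected layer for `num_classes` classes
pub fn resnet34(vb: VarBuilder, num_classes: usize) -> Result<ResNet> {
    resnet(vb, Some(num_classes), 3, 4, 6, 3)
}

/// ResNet-34 without the final fully connected layer (feature extractor)
pub fn resnet34_no_final_layer(vb: VarBuilder) -> Result<ResNet> {
    resnet(vb, None, 3, 4, 6, 3)
}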
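The names resnet18 and resnet34 count the weighted layers: the stem convolution, the convolutions inside every residual block, and the final fully connected layer (the 1x1 downsample convolutions are conventionally not counted). The plain-Rust helper below is an illustrative sketch that checks this arithmetic for the block counts used above, and for the bottleneck variants, which have three convolutions per block (the counts for 101 and 152 are the standard torchvision configurations).

```rust
/// Number of weighted layers in a ResNet: convolutions in all residual blocks,
/// plus the stem convolution and the final fully connected layer.
/// `convs_per_block` is 2 for BasicBlock and 3 for Bottleneck.
pub fn resnet_depth(blocks: [usize; 4], convs_per_block: usize) -> usize {
    let total_blocks: usize = blocks.iter().sum();
    total_blocks * convs_per_block + 2
}
```

For example, `resnet_depth([2, 2, 2, 2], 2)` gives 18 and `resnet_depth([3, 4, 6, 3], 2)` gives 34, while the same (3, 4, 6, 3) layout with bottleneck blocks gives 50.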
ResNet 50/101/152
These versions use a different residual block design from ResNet18/34: BottleneckBlock instead of BasicBlock.
This is a BasicBlock:
BasicBlock(
  (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
This is a BottleneckBlock:
Bottleneck(
  (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (downsample): Sequential(
    (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
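A Bottleneck block uses a 1x1 convolution to reduce the number of channels, a 3x3 convolution at the reduced width, and another 1x1 convolution to expand the channels again (here from 64 to 256, a factor of 4). Because the channel count changes from 64 to 256, this block also needs a downsample branch even though its stride is 1. Therefore, we also need to define a separate BottleneckBlock structure: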
// Residual block version for ResNet 50, 101, and 152.
#[derive(Debug, Clone)]
pub struct BottleneckBlock {
    conv1: Conv2d,                   // First convolutional layer (1x1, reduces channels)
    bn1: nn::BatchNorm,              // First batch normalization layer
    conv2: Conv2d,                   // Second convolutional layer (3x3)
    bn2: nn::BatchNorm,              // Second batch normalization layer
    conv3: Conv2d,                   // Third convolutional layer (1x1, expands channels)
    bn3: nn::BatchNorm,              // Third batch normalization layer
    downsample: Option<Downsample>,  // Optional downsampling layer
}
The rest of the logic is essentially the same as for BasicBlock and is not expanded here.
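To put a number on the difference between the two block designs, the plain-Rust sketch below counts the convolution weights of a BasicBlock with 256 channels and of a Bottleneck that maps 256 channels to 256 through a 64-channel middle layer. Batch normalization and downsample parameters are ignored, and the helper names are illustrative. The 1x1 convolutions shrink the width before the expensive 3x3 convolution, so the bottleneck needs far fewer weights.

```rust
/// Number of weights in a bias-free convolution: in * out * k * k.
pub fn conv_params(in_ch: usize, out_ch: usize, k: usize) -> usize {
    in_ch * out_ch * k * k
}

/// Convolution weights of a BasicBlock with `ch` channels: two 3x3 convolutions.
pub fn basic_block_conv_params(ch: usize) -> usize {
    conv_params(ch, ch, 3) + conv_params(ch, ch, 3)
}

/// Convolution weights of a Bottleneck mapping `ch` -> `ch` channels through
/// a narrower `mid` width: 1x1 reduce, 3x3, 1x1 expand.
pub fn bottleneck_conv_params(ch: usize, mid: usize) -> usize {
    conv_params(ch, mid, 1) + conv_params(mid, mid, 3) + conv_params(mid, ch, 1)
}
```

Here `basic_block_conv_params(256)` is 1,179,648 while `bottleneck_conv_params(256, 64)` is 69,632, which is what makes the deeper 50/101/152-layer variants affordable.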
Testing
The testing phase is crucial, as it helps us discover omissions in the code conversion process and verify the equivalence of different operations.
Unfortunately, Candle has no convenient way to print the structure of a model like the one implemented above, so we can only print intermediate tensors with println!("{xs}").
To verify that the model structure is complete, we have to compare the code against the PyTorch definition line by line.
To verify the correctness of the model's forward propagation logic, I use the following method: load the same model weights in both frameworks, print the xs variable at positions where problems are likely (usually in the forward function), and compare whether the two outputs are consistent.
Strictly speaking, "consistent" is not the right word: different inference frameworks and underlying implementation languages can produce slightly different floating-point results, so there will always be some small numerical differences, and the outputs should be compared within a tolerance.
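A tolerance-based comparison makes this practical. The plain-Rust sketch below uses the same criterion as numpy.allclose, |a - b| <= atol + rtol * |b|, on two output buffers, for example values dumped from PyTorch and from Candle; the function names and tolerance values are illustrative choices, not fixed recommendations.

```rust
/// Element-wise closeness check in the style of numpy.allclose:
/// true if the lengths match and |a - b| <= atol + rtol * |b| for every pair.
pub fn all_close(a: &[f32], b: &[f32], rtol: f32, atol: f32) -> bool {
    a.len() == b.len()
        && a.iter()
            .zip(b)
            .all(|(x, y)| (x - y).abs() <= atol + rtol * y.abs())
}

/// Largest absolute difference between two buffers of equal length,
/// useful for reporting how far two outputs diverge.
pub fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y).abs())
        .fold(0.0, f32::max)
}
```

Printing max_abs_diff alongside the pass/fail result at each checkpoint in forward helps locate the first layer where the two implementations drift apart.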
Final Verification
We will use this image of a computer mouse (./testdata/mouse.png) as an example:
Write a Test Case
use candle_core::{DType, IndexOp, Module, D};
use candle_nn::VarBuilder;

// Test the ResNet-18 model: load pre-trained weights and use them to classify an image.
// load_image224 (loads an image as a normalized (3, 224, 224) tensor) and CLASSES (the
// ImageNet class names) are assumed to come from a helper module such as the imagenet
// module in candle-examples; they are not defined in this article.
#[test]
fn test_resnet18() -> candle_core::Result<()> {
    // Path to the model weights
    let model_file = "./testdata/resnet18.safetensors";
    // Run on the CPU
    let device = candle_core::Device::Cpu;
    // Memory-map the weight parameters from the safetensors file
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
    // Load and preprocess the input image
    let image = load_image224("./testdata/mouse.png")?;
    // Build the ResNet-18 model with the 1000 ImageNet classes
    let model = resnet18(vb, 1000)?;
    // Add a batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
    let image = image.unsqueeze(0)?;
    // Forward pass to get the logits (unnormalized predictions)
    let logits = model.forward(&image)?;
    // Softmax turns the logits into probabilities; take the first (only) batch item
    let prs = candle_nn::ops::softmax(&logits, D::Minus1)?.i(0)?.to_vec1::<f32>()?;
    // Pair each probability with its class index and sort in descending order
    let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
    prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
    // Print the five most likely classes with their probabilities
    for &(category_idx, pr) in prs.iter().take(5) {
        println!("{:24}: {:.2}%", CLASSES[category_idx], 100.0 * pr);
    }
    Ok(())
}
Running the test case produces the following output. The top prediction is "mouse, computer mouse" with a probability of about 90%, which confirms that the converted model works:
running 1 test
test test_resnet18 ... ok
successes:
---- test_resnet18 stdout ----
mouse, computer mouse : 90.03%
punching bag, punch bag, punching ball, punchball: 4.49%
joystick : 1.80%
radio, wireless : 0.44%
vacuum, vacuum cleaner : 0.19%
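The post-processing in test_resnet18 (softmax over the logits, then a descending sort and the top five entries) can be checked on its own. The plain-Rust sketch below mirrors it on a plain slice of logits; the function names are illustrative, and subtracting the largest logit before exponentiating is a common way to keep exp from overflowing without changing the result.

```rust
/// Numerically stable softmax: subtract the largest logit before exponentiating.
pub fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&z| (z - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.into_iter().map(|e| e / sum).collect()
}

/// Class indices and probabilities of the `k` most likely classes, highest first.
pub fn top_k(probs: &[f32], k: usize) -> Vec<(usize, f32)> {
    let mut pairs: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
    // total_cmp gives a total order on floats, as in the test above
    pairs.sort_by(|a, b| b.1.total_cmp(&a.1));
    pairs.truncate(k);
    pairs
}
```

For logits `[1.0, 3.0, 2.0]`, `top_k(&softmax(&logits), 2)` returns class 1 followed by class 2; indexing a class-name table with those indices reproduces the listing above.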
Summary
When converting PyTorch model code to Candle code, the biggest challenge lies in finding equivalent APIs. In such cases, we need to delve into the source code of both PyTorch and Candle to find equivalent implementations.
Testing also presents difficulties due to the lack of comprehensive testing tools, leaving us to rely on practical experience for verification. This can be unfriendly to beginners.