Apr 25, 2025
8 min read
Rust, pytorch, candle, typescript, wasm

Building a Handwriting Input Method from Scratch: Model Training

First, we need to process the dataset. I use the CASIA-HWDB 1.0 dataset provided by the Chinese Academy of Sciences. It consists of isolated handwritten character samples, which makes it a good fit for a Chinese handwriting input method.

Processing the Dataset

Dataset official website: https://nlpr.ia.ac.cn/databases/handwriting/Download.html

HWDB1.0 covers 3,866 Chinese characters (3,740 of them in the GB2312-80 level-1 set) plus 171 punctuation marks, special symbols, digits, and English letters, for a total of 4,037 classes.

Download the dataset:

wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0TrainPart1.zip
wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0TrainPart2.zip
wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0TrainPart3.zip

wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0Test.zip

The first three archives together form the training set, and the last one is the test set, which is used here for validation.

You may notice that after decompression, the files are not regular images but .gnt format binary files. The .gnt format is a compact binary format that stores a large number of handwritten Chinese character images along with their labels.
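To make the layout concrete, here is a minimal sketch of a record reader. It assumes the commonly documented GNT layout: a 10-byte little-endian header (sample size, 2-byte GB tag code, width, height) followed by width × height grayscale bytes. The names `iter_gnt_records` and `make_record` are hypothetical, and the demo bytes are synthetic rather than taken from the dataset.

```python
import io
import struct
from typing import BinaryIO, Iterator, Tuple

# Assumed header layout: sample size (uint32), tag code (2 bytes), width, height (uint16 each)
HEADER = struct.Struct('<I2sHH')


def iter_gnt_records(stream: BinaryIO) -> Iterator[Tuple[str, int, int, bytes]]:
    """Yield (character, width, height, bitmap bytes) for each record in a GNT stream."""
    while True:
        header = stream.read(HEADER.size)
        if len(header) < HEADER.size:
            return  # End of stream (or a truncated trailing record)
        _size, tag, width, height = HEADER.unpack(header)
        bitmap = stream.read(width * height)
        # ASCII labels are padded with a NUL byte, so strip it after decoding
        yield tag.decode('gbk').strip('\x00'), width, height, bitmap


def make_record(char: str, width: int, height: int) -> bytes:
    """Build a synthetic all-white record for the demo below (not real dataset content)."""
    tag = char.encode('gbk').ljust(2, b'\x00')
    bitmap = bytes([255]) * (width * height)
    # Assumption: the size field counts the header plus the bitmap
    return HEADER.pack(HEADER.size + len(bitmap), tag, width, height) + bitmap


demo = io.BytesIO(make_record('把', 3, 2) + make_record('A', 2, 2))
records = list(iter_gnt_records(demo))
```

Reading the whole header at once also gives a clean end-of-file check: a short read means there are no more records.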

Here is a parsing script found online:

import struct
from pathlib import Path

from PIL import Image


def write_txt(save_path: str, content: list, mode='w'):
    """
    Write list content to a txt file
    @param
    content: list content
    save_path: absolute path str
    @return: None
    """
    with open(save_path, mode, encoding='utf-8') as f:
        for value in content:
            f.write(value + '\n')



path = 'data/test_gnt'  # All files in this directory are gnt files
save_dir = 'data/test'

gnt_paths = list(Path(path).iterdir())

label_list = []
for gnt_path in gnt_paths:
    count = 0
    print(gnt_path)
    # Create the output directory once per gnt file
    save_path = Path(save_dir) / 'images' / gnt_path.stem
    save_path.mkdir(parents=True, exist_ok=True)
    with open(gnt_path, 'rb') as f:
        while True:
            # Each record starts with a 10-byte header:
            # sample size (4 bytes), GBK tag code (2 bytes), width (2 bytes), height (2 bytes)
            header = f.read(10)
            if len(header) < 10:
                # End of file. In Python 3, f.read() returns b'' (not '') at EOF,
                # so a loop condition comparing against '' never stops and struct.unpack
                # eventually fails with "unpack requires a buffer of 4 bytes".
                break
            count += 1
            length_bytes, tag_code, width, height = struct.unpack('<I2sHH', header)

            # The bitmap is width * height bytes of 8-bit grayscale, stored row by row
            im = Image.frombytes('L', (width, height), f.read(width * height)).convert('RGB')

            filename = f'{count}.png'
            tag_code = tag_code.decode('gbk').strip('\x00')
            im.save(save_path / filename)

            label_list.append(f'{gnt_path.stem}/{filename}\t{tag_code}')

write_txt(f'{save_dir}/gt.txt', label_list)

This way, we obtain the images and the corresponding annotations in gt.txt, which looks like this:

//...
240-t/218.png	拔
240-t/219.png	跋
240-t/220.png	靶
240-t/221.png	把
240-t/222.png	耙
//...

We also need to extract the Chinese character labels from gt.txt and convert them into indices for easier model training.

The label.txt format is as follows:

!	0
"	1
#	2
$	3
%	4
(	5
)	6
*	7
+	8
,	9
-	10
.	11
/	12
0	13
1	14
2	15
3	16
//More...
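One way to derive such a mapping is to collect the distinct labels in gt.txt and number them in sorted order. The sketch below assumes code-point ordering, which is consistent with the excerpt above but is not confirmed as the method actually used. The names `build_label_map` and `write_label_file` are hypothetical.

```python
from pathlib import Path
from typing import Dict, Iterable


def build_label_map(gt_lines: Iterable[str]) -> Dict[str, int]:
    # Collect every distinct character that appears as a label
    chars = set()
    for line in gt_lines:
        line = line.rstrip('\r\n')
        if not line:
            continue
        # Each line is "<relative image path><TAB><character>"
        _, char = line.split('\t', 1)
        chars.add(char)
    # Sorting by code point keeps the mapping stable across runs
    return {char: idx for idx, char in enumerate(sorted(chars))}


def write_label_file(label_map: Dict[str, int], path: Path) -> None:
    # One "<character><TAB><index>" pair per line, the format shown for label.txt
    with open(path, 'w', encoding='utf-8') as f:
        for char, idx in label_map.items():
            f.write(f'{char}\t{idx}\n')


sample = [
    '240-t/218.png\t拔',
    '240-t/219.png\t跋',
    '240-t/1.png\t!',
    '240-t/2.png\t0',
    '241-t/5.png\t拔',
]
label_map = build_label_map(sample)
```

Whatever ordering is chosen, the training and test splits need to share the same label.txt so that an index means the same character in both.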

Next, we can build the dataset.

First, let’s look at the image sizes. Counting widths and heights across the whole dataset gives the following distribution:

Image width distribution:
[0-50]: 158,633 images
[50-100]: 1,425,052 images
[100-150]: 96,269 images
[150-200]: 300 images
[200+]: 4 images
Image height distribution:
[0-50]: 55,862 images
[50-100]: 1,385,139 images
[100-150]: 236,884 images
[150-200]: 2,320 images
[200+]: 53 images
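A distribution like this can be produced with a few lines of bucketing code. This is a minimal sketch, not the script used for the numbers above. It assumes half-open bins (a value of exactly 50 lands in [50-100]), and `bucket_label` and `size_histogram` are hypothetical names.

```python
from collections import Counter
from typing import Iterable


def bucket_label(value: int, step: int = 50, cap: int = 200) -> str:
    """Return the bin label for one width or height value."""
    if value >= cap:
        return f'[{cap}+]'
    low = (value // step) * step
    return f'[{low}-{low + step}]'


def size_histogram(values: Iterable[int]) -> Counter:
    """Count how many values fall into each bin."""
    return Counter(bucket_label(v) for v in values)


# Toy widths for illustration only
widths = [10, 50, 99, 150, 200, 250]
hist = size_histogram(widths)
```

With Pillow, the real widths and heights could be gathered from `Image.open(path).size` for each extracted PNG and passed to `size_histogram`.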

Loading labels:

    def load_label_to_idx(self,label_path:Path):
        char_to_idx = {}
        with open(str(label_path), 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip().split('\t')
                char = line[0]
                char_to_idx[char] = int(line[1])
        return char_to_idx

Loading data:

    def load_data(self):
        data = []
        txt_file = Path(self.root_dir) / self.mode / 'gt.txt'
        with open(txt_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip().split('\t')
                img_path = line[0]
                label = line[1]
                data.append((img_path, label))
        return data

Sample processing:

def __getitem__(self, index):
    # Get the image path and corresponding character label from self.data based on the index
    img_path, char = self.data[index]
    
    # Build the full image path: root directory/mode (e.g., train or test)/images/image file name
    img_path = Path(self.root_dir) / self.mode / 'images' / img_path
    
    # Open the image file
    img = Image.open(img_path)
    # Ensure the image is in RGB format (3 channels)
    img = img.convert('RGB')

    # Apply transform (data augmentation/preprocessing) if defined
    if self.transform: img = self.transform(img)
    
    # Get the numerical index of the character from the label_idx dictionary
    label_indices = self.label_idx[char]
    # Convert the label to a torch.long tensor
    label_tensor = torch.tensor(int(label_indices), dtype=torch.long)
    
    # Return the processed image and corresponding label tensor
    return img, label_tensor

Note that the PNGs produced by the parsing script are saved as RGB even though the source bitmaps are grayscale, so all three channels are identical. Feeding the model a single channel instead would shrink the input tensor and the first convolution, at the cost of modifying the model’s input layer.

Because the images show dark strokes on a white background, ImageNet’s normalization statistics are a poor fit. Here, I simply use transforms.Normalize(mean=[0.95], std=[0.2]).
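To see what these values do, recall that normalization maps each pixel x (already scaled to [0, 1]) to (x - mean) / std. The sketch below applies that formula to single pixels and also shows how a mean and standard deviation could be estimated from pixel data. The toy "image" with 95% white pixels is an assumption chosen for illustration, not a measurement of the dataset, and the function names are hypothetical.

```python
import math
from typing import Iterable, Tuple


def normalize(pixel: int, mean: float = 0.95, std: float = 0.2) -> float:
    """Scale an 8-bit pixel to [0, 1], then apply (x - mean) / std."""
    return (pixel / 255 - mean) / std


def mean_std(pixels: Iterable[int]) -> Tuple[float, float]:
    """Mean and population standard deviation of 8-bit pixels scaled to [0, 1]."""
    n = 0
    total = 0.0
    total_sq = 0.0
    for p in pixels:
        v = p / 255
        n += 1
        total += v
        total_sq += v * v
    mean = total / n
    # Clamp at zero to guard against tiny negative values from rounding
    return mean, math.sqrt(max(total_sq / n - mean * mean, 0.0))


white, black = normalize(255), normalize(0)  # roughly 0.25 and -4.75
toy_mean, toy_std = mean_std([255] * 95 + [0] * 5)  # 95% white, 5% black
```

In this toy case the estimated mean is 0.95 and the standard deviation is about 0.22, which is in the same range as the hand-picked parameters.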

Model Design

It’s best to start from a proven architecture such as MobileNet or ResNet. A design built from scratch rarely reaches the same accuracy without considerable time spent tuning the training “recipe”. Here, I start with a MobileNetV2-based model, although a ResNet would likely yield higher accuracy.

class MobileNetV2_Chinese(MobileNetV2):
    def __init__(self, num_classes=4037):
        super().__init__()
        # Adjust the stride of the first convolutional layer
        #self.features[0][0].stride = (1, 1)
        # Redefine the classifier
        last_channel = 1280
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(last_channel, num_classes),
        )
        # Adaptive pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # Equivalent to the above operation
        # x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        # x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

MobileNetV2_Chinese mainly changes the number of classes in the fully connected layer from 1,000 to 4,037.
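The effect of this change on model size is easy to estimate, since a fully connected layer has in_features × out_features weights plus one bias per output. A quick sketch of the arithmetic, using the 1280-channel feature width from the class above (`linear_params` is a hypothetical helper):

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


imagenet_head = linear_params(1280, 1000)  # original 1,000-class head
hwdb_head = linear_params(1280, 4037)      # 4,037-class head used here
extra = hwdb_head - imagenet_head
```

The head grows from 1,281,000 to 5,171,397 parameters, so the classifier alone adds roughly 3.9 million parameters.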

Training

The training code is straightforward.

class HandwritingTrainer(pl.LightningModule):

    def __init__(self,model = 'mobilenetv2',batch_size=32):
        super().__init__()
        if model == 'resnet18':
            self.model = ResNet18_Chinese(4037)
        elif model == 'mobilenetv2':
            self.model = MobileNetV2_Chinese(4037)
        else:
            raise ValueError("model must be resnet18 or mobilenetv2")
        self.criterion = torch.nn.CrossEntropyLoss()
        self.batch_size = batch_size
        self.root_dir = "data"

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(),lr=1e-2,weight_decay=1e-4,momentum=0.9)
        # optimizer = torch.optim.AdamW(
        #     self.model.parameters(), lr=1e-3, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100)
        return [optimizer], [scheduler]

    def common_transforms_compose(self,mode='train'):
        if mode == 'train':
            return transforms.Compose([
                lambda img: resize_to_sqr(img),
                transforms.Resize((96,96)),
                transforms.Lambda(lambda img: img.point(lambda x: 0 if x < 200 else 255)),  # Set threshold to 200
                transforms.RandomGrayscale(p=0.1),  # Random grayscale to simulate different ink colors
                transforms.RandomRotation(15),
                transforms.RandomPerspective(distortion_scale=0.2, p=0.1),  # Simulate writing tilt
                transforms.RandomApply([transforms.GaussianBlur(3)], p=0.1),  # Slight blur for noise resistance
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.95], std=[0.2])
            ])
        else:
            return transforms.Compose([
                lambda img: resize_to_sqr(img),
                transforms.Resize((96,96)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.95], std=[0.2])
            ])

    def train_dataloader(self):
        train_dataset = HWDB1Dataset(
            root_dir=self.root_dir, mode='train', transform=self.common_transforms_compose('train'))
        return DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=8,
            persistent_workers=True,
            pin_memory=True,
            drop_last=True

        )

    def val_dataloader(self):
        val_dataset = HWDB1Dataset(
            root_dir=self.root_dir, mode='test', transform=self.common_transforms_compose('test'))
        return DataLoader(
            dataset=val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=8,
            persistent_workers=True,
            pin_memory=True,
            drop_last=False  # Keep every validation sample so val_loss covers the whole set

        )
    
    def compute_grad_norm(self):
        # Calculate the gradient norm of all parameters
        total_norm = 0.0
        for param in self.parameters():
            if param.grad is not None:
                param_norm = param.grad.data.norm(2)  # L2 norm
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        return total_norm
    
    def on_after_backward(self):
        # Calculate gradient norm
        grad_norm = self.compute_grad_norm()
        self.log("grad_norm", grad_norm, on_step=True, on_epoch=False, prog_bar=True, logger=True)

    
    def common_steps(self, batch, batch_idx, mode="train"):
        x, y = batch
        pred = self.model(x)
        # print(f"Model output: {pred}")  # Print model output

        loss = self.criterion(pred, y)

        self.log(f'{mode}_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self.common_steps(batch, batch_idx, mode="train")

    def validation_step(self, batch, batch_idx):
        return self.common_steps(batch, batch_idx, mode="val")
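For intuition about the scheduler configured in configure_optimizers, the closed form of cosine annealing is lr(t) = eta_min + (base_lr - eta_min) / 2 × (1 + cos(π t / T_max)). The sketch below evaluates it for the values used above (base_lr = 1e-2, T_max = 100). It assumes eta_min is left at its default of 0 and that the scheduler is stepped once per epoch. The function name is hypothetical.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float = 1e-2,
                        t_max: int = 100, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing learning rate at a given epoch."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


# Learning rate at the start, quarter points, and end of the schedule
schedule = [cosine_annealing_lr(e) for e in (0, 25, 50, 75, 100)]
```

The rate starts at 0.01, passes 0.005 at the halfway point, and reaches 0 at epoch 100, so if early stopping ends the run sooner, training stops at a learning rate that has not fully decayed.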

from argparse import ArgumentParser

# Assumes the `pytorch_lightning` package; newer installs may expose it as `lightning.pytorch`
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from trainer import HandwritingTrainer

if __name__ == "__main__":

    parse = ArgumentParser()
    parse.add_argument('-m', '--model', type=str, default="mobilenetv2")
    parse.add_argument('-b', '--batch_size', type=int, default=32)
    args = parse.parse_args()

    # Create a TensorBoardLogger instance and specify the log save path
    logger = TensorBoardLogger(save_dir="logs", name=args.model,log_graph=True)

    # Set up the ModelCheckpoint callback  
    checkpoint_callback = ModelCheckpoint(  
        monitor='val_loss',  # Monitoring metric, e.g., validation loss  
        dirpath=logger.log_dir,  # Checkpoint save path  
        filename='checkpoint-{epoch:02d}-{val_loss:.3f}',  # File naming format  
        save_top_k=3,  # Only save the best k models  
        mode='min',  # 'min' means the lower the metric, the better  
        save_weights_only=True,  # Only save weights, not the entire model  
    ) 
    early_stopping = EarlyStopping('val_loss', mode="min")
    trainer_args = {
        'accelerator': 'gpu',
        'devices': [0],
        'callbacks': [checkpoint_callback,early_stopping],
        'max_epochs': 100,
    }
    trainer = pl.Trainer(logger=logger, **trainer_args, fast_dev_run=False)
    # Training data
    model = HandwritingTrainer(model=args.model,batch_size=args.batch_size)
    print(model)
    trainer.fit(model)

Start training:

uv run train.py -b 64

Optimization Points

Several questions still need ablation experiments to answer, such as:

  1. Is 96x96 better than 224x224?
  2. Is a single channel better than three channels?
  3. Would modifying other downsampling layers improve performance?

However, it’s important to note that removing downsampling layers does not add weights but enlarges every subsequent feature map, so computation and memory use grow and training takes longer.
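The spatial cost of these choices can be estimated with the standard convolution output-size formula, out = floor((in + 2·padding - kernel) / stride) + 1. The sketch below assumes the usual MobileNetV2 layout of five stride-2 stages built from 3×3 convolutions with padding 1. The helper names are hypothetical.

```python
from typing import Sequence


def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def feature_map_side(size: int, strides: Sequence[int] = (2, 2, 2, 2, 2)) -> int:
    """Side length of the final feature map after the given downsampling stages."""
    for stride in strides:
        size = conv_out(size, stride=stride)
    return size


side_96 = feature_map_side(96)     # 96 -> 48 -> 24 -> 12 -> 6 -> 3
side_224 = feature_map_side(224)   # 224 -> 112 -> 56 -> 28 -> 14 -> 7
side_96_stride1 = feature_map_side(96, strides=(1, 2, 2, 2, 2))  # first stage kept at full resolution
```

At 96x96 the network ends on a 3x3 map. Setting the first stride to 1 doubles the side of every later feature map (6x6 at the end), which is about four times as many activations to compute per layer.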

Through some ablation experiments, we found:

  1. At 96x96 resolution, the original model layers can achieve good results without modification.
  2. 224x224 resolution did not show better recognition rates.
  3. Dataset processing is crucial. For example, characters like “一” and “|” have extreme aspect ratios, so custom resizing is necessary to prevent excessive distortion.
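The `resize_to_sqr` helper used in the transforms is not shown in this article, so the following is only a guess at one possible implementation: pad the image onto a white square canvas before resizing, so that a wide stroke like “一” keeps its proportions instead of being stretched to fill 96x96. The `square_padding` helper and the white fill are assumptions.

```python
from typing import Tuple


def square_padding(width: int, height: int) -> Tuple[int, int, int]:
    """Return (side, left, top) for centering a width x height image on a square canvas."""
    side = max(width, height)
    return side, (side - width) // 2, (side - height) // 2


def resize_to_sqr(img, fill='white'):
    """Pad a PIL image to a square canvas instead of stretching it."""
    # Pillow is imported lazily so the geometry helper above has no dependencies
    from PIL import Image

    side, left, top = square_padding(*img.size)
    canvas = Image.new(img.mode, (side, side), fill)
    canvas.paste(img, (left, top))
    return canvas


flat = square_padding(120, 10)  # a wide, flat stroke such as "一"
tall = square_padding(8, 90)    # a tall, thin stroke such as "|"
```

A 120x10 image is centered 55 pixels from the top of a 120x120 canvas, so the later `transforms.Resize((96, 96))` scales it uniformly.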