First, we need to process the dataset. I will use the CASIA-HWDB 1.0 dataset published by the Institute of Automation, Chinese Academy of Sciences. It consists of isolated handwritten character samples, which makes it a good fit for a handwritten Chinese character input scenario.
Processing the Dataset
Dataset official website: https://nlpr.ia.ac.cn/databases/handwriting/Download.html
HWDB1.0 contains 3,866 Chinese character classes (3,740 of which fall within the GB2312-80 level-1 set), plus punctuation marks, special symbols, digits, and English letters, for a total of 4,037 categories.
Download the dataset:
wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0TrainPart1.zip
wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0TrainPart2.zip
wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0TrainPart3.zip
wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0Test.zip
The first three archives make up the training set; the last one is the test set, which we will also use for validation.
You may notice that after decompression, the files are not regular images but .gnt format binary files. The .gnt format is a compact binary format that stores a large number of handwritten Chinese character images along with their labels.
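As the parsing code below implies, each record in a .gnt file is laid out sequentially:
- 4 bytes (little-endian unsigned int): total size of this record in bytes
- 2 bytes: the character label as a GBK-encoded tag code
- 2 bytes (little-endian unsigned short): image width
- 2 bytes (little-endian unsigned short): image height
- width × height bytes: grayscale pixel values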
Here is a parsing script found online:
from pathlib import Path
import struct
from PIL import Image


def write_txt(save_path: str, content: list, mode='w'):
    """
    Write list content to a txt file
    @param
        content: list content
        save_path: absolute path str
    @return: None
    """
    with open(save_path, mode, encoding='utf-8') as f:
        for value in content:
            f.write(value + '\n')


path = 'data/test_gnt'  # All files in this directory are gnt files
save_dir = 'data/test'

gnt_paths = list(Path(path).iterdir())
label_list = []
for gnt_path in gnt_paths:
    count = 0
    print(gnt_path)
    with open(str(gnt_path), 'rb') as f:
        while f.read(1) != b'':  # peek one byte; b'' means end of file
            f.seek(-1, 1)  # rewind the peeked byte
            count += 1
            try:
                # A truncated final record makes struct.unpack raise
                # "struct.error: unpack requires a buffer of 4 bytes",
                # so any parse error means we are done with this file
                record_size = struct.unpack('<I', f.read(4))[0]  # total record size (unused)
                tag_code = f.read(2)  # GBK-encoded label
                width = struct.unpack('<H', f.read(2))[0]
                height = struct.unpack('<H', f.read(2))[0]
                im = Image.new('RGB', (width, height))
                img_array = im.load()
                for x in range(0, height):
                    for y in range(0, width):
                        pixel = struct.unpack('<B', f.read(1))[0]
                        img_array[y, x] = (pixel, pixel, pixel)
                filename = str(count) + '.png'
                tag_code = tag_code.decode('gbk').strip('\x00')
                save_path = f'{save_dir}/images/{gnt_path.stem}'
                Path(save_path).mkdir(parents=True, exist_ok=True)
                im.save(f'{save_path}/{filename}')
                label_list.append(f'{gnt_path.stem}/{filename}\t{tag_code}')
            except Exception:
                break

write_txt(f'{save_dir}/gt.txt', label_list)
This way, we obtain the images and the corresponding annotations in gt.txt, which looks like this:
//...
240-t/218.png 拔
240-t/219.png 跋
240-t/220.png 靶
240-t/221.png 把
240-t/222.png 耙
//...
We also need to extract the Chinese character labels from gt.txt and map them to integer indices for training; a small generation script is sketched after the sample below.
The label.txt format is as follows:
! 0
" 1
# 2
$ 3
% 4
( 5
) 6
* 7
+ 8
, 9
- 10
. 11
/ 12
0 13
1 14
2 15
3 16
//More...
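A minimal sketch of how label.txt could be generated from the training gt.txt (the file paths here are assumptions matching the layout used above; adjust as needed):

from pathlib import Path

save_dir = 'data/train'  # directory containing gt.txt

# Collect the unique labels from gt.txt, sort them for a stable ordering,
# and assign each character an integer index.
chars = set()
with open(f'{save_dir}/gt.txt', 'r', encoding='utf-8') as f:
    for line in f:
        _, char = line.rstrip('\n').split('\t')
        chars.add(char)

with open('data/label.txt', 'w', encoding='utf-8') as f:
    for idx, char in enumerate(sorted(chars)):
        f.write(f'{char}\t{idx}\n')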
Next, we can build the dataset.
First, let’s analyze the images. Collecting statistics over the extracted PNGs gives the following size distributions:
Image width distribution:
[0-50]: 158,633 images
[50-100]: 1,425,052 images
[100-150]: 96,269 images
[150-200]: 300 images
[200+]: 4 images
Image height distribution:
[0-50]: 55,862 images
[50-100]: 1,385,139 images
[100-150]: 236,884 images
[150-200]: 2,320 images
[200+]: 53 images
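For reference, here is a quick sketch of how such statistics can be gathered (the directory path assumes the layout produced by the extraction script above):

from collections import Counter
from pathlib import Path
from PIL import Image

def bucket(v: int) -> str:
    # Map a pixel dimension into a 50px-wide bucket label
    if v >= 200:
        return '[200+]'
    lo = (v // 50) * 50
    return f'[{lo}-{lo + 50}]'

width_hist, height_hist = Counter(), Counter()
for img_path in Path('data/train/images').rglob('*.png'):
    with Image.open(img_path) as im:
        w, h = im.size
    width_hist[bucket(w)] += 1
    height_hist[bucket(h)] += 1

print('widths:', dict(width_hist))
print('heights:', dict(height_hist))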
Loading labels:
def load_label_to_idx(self, label_path: Path):
    # Map each character to its integer index, as listed in label.txt
    char_to_idx = {}
    with open(str(label_path), 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip().split('\t')
            char = line[0]
            char_to_idx[char] = int(line[1])
    return char_to_idx
Loading data:
def load_data(self):
    # Read (image path, character) pairs from gt.txt
    data = []
    txt_file = Path(self.root_dir) / self.mode / 'gt.txt'
    with open(txt_file, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip().split('\t')
            img_path = line[0]
            label = line[1]
            data.append((img_path, label))
    return data
Sample processing:
def __getitem__(self, index):
    # Get the image path and corresponding character label from self.data based on the index
    img_path, char = self.data[index]
    # Build the full image path: root directory/mode (e.g., train or test)/images/image file name
    img_path = Path(self.root_dir) / self.mode / 'images' / img_path
    # Open the image file
    img = Image.open(img_path)
    # Ensure the image is in RGB format (3 channels)
    img = img.convert('RGB')
    # Apply transform (data augmentation/preprocessing) if defined
    if self.transform:
        img = self.transform(img)
    # Get the numerical index of the character from the label_idx dictionary
    label_indices = self.label_idx[char]
    # Convert the label to a torch.long tensor
    label_tensor = torch.tensor(int(label_indices), dtype=torch.long)
    # Return the processed image and corresponding label tensor
    return img, label_tensor
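Putting the pieces together, the surrounding class might look like the following (a sketch; the constructor signature and the label.txt location are assumptions based on how the methods above are used):

import torch
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset

class HWDB1Dataset(Dataset):
    def __init__(self, root_dir: str, mode: str = 'train', transform=None):
        self.root_dir = root_dir
        self.mode = mode            # 'train' or 'test' subdirectory
        self.transform = transform
        # label.txt is assumed to sit next to the train/test directories
        self.label_idx = self.load_label_to_idx(Path(root_dir) / 'label.txt')
        self.data = self.load_data()

    def __len__(self):
        return len(self.data)

    # load_label_to_idx, load_data, and __getitem__ as defined above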
It’s worth noting that the extracted images are saved as 3-channel RGB, but since the underlying data is grayscale, they could be collapsed to a single channel for recognition, which would shrink the first convolution’s input and trim some parameters and compute.
Since the images look nothing like natural photos (white background, black strokes), ImageNet’s normalization statistics are a poor fit. Here I simply use transforms.Normalize(mean=[0.95], std=[0.2]); a single mean/std value is broadcast across all three channels.
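If you want to derive these values from the data rather than eyeball them, here is a rough sketch (HWDB1Dataset and resize_to_sqr are the helpers used elsewhere in this post; paths and batch sizes are assumptions):

from torch.utils.data import DataLoader
from torchvision import transforms

# Estimate the dataset's grayscale mean/std after the same geometric preprocessing
dataset = HWDB1Dataset(
    root_dir='data', mode='train',
    transform=transforms.Compose([
        lambda img: resize_to_sqr(img),
        transforms.Resize((96, 96)),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ]))
loader = DataLoader(dataset, batch_size=256, num_workers=4)

total, total_sq, n = 0.0, 0.0, 0
for imgs, _ in loader:
    total += imgs.sum().item()
    total_sq += (imgs ** 2).sum().item()
    n += imgs.numel()

mean = total / n
std = (total_sq / n - mean ** 2) ** 0.5
print(f'mean={mean:.3f}, std={std:.3f}')  # should land near the 0.95 / 0.2 used above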
Model Design
It’s best to start from a proven architecture such as MobileNet or ResNet; a from-scratch design is unlikely to reach the same accuracy without considerably more time spent tuning the training “recipe”. Here I go with a MobileNetV2-based model (in fact, a ResNet would yield higher accuracy).
import torch.nn as nn
from torchvision.models import MobileNetV2

class MobileNetV2_Chinese(MobileNetV2):
    def __init__(self, num_classes=4037):
        super().__init__()
        # Optionally reduce the stride of the first convolutional layer for small inputs
        # self.features[0][0].stride = (1, 1)
        # Redefine the classifier for our number of classes
        last_channel = 1280
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(last_channel, num_classes),
        )
        # Adaptive pooling so any input resolution maps to a 1x1 feature map
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # Equivalent to the two lines above:
        # x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        # x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x
MobileNetV2_Chinese mainly changes the number of classes in the fully connected layer from 1,000 to 4,037.
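A quick smoke test to confirm the shapes work out at our target resolution:

import torch

# One random 96x96 RGB batch through the model
model = MobileNetV2_Chinese(num_classes=4037)
x = torch.randn(2, 3, 96, 96)
print(model(x).shape)  # expected: torch.Size([2, 4037])
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M params')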
Training
The training code is straightforward.
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms

# Module paths are assumptions; adjust to wherever HWDB1Dataset, the model
# classes, and the resize_to_sqr helper (sketched at the end of this post) live
from dataset import HWDB1Dataset
from models import MobileNetV2_Chinese, ResNet18_Chinese


class HandwritingTrainer(pl.LightningModule):
    def __init__(self, model='mobilenetv2', batch_size=32):
        super().__init__()
        if model == 'resnet18':
            self.model = ResNet18_Chinese(4037)
        elif model == 'mobilenetv2':
            self.model = MobileNetV2_Chinese(4037)
        else:
            raise ValueError("model must be resnet18 or mobilenetv2")
        self.criterion = torch.nn.CrossEntropyLoss()
        self.batch_size = batch_size
        self.root_dir = "data"

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.model.parameters(), lr=1e-2, weight_decay=1e-4, momentum=0.9)
        # optimizer = torch.optim.AdamW(
        #     self.model.parameters(), lr=1e-3, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100)
        return [optimizer], [scheduler]

    def common_transforms_compose(self, mode='train'):
        if mode == 'train':
            return transforms.Compose([
                lambda img: resize_to_sqr(img),
                transforms.Resize((96, 96)),
                transforms.Lambda(lambda img: img.point(lambda x: 0 if x < 200 else 255)),  # Binarize with threshold 200
                transforms.RandomGrayscale(p=0.1),  # Random grayscale to simulate different ink colors
                transforms.RandomRotation(15),
                transforms.RandomPerspective(distortion_scale=0.2, p=0.1),  # Simulate writing tilt
                transforms.RandomApply([transforms.GaussianBlur(3)], p=0.1),  # Slight blur for noise resistance
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.95], std=[0.2])
            ])
        else:
            return transforms.Compose([
                lambda img: resize_to_sqr(img),
                transforms.Resize((96, 96)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.95], std=[0.2])
            ])

    def train_dataloader(self):
        train_dataset = HWDB1Dataset(
            root_dir=self.root_dir, mode='train',
            transform=self.common_transforms_compose('train'))
        return DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=8,
            persistent_workers=True,
            pin_memory=True,
            drop_last=True
        )

    def val_dataloader(self):
        val_dataset = HWDB1Dataset(
            root_dir=self.root_dir, mode='test',
            transform=self.common_transforms_compose('test'))
        return DataLoader(
            dataset=val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=8,
            persistent_workers=True,
            pin_memory=True,
            drop_last=True
        )

    def compute_grad_norm(self):
        # Accumulate the global L2 norm over all parameter gradients
        total_norm = 0.0
        for param in self.parameters():
            if param.grad is not None:
                param_norm = param.grad.data.norm(2)  # per-parameter L2 norm
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        return total_norm

    def on_after_backward(self):
        # Log the gradient norm after each backward pass
        grad_norm = self.compute_grad_norm()
        self.log("grad_norm", grad_norm, on_step=True, on_epoch=False, prog_bar=True, logger=True)

    def common_steps(self, batch, batch_idx, mode="train"):
        x, y = batch
        pred = self.model(x)
        loss = self.criterion(pred, y)
        self.log(f'{mode}_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self.common_steps(batch, batch_idx, mode="train")

    def validation_step(self, batch, batch_idx):
        return self.common_steps(batch, batch_idx, mode="val")
import pytorch_lightning as pl
from argparse import ArgumentParser
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from trainer import HandwritingTrainer

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('-m', '--model', type=str, default="mobilenetv2")
    parser.add_argument('-b', '--batch_size', type=int, default=32)
    args = parser.parse_args()

    # Create a TensorBoardLogger instance and specify the log save path
    logger = TensorBoardLogger(save_dir="logs", name=args.model, log_graph=True)

    # Set up the ModelCheckpoint callback
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',                                # monitoring metric, e.g., validation loss
        dirpath=logger.log_dir,                            # checkpoint save path
        filename='checkpoint-{epoch:02d}-{val_loss:.3f}',  # file naming format
        save_top_k=3,                                      # only keep the best k models
        mode='min',                                        # 'min' means the lower the metric, the better
        save_weights_only=True,                            # only save weights, not the entire model
    )
    early_stopping = EarlyStopping('val_loss', mode="min")

    trainer_args = {
        'accelerator': 'gpu',
        'devices': [0],
        'callbacks': [checkpoint_callback, early_stopping],
        'max_epochs': 100,
    }
    trainer = pl.Trainer(logger=logger, **trainer_args, fast_dev_run=False)

    model = HandwritingTrainer(model=args.model, batch_size=args.batch_size)
    print(model)
    trainer.fit(model)
Start training:
uv run train.py -b 64
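Training progress (the loss curves and the gradient norm we log) can then be watched in TensorBoard:
tensorboard --logdir logs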
Optimization Points
More ablation experiments would be needed to verify design choices such as:
- Is 96x96 better than 224x224?
- Is a single channel better than three channels?
- Would modifying other downsampling layers improve performance?
However, it’s important to note that removing downsampling layers keeps the feature maps larger throughout the network, which inflates compute and memory use, so training takes noticeably longer.
Through some ablation experiments, we found:
- At 96x96 resolution, the original model layers can achieve good results without modification.
- 224x224 resolution did not show better recognition rates.
- Dataset processing is crucial. For example, characters like “一” and “|” have extreme aspect ratios, so a custom resize is needed to prevent severe distortion; this is what resize_to_sqr in the transforms above is for (see the sketch below).
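Since resize_to_sqr never appears in the post itself, here is a plausible sketch: pad the image to a square with the background color before the uniform Resize, so thin characters keep their aspect ratio (the white fill value is an assumption matching the dataset's white background):

from PIL import Image

def resize_to_sqr(img: Image.Image, fill=(255, 255, 255)) -> Image.Image:
    """Pad an image to a square canvas, centering the original content.

    This keeps the aspect ratio intact so that thin characters such as
    "一" or "|" are not stretched by the later Resize((96, 96)).
    """
    w, h = img.size
    side = max(w, h)
    canvas = Image.new('RGB', (side, side), fill)
    # Paste the original image centered on the square canvas
    canvas.paste(img, ((side - w) // 2, (side - h) // 2))
    return canvas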