Apr 25, 2025
5 min read
Rust,
pytorch,
candle,
typescript,
wasm,

从零开始构建手写输入法:模型训练篇

我们首先需要处理数据集。我这里直接采用中科院自动化所提供的CASIA-HWDB 1.0 数据集。这是一个手写单字样本数据集,非常适合手写汉字输入场景。

处理数据集

数据集官网地址: https://nlpr.ia.ac.cn/databases/handwriting/Download.html

HWDB1.0 有 3740 个汉字样本,不过里面除了汉字,还有一些标点符号、特殊符号、数字以及英文字母等,一共加起来其实有 4037 个分类。

下载数据集:

wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0TrainPart1.zip
wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0TrainPart2.zip
wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0TrainPart3.zip

wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0Test.zip

其中前三个为训练集,最后一个为测试/验证集。

你可能会发现,解压后得到的并不是正常的图片,而是 gnt 格式的二进制文件。.gnt 格式是一种紧凑的二进制格式,里面存储大量手写汉字图像及其标签。

网上找到的解析脚本:

import os
from pathlib import Path
import struct
from PIL import Image
def write_txt(save_path: str, content: list, mode='w'):
    """
    将list内容写入txt中
    @param
    content: list格式内容
    save_path: 绝对路径str
    @return:None
    """
    with open(save_path, mode, encoding='utf-8') as f:
        for value in content:
            f.write(value + '\n')



path = 'data/test_gnt'
save_dir = 'data/test'  # 目录下均为gnt文件

gnt_paths = list(Path(path).iterdir())

label_list = []
for gnt_path in gnt_paths:
    count = 0
    print(gnt_path)
    with open(str(gnt_path), 'rb') as f:
        while f.read(1) != "":
            f.seek(-1, 1)
            count += 1
            try:
                # 只所以添加try,是因为有时f.read会报错 struct.error: unpack requires a buffer of 4 bytes
                # 原因尚未找到
                length_bytes = struct.unpack('<I', f.read(4))[0]

                tag_code = f.read(2)

                width = struct.unpack('<H', f.read(2))[0]

                height = struct.unpack('<H', f.read(2))[0]

                im = Image.new('RGB', (width, height))
                img_array = im.load()
                for x in range(0, height):
                    for y in range(0, width):
                        pixel = struct.unpack('<B', f.read(1))[0]
                        img_array[y, x] = (pixel, pixel, pixel)

                filename = str(count) + '.png'
                tag_code = tag_code.decode('gbk').strip('\x00')
                save_path = f'{save_dir}/images/{gnt_path.stem}'
                if not Path(save_path).exists():
                    Path(save_path).mkdir(parents=True, exist_ok=True)
                im.save(f'{save_path}/{filename}')

                label_list.append(f'{gnt_path.stem}/{filename}\t{tag_code}')
            except:
                break

write_txt(f'{save_dir}/gt.txt', label_list)

这样我们就得到了图片,以及标注 gt.txt,大概标注内容如下:

//...
240-t/218.png	拔
240-t/219.png	跋
240-t/220.png	靶
240-t/221.png	把
240-t/222.png	耙
//...

我们还需要基于gt.txt把汉字标签提取出来,转换为索引,方便模型训练。

label.txt 格式如下:

!	0
"	1
#	2
$	3
%	4
(	5
)	6
*	7
+	8
,	9
-	10
.	11
/	12
0	13
1	14
2	15
3	16
//更多...

接下来就可以构建数据集了。

先分析一下图片,通过统计发现,数据集分辨率区间如下:

图片宽度区间分布统计:
[0-50]: 158633 张图片
[50-100]: 1425052 张图片
[100-150]: 96269 张图片
[150-200]: 300 张图片
[200,]: 4 张图片
图片高度区间分布统计:
[0-50]: 55862 张图片
[50-100]: 1385139 张图片
[100-150]: 236884 张图片
[150-200]: 2320 张图片
[200,]: 53 张图片

加载标签:

    def load_label_to_idx(self,label_path:Path):
        char_to_idx = {}
        with open(str(label_path), 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip().split('\t')
                char = line[0]
                char_to_idx[char] = int(line[1])
        return char_to_idx

加载数据:

    def load_data(self):
        data = []
        txt_file = Path(self.root_dir) / self.mode / 'gt.txt'
        with open(txt_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip().split('\t')
                img_path = line[0]
                label = line[1]
                data.append((img_path, label))
        return data

样本处理:

def __getitem__(self, index):
    # 根据索引index从self.data中获取图片路径和对应的字符标签
    img_path, char = self.data[index]
    
    # 构建完整的图片路径:根目录/模式(如train或test)/images/图片文件名
    img_path = Path(self.root_dir) / self.mode / 'images' / img_path
    
    # 打开图片文件
    img = Image.open(img_path)
    # 确保图片是RGB格式(3通道)
    img = img.convert('RGB')

    # 如果有定义transform(数据增强/预处理),则应用transform
    if self.transform: img = self.transform(img)
    
    # 从label_idx字典中获取字符对应的数字索引
    label_indices = self.label_idx[char]
    # 将标签转换为torch.long类型的张量
    label_tensor = torch.tensor(int(label_indices), dtype=torch.long)
    
    # 返回处理后的图片和对应的标签张量
    return img, label_tensor

这里我们可以注意的是,其实数据集默认就是 RGB 的,但是如果可以优化的话,可以直接优化为一个通道就可以做到识别了,这样会让模型的参数量减少很多。

由于数据集的数据集比较特殊,基本是白底背景加上黑色字,因此使用 imagenet 的归一化参数不太适合了,我这里直接使用 transforms.Normalize(mean=[0.95], std=[0.2])

模型设计

模型最好采用已成功验证过的模型方案进行修改,比如 mobilenet 或者 resnet。自己实现相对来说是比较难达到上面这些模型的精度(或者需要自己多花时间修改“丹方”)。我这里优先采用基于 mobilenetv2 的方案。事实上,使用 resnet 精度会更高。

class MobileNetV2_Chinese(MobileNetV2):
    def __init__(self, num_classes=4037):
        super().__init__()
        # 调整第一个卷积层的步长
        #self.features[0][0].stride = (1, 1)
        # 重新定义分类器
        last_channel = 1280
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(last_channel, num_classes),
        )
        # 自适应池化
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    
    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # 等效于上面的操作
        # x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
        # x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

MobileNetV2_Chinese 主要是把全连接层分类数要从 1000 改为4037。

训练

训练代码没什么可说的。

class HandwritingTrainer(pl.LightningModule):

    def __init__(self,model = 'mobilenetv2',batch_size=32):
        super().__init__()
        if model == 'resnet18':
            self.model = ResNet18_Chinese(4037)
        elif model == 'mobilenetv2':
            self.model = MobileNetV2_Chinese(4037)
        else:
            raise ValueError("model must be resnet18 or mobilenetv2")
        self.criterion = torch.nn.CrossEntropyLoss()
        self.batch_size = batch_size
        self.root_dir = "data"

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.model.parameters(),lr=1e-2,weight_decay=1e-4,momentum=0.9)
        # optimizer = torch.optim.AdamW(
        #     self.model.parameters(), lr=1e-3, weight_decay=1e-4)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=100)
        return [optimizer], [scheduler]

        def common_transforms_compose(self,mode='train'):
        if mode == 'train':
            return transforms.Compose([
                lambda img: resize_to_sqr(img),
                transforms.Resize((96,96)),
                transforms.Lambda(lambda img: img.point(lambda x: 0 if x < 200 else 255)),  # 阈值设为200
                transforms.RandomGrayscale(p=0.1),  # 随机灰度化模拟不同墨色
                transforms.RandomRotation(15),
                transforms.RandomPerspective(distortion_scale=0.2, p=0.1),  # 模拟书写倾斜
                transforms.RandomApply([transforms.GaussianBlur(3)], p=0.1),  # 轻微模糊抗噪
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.95], std=[0.2])
            ])
        else:
            return transforms.Compose([
                lambda img: resize_to_sqr(img),
                transforms.Resize((96,96)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.95], std=[0.2])
            ])

    def train_dataloader(self):
        train_dataset = HWDB1Dataset(
            root_dir=self.root_dir, mode='train', transform=self.common_transforms_compose('train'))
        return DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=8,
            persistent_workers=True,
            pin_memory=True,
            drop_last=True

        )

    def val_dataloader(self):
        train_dataset = HWDB1Dataset(
            root_dir=self.root_dir, mode='test', transform=self.common_transforms_compose('test'))
        return DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=8,
            persistent_workers=True,
            pin_memory=True,
            drop_last=True

        )
    
    def compute_grad_norm(self):
        # 计算所有参数的梯度范数
        total_norm = 0.0
        for param in self.parameters():
            if param.grad is not None:
                param_norm = param.grad.data.norm(2)  # L2 范数
                total_norm += param_norm.item() ** 2
        total_norm = total_norm ** 0.5
        return total_norm
    
    def on_after_backward(self):
        # 计算梯度范数
        grad_norm = self.compute_grad_norm()
        self.log("grad_norm", grad_norm, on_step=True, on_epoch=False, prog_bar=True, logger=True)

    
    def common_setps(self, batch, batch_idx,mode = "train"):
        x,y = batch
        pred = self.model(x)
        # print(f"Model output: {pred}")  # 打印模型输出

        loss = self.criterion(pred, y)

        self.log(f'{mode}_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss
    
    def training_step(self, batch, batch_idx):
        return self.common_setps(batch, batch_idx,mode = "train")
    
    def validation_step(self, batch, batch_idx):
        return self.common_setps(batch, batch_idx,mode = "val")

from trainer import HandwritingTrainer
from argparse import ArgumentParser

if __name__ == "__main__":

    parse =ArgumentParser()
    parse.add_argument('-m','--model', type=str,default="mobilenetv2")
    parse.add_argument('-b','--batch_size', type=int,default=32)
    args = parse.parse_args()

    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath='./checkpoints/',
        filename='best-model',
        save_top_k=1,
        mode='min'
    )

    # 创建TensorBoardLogger实例并指定日志保存路径
    logger = TensorBoardLogger(save_dir="logs", name=args.model,log_graph=True)

    # 设置 ModelCheckpoint 回调  
    checkpoint_callback = ModelCheckpoint(  
        monitor='val_loss',  # 监控指标,例如验证损失  
        dirpath=logger.log_dir,  # 检查点保存路径  
        filename='checkpoint-{epoch:02d}-{val_loss:.3f}',  # 保存文件的命名格式  
        save_top_k=3,  # 只保存最好的k个模型  
        mode='min',  # 'min'表示监控指标越小越好  
        save_weights_only=True,  # 只保存权重而不是整个模型  
    ) 
    early_stopping = EarlyStopping('val_loss', mode="min")
    trainer_args = {
        'accelerator': 'gpu',
        'devices': [0],
        'callbacks': [checkpoint_callback,early_stopping],
        'max_epochs': 100,
    }
    trainer = pl.Trainer(logger=logger, **trainer_args, fast_dev_run=False)
    # 训练数据
    model = HandwritingTrainer(model=args.model,batch_size=args.batch_size)
    print(model)
    trainer.fit(model)

启动训练

uv run train.py -b 64

优化点

需要做更多的消融实验来验证,比如:

  1. 96x96 是否比 224x224 更优?
  2. 单通道是否比三通道更好?
  3. 修改其他下采样是否会更好?

但是有一些需要注意的,就是移除下采样会让模型参数膨胀,训练也会需要更长的时候。

通过一些消融实现发现:

  1. 96x96 分辨率下,不修改原版模型的层也能达到很好的效果。
  2. 224x224 分辨率下并没有表现出更好的识别率。
  3. 数据集的处理非常重要,比如,像 “一”、“|” 这类字的长宽比是非常夸张的,最好自定义 resize 让其保持不过分变形。