我们首先需要处理数据集。我这里直接采用中科院自动化所提供的CASIA-HWDB 1.0 数据集。这是一个手写单字样本数据集,非常适合手写汉字输入场景。
处理数据集
数据集官网地址: https://nlpr.ia.ac.cn/databases/handwriting/Download.html
HWDB1.0 有 3740 个汉字样本,不过里面除了汉字,还有一些标点符号、特殊符号、数字以及英文字母等,一共加起来其实有 4037 个分类。
下载数据集:
wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0TrainPart1.zip
wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0TrainPart2.zip
wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0TrainPart3.zip
wget https://nlpr.ia.ac.cn/databases/Download/Offline/CharData/Gnt1.0Test.zip
其中前三个为训练集,最后一个为测试/验证集。
你可能会发现,解压后得到的并不是正常的图片,而是 gnt 格式的二进制文件。.gnt 格式是一种紧凑的二进制格式,里面存储大量手写汉字图像及其标签。
网上找到的解析脚本:
import os
from pathlib import Path
import struct
from PIL import Image
def write_txt(save_path: str, content: list, mode='w'):
"""
将list内容写入txt中
@param
content: list格式内容
save_path: 绝对路径str
@return:None
"""
with open(save_path, mode, encoding='utf-8') as f:
for value in content:
f.write(value + '\n')
path = 'data/test_gnt'
save_dir = 'data/test' # 目录下均为gnt文件
gnt_paths = list(Path(path).iterdir())
label_list = []
for gnt_path in gnt_paths:
count = 0
print(gnt_path)
with open(str(gnt_path), 'rb') as f:
while f.read(1) != "":
f.seek(-1, 1)
count += 1
try:
# 只所以添加try,是因为有时f.read会报错 struct.error: unpack requires a buffer of 4 bytes
# 原因尚未找到
length_bytes = struct.unpack('<I', f.read(4))[0]
tag_code = f.read(2)
width = struct.unpack('<H', f.read(2))[0]
height = struct.unpack('<H', f.read(2))[0]
im = Image.new('RGB', (width, height))
img_array = im.load()
for x in range(0, height):
for y in range(0, width):
pixel = struct.unpack('<B', f.read(1))[0]
img_array[y, x] = (pixel, pixel, pixel)
filename = str(count) + '.png'
tag_code = tag_code.decode('gbk').strip('\x00')
save_path = f'{save_dir}/images/{gnt_path.stem}'
if not Path(save_path).exists():
Path(save_path).mkdir(parents=True, exist_ok=True)
im.save(f'{save_path}/{filename}')
label_list.append(f'{gnt_path.stem}/{filename}\t{tag_code}')
except:
break
write_txt(f'{save_dir}/gt.txt', label_list)
这样我们就得到了图片,以及标注 gt.txt,大概标注内容如下:
//...
240-t/218.png 拔
240-t/219.png 跋
240-t/220.png 靶
240-t/221.png 把
240-t/222.png 耙
//...
我们还需要基于gt.txt把汉字标签提取出来,转换为索引,方便模型训练。
label.txt 格式如下:
! 0
" 1
# 2
$ 3
% 4
( 5
) 6
* 7
+ 8
, 9
- 10
. 11
/ 12
0 13
1 14
2 15
3 16
//更多...
接下来就可以构建数据集了。
先分析一下图片,通过统计发现,数据集分辨率区间如下:
图片宽度区间分布统计:
[0-50]: 158633 张图片
[50-100]: 1425052 张图片
[100-150]: 96269 张图片
[150-200]: 300 张图片
[200,]: 4 张图片
图片高度区间分布统计:
[0-50]: 55862 张图片
[50-100]: 1385139 张图片
[100-150]: 236884 张图片
[150-200]: 2320 张图片
[200,]: 53 张图片
加载标签:
def load_label_to_idx(self,label_path:Path):
char_to_idx = {}
with open(str(label_path), 'r', encoding='utf-8') as f:
for line in f:
line = line.strip().split('\t')
char = line[0]
char_to_idx[char] = int(line[1])
return char_to_idx
加载数据:
def load_data(self):
data = []
txt_file = Path(self.root_dir) / self.mode / 'gt.txt'
with open(txt_file, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip().split('\t')
img_path = line[0]
label = line[1]
data.append((img_path, label))
return data
样本处理:
def __getitem__(self, index):
# 根据索引index从self.data中获取图片路径和对应的字符标签
img_path, char = self.data[index]
# 构建完整的图片路径:根目录/模式(如train或test)/images/图片文件名
img_path = Path(self.root_dir) / self.mode / 'images' / img_path
# 打开图片文件
img = Image.open(img_path)
# 确保图片是RGB格式(3通道)
img = img.convert('RGB')
# 如果有定义transform(数据增强/预处理),则应用transform
if self.transform: img = self.transform(img)
# 从label_idx字典中获取字符对应的数字索引
label_indices = self.label_idx[char]
# 将标签转换为torch.long类型的张量
label_tensor = torch.tensor(int(label_indices), dtype=torch.long)
# 返回处理后的图片和对应的标签张量
return img, label_tensor
这里我们可以注意的是,其实数据集默认就是 RGB 的,但是如果可以优化的话,可以直接优化为一个通道就可以做到识别了,这样会让模型的参数量减少很多。
由于数据集的数据集比较特殊,基本是白底背景加上黑色字,因此使用 imagenet 的归一化参数不太适合了,我这里直接使用 transforms.Normalize(mean=[0.95], std=[0.2])。
模型设计
模型最好采用已成功验证过的模型方案进行修改,比如 mobilenet 或者 resnet。自己实现相对来说是比较难达到上面这些模型的精度(或者需要自己多花时间修改“丹方”)。我这里优先采用基于 mobilenetv2 的方案。事实上,使用 resnet 精度会更高。
class MobileNetV2_Chinese(MobileNetV2):
def __init__(self, num_classes=4037):
super().__init__()
# 调整第一个卷积层的步长
#self.features[0][0].stride = (1, 1)
# 重新定义分类器
last_channel = 1280
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(last_channel, num_classes),
)
# 自适应池化
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = x.view(x.size(0), -1)
# 等效于上面的操作
# x = nn.functional.adaptive_avg_pool2d(x, (1, 1))
# x = torch.flatten(x, 1)
x = self.classifier(x)
return x
MobileNetV2_Chinese 主要是把全连接层分类数要从 1000 改为4037。
训练
训练代码没什么可说的。
class HandwritingTrainer(pl.LightningModule):
def __init__(self,model = 'mobilenetv2',batch_size=32):
super().__init__()
if model == 'resnet18':
self.model = ResNet18_Chinese(4037)
elif model == 'mobilenetv2':
self.model = MobileNetV2_Chinese(4037)
else:
raise ValueError("model must be resnet18 or mobilenetv2")
self.criterion = torch.nn.CrossEntropyLoss()
self.batch_size = batch_size
self.root_dir = "data"
def forward(self, x):
return self.model(x)
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.model.parameters(),lr=1e-2,weight_decay=1e-4,momentum=0.9)
# optimizer = torch.optim.AdamW(
# self.model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=100)
return [optimizer], [scheduler]
def common_transforms_compose(self,mode='train'):
if mode == 'train':
return transforms.Compose([
lambda img: resize_to_sqr(img),
transforms.Resize((96,96)),
transforms.Lambda(lambda img: img.point(lambda x: 0 if x < 200 else 255)), # 阈值设为200
transforms.RandomGrayscale(p=0.1), # 随机灰度化模拟不同墨色
transforms.RandomRotation(15),
transforms.RandomPerspective(distortion_scale=0.2, p=0.1), # 模拟书写倾斜
transforms.RandomApply([transforms.GaussianBlur(3)], p=0.1), # 轻微模糊抗噪
transforms.ToTensor(),
transforms.Normalize(mean=[0.95], std=[0.2])
])
else:
return transforms.Compose([
lambda img: resize_to_sqr(img),
transforms.Resize((96,96)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.95], std=[0.2])
])
def train_dataloader(self):
train_dataset = HWDB1Dataset(
root_dir=self.root_dir, mode='train', transform=self.common_transforms_compose('train'))
return DataLoader(
dataset=train_dataset,
batch_size=self.batch_size,
shuffle=True,
num_workers=8,
persistent_workers=True,
pin_memory=True,
drop_last=True
)
def val_dataloader(self):
train_dataset = HWDB1Dataset(
root_dir=self.root_dir, mode='test', transform=self.common_transforms_compose('test'))
return DataLoader(
dataset=train_dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=8,
persistent_workers=True,
pin_memory=True,
drop_last=True
)
def compute_grad_norm(self):
# 计算所有参数的梯度范数
total_norm = 0.0
for param in self.parameters():
if param.grad is not None:
param_norm = param.grad.data.norm(2) # L2 范数
total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5
return total_norm
def on_after_backward(self):
# 计算梯度范数
grad_norm = self.compute_grad_norm()
self.log("grad_norm", grad_norm, on_step=True, on_epoch=False, prog_bar=True, logger=True)
def common_setps(self, batch, batch_idx,mode = "train"):
x,y = batch
pred = self.model(x)
# print(f"Model output: {pred}") # 打印模型输出
loss = self.criterion(pred, y)
self.log(f'{mode}_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss
def training_step(self, batch, batch_idx):
return self.common_setps(batch, batch_idx,mode = "train")
def validation_step(self, batch, batch_idx):
return self.common_setps(batch, batch_idx,mode = "val")
from trainer import HandwritingTrainer
from argparse import ArgumentParser
if __name__ == "__main__":
parse =ArgumentParser()
parse.add_argument('-m','--model', type=str,default="mobilenetv2")
parse.add_argument('-b','--batch_size', type=int,default=32)
args = parse.parse_args()
checkpoint_callback = ModelCheckpoint(
monitor='val_loss',
dirpath='./checkpoints/',
filename='best-model',
save_top_k=1,
mode='min'
)
# 创建TensorBoardLogger实例并指定日志保存路径
logger = TensorBoardLogger(save_dir="logs", name=args.model,log_graph=True)
# 设置 ModelCheckpoint 回调
checkpoint_callback = ModelCheckpoint(
monitor='val_loss', # 监控指标,例如验证损失
dirpath=logger.log_dir, # 检查点保存路径
filename='checkpoint-{epoch:02d}-{val_loss:.3f}', # 保存文件的命名格式
save_top_k=3, # 只保存最好的k个模型
mode='min', # 'min'表示监控指标越小越好
save_weights_only=True, # 只保存权重而不是整个模型
)
early_stopping = EarlyStopping('val_loss', mode="min")
trainer_args = {
'accelerator': 'gpu',
'devices': [0],
'callbacks': [checkpoint_callback,early_stopping],
'max_epochs': 100,
}
trainer = pl.Trainer(logger=logger, **trainer_args, fast_dev_run=False)
# 训练数据
model = HandwritingTrainer(model=args.model,batch_size=args.batch_size)
print(model)
trainer.fit(model)
启动训练
uv run train.py -b 64
优化点
需要做更多的消融实验来验证,比如:
- 96x96 是否比 224x224 更优?
- 单通道是否比三通道更好?
- 修改其他下采样是否会更好?
但是有一些需要注意的,就是移除下采样会让模型参数膨胀,训练也会需要更长的时候。
通过一些消融实现发现:
- 96x96 分辨率下,不修改原版模型的层也能达到很好的效果。
- 224x224 分辨率下并没有表现出更好的识别率。
- 数据集的处理非常重要,比如,像 “一”、“|” 这类字的长宽比是非常夸张的,最好自定义 resize 让其保持不过分变形。