Datawhale干货
作者:王浩,结行科技算法工程师
参加了“世界人工智能创新大赛”——手写体 OCR 识别竞赛(任务一),取得了Top1的成绩。队伍随机组的,有人找我我就加了进来,这是我第一次做OCR相关的项目,所以随意起了个名字。下面通过这篇文章来详细介绍我们的方案。
实践背景
赛题背景
银行日常业务中涉及到各类凭证的识别录入,例如身份证录入、支票录入、对账单录入等。以往的录入方式主要是以人工录入为主,效率较低,人力成本较高。近几年来,OCR相关技术以其自动执行、人为干预较少等特点正逐步替代传统的人工录入方式。但OCR技术在实际应用中也存在一些问题,在各类凭证字段的识别中,手写体由于其字体差异性大、字数不固定、语义关联性较低、凭证背景干扰等原因,导致OCR识别率准确率不高,需要大量人工校正,对日常的银行录入业务造成了一定的影响。
赛题地址:.cn/competitions/65
赛题任务
本次赛题将提供手写体图像切片数据集,数据集从真实业务场景中,经过切片脱敏得到,参赛队伍通过识别技术,获得对应的识别结果。即:
输入:手写体图像切片数据集
输出:对应的识别结果
本任务提供开放可下载的训练集及测试集,允许线下建模或线上提供 Notebook 环境及 Terminal 容器环境(脱网)建模,输出识别结果完成赛题。
赛题数据
A.数据规模和内容覆盖
B.数据示例
原始手写体图像共分为三类,分别涉及银行名称、年月日、金额三大类,分别示意如下:
相应图片切片中可能混杂有一定量的干扰信息,分别示例如下:
识别结果 JSON 在训练集中的格式如下:
json 文件内容规范:{"image1":"陆万捌千零贰拾伍元整","image2":"付经管院工资","image3":"",...}
实践方案
通过在网上查阅资料,得知OCR比赛最常用的模型是CRNN+CTC。所以我最开始也是采用这个方案。
上图是我找到的资料,有好多个版本。因为是第一次做OCR的项目,所以我优先选择有数据集的项目,这样可以快速的了解模型的输入输出。
所以我选择的第一个Attention_ocr.pytorch-master.zip,从名字上可以看出这个是加入注意力机制,感觉效果会好一些。
构建数据集
下图是Attention_ocr.pytorch-master.zip自带的数据集截图,从截图上可以看出,数据的格式:“图片路径+空格+标签”。我们也需要按照这样的格式构建数据集。
新建makedata.py文件,插入下面的代码。
importosimportjson#官方给的数据集image_path_amount="./data/train/amount/images"image_path_date="./data/train/date/images"#增强数据集image_path_test='./data/gan_test_15000/images/0'image_path_train='./data/gan_train_15500_0/images/0'amount_list=os.listdir(image_path_amount)amount_list=os.listdir(image_path_amount)new_amount_list=[]forfilenameinamount_list:new_amount_list.append(image_path_amount+"/"+filename)date_list=os.listdir(image_path_date)new_date_list=[]forfilenameindate_list:new_date_list.append(image_path_date+"/"+filename)new_test_list=[]forfilenameinamount_list:new_test_list.append(image_path_amount+"/"+filename)new_train_list=[]forfilenameinamount_list:new_train_list.append(image_path_amount+"/"+filename)
image_path_amount和image_path_date是官方给定的数据集路径。
image_path_test和image_path_train是增强的数据集(在后面会讲如何做增强)
创建建立list,保存图片的路径。
amount_json="./data/train/amount/gt.json"date_json="./data/train/date/gt.json"train_json="train_data.json"test_json="test_data.json"withopen(amount_json,"r",encoding='utf-8')asf:load_dict_amount=json.load(f)withopen(date_json,"r",encoding='utf-8')asf:load_dict_date=json.load(f)withopen(train_json,"r",encoding='utf-8')asf:load_dict_train=json.load(f)withopen(test_json,"r",encoding='utf-8')asf:load_dict_test=json.load(f)
四个json文件对应上面的四个list,json文件存储的是图片的名字和图片的标签,把json解析出来存到字典中。
#聚合listall_list=new_amount_list+new_date_list+new_test_list+new_train_listfromsklearn.model_selectionimporttrain_test_split#切分训练集合和验证集train_list,test_list=train_test_split(all_list,test_size=0.15,random_state=42)#聚合字典all_dic={}all_dic.update(load_dict_amount)all_dic.update(load_dict_date)all_dic.update(load_dict_train)all_dic.update(load_dict_test)withopen('train.txt','w')asf:forlineintrain_list:f.write(line+""+all_dic[line.split('/')[-1]]+"\n")withopen('val.txt','w')asf:forlineintest_list:f.write(line+""+all_dic[line.split('/')[-1]]+"\n")
将四个list聚合为一个list。
使用train_test_split切分训练集和验证集。
聚合字典。
然后分别遍历trainlist和testlist,将其写入train.txt和val.txt。
到这里数据集就制作完成了。得到train.txt和val.txt
查看train.txt
数据集和自带的数据集格式一样了,然后我们就可以开始训练了。
获取class
新建getclass.py文件夹,加入以下代码:
importjsonamount_json="./data/train/amount/gt.json"date_json="./data/train/date/gt.json"withopen(amount_json,"r",encoding='utf-8')asf:load_dict_amount=json.load(f)withopen(date_json,"r",encoding='utf-8')asf:load_dict_date=json.load(f)all_dic={}all_dic.update(load_dict_amount)all_dic.update(load_dict_date)list_key=[]forkeylineinall_dic.values():forkeyinkeyline:ifkeynotinlist_key:list_key.append(key)withopen('data/char_std_5990.txt','w')asf:forlineinlist_key:f.write(line+"\n")
执行完就可以得到存储class的txt文件。打开char_std_5990.txt,看到有21个类。
模型改进
crnn的卷积部分类似VGG,我对模型的改进主要有一下几个方面:
1、加入激活函数Swish。
2、加入BatchNorm。
3、加入SE注意力机制。
4、适当加深模型。
代码如下:
n=nn.Sequential(nn.Conv2d(nc,64,3,1,1),Swish(),nn.BatchNorm2d(64),nn.MaxPool2d(2,2),#64x16x50nn.Conv2d(64,128,3,1,1),Swish(),nn.BatchNorm2d(128),nn.MaxPool2d(2,2),#128x8x25nn.Conv2d(128,256,3,1,1),nn.BatchNorm2d(256),Swish(),#256x8x25nn.Conv2d(256,256,3,1,1),nn.BatchNorm2d(256),Swish(),#256x8x25SELayer(256,16),nn.MaxPool2d((2,2),(2,1),(0,1)),#256x4x25nn.Conv2d(256,512,3,1,1),nn.BatchNorm2d(512),Swish(),#512x4x25nn.Conv2d(512,512,1),nn.BatchNorm2d(512),Swish(),nn.Conv2d(512,512,3,1,1),nn.BatchNorm2d(512),Swish(),#512x4x25SELayer(512,16),nn.MaxPool2d((2,2),(2,1),(0,1)),#512x2x25nn.Conv2d(512,512,2,1,0),nn.BatchNorm2d(512),Swish())#512x1x25
SE和Swish
classSELayer(nn.Module):def__init__(self,channel,reduction=16):super(SELayer,self).__init__()self.avg_pool=nn.AdaptiveAvgPool2d(1)self.fc=nn.Sequential(nn.Linear(channel,channel//reduction,bias=True),nn.LeakyReLU(inplace=True),nn.Linear(channel//reduction,channel,bias=True),nn.Sigmoid())defforward(self,x):b,c,_,_=x.size()y=self.avg_pool(x).view(b,c)y=self.fc(y).view(b,c,1,1)returnx*y.expand_as(x)classSwish(nn.Module):defforward(self,x):returnx*torch.sigmoid(x)
模型训练
打开train.py ,在训练之前,我们还要调节一下参数。
parser=argparse.ArgumentParser()parser.add_argument('--trainlist',default='train.txt')parser.add_argument('--vallist',default='val.txt')parser.add_argument('--workers',type=int,help='numberofdataloadingworkers',default=0)parser.add_argument('--batchSize',type=int,default=4,help='inputbatchsize')parser.add_argument('--imgH',type=int,default=32,help='theheightoftheinputimagetonetwork')parser.add_argument('--imgW',type=int,default=512,help='thewidthoftheinputimagetonetwork')parser.add_argument('--nh',type=int,default=512,help='sizeofthelstmhiddenstate')parser.add_argument('--niter',type=int,default=300,help='numberofepochstotrainfor')parser.add_argument('--lr',type=float,default=0.00005,help='learningrateforCritic,default=0.00005')parser.add_argument('--beta1',type=float,default=0.5,help='beta1foradam.default=0.5')parser.add_argument('--cuda',action='store_true',help='enablescuda',default=True)parser.add_argument('--ngpu',type=int,default=1,help='numberofGPUstouse')parser.add_argument('--encoder',type=str,default='',help="pathtoencoder(tocontinuetraining)")parser.add_argument('--decoder',type=str,default='',help='pathtodecoder(tocontinuetraining)')parser.add_argument('--experiment',default='./expr/attentioncnn',help='Wheretostoresamplesandmodels')parser.add_argument('--displayInterval',type=int,default=100,help='Intervaltobedisplayed')parser.add_argument('--valInterval',type=int,default=1,help='Intervaltobedisplayed')parser.add_argument('--saveInterval',type=int,default=1,help='Intervaltobedisplayed')parser.add_argument('--adam',default=True,action='store_true',help='Whethertouseadam(defaultisrmsprop)')parser.add_argument('--adadelta',action='store_true',help='Whethertouseadadelta(defaultisrmsprop)')parser.add_argument('--keep_ratio',default=True,action='store_true',help='whethertokeepratioforimageresize')parser.add_argument('--random_sample',default=True,action='store_true',help='whethertosamplethedatasetwithrandomsampler')parser.add_argument('--teaching_forcing_prob',type=float,default=0.5,help='wheretouseteachforcing')parser.add_argument('--max_width',type=int,default=129,help='thewidthofthefeaturemapoutfromcnn')parser.add_argument("--output_file",default='deep_model.log',type=str,required=False)opt=parser.parse_args()
trainlist:训练集,默认是train.txt。
vallist:验证集路径,默认是val.txt。
batchSize:批大小,根据显存大小设置。
imgH:图片的高度,crnn模型默认为32,这里不需要修改。
imgW:图片宽度,我在这里设置为512。
keep_ratio:设置为True,设置为True后,程序会保持图片的比率,然后在一个batch内统一尺寸,这样训练的模型精度更高。
lr:学习率,设置为0.00005,这里要注意,不要太大,否则不收敛。
其他的参数就不一一介绍了,大家可以自行尝试。
运行结果:
训练完成后,可以在expr文件夹下面找到模型。
结果预测
在推理之前,我们还需要确认最长的字符串,新建getmax.py,添加如下代码:
importosimportjsonimage_path_amount="./data/train/amount/images"image_path_date="./data/train/date/images"amount_list=os.listdir(image_path_amount)new_amount_list=[]forfilenameinamount_list:new_amount_list.append(image_path_amount+"/"+filename)date_list=os.listdir(image_path_date)new_date_list=[]forfilenameindate_list:new_date_list.append(image_path_date+"/"+filename)amount_json="./data/train/amount/gt.json"date_json="./data/train/date/gt.json"withopen(amount_json,"r",encoding='utf-8')asf:load_dict_amount=json.load(f)withopen(date_json,"r",encoding='utf-8')asf:load_dict_date=json.load(f)all_list=new_amount_list+new_date_listfromsklearn.model_selectionimporttrain_test_splitall_dic={}all_dic.update(load_dict_amount)all_dic.update(load_dict_date)maxLen=0foriinall_dic.values():if(len(i)>maxLen):maxLen=len(i)print(maxLen)
运行结果:28
将test.py中的max_length设置为28。
修改模型的路径,包括encoder_path和decoder_path。
encoder_path='./expr/attentioncnn/encoder_22.pth'decoder_path='./expr/attentioncnn/decoder_22.pth'
修改测试集的路径:
forpathintqdm(glob.glob('./data/测试集/date/images/*.jpg')):text,prob=test(path)ifprob<0.8:count+=1result_dict[os.path.basename(path)]={'result':text,'confidence':prob}forpathintqdm(glob.glob('./data/测试集/amount/images/*.jpg')):text,prob=test(path)ifprob<0.8:count+=1result_dict[os.path.basename(path)]={'result':text,'confidence':prob}
写到最后
作者第一次参加OCR相关的赛事,在任务一中取得Top1的好成绩,背后的付出和努力通过方案分享也能看到。近期接触到很多在比赛中拿到不错成绩的小伙伴,不少是第一次尝试。所以,努力后还是可以得到自己满意的结果的。
整理不易,点赞三连↓