100字范文,内容丰富有趣,生活中的好帮手!
100字范文 > 世界人工智能大赛 Top1 方案!手写体 OCR 识别

世界人工智能大赛 Top1 方案!手写体 OCR 识别

时间:2022-12-14 17:09:16

相关推荐

世界人工智能大赛 Top1 方案!手写体 OCR 识别

Datawhale干货

作者:王浩,结行科技算法工程师

参加了“世界人工智能创新大赛”——手写体 OCR 识别竞赛(任务一),取得了Top1的成绩。队伍随机组的,有人找我我就加了进来,这是我第一次做OCR相关的项目,所以随意起了个名字。下面通过这篇文章来详细介绍我们的方案。

实践背景

赛题背景

银行日常业务中涉及到各类凭证的识别录入,例如身份证录入、支票录入、对账单录入等。以往的录入方式主要是以人工录入为主,效率较低,人力成本较高。近几年来,OCR相关技术以其自动执行、人为干预较少等特点正逐步替代传统的人工录入方式。但OCR技术在实际应用中也存在一些问题,在各类凭证字段的识别中,手写体由于其字体差异性大、字数不固定、语义关联性较低、凭证背景干扰等原因,导致OCR识别率准确率不高,需要大量人工校正,对日常的银行录入业务造成了一定的影响。

赛题地址:.cn/competitions/65

赛题任务

本次赛题将提供手写体图像切片数据集,数据集从真实业务场景中,经过切片脱敏得到,参赛队伍通过识别技术,获得对应的识别结果。即:

输入:手写体图像切片数据集

输出:对应的识别结果

本任务提供开放可下载的训练集及测试集,允许线下建模或线上提供 Notebook 环境及 Terminal 容器环境(脱网)建模,输出识别结果完成赛题。

赛题数据

A.数据规模和内容覆盖

B.数据示例

原始手写体图像共分为三类,分别涉及银行名称、年月日、金额三大类,分别示意如下:

相应图片切片中可能混杂有一定量的干扰信息,分别示例如下:

识别结果 JSON 在训练集中的格式如下:

json 文件内容规范:{"image1":"陆万捌千零贰拾伍元整","image2":"付经管院工资","image3":"",...}

实践方案

通过在网上查阅资料,得知OCR比赛最常用的模型是CRNN+CTC。所以我最开始也是采用这个方案。

上图是我找到的资料,有好多个版本。因为是第一次做OCR的项目,所以我优先选择有数据集的项目,这样可以快速的了解模型的输入输出。

所以我选择的第一个Attention_ocr.pytorch-master.zip,从名字上可以看出这个是加入注意力机制,感觉效果会好一些。

构建数据集

下图是Attention_ocr.pytorch-master.zip自带的数据集截图,从截图上可以看出,数据的格式:“图片路径+空格+标签”。我们也需要按照这样的格式构建数据集。

新建makedata.py文件,插入下面的代码。

importosimportjson#官方给的数据集image_path_amount="./data/train/amount/images"image_path_date="./data/train/date/images"#增强数据集image_path_test='./data/gan_test_15000/images/0'image_path_train='./data/gan_train_15500_0/images/0'amount_list=os.listdir(image_path_amount)amount_list=os.listdir(image_path_amount)new_amount_list=[]forfilenameinamount_list:new_amount_list.append(image_path_amount+"/"+filename)date_list=os.listdir(image_path_date)new_date_list=[]forfilenameindate_list:new_date_list.append(image_path_date+"/"+filename)new_test_list=[]forfilenameinamount_list:new_test_list.append(image_path_amount+"/"+filename)new_train_list=[]forfilenameinamount_list:new_train_list.append(image_path_amount+"/"+filename)

image_path_amount和image_path_date是官方给定的数据集路径。

image_path_test和image_path_train是增强的数据集(在后面会讲如何做增强)

创建建立list,保存图片的路径。

amount_json="./data/train/amount/gt.json"date_json="./data/train/date/gt.json"train_json="train_data.json"test_json="test_data.json"withopen(amount_json,"r",encoding='utf-8')asf:load_dict_amount=json.load(f)withopen(date_json,"r",encoding='utf-8')asf:load_dict_date=json.load(f)withopen(train_json,"r",encoding='utf-8')asf:load_dict_train=json.load(f)withopen(test_json,"r",encoding='utf-8')asf:load_dict_test=json.load(f)

四个json文件对应上面的四个list,json文件存储的是图片的名字和图片的标签,把json解析出来存到字典中。

#聚合listall_list=new_amount_list+new_date_list+new_test_list+new_train_listfromsklearn.model_selectionimporttrain_test_split#切分训练集合和验证集train_list,test_list=train_test_split(all_list,test_size=0.15,random_state=42)#聚合字典all_dic={}all_dic.update(load_dict_amount)all_dic.update(load_dict_date)all_dic.update(load_dict_train)all_dic.update(load_dict_test)withopen('train.txt','w')asf:forlineintrain_list:f.write(line+""+all_dic[line.split('/')[-1]]+"\n")withopen('val.txt','w')asf:forlineintest_list:f.write(line+""+all_dic[line.split('/')[-1]]+"\n")

将四个list聚合为一个list。

使用train_test_split切分训练集和验证集。

聚合字典。

然后分别遍历trainlist和testlist,将其写入train.txt和val.txt。

到这里数据集就制作完成了。得到train.txt和val.txt

查看train.txt

数据集和自带的数据集格式一样了,然后我们就可以开始训练了。

获取class

新建getclass.py文件夹,加入以下代码:

importjsonamount_json="./data/train/amount/gt.json"date_json="./data/train/date/gt.json"withopen(amount_json,"r",encoding='utf-8')asf:load_dict_amount=json.load(f)withopen(date_json,"r",encoding='utf-8')asf:load_dict_date=json.load(f)all_dic={}all_dic.update(load_dict_amount)all_dic.update(load_dict_date)list_key=[]forkeylineinall_dic.values():forkeyinkeyline:ifkeynotinlist_key:list_key.append(key)withopen('data/char_std_5990.txt','w')asf:forlineinlist_key:f.write(line+"\n")

执行完就可以得到存储class的txt文件。打开char_std_5990.txt,看到有21个类。

模型改进

crnn的卷积部分类似VGG,我对模型的改进主要有一下几个方面:

1、加入激活函数Swish。

2、加入BatchNorm。

3、加入SE注意力机制。

4、适当加深模型。

代码如下:

n=nn.Sequential(nn.Conv2d(nc,64,3,1,1),Swish(),nn.BatchNorm2d(64),nn.MaxPool2d(2,2),#64x16x50nn.Conv2d(64,128,3,1,1),Swish(),nn.BatchNorm2d(128),nn.MaxPool2d(2,2),#128x8x25nn.Conv2d(128,256,3,1,1),nn.BatchNorm2d(256),Swish(),#256x8x25nn.Conv2d(256,256,3,1,1),nn.BatchNorm2d(256),Swish(),#256x8x25SELayer(256,16),nn.MaxPool2d((2,2),(2,1),(0,1)),#256x4x25nn.Conv2d(256,512,3,1,1),nn.BatchNorm2d(512),Swish(),#512x4x25nn.Conv2d(512,512,1),nn.BatchNorm2d(512),Swish(),nn.Conv2d(512,512,3,1,1),nn.BatchNorm2d(512),Swish(),#512x4x25SELayer(512,16),nn.MaxPool2d((2,2),(2,1),(0,1)),#512x2x25nn.Conv2d(512,512,2,1,0),nn.BatchNorm2d(512),Swish())#512x1x25

SE和Swish

classSELayer(nn.Module):def__init__(self,channel,reduction=16):super(SELayer,self).__init__()self.avg_pool=nn.AdaptiveAvgPool2d(1)self.fc=nn.Sequential(nn.Linear(channel,channel//reduction,bias=True),nn.LeakyReLU(inplace=True),nn.Linear(channel//reduction,channel,bias=True),nn.Sigmoid())defforward(self,x):b,c,_,_=x.size()y=self.avg_pool(x).view(b,c)y=self.fc(y).view(b,c,1,1)returnx*y.expand_as(x)classSwish(nn.Module):defforward(self,x):returnx*torch.sigmoid(x)

模型训练

打开train.py ,在训练之前,我们还要调节一下参数。

parser=argparse.ArgumentParser()parser.add_argument('--trainlist',default='train.txt')parser.add_argument('--vallist',default='val.txt')parser.add_argument('--workers',type=int,help='numberofdataloadingworkers',default=0)parser.add_argument('--batchSize',type=int,default=4,help='inputbatchsize')parser.add_argument('--imgH',type=int,default=32,help='theheightoftheinputimagetonetwork')parser.add_argument('--imgW',type=int,default=512,help='thewidthoftheinputimagetonetwork')parser.add_argument('--nh',type=int,default=512,help='sizeofthelstmhiddenstate')parser.add_argument('--niter',type=int,default=300,help='numberofepochstotrainfor')parser.add_argument('--lr',type=float,default=0.00005,help='learningrateforCritic,default=0.00005')parser.add_argument('--beta1',type=float,default=0.5,help='beta1foradam.default=0.5')parser.add_argument('--cuda',action='store_true',help='enablescuda',default=True)parser.add_argument('--ngpu',type=int,default=1,help='numberofGPUstouse')parser.add_argument('--encoder',type=str,default='',help="pathtoencoder(tocontinuetraining)")parser.add_argument('--decoder',type=str,default='',help='pathtodecoder(tocontinuetraining)')parser.add_argument('--experiment',default='./expr/attentioncnn',help='Wheretostoresamplesandmodels')parser.add_argument('--displayInterval',type=int,default=100,help='Intervaltobedisplayed')parser.add_argument('--valInterval',type=int,default=1,help='Intervaltobedisplayed')parser.add_argument('--saveInterval',type=int,default=1,help='Intervaltobedisplayed')parser.add_argument('--adam',default=True,action='store_true',help='Whethertouseadam(defaultisrmsprop)')parser.add_argument('--adadelta',action='store_true',help='Whethertouseadadelta(defaultisrmsprop)')parser.add_argument('--keep_ratio',default=True,action='store_true',help='whethertokeepratioforimageresize')parser.add_argument('--random_sample',default=True,action='store_true',help='whethertosamplethedatasetwithrandomsampler')parser.add_argument('--teaching_forcing_prob',type=float,default=0.5,help='wheretouseteachforcing')parser.add_argument('--max_width',type=int,default=129,help='thewidthofthefeaturemapoutfromcnn')parser.add_argument("--output_file",default='deep_model.log',type=str,required=False)opt=parser.parse_args()

trainlist:训练集,默认是train.txt。

vallist:验证集路径,默认是val.txt。

batchSize:批大小,根据显存大小设置。

imgH:图片的高度,crnn模型默认为32,这里不需要修改。

imgW:图片宽度,我在这里设置为512。

keep_ratio:设置为True,设置为True后,程序会保持图片的比率,然后在一个batch内统一尺寸,这样训练的模型精度更高。

lr:学习率,设置为0.00005,这里要注意,不要太大,否则不收敛。

其他的参数就不一一介绍了,大家可以自行尝试。

运行结果:

运行结果

训练完成后,可以在expr文件夹下面找到模型。

训练的模型

结果预测

在推理之前,我们还需要确认最长的字符串,新建getmax.py,添加如下代码:

importosimportjsonimage_path_amount="./data/train/amount/images"image_path_date="./data/train/date/images"amount_list=os.listdir(image_path_amount)new_amount_list=[]forfilenameinamount_list:new_amount_list.append(image_path_amount+"/"+filename)date_list=os.listdir(image_path_date)new_date_list=[]forfilenameindate_list:new_date_list.append(image_path_date+"/"+filename)amount_json="./data/train/amount/gt.json"date_json="./data/train/date/gt.json"withopen(amount_json,"r",encoding='utf-8')asf:load_dict_amount=json.load(f)withopen(date_json,"r",encoding='utf-8')asf:load_dict_date=json.load(f)all_list=new_amount_list+new_date_listfromsklearn.model_selectionimporttrain_test_splitall_dic={}all_dic.update(load_dict_amount)all_dic.update(load_dict_date)maxLen=0foriinall_dic.values():if(len(i)>maxLen):maxLen=len(i)print(maxLen)

运行结果:28

将test.py中的max_length设置为28。

修改模型的路径,包括encoder_path和decoder_path。

encoder_path='./expr/attentioncnn/encoder_22.pth'decoder_path='./expr/attentioncnn/decoder_22.pth'

修改测试集的路径:

forpathintqdm(glob.glob('./data/测试集/date/images/*.jpg')):text,prob=test(path)ifprob<0.8:count+=1result_dict[os.path.basename(path)]={'result':text,'confidence':prob}forpathintqdm(glob.glob('./data/测试集/amount/images/*.jpg')):text,prob=test(path)ifprob<0.8:count+=1result_dict[os.path.basename(path)]={'result':text,'confidence':prob}

写到最后

作者第一次参加OCR相关的赛事,在任务一中取得Top1的好成绩,背后的付出和努力通过方案分享也能看到。近期接触到很多在比赛中拿到不错成绩的小伙伴,不少是第一次尝试。所以,努力后还是可以得到自己满意的结果的。

整理不易,三连

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。