
Text Classification with the TorchText Library

Date: 2024-07-03 02:18:41

Text classification with the Torchtext library (the official example)

Environment:

torch 1.8.1+cpu

torchtext 0.9.1

Link to the official tutorial: /tutorials/beginner/text_sentiment_ngrams_tutorial.html

The AG_NEWS dataset sometimes loads and sometimes fails to load (depending on whether split is passed or not); I never quite figured out why, and it doesn't seem to be the main point. If anyone understands it, please explain in the comments (thanks)!

(1) Read in the news data; the AG_NEWS dataset is used here

Downloading directly from the official source raises an error, so download the dataset first and then load it with the method below.

Download link for the dataset: /s/1Rz_XoaTZWSRiHGOwkACosQ extraction code: j0no

Dataset overview: 4 news categories, namely World, Sports, Business, and Sci/Tec (labeled 1-4 in the raw CSV and mapped to 0-3 for the model later on).

The CSV files have three columns: label, title, and the news text (description).

import torch
from torchtext.datasets import AG_NEWS

path = 'C:/Users/USER/Desktop/AG_NEWS.data'  # change this to the directory where your dataset is stored
train_data, test_dataset = AG_NEWS(root=path, split=('train', 'test'))
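As a quick sanity check (a minimal sketch, assuming the download and the loading code above worked), each item yielded by the iterator is a (label, text) pair, with the raw label being an integer from 1 to 4:

label, text = next(iter(train_data))   # first training example
print(label)                           # e.g. 3
print(text[:80])                       # beginning of the news text (title plus description)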

(2) Build the vocabulary

from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab

tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(split='train')
counter = Counter()
for (label, line) in train_iter:
    counter.update(tokenizer(line))    # count how often each token occurs
vocab = Vocab(counter, min_freq=1)     # build the vocabulary; min_freq=1 keeps every token seen at least once
# print(vocab.freqs)                   # inspect the tokens and their frequencies

# the vocabulary converts tokens to integers
for token in ['here', 'is', 'an', 'example']:   # quick test
    print(vocab[token])

More on Vocab (to be expanded):

/text/stable/vocab.html
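To partly fill in the placeholder above, here is a small sketch of what the legacy Vocab object in torchtext 0.9.1 exposes (attribute names follow that version's API; the exact numbers depend on your copy of the dataset):

# the Vocab built above maps tokens to indices and back (torchtext 0.9.1 legacy API)
print(len(vocab))              # vocabulary size, including the special tokens <unk> and <pad>
print(vocab['example'])        # token -> index via __getitem__ (unknown tokens map to <unk>)
print(vocab.itos[:10])         # index -> token list, first ten entries
print(vocab.freqs['example'])  # raw frequency from the Counter used to build the vocabulary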

(3) Generate data batches and iterators

text_pipeline and label_pipeline set up the text and label processing pipelines, which convert the raw strings coming from the dataset iterator into numeric form.

text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]
label_pipeline = lambda x: int(x) - 1
# print(text_pipeline('here is the an example'))  # text pipeline: convert a string to vocabulary indices
#                                                 # output: [475, 21, 2, 30, 5286]
# print(label_pipeline('10'))                     # label pipeline: convert the label to an integer, output: 9

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # not defined in the original post, but needed below

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)

The collate_fn function operates on a batch of samples generated by the DataLoader; its input is one batch of data whose size is the batch_size passed to the DataLoader,

and collate_fn processes that raw data according to the pipelines declared earlier.

Note: make sure collate_fn is declared as a top-level def, so that the function is available in every worker process.

In this example, the text entries in the raw input batch are packed into a list and concatenated into a single tensor, which becomes the input to nn.EmbeddingBag.

offsets is a tensor of delimiters that marks the starting index of each individual sequence within the text tensor;

label is a tensor that holds the labels of the individual text entries.

Three lists are built first:

label_list: the label of each text in the batch, e.g. [0, 1, 2, 3]

text_list: each text in the batch converted to vocabulary indices, e.g. [12, 34, 578, ...]

offsets: initially the length of each text in the batch, e.g. [10, 20, 30, 40]; the cumulative sum below turns these lengths into start indices

A test of offsets = torch.tensor(offsets[:-1]).cumsum(dim=0):

Intuitively: drop the last length, then take the cumulative sum, as the sketch below shows.
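The test itself is not reproduced on this page; a minimal sketch with made-up lengths illustrates the behaviour:

lengths = [0, 10, 20, 30, 40]                        # 0 followed by the length of each of the 4 texts
starts = torch.tensor(lengths[:-1]).cumsum(dim=0)    # drop the last length, then accumulate
print(starts)                                        # tensor([ 0, 10, 30, 60]): start index of each text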

A test of cat():

Purpose of the function: concatenate a sequence of input tensors along a given dimension.

t1=[1,2,3]

t2=[4,5,6]

Convert them to tensors and then concatenate, as sketched below.
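The test code is not shown on this page either; a minimal sketch of the concatenation:

t1 = torch.tensor([1, 2, 3])
t2 = torch.tensor([4, 5, 6])
print(torch.cat([t1, t2]))   # tensor([1, 2, 3, 4, 5, 6]): joined along dimension 0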

(4) Define the neural network model

The model consists of an nn.EmbeddingBag layer and a linear layer for classification.

With its default mode 'mean', nn.EmbeddingBag computes the mean of the embeddings in each "bag" (here, each text).

Although the text entries here have different lengths, nn.EmbeddingBag can still handle them, because the text lengths are preserved as offsets.

The figure in the official tutorial (not reproduced here) shows the model structure: the indices produced by the word look-up table pass through the EmbeddingBag layer and then the linear layer to produce the output.

from torch import nn

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)

More on EmbeddingBag (to be expanded).
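To fill in the placeholder, here is a minimal sketch (made-up sizes) of how nn.EmbeddingBag with offsets mean-pools the embeddings of each variable-length text:

bag = nn.EmbeddingBag(num_embeddings=10, embedding_dim=4, mode='mean')
text = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])   # two texts packed into one flat tensor
offsets = torch.tensor([0, 4])                  # the first text starts at index 0, the second at index 4
out = bag(text, offsets)
print(out.shape)                                # torch.Size([2, 4]): one mean-pooled vector per text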

(5) Training, evaluation, and testing code

import time

def train(dataloader):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()
    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_acc / total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)

This function clips the gradients: it computes the total p-norm (the 2-norm by default) of all parameter gradients, and if that norm exceeds 0.1 the gradients are rescaled so that their total norm becomes 0.1.
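A quick sketch with toy numbers shows the clipping behaviour described above:

p = torch.nn.Parameter(torch.tensor([3.0, 4.0]))
p.grad = torch.tensor([3.0, 4.0])                      # gradient whose 2-norm is 5.0
total = torch.nn.utils.clip_grad_norm_([p], 0.1)       # returns the norm measured before clipping
print(total)    # 5.0
print(p.grad)   # tensor([0.0600, 0.0800]): rescaled so the total norm is 0.1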

(6) The main function

if __name__ == '__main__':
    from torch.utils.data import DataLoader, random_split   # not imported in the original post

    train_iter = AG_NEWS(split='train')
    num_class = len(set([label for (label, text) in train_iter]))
    vocab_size = len(vocab)
    emsize = 64
    model = TextClassificationModel(vocab_size, emsize, num_class).to(device)  # 95812, 64, 4
    EPOCHS = 10
    LR = 5
    BATCH_SIZE = 64
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    total_accu = None
    train_iter, test_iter = AG_NEWS()
    train_dataset = list(train_iter)
    test_dataset = list(test_iter)
    num_train = int(len(train_dataset) * 0.95)
    split_train_, split_valid_ = \
        random_split(train_dataset, [num_train, len(train_dataset) - num_train])
    train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                                  shuffle=True, collate_fn=collate_batch)
    valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                                  shuffle=True, collate_fn=collate_batch)
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                                 shuffle=True, collate_fn=collate_batch)

    # training loop
    for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train(train_dataloader)
        accu_val = evaluate(valid_dataloader)
        if total_accu is not None and total_accu > accu_val:
            scheduler.step()
        else:
            total_accu = accu_val
        print('| end of epoch {:3d} | time: {:5.2f}s | '
              'valid accuracy {:8.3f} '.format(epoch,
                                               time.time() - epoch_start_time,
                                               accu_val))

    # evaluate the model on the test dataset
    print('Checking the results of test dataset.')
    accu_test = evaluate(test_dataloader)
    print('test accuracy {:8.3f}'.format(accu_test))

    # prediction
    ag_news_label = {1: "World",
                     2: "Sports",
                     3: "Business",
                     4: "Sci/Tec"}
    ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
        enduring the season’s worst weather conditions on Sunday at The \
        Open on his way to a closing 75 at Royal Portrush, which \
        considering the wind and the rain was a respectable showing. \
        Thursday’s first round at the WGC-FedEx St. Jude Invitational \
        was another story. With temperatures in the mid-80s and hardly any \
        wind, the Spaniard was 13 strokes better in a flawless round. \
        Thanks to his best putting performance on the PGA Tour, Rahm \
        finished with an 8-under 62 for a three-stroke lead, which \
        was even more impressive considering he’d never played the \
        front nine at TPC Southwind."
    model = model.to("cpu")
    print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])
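Note that the main block calls a predict helper that is never defined in this post. The official tutorial defines it roughly as below (a sketch; place it before the main block, and note the +1 that maps the model's 0-3 output back to the 1-4 keys of ag_news_label):

def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))   # a single text, so the only offset is 0
        return output.argmax(1).item() + 1        # shift the 0-3 class index back to the 1-4 label ids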

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1.0, gamma=0.1) adjusts the learning rate.

step_size is the adjustment interval, usually counted in epochs.

lr = gamma*lr

So here, each time scheduler.step() is called, the learning rate goes 5, 0.5, 0.05, 0.005, and so on. (In the main loop above, scheduler.step() only runs when the validation accuracy fails to improve, so the decay does not necessarily happen every epoch.)
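A small sketch with a toy optimizer confirms that sequence (values shown up to float rounding):

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=5)
sch = torch.optim.lr_scheduler.StepLR(opt, step_size=1, gamma=0.1)
for _ in range(3):
    opt.step()
    sch.step()
    print(sch.get_last_lr())   # [0.5], then [0.05], then [0.005]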
