【NLP9-Transformer经典案例】

Transformer经典案例

1、语言模型

以一个符合语言规律的序列为输入,模型将利用序列间关系等特征,输出在一个在所有词汇上的概率分布,这样的模型称为语言模型。

2、语言模型能解决的问题

根据语言模型定义,可以在它的基础上完成机器翻译,文本生成等任务,因为我们通过最后输出的概率分布来预测下一个词汇是什么

语言模型可以判断输入的序列是否为一句完整的话,因为我们可以根据输出的概率分布查看最大概率是否落在句子结束符上,来判断完整性

语言模型本身的训练目标是预测下一个词,因为它的特征提取部分会抽象很多语言序列之间的关系,这些关系可能同样对其它语言类任务有效果。因此可以作为预训练模型进行迁移学习

3、模型实现步骤

1、导包

2、导入wikiText-2数据集并作基本处理

3、构建用于模型输入的批次化数据

4、构建训练和评估函数

5、进行训练和评估(包括验证以及测试)

4、数据准备

wikiText-2数据集体量中等,训练集共有600篇文章,共208万左右词汇,在33278个不重复词汇,OVV(有多少正常英文词汇不在该数据集中的占比)为2.6%。验证集和测试集均为60篇。

torchtext重要功能

对文本数据进行处理,比如文本语料加载,文本迭代器构建等。

包含很多经典文本语料的预加载方法。其中包括的语料有:用于情感分析的SST和IMDB,用于问题分类TREC,用于及其翻译的WMT14,IWSLT,以及用于语言模型任务wikiText-2

# 1、导包
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
#英文文本数据集工具包
import torchtext#导入英文分词工具
#from torchtext.legacy.data import Fieldfrom torchtext.data.utils import get_tokenizer
#from torchtext.legacy.data import *
#导入已经构建完成的Transformer包from pyitcast.transformer import TransformerModelimport torch
import torchtext
from torchtext.legacy.data import Field,TabularDataset,Iterator,BucketIteratorTEXT=torchtext.data.Field(tokenize=get_tokenizer("basic_english"),init_token ='<sos>',eos_token ='<eos>',lower = True)
print(TEXT)
# 2、导入wikiText-2数据集并作基本处理
train_txt,val_txt,test_txt = torchtext.datasets.WikiText2.splits(TEXT)
print(test_txt.examples[0].text[:10])#将训练集文本数据构建一个vocab对象,可以使用vocab对象的stoi方法统计文本共包含的不重复的词汇总数
TEXT.build_vocab(train_txt)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 3、构建用于模型输入的批次化数据
def batchify(data,batch_size):#data,之前得到的文本数据# bsz,批次的样本量#第一步使用TEXT的numericalize方法将单词映射成对应的联系数字data = TEXT.numericalize([data.examples[0].text])#取得需要经过多少次的batch_size后能够遍历完所有的数据nbatch = data.size(0) // batch_size#利用narrow方法对数据进行切割#第一参数代表横轴切割还是纵轴切割,0 代表横轴,1代表纵轴#第二个参数,第三个参数分别代表切割的起始位置和终止位置data=data.narrow(0,0,nbatch * batch_size)#对data的形状进行转变data = data.view(batch_size,-1).t().contiguous()return data.to(device)# x=torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
# print(x.narrow(0,0,2))
# print(x.narrow(1,1,2))#设置训练数据、验证数据、测试数据的批次大小
batch_size =20
eval_batch_size =10
train_data = batchify(train_txt,batch_size)
val_data = batchify(val_txt,eval_batch_size)
test_data = batchify(test_txt,eval_batch_size)#设定句子的最大长度
bptt =35
def get_batch(source,i):seq_len = min(bptt,len(source) - 1 - i)data = source[i:i+seq_len]target = source[i+1:i+1+seq_len].view(-1)return data,target# source = test_data
# i =1
# x,y = get_batch(source,i)
# print(x)
# print(y)# 4、构建训练和评估函数
#通过TEXT.vocab.stoi方法获取不重复的词汇总数
ntokens = len(TEXT.vocab.stoi)
#设置词嵌入维度的值等于200
emsize =200
#设置前馈全连接层的节点数等于200
nhid =200
#设置编码层层数等于2
nlayers=2
#设置多头注意力中的头数等于2
nhead =2
#设置置零比率
dropout =0.2
#将参数传入TransformerModel 中实例化模型
model = TransformerModel(ntokens,emsize,nhead,nhid,nlayers,dropout).to(device)#设定损失函数,采用交叉熵损失函数
criterion= nn.CrossEntropyLoss()#设置学习率
lr =5.0#设置优化器
optimizer = torch.optim.SGD(model.parameters(),lr=lr)criterion = nn.CrossEntropyLoss()
lr =5.0
optimizer = torch.optim.SGD(model.parameters(),lr=lr)
#定义学习率调整器,使用torch自带的lr_scheduler,将优化器传入
scheduler = torch.optim.lr_scheduler.StepLR(optimizer,1.0,gamma=0.95)# 5、进行训练和评估(包括验证以及测试)#训练函数
import time
def train():#首先开启训练模式model.train()#定义初始损失值total_loss =0start_time = time.time()#遍历训练数据进行模型的训练for batch,i in enumerate(range(0,train_data.size(0) -1 ,bptt)):#通过前面的get_batch函数获取源数据和目标数据data,targets = get_batch(train_data,i)#设置梯度归零optimizer.zero_grad()#通过模型预测输出output = model(data)#计算损失值loss = criterion(output.view(-1,ntokens),targets)#进行反向传播loss.backward()#进行梯度规范化,防止出现梯度爆炸或者梯度消失torch.nn.utils.clip_grad_norm(model.parameters(),0.5)#进行参数更新optimizer.step()#将损失值进行累加total_loss +=loss.item()# 获取当前开始时间log_interval =200#打印日志信息if batch % log_interval ==0 and batch>0:#计算平均损失cur_loss = total_loss/log_interval#计算训练到目前的耗时elapsed = time.time() - start_time#打印日志信息print('| epoch {:3d} | {:5d}/{:5d} batches | ''lr {:02.2f} | ms/batch {:5.2f} | ''loss {:5.2f} | ppl {:8.2f}'.format(epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],elapsed * 1000 / log_interval,cur_loss, math.exp(cur_loss)))#每个打印批次借结束后,将总损失值清零total_loss =0start_time= time.time()#评估函数
def evaluate(eval_model,data_source):#eval_model,代表每轮训练后产生的模型# data_source,代表验证集数据或者测试集数据#首先开启评估模式eval_model.eval()#初始化总损失值total_loss =0#模型开启评估模式,不进行反向传播求梯度with torch.no_grad():#遍历验证数据for i in range(0,data_source.size(0)-1,bptt):#首先通过get_batch函数获取源数据和目标数据data,targets = get_batch(data_source,i)#将源数据放入评估模型中,进行预测output = eval_model(data)#对输出张量进行变形output_flat = output.view(-1,ntokens)#累加损失值total_loss +=criterion(output_flat,targets).item()#返回评估的总损失值return total_lossbest_val_loss = float("inf")epochs =3
best_model =Nonefor epoch in range(1,epochs +1):epoch_start_time = time.time()train()val_loss = evaluate(model,val_data)print('-'*50)print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | ''valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),val_loss, math.exp(val_loss)))print('-'*50)#通过比较当前轮次的损失值,获取最佳模型if val_loss<best_val_loss:best_val_loss=val_lossbest_model=model#每个轮次后调整优化器的学习率scheduler.step()#添加测试的流程代码
test_loss = evaluate(best_model,test_data)
print('-'*90)
print('|End of training |test loss {:5,2f}'.format(test_loss))
print('-'*50)

没有运行成功

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/755020.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PHP反序列化---字符串逃逸(增加/减少)

一、PHP反序列化逃逸--增加&#xff1a; 首先分析源码&#xff1a; <?php highlight_file(__FILE__); error_reporting(0); class A{public $v1 ls;public $v2 123;public function __construct($arga,$argc){$this->v1 $arga;$this->v2 $argc;} } $a $_GET[v…

探索区块链世界:从加密货币到去中心化应用

相信提到区块链&#xff0c;很多人会想到比特币这样的加密货币&#xff0c;但实际上&#xff0c;区块链技术远不止于此&#xff0c;它正在深刻地改变我们的生活和商业。 首先&#xff0c;让我们来简单了解一下什么是区块链。区块链是一种分布式数据库技术&#xff0c;它通过将…

三维GIS仿真技术发展的几点思考

随着近几年三维GIS的快速发展&#xff0c;三维GIS可视化在数字孪生、元宇宙等热点方向起着重要技术支撑&#xff0c;结合这几年三维GIS技术进展,其出现的进展与问题&#xff0c;有以下几点思考&#xff0c;供读者参考&#xff1a; 1.关于国内GIS三维仿真技术处于什么水平&#…

蓝桥杯-python-递归

递归&#xff1a;通过自我调用解决问题的函数 注意&#xff1a; #1.递归出口 #2.当前问题如何变成子问题 例子&#xff1a;利用递归写一个阶乘函数&#xff0c;F(n),求n的阶乘 def f(n):if n < 1:return 1ans n * f(n-1)return ans print(f(5)) 例子&#xff1a;汉诺塔…

vsto excel 插件注册表属性值含义

在 VSTO (Visual Studio Tools for Office) 中&#xff0c;LoadBehavior 是用于指定 Office 插件加载行为的一个属性。具体含义如下&#xff1a; - LoadBehavior 0&#xff1a;此值表示插件已被禁用&#xff0c;将不会加载。 - LoadBehavior 1&#xff1a;此值表示插件将在 O…

【CSP试题回顾】202303-1-田地丈量

CSP-202303-1-田地丈量 解题思路 问题描述&#xff1a;给定一个长为 a、宽为 b 的矩形区域&#xff08;坐标范围从 (0,0) 到 (a,b)&#xff09;&#xff0c;以及 n 个其他矩形&#xff0c;每个矩形由其对角线的两个顶点 (x1, y1) 和 (x2, y2) 定义。需要计算这些矩形与给定矩形…

深度刨析Android ANR触发原理

一、概述 ANR(Application Not responding)&#xff0c;是指应用程序未响应&#xff0c;Android系统对于一些事件需要在一定的时间范围内完成&#xff0c;如果超过预定时间能未能得到有效响应或者响应时间过长&#xff0c;都会造成ANR。一般地&#xff0c;这时往往会弹出一个提…

015 Linux_生产消费模型

​&#x1f308;个人主页&#xff1a;Fan_558 &#x1f525; 系列专栏&#xff1a;Linux &#x1f339;关注我&#x1f4aa;&#x1f3fb;带你学更多操作系统知识 文章目录 前言一、生产消费模型&#xff08;1&#xff09;概念引入&#xff08;2&#xff09;生产消费模型的优点…

键牌 6寸水口钳工业级电子斜嘴水口剪偏口钳子电工专用小斜口钳

品牌&#xff1a;键牌 型号&#xff1a;6寸水口钳灰红 材质&#xff1a;不锈钢 颜色分类&#xff1a;6寸水口钳灰红 多用途电工钳&#xff0c;高硬度&#xff0c;韧性好&#xff0c;材质优。 匠心之作&#xff0c;精工典范&#xff0c;不锈钢材质&#xff0c;加厚刀刃&am…

JAVA学习-NIO.Buffer(缓冲器)

Java NIO中的缓冲器&#xff08;Buffer&#xff09;是用来存储数据的对象。它是一个固定大小的数组&#xff0c;可以容纳特定类型的数据。 一、Java NIO中提供了7种类型的缓冲器&#xff0c;分别是&#xff1a; 1. ByteBuffer&#xff1a; 字节缓冲器&#xff0c;用来存储字…

解决 ArrayList 的并发问题

目录 1. 场景复现1.1 数据不一致问题示例代码1.2 ConcurrentModificationException 问题示例代码 2. 解决并发的三种方法2.1 使用 Collections.synchronizedList2.2 使用 CopyOnWriteArrayList&#xff08;推荐使用&#xff09;2.3 使用显式的同步控制 3. 总结 ArrayList是java…

【JavaWeb】Spring非阻塞通信 - Spring Reactive之WebFlux的使用

【JavaWeb】Spring非阻塞通信 - Spring Reactive之WebFlux的使用 文章目录 【JavaWeb】Spring非阻塞通信 - Spring Reactive之WebFlux的使用参考资料一、初识WebFlux1、什么是函数式编程1&#xff09;面向对象编程思维 VS 函数式编程思维&#xff08;封装、继承和多态描述事物间…

Magical Combat VFX 2

我们为Unity推出的最新资产包:魔法战斗VFX包!这个包非常适合为你的游戏添加激烈而致命的魔法。有30多种独特的效果,包括血液、酸和毒咒,你可以在战斗场景中大显身手。而且移动支持和优化是首要任务,你可以在旅途中使用这些效果,而不用担心性能问题。使用功能齐全、移动就…

【 React 】React render方法的原理?在什么时候会被触发?

1. 原理 首先&#xff0c;render函数在react中有两种形式&#xff1a; 在类组件中&#xff0c;指的是render方法&#xff1a; class Foo extends React.Component {render() {return <h1> Foo </h1>; } }在函数组件中&#xff0c;指的是函数组件本身&#xff1a…

windows11安装SQL server数据库报错等待数据库引擎恢复句柄失败(二)

windows11安装SQL server数据库报错等待数据库引擎恢复句柄失败&#xff08;二&#xff09;&#xff0c;昨天在给网友远程的时候发现了一个新的问题。 计算机系统同样是Windows11&#xff0c;通过命令查出来的扇区相关结果也都是4096&#xff0c;但是最后的安装还是提示SQL ser…

JVM内存模型深度解读

JVM&#xff08;Java Virtual Machine&#xff0c;Java虚拟机&#xff09;对于Java开发者和运行 Java 应用程序而言至关重要。其重要性主要体现在跨平台性、内存管理和垃圾回收、性能优化、安全性和稳定性、故障排查与性能调优等方面。今天就下学习一下 JVM 的内存模型。 一、…

嵌入式学习40-数据结构

数据结构 1.定义 一组用来保存一种或者多种特定关系的 数据的集合&#xff08;组织和存储数据&#xff09; 程序的设计&#xff1a; …

31-Java前端控制器模式(Front Controller Pattern)

Java前端控制器模式 实现范例 前端控制器模式&#xff08;Front Controller Pattern&#xff09;是用来提供一个集中的请求处理机制&#xff0c;所有的请求都将由一个单一的处理程序处理该处理程序可以做认证/授权/记录日志&#xff0c;或者跟踪请求&#xff0c;然后把请求传给…

使用RabbitMQ,关键点总结

文章目录 1.MQ的基本概念2.常见的MQ产品3.MQ 的优势和劣势3.1 优势3.2 劣势 4.RabbitMQ简介4.1RabbitMQ 中的相关概念 1.MQ的基本概念 MQ全称 Message Queue&#xff08;消息队列&#xff09;&#xff0c;是在消息的传输过程中保存消息的容器。多用于分布式系统之间进行通信。…

nginx相关内容的安装

nginx的安装 安装依赖 yum install gcc gcc-c automake autoconf libtool make gd gd-devel libxslt-devel -y 安装lua与lua依赖 lua安装步骤如下: mkdir /www mkdir /www/server #选择你自己的目录即可,不需要跟我一致 cd /www/server tar -zxvf lua-5.4.3.tar.gz cd lua-5.4…