BERT数据处理,模型,预训练

代码来自李沐老师《动手学pytorch》
在数据处理时,首先执行以下代码
def load_data_wiki(batch_size, max_len):"""加载WikiText-2数据集"""num_workers = d2l.get_dataloader_workers()data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')以上两句代码,不再说明paragraphs = _read_wiki(data_dir)train_set = _WikiTextDataset(paragraphs, max_len)train_iter = torch.utils.data.DataLoader(train_set, batch_size,shuffle=True)return train_iter, train_set.vocab

d2l.DATA_HUB['wikitext-2'] = ('https://s3.amazonaws.com/research.metamind.io/wikitext/''wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')#@save
def _read_wiki(data_dir):file_name = os.path.join(data_dir, 'wiki.train.tokens')with open(file_name, 'r',encoding='utf-8') as f:lines = f.readlines()# 大写字母转换为小写字母 ,每行文本中包含两个句子,才进行处理,否则舍去文本paragraphs = [line.strip().lower().split(' . ')for line in lines if len(line.split(' . ')) >= 2]random.shuffle(paragraphs)return paragraphs

首先读取文本,每个文本必须包含两个以上句子(为了第二个预训练任务:判断两个句子,是否连续)。paragraphs 其中一部分结果如下所示

文本中包含了三个句子,每个’‘里面,代表一个句子
['common starlings are trapped for food in some mediterranean countries'
, 'the meat is tough and of low quality , so it is <unk> or made into <unk>'
, 'one recipe said it should be <unk> " until tender , however long that may be "'
, 'even when correctly prepared , it may still be seen as an acquired taste .']
class _WikiTextDataset(torch.utils.data.Dataset):def __init__(self, paragraphs, max_len):'''每一个paragraph就是上面的包含多个句子的列表,将其进行分词处理。下面是一个分词的例子[['common', 'starlings', 'are', 'trapped', 'for', 'food', 'in', 'some', 'mediterranean', 'countries'], ['the', 'meat', 'is', 'tough', 'and', 'of', 'low', 'quality', ',', 'so', 'it', 'is', '<unk>', 'or', 'made', 'into', '<unk>'], ['one', 'recipe', 'said', 'it', 'should', 'be', '<unk>', '"', 'until', 'tender', ',', 'however', 'long', 'that', 'may', 'be', '"'], ['even', 'when', 'correctly', 'prepared', ',', 'it', 'may', 'still', 'be', 'seen', 'as', 'an', 'acquired', 'taste', '.']]'''paragraphs = [d2l.tokenize(paragraph, token='word') for paragraph in paragraphs]#将词提取处理,保存sentences = [sentence for paragraph in paragraphsfor sentence in paragraph]#形成一个词典,min_freq为词最少出现的次数,少于5次,则不保存进词典中self.vocab = d2l.Vocab(sentences, min_freq=5, reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])# 获取下一句子预测任务的数据examples = []for paragraph in paragraphs:examples.extend(_get_nsp_data_from_paragraph(paragraph, paragraphs, self.vocab, max_len))'''
def _get_nsp_data_from_paragraph(paragraph,paragraphs,vocab,max_len):nsp_data_from_paragraph=[]for i in range(len(paragraph)-1):_get_next_sentence函数传入的是相邻的句子a,b。函数中b会有一定概率替换为其他的句子tokens_a, tokens_b, is_next = _get_next_sentence(paragraph[i], paragraph[i + 1], paragraphs)句子长度大于bert限制的长度,则舍去。if len(tokens_a)+len(tokens_b)+3>max_len:continue#加上<cls>和<sep>,segments用于区token在哪个句子中tokens, segments = d2l.get_tokens_and_segments(tokens_a, tokens_b)nsp_data_from_paragraph.append((tokens, segments, is_next))return nsp_data_from_paragraphtoken和segments的例子: True表示两个句子相邻,False表示b被随机替换,a,b不相邻。(['<cls>', 'mushrooms', 'grow', '<unk>', 'or', 'in', '"', '<unk>', 'groups', '"', 'in', 'late', 'summer', 'and', 'throughout', 'autumn', ',', 'though', 'it', 'is', 'not', 'commonly', 'encountered', 'species', '<sep>', 'it','can', 'be', 'found', 'in', 'europe', ',', 'asia', 'and', 'north', 'america', '.', '<sep>'], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,1, 1, 1, 1, 1], True),'''# 获取遮蔽语言模型任务的数据'''在这里我们会将句子中单词,替换为在词典中的索引。13意思为,句子的第13个词,进行了处理,可能不变,可能替换为其他词,可能替换为mask。在这里这个词没有替换。0与1区分两个句子,False代表两个句子不相邻。examples中的结果;([3, 2510, 31, 337, 9, 0, 6, 6891, 8, 11621, 6, 21, 11, 60, 3405, 14, 1542, 9546, 4, 2524,21, 185, 4421, 649, 38, 277, 2872, 13233, 4], [13], [60], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], False)'''examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)+ (segments, is_next))for tokens, segments, is_next in examples]#_pad_bert_inputs对数据进行填充,all_mlm_weights中1为需要预测,0为填充#    all_mlm_weights= tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0.](self.all_token_ids, self.all_segments, self.valid_lens,self.all_pred_positions, self.all_mlm_weights,self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(examples, max_len, self.vocab)def __getitem__(self, idx):return (self.all_token_ids[idx], self.all_segments[idx],self.valid_lens[idx], self.all_pred_positions[idx],self.all_mlm_weights[idx], self.all_mlm_labels[idx],self.nsp_labels[idx])def __len__(self):return len(self.all_token_ids)

上述已经将数据处理完,最后看一下处理后的例子:

将原来的句子列表填充1,一直到到大小为64
tensor([[    3,     5,     0, 18306,    23,    11,  2659,   156,  5779,   382,1296,   110,   158,    22,     5,  1771,   496,     0,  3398,     2,5,  3496,   110,  5038,   179,     4,    16,    11, 19837,     6,58,    13,     5,   685,     7,    66,   156,     0,  3063,    77,3842,    19,     4,     1,     1,     1,     1,     1,     1,     1,1,     1,     1,     1,     1,     1,     1,     1,     1,     1,1,     1,     1,     1]])
segments用于区分两个句子,0为第一个句子中的词,1为第二个句子中的词,后面的0为填充
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])
valid_lens表示句子列表的有效长度
tensor([43.])
pred_positions需要预测的位置,0为填充
tensor([[19,  0,  0,  0,  0,  0,  0,  0,  0,  0]])
mlm_weights需要预测多少个词,0为填充
tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
预测位置的真实标签,0为填充
tensor([[22,  0,  0,  0,  0,  0,  0,  0,  0,  0]])
两句话是否相邻
tensor([0])

随后就是把处理好的数据,送入bert中。在 BERTEncoder 中,执行如下代码:

 def forward(self, tokens, segments, valid_lens):# Shape of `X` remains unchanged in the following code snippet:# (batch size, max sequence length, `num_hiddens`)#  将token和segment分别进行embedding,X = self.token_embedding(tokens) + self.segment_embedding(segments)#加入位置编码X = X + self.pos_embedding.data[:, :X.shape[1], :]for blk in self.blks:X = blk(X, valid_lens)return X

将编码完后的数据,进行多头注意力和残差化

    def forward(self, X, valid_lens):Y = self.addnorm1(X, self.attention(X, X, X, valid_lens))return self.addnorm2(Y, self.ffn(Y))

将结果返回到如下代码中:其中encoded_X .shape=torch.Size([1, 64, 128]),1代表批次大小为1,我们设置的每个批次只有行文本,每行文本由64个词组成,bert提取128维的向量来表示每个词。随后进行两个任务,一个是预测被掩盖的单词,另一个为判断两个句子是否为相邻。

    def forward(self, tokens, segments, valid_lens=None, pred_positions=None):encoded_X = self.encoder(tokens, segments, valid_lens)if pred_positions is not None:mlm_Y_hat = self.mlm(encoded_X, pred_positions)else:mlm_Y_hat = None# The hidden layer of the MLP classifier for next sentence prediction.# 0 is the index of the '<cls>' tokennsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))return encoded_X, mlm_Y_hat, nsp_Y_hat

第一个任务为预测被mask的单词:

'''
例如:batch为1,X为1*64*128,其中num_pred_positions =10,batch_idx 会重复为[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],pred_positions为[ 3,  6, 10, 12, 15, 20,  0,  0,  0,  0],X[batch_idx, pred_positions]会将需要预测的向量取出。然后reshape为1*10*128的矩阵。最后连接一个mlp,经过规范化后接nn.Linear(num_hiddens, vocab_size)),会生成再vocab上的预测'''def forward(self, X, pred_positions):num_pred_positions = pred_positions.shape[1]pred_positions = pred_positions.reshape(-1)batch_size = X.shape[0]batch_idx = torch.arange(0, batch_size)# Suppose that `batch_size` = 2, `num_pred_positions` = 3, then# `batch_idx` is `torch.tensor([0, 0, 0, 1, 1, 1])`batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)masked_X = X[batch_idx, pred_positions]masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))mlm_Y_hat = self.mlp(masked_X)return mlm_Y_hat

结束后,会返回到上层的代码中:

def forward(self, tokens, segments, valid_lens=None, pred_positions=None):encoded_X = self.encoder(tokens, segments, valid_lens)if pred_positions is not None:mlm_Y_hat = self.mlm(encoded_X, pred_positions)else:mlm_Y_hat = None# The hidden layer of the MLP classifier for next sentence prediction.# 0 is the index of the '<cls>' token判断句子是否连续,将<cls>的向量,放入mlp中,接一个nn.Linear(num_inputs, 2),最后变成一个二分类问题。nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))return encoded_X, mlm_Y_hat, nsp_Y_hat

后面就是计算损失:

将mlm_Y_hat进行reshap,与mlm_Y求loss,最后需要乘mlm_weights_X,将填充的无用数据进行去除。mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)取平均lossmlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)nsp_l = loss(nsp_Y_hat, nsp_y)l = mlm_l + nsp_l

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/40643.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

django——配置 settings.py 及相关参数说明

3. 配置 settings.py 及相关参数说明 3.1 配置setting.py文件 设置setting.py文件 加入安装的库 apps.erp_test, rest_framework, django_filters, drf_spectacular,加入新增的APP users启动项目 # 运行项目先执行数据库相关操作&#xff0c;再启动 django 项目 python manag…

【JavaSE】面向对象之继承

继承 继承概念继承的语法父类成员的访问子类和父类没有同名成员变量子类和父类有同名成员变量成员方法名字不同成员方法名字相同 super关键字子类构造方法super和this继承方式 继承概念 继承(inheritance)机制&#xff1a;是面向对象程序设计使代码可以复用的最重要的手段&…

docker 安装nacos

1、下载nacos docker pull nacos/nacos-server2、启动nacos docker run --restart always --env MODEstandalone --name nacos -d -p 8848:8848 -p 9848:9848 -p 9849:9849 nacos/nacos-server3、验证nacos http://localhost:8848/nacos 默认用户名和密码&#xff1a;nacos

lvs集群与nat模式

一&#xff0c;什么是集群&#xff1a; 集群&#xff0c;群集&#xff0c;Cluster&#xff0c;由多台主机构成&#xff0c;但是对外只表现为一个整体&#xff0c;只提供一个访问入口&#xff08;域名与ip地址&#xff09;&#xff0c;相当于一台大型计算机。 二&#xff0c;集…

Java书签 #使用MyBatis接入多数据源

楔子&#xff1a;当然&#xff0c;世上有很多优秀的女性&#xff0c;我也会被她们吸引。这对男人来说是理所当然的。但目光被吸引和内心被吸引是截然不同的。- 东野圭吾《黎明之街》 今日书签 在一些应用场景中&#xff0c;可能需要连接多个不同的数据库&#xff0c;例如连接不…

Centos 防火墙命令

查看防火墙状态 systemctl status firewalld.service 或者 firewall-cmd --state 开启防火墙 单次开启防火墙 systemctl start firewalld.service 开机自启动防火墙 systemctl enable firewalld.service 重启防火墙 systemctl restart firewalld.service 防火墙设置开…

8.15 IO的多路复用

select的TCP客户端 poll的TCP客户端

Chart GPT免费可用地址共享资源

GPT4.0&#xff1a; https://gpt4e.ninvfeng.xyz github:https://github.com/ninvfeng/chatgpt4 WeUseAi&#xff1a;https://chatb.weuseai.pro AI.LS&#xff1a;https://n7.gpt03.xyz ChatX (iOS/macOS应用)&#xff1a;https://itunes.apple.com/app/id6446304087 ch…

C/C++ : C/C++的详解,C语言与C++的常用算法以及算法的各自用法和应用(初级,中级),C++ CSP考题(J居多,S偏少)的详解,NOI的真题题解

目录 1.C语言 2.C 3.C与C语言的共同/不同点 4.导读 5.相关文章 5.1&#xff1a;Dev-C是Windows 环境下的一个轻量级 C/C 集成开发环境&#xff08;IDE&#xff09; 5.2&#xff1a;C是从C语言发展而来的&#xff0c;而C语言的历史可以追溯到1969年 6.C/C最新年度总…

​LeetCode解法汇总88. 合并两个有序数组

目录链接&#xff1a; 力扣编程题-解法汇总_分享记录-CSDN博客 GitHub同步刷题项目&#xff1a; https://github.com/September26/java-algorithms 原题链接&#xff1a;力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 描述&#xff1a; 给你两个按…

解决方案:如何在 Amazon EMR Serverless 上执行纯 SQL 文件?

长久已来&#xff0c;SQL以其简单易用、开发效率高等优势一直是ETL的首选编程语言&#xff0c;在构建数据仓库和数据湖的过程中发挥着不可替代的作用。Hive和Spark SQL也正是立足于这一点&#xff0c;才在今天的大数据生态中牢牢占据着主力位置。在常规的Spark环境中&#xff0…

国企的大数据岗位方向的分析

现如今大数据已无所不在&#xff0c;并且正被越来越广泛的被应用到历史、政治、科学、经济、商业甚至渗透到我们生活的方方面面中&#xff0c;获取的渠道也越来越便利。 今天我们就来聊一聊“大屏应用”&#xff0c;说到大屏就一定要聊到数据可视化&#xff0c;现如今&#xf…

【Git】(三)回退版本

1、git reset命令 1.1 回退至上一个版本 git reset --hard HEAD^ 1.2 将本地的状态回退到和远程的一样 git reset --hard origin/master 注意&#xff1a;谨慎使用 –-hard 参数&#xff0c;它会删除回退点之前的所有信息。HEAD 说明&#xff1a;HEAD 表示当前版本HEAD^ 上…

服务链路追踪

一、服务链路追踪导论 1.背景 对于一个大型的几十个、几百个微服务构成的微服务架构系统&#xff0c;通常会遇到下面一些问题&#xff0c;比如&#xff1a; 如何串联整个调用链路&#xff0c;快速定位问题&#xff1f;如何理清各个微服务之间的依赖关系&#xff1f;如何进行…

pycorrector一键式文本纠错工具,整合了BERT、MacBERT、ELECTRA、ERNIE等多种模型,让您立即享受纠错的便利和效果

pycorrector&#xff1a;一键式文本纠错工具&#xff0c;整合了Kenlm、ConvSeq2Seq、BERT、MacBERT、ELECTRA、ERNIE、Transformer、T5等多种模型&#xff0c;让您立即享受纠错的便利和效果 pycorrector: 中文文本纠错工具。支持中文音似、形似、语法错误纠正&#xff0c;pytho…

Python OpenGL环境配置

1.Python的安装请参照 Anconda安装_安装anconda_lwb-nju的博客-CSDN博客anconda安装教程_安装ancondahttps://blog.csdn.net/lwbCUMT/article/details/125322193?spm1001.2014.3001.5501 Anconda换源虚拟环境创建及使用&#xff08;界面操作&#xff09;_anconda huanyuan_l…

彻底卸载Android Studio

永恒的爱是永远恪守最初的诺言。 在安装Android Studio会有很多问题导致无法正常运行&#xff0c;多次下载AS多次错误后了解到&#xff0c;删除以下四个文件才能彻底卸载Android Studio。 第一个文件&#xff1a;.gradle 路径&#xff1a;C:\Users\yao&#xff08;这里yao是本…

解密人工智能:线性回归 | 逻辑回归 | SVM

文章目录 1、机器学习算法简介1.1 机器学习算法包含的两个步骤1.2 机器学习算法的分类 2、线性回归算法2.1 线性回归的假设是什么&#xff1f;2.2 如何确定线性回归模型的拟合优度&#xff1f;2.3 如何处理线性回归中的异常值&#xff1f; 3、逻辑回归算法3.1 什么是逻辑函数?…

火山引擎联合Forrester发布《中国云原生安全市场现状及趋势白皮书》,赋能企业构建云原生安全体系

国际权威研究咨询公司Forrester 预测&#xff0c;2023年全球超过40%的企业将会采用云原生优先战略。然而&#xff0c;云原生在改变企业上云及构建新一代基础设施的同时&#xff0c;也带来了一系列的新问题&#xff0c;针对涵盖云原生应用、容器、镜像、编排系统平台以及基础设施…

用栈解决有效的括号匹配问题

//用数组实现栈 typedef char DataType; typedef struct stack {DataType* a;//动态数组int top;//栈顶int capacity; //容量 }ST;void STInit(ST*pst);//初始化void STDestroy(ST* pst);//销毁所有空间void STPush(ST* pst, DataType x);//插入数据到栈中void STPop(ST* pst);…