【自然语言处理(NLP)】基于Transformer架构的预训练语言模型:BERT 训练之数据集处理、训练代码实现

文章目录

  • 介绍
  • BERT 训练之数据集处理
    • BERT 原理及模型代码实现
    • 数据集处理
      • 导包
      • 加载数据
      • 生成下一句预测任务的数据
      • 从段落中获取nsp数据
      • 生成遮蔽语言模型任务的数据
      • 从token中获取mlm数据
      • 将文本转换为预训练数据集
      • 创建Dataset
      • 加载WikiText-2数据集
  • BERT 训练代码实现
    • 导包
    • 加载数据
    • 构建BERT模型
    • 模型损失
    • 训练
    • 获取BERT编码器

个人主页:道友老李
欢迎加入社区:道友老李的学习社区

介绍

**自然语言处理(Natural Language Processing,NLP)**是计算机科学领域与人工智能领域中的一个重要方向。它研究的是人类(自然)语言与计算机之间的交互。NLP的目标是让计算机能够理解、解析、生成人类语言,并且能够以有意义的方式回应和操作这些信息。

NLP的任务可以分为多个层次,包括但不限于:

  1. 词法分析:将文本分解成单词或标记(token),并识别它们的词性(如名词、动词等)。
  2. 句法分析:分析句子结构,理解句子中词语的关系,比如主语、谓语、宾语等。
  3. 语义分析:试图理解句子的实际含义,超越字面意义,捕捉隐含的信息。
  4. 语用分析:考虑上下文和对话背景,理解话语在特定情境下的使用目的。
  5. 情感分析:检测文本中表达的情感倾向,例如正面、负面或中立。
  6. 机器翻译:将一种自然语言转换为另一种自然语言。
  7. 问答系统:构建可以回答用户问题的系统。
  8. 文本摘要:从大量文本中提取关键信息,生成简短的摘要。
  9. 命名实体识别(NER):识别文本中提到的特定实体,如人名、地名、组织名等。
  10. 语音识别:将人类的语音转换为计算机可读的文字格式。

NLP技术的发展依赖于算法的进步、计算能力的提升以及大规模标注数据集的可用性。近年来,深度学习方法,特别是基于神经网络的语言模型,如BERT、GPT系列等,在许多NLP任务上取得了显著的成功。随着技术的进步,NLP正在被应用到越来越多的领域,包括客户服务、智能搜索、内容推荐、医疗健康等。

BERT 训练之数据集处理

BERT 原理及模型代码实现

【自然语言处理(NLP)】基于Transformer架构的预训练语言模型:BERT 原理及代码实现

数据集处理

导包

import os
import random
import torch
import dltools

加载数据

def _read_wiki(data_dir):file_name = os.path.join(data_dir, 'wiki.train.tokens')with open(file_name, 'r',encoding="utf-8") as f:lines = f.readlines()# 大写字母转换为小写字母paragraphs = [line.strip().lower().split(' . ') for line in lines if len(line.split(' . ')) >= 2]random.shuffle(paragraphs)return paragraphs_read_wiki('./wikitext-2')

在这里插入图片描述

生成下一句预测任务的数据

def _get_next_sentence(sentence, next_sentence, paragraphs):if random.random() < 0.5:is_next = Trueelse:# paragraphs是三重列表的嵌套next_sentence = random.choice(random.choice(paragraphs))is_next = Falsereturn sentence, next_sentence, is_next

从段落中获取nsp数据

def _get_nsp_data_from_paragraph(paragraph, paragraphs, vocab, max_len):nsp_data_from_paragraph = []for i in range(len(paragraph) - 1):tokens_a, tokens_b, is_next = _get_next_sentence(paragraph[i], paragraph[i + 1], paragraphs)# 考虑1个'<cls>'词元和2个'<sep>'词元if len(tokens_a) + len(tokens_b) + 3 > max_len:continuetokens, segments = dltools.get_tokens_and_segments(tokens_a, tokens_b)nsp_data_from_paragraph.append((tokens, segments, is_next))return nsp_data_from_paragraph

生成遮蔽语言模型任务的数据

  1. 为遮蔽语言模型的输入创建新的词元副本,其中输入可能包含替换的mask或随机词元
  2. 打乱后用于在遮蔽语言模型任务中获取15%的随机词元进行预测
  3. 80%的时间:将词替换为mask词元
  4. 10%的时间:保持词不变
  5. 10%的时间:用随机词替换该词
def _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds,vocab):# 为遮蔽语言模型的输入创建新的词元副本,其中输入可能包含替换的“<mask>”或随机词元mlm_input_tokens = [token for token in tokens]pred_positions_and_labels = []# 打乱后用于在遮蔽语言模型任务中获取15%的随机词元进行预测random.shuffle(candidate_pred_positions)for mlm_pred_position in candidate_pred_positions:if len(pred_positions_and_labels) >= num_mlm_preds:breakmasked_token = None# 80%的时间:将词替换为“<mask>”词元if random.random() < 0.8:masked_token = '<mask>'else:# 10%的时间:保持词不变if random.random() < 0.5:masked_token = tokens[mlm_pred_position]# 10%的时间:用随机词替换该词else:masked_token = random.choice(vocab.idx_to_token)mlm_input_tokens[mlm_pred_position] = masked_tokenpred_positions_and_labels.append((mlm_pred_position, tokens[mlm_pred_position]))return mlm_input_tokens, pred_positions_and_labels

从token中获取mlm数据

在遮蔽语言模型任务中不会预测特殊词元

def _get_mlm_data_from_tokens(tokens, vocab):candidate_pred_positions = []# tokens是一个字符串列表for i, token in enumerate(tokens):# 在遮蔽语言模型任务中不会预测特殊词元if token in ['<cls>', '<sep>']:continuecandidate_pred_positions.append(i)# 遮蔽语言模型任务中预测15%的随机词元num_mlm_preds = max(1, round(len(tokens) * 0.15))mlm_input_tokens, pred_positions_and_labels = _replace_mlm_tokens(tokens, candidate_pred_positions, num_mlm_preds, vocab)pred_positions_and_labels = sorted(pred_positions_and_labels,key=lambda x: x[0])pred_positions = [v[0] for v in pred_positions_and_labels]mlm_pred_labels = [v[1] for v in pred_positions_and_labels]return vocab[mlm_input_tokens], pred_positions, vocab[mlm_pred_labels]

将文本转换为预训练数据集

  1. valid_lens不包括’'的计数
  2. 填充词元的预测将通过乘以0权重在损失中过滤掉
def _pad_bert_inputs(examples, max_len, vocab):max_num_mlm_preds = round(max_len * 0.15)all_token_ids, all_segments, valid_lens,  = [], [], []all_pred_positions, all_mlm_weights, all_mlm_labels = [], [], []nsp_labels = []for (token_ids, pred_positions, mlm_pred_label_ids, segments,is_next) in examples:all_token_ids.append(torch.tensor(token_ids + [vocab['<pad>']] * (max_len - len(token_ids)), dtype=torch.long))all_segments.append(torch.tensor(segments + [0] * (max_len - len(segments)), dtype=torch.long))# valid_lens不包括'<pad>'的计数valid_lens.append(torch.tensor(len(token_ids), dtype=torch.float32))all_pred_positions.append(torch.tensor(pred_positions + [0] * (max_num_mlm_preds - len(pred_positions)), dtype=torch.long))# 填充词元的预测将通过乘以0权重在损失中过滤掉all_mlm_weights.append(torch.tensor([1.0] * len(mlm_pred_label_ids) + [0.0] * (max_num_mlm_preds - len(pred_positions)),dtype=torch.float32))all_mlm_labels.append(torch.tensor(mlm_pred_label_ids + [0] * (max_num_mlm_preds - len(mlm_pred_label_ids)), dtype=torch.long))nsp_labels.append(torch.tensor(is_next, dtype=torch.long))return (all_token_ids, all_segments, valid_lens, all_pred_positions,all_mlm_weights, all_mlm_labels, nsp_labels)

创建Dataset

  1. 输入paragraphs[i]是代表段落的句子字符串列表
  2. 而输出paragraphs[i]是代表段落的句子列表,其中每个句子都是词元列表
  3. 获取下一句子预测任务的数据
  4. 获取遮蔽语言模型任务的数据
  5. 填充输入
class _WikiTextDataset(torch.utils.data.Dataset):def __init__(self, paragraphs, max_len):# 输入paragraphs[i]是代表段落的句子字符串列表;# 而输出paragraphs[i]是代表段落的句子列表,其中每个句子都是词元列表paragraphs = [dltools.tokenize(paragraph, token='word') for paragraph in paragraphs]sentences = [sentence for paragraph in paragraphsfor sentence in paragraph]self.vocab = dltools.Vocab(sentences, min_freq=5, reserved_tokens=['<pad>', '<mask>', '<cls>', '<sep>'])# 获取下一句子预测任务的数据examples = []for paragraph in paragraphs:examples.extend(_get_nsp_data_from_paragraph(paragraph, paragraphs, self.vocab, max_len))# 获取遮蔽语言模型任务的数据examples = [(_get_mlm_data_from_tokens(tokens, self.vocab)+ (segments, is_next))for tokens, segments, is_next in examples]# 填充输入(self.all_token_ids, self.all_segments, self.valid_lens,self.all_pred_positions, self.all_mlm_weights,self.all_mlm_labels, self.nsp_labels) = _pad_bert_inputs(examples, max_len, self.vocab)def __getitem__(self, idx):return (self.all_token_ids[idx], self.all_segments[idx],self.valid_lens[idx], self.all_pred_positions[idx],self.all_mlm_weights[idx], self.all_mlm_labels[idx],self.nsp_labels[idx])def __len__(self):return len(self.all_token_ids)

加载WikiText-2数据集

def load_data_wiki(batch_size, max_len):"""加载WikiText-2数据集"""num_workers = dltools.get_dataloader_workers()data_dir = "./wikitext-2/"paragraphs = _read_wiki(data_dir)train_set = _WikiTextDataset(paragraphs, max_len)train_iter = torch.utils.data.DataLoader(train_set, batch_size,shuffle=True, num_workers=num_workers)return train_iter, train_set.vocabbatch_size, max_len = 512, 64
train_iter, vocab = load_data_wiki(batch_size, max_len)for (tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X,mlm_Y, nsp_y) in train_iter:print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,nsp_y.shape)break
torch.Size([512, 64]) torch.Size([512, 64]) torch.Size([512]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512, 10]) torch.Size([512])
len(vocab)
20256

BERT 训练代码实现

导包

import torch
from torch import nn
import dltools

加载数据

dltools中加载本地wiki文件,请自行修改路径 ./data/wikitext-2

batch_size, max_len = 1, 64
# dltools中加载本地wiki文件,请自行修改路径 ./data/wikitext-2
train_iter, vocab = dltools.load_data_wiki(batch_size, max_len)# tokens, segments, valid_lens, pred_positions, mlm_weights,mlm, nsp
for i in train_iter:break
i

在这里插入图片描述

构建BERT模型

net = dltools.BERTModel(len(vocab), num_hiddens=128, norm_shape=[128],ffn_num_input=128, ffn_num_hiddens=256, num_heads=2,num_layers=2, dropout=0.2, key_size=128, query_size=128,value_size=128, hid_in_features=128, mlm_in_features=128,nsp_in_features=128)
devices = dltools.try_all_gpus()

模型损失

  1. 前向传播
  2. 计算遮蔽语言模型损失
  3. 计算下一句子预测任务的损失
loss = nn.CrossEntropyLoss()def _get_batch_loss_bert(net, loss, vocab_size, tokens_X,segments_X, valid_lens_x,pred_positions_X, mlm_weights_X,mlm_Y, nsp_y):# 前向传播_, mlm_Y_hat, nsp_Y_hat = net(tokens_X, segments_X,valid_lens_x.reshape(-1),pred_positions_X)# 计算遮蔽语言模型损失mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1)) * mlm_weights_X.reshape(-1, 1)mlm_l = mlm_l.sum() / (mlm_weights_X.sum() + 1e-8)# 计算下一句子预测任务的损失nsp_l = loss(nsp_Y_hat, nsp_y)l = mlm_l + nsp_lreturn mlm_l, nsp_l, l

训练

遮蔽语言模型损失的和,下一句预测任务损失的和,句子对的数量,计数

def train_bert(train_iter, net, loss, vocab_size, devices, num_steps):net = nn.DataParallel(net, device_ids=devices).to(devices[0])trainer = torch.optim.Adam(net.parameters(), lr=0.01)step, timer = 0, dltools.Timer()animator = dltools.Animator(xlabel='step', ylabel='loss',xlim=[1, num_steps], legend=['mlm', 'nsp'])# 遮蔽语言模型损失的和,下一句预测任务损失的和,句子对的数量,计数metric = dltools.Accumulator(4)num_steps_reached = Falsewhile step < num_steps and not num_steps_reached:for tokens_X, segments_X, valid_lens_x, pred_positions_X,mlm_weights_X, mlm_Y, nsp_y in train_iter:tokens_X = tokens_X.to(devices[0])segments_X = segments_X.to(devices[0])valid_lens_x = valid_lens_x.to(devices[0])pred_positions_X = pred_positions_X.to(devices[0])mlm_weights_X = mlm_weights_X.to(devices[0])mlm_Y, nsp_y = mlm_Y.to(devices[0]), nsp_y.to(devices[0])trainer.zero_grad()timer.start()mlm_l, nsp_l, l = _get_batch_loss_bert(net, loss, vocab_size, tokens_X, segments_X, valid_lens_x,pred_positions_X, mlm_weights_X, mlm_Y, nsp_y)l.backward()trainer.step()metric.add(mlm_l, nsp_l, tokens_X.shape[0], 1)timer.stop()animator.add(step + 1,(metric[0] / metric[3], metric[1] / metric[3]))step += 1if step == num_steps:num_steps_reached = Truebreakprint(f'MLM loss {metric[0] / metric[3]:.3f}, 'f'NSP loss {metric[1] / metric[3]:.3f}')print(f'{metric[2] / timer.sum():.1f} sentence pairs/sec on 'f'{str(devices)}')train_bert(train_iter, net, loss, len(vocab), devices, 500)

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

获取BERT编码器

def get_bert_encoding(net, tokens_a, tokens_b=None):tokens, segments = dltools.get_tokens_and_segments(tokens_a, tokens_b)token_ids = torch.tensor(vocab[tokens], device=devices[0]).unsqueeze(0)segments = torch.tensor(segments, device=devices[0]).unsqueeze(0)valid_len = torch.tensor(len(tokens), device=devices[0]).unsqueeze(0)encoded_X, _, _ = net(token_ids, segments, valid_len)return encoded_Xtokens_a = ['a', 'crane', 'is', 'flying']
encoded_text = get_bert_encoding(net, tokens_a)
# 词元:'<cls>','a','crane','is','flying','<sep>'
encoded_text_cls = encoded_text[:, 0, :]
encoded_text_crane = encoded_text[:, 2, :]
encoded_text.shape, encoded_text_cls.shape, encoded_text_crane[0][:3]
(torch.Size([1, 6, 128]),torch.Size([1, 128]),tensor([-1.0005,  0.8355,  0.2930], grad_fn=<SliceBackward0>))
tokens_a, tokens_b = ['a', 'crane', 'driver', 'came'], ['he', 'just', 'left']
encoded_pair = get_bert_encoding(net, tokens_a, tokens_b)
# 词元:'<cls>','a','crane','driver','came','<sep>','he','just',
# 'left','<sep>'
encoded_pair_cls = encoded_pair[:, 0, :]
encoded_pair_crane = encoded_pair[:, 2, :]
encoded_pair.shape, encoded_pair_cls.shape, encoded_pair_crane[0][:3]
(torch.Size([1, 10, 128]),torch.Size([1, 128]),tensor([-1.0168,  0.8235,  0.2141], grad_fn=<SliceBackward0>))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/67796.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Numpy核心编程攻略:Python数据处理、分析详解与科学计算】2.5 高级索引应用:图像处理中的区域提取

2.5 高级索引应用&#xff1a;图像处理中的区域提取 目录/提纲 #mermaid-svg-BI09xc20YqcpUam7 {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-BI09xc20YqcpUam7 .error-icon{fill:#552222;}#mermaid-svg-BI09xc20…

通过Redisson构建延时队列并实现注解式消费

目录 一、序言二、延迟队列实现1、Redisson延时消息监听注解和消息体2、Redisson延时消息发布器3、Redisson延时消息监听处理器 三、测试用例四、结语 一、序言 两个月前接了一个4万的私活&#xff0c;做一个线上商城小程序&#xff0c;在交易过程中不可避免的一个问题就是用户…

Baklib构建高效协同的基于云的内容中台解决方案

内容概要 随着云计算技术的飞速发展&#xff0c;内容管理的方式也在不断演变。企业面临着如何在数字化转型过程中高效管理和协同处理内容的新挑战。为应对这些挑战&#xff0c;引入基于云的内容中台解决方案显得尤为重要。 Baklib作为创新型解决方案提供商&#xff0c;致力于…

deepseek+vscode自动化测试脚本生成

近几日Deepseek大火,我这里也尝试了一下,确实很强。而目前vscode的AI toolkit插件也已经集成了deepseek R1,这里就介绍下在vscode中利用deepseek帮助我们完成自动化测试脚本的实践分享 安装AI ToolKit并启用Deepseek 微软官方提供了一个针对AI辅助的插件,也就是 AI Toolk…

电介质超表面中指定涡旋的非线性生成

涡旋光束在众多领域具有重要应用&#xff0c;但传统光学器件产生涡旋光束的方式限制了其在集成系统中的应用。超表面的出现为涡旋光束的产生带来了新的可能性&#xff0c;尤其是在非线性领域&#xff0c;尽管近些年来已经有一些研究&#xff0c;但仍存在诸多问题&#xff0c;如…

基于Springboot+mybatis+mysql+html图书管理系统2

基于Springbootmybatismysqlhtml图书管理系统2 一、系统介绍二、功能展示1.用户登陆2.用户主页3.图书查询4.还书5.个人信息修改6.图书管理&#xff08;管理员&#xff09;7.学生管理&#xff08;管理员&#xff09;8.废除记录&#xff08;管理员&#xff09; 三、数据库四、其它…

本地部署DeepSeek方法

本地部署完成后的效果如下图&#xff0c;整体与chatgpt类似&#xff0c;只是模型在本地推理。 我们在本地部署主要使用两个工具&#xff1a; ollamaopen-webui ollama是在本地管理和运行大模型的工具&#xff0c;可以直接在terminal里和大模型对话。open-webui是提供一个类…

游戏引擎 Unity - Unity 启动(下载 Unity Editor、生成 Unity Personal Edition 许可证)

Unity Unity 首次发布于 2005 年&#xff0c;属于 Unity Technologies Unity 使用的开发技术有&#xff1a;C# Unity 的适用平台&#xff1a;PC、主机、移动设备、VR / AR、Web 等 Unity 的适用领域&#xff1a;开发中等画质中小型项目 Unity 适合初学者或需要快速上手的开…

【开源免费】基于Vue和SpringBoot的公寓报修管理系统(附论文)

本文项目编号 T 186 &#xff0c;文末自助获取源码 \color{red}{T186&#xff0c;文末自助获取源码} T186&#xff0c;文末自助获取源码 目录 一、系统介绍二、数据库设计三、配套教程3.1 启动教程3.2 讲解视频3.3 二次开发教程 四、功能截图五、文案资料5.1 选题背景5.2 国内…

《苍穹外卖》项目学习记录-Day11订单统计

根据起始时间和结束时间&#xff0c;先把begin放入集合中用while循环当begin不等于end的时候&#xff0c;让begin加一天&#xff0c;这样就把这个区间内的时间放到List集合。 查询每天的订单总数也就是查询的时间段是大于当天的开始时间&#xff08;0点0分0秒&#xff09;小于…

【python】python油田数据分析与可视化(源码+数据集)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;专__注&#x1f448;&#xff1a;专注主流机器人、人工智能等相关领域的开发、测试技术。 【python】python油田数据分析与可视化&#xff08…

FBX SDK的使用:基础知识

Windows环境配置 FBX SDK安装后&#xff0c;目录下有三个文件夹&#xff1a; include 头文件lib 编译的二进制库&#xff0c;根据你项目的配置去包含相应的库samples 官方使用案列 动态链接 libfbxsdk.dll, libfbxsdk.lib是动态库&#xff0c;需要在配置属性->C/C->预…

一文讲解HashMap线程安全相关问题(上)

HashMap不是线程安全的&#xff0c;主要有以下几个问题&#xff1a; ①、多线程下扩容会死循环。JDK1.7 中的 HashMap 使用的是头插法插入元素&#xff0c;在多线程的环境下&#xff0c;扩容的时候就有可能导致出现环形链表&#xff0c;造成死循环。 JDK 8 时已经修复了这个问…

python学习——常用的内置函数汇总

文章目录 类型转换函数数学函数常用的迭代器操作函数常用的其他内置函数 类型转换函数 数学函数 常用的迭代器操作函数 实操&#xff1a; from cv2.gapi import descr_oflst [55, 42, 37, 2, 66, 23, 18, 99]# (1) 排序操作 asc_lst sorted(lst) # 升序 desc_lst sorted(l…

MySQL数据库环境搭建

下载MySQL 官网&#xff1a;https://downloads.mysql.com/archives/installer/ 下载社区版就行了。 安装流程 看b站大佬的视频吧&#xff1a;https://www.bilibili.com/video/BV12q4y1477i/?spm_id_from333.337.search-card.all.click&vd_source37dfd298d2133f3e1f3e3c…

如何用微信小程序写春联

​ 生活没有模板,只需心灯一盏。 如果笑能让你释然,那就开怀一笑;如果哭能让你减压,那就让泪水流下来。如果沉默是金,那就不用解释;如果放下能更好地前行,就别再扛着。 一、引入 Vant UI 1、通过 npm 安装 npm i @vant/weapp -S --production​​ 2、修改 app.json …

[SAP ABAP] 静态断点的使用

在 ABAP 编程环境中&#xff0c;静态断点通过关键字BREAK-POINT实现&#xff0c;当程序执行到这一语句时&#xff0c;会触发调试器中断程序的运行&#xff0c;允许开发人员检查当前状态并逐步跟踪后续代码逻辑 通常情况下&#xff0c;在代码的关键位置插入静态断点可以帮助开发…

96,【4】 buuctf web [BJDCTF2020]EzPHP

进入靶场 查看源代码 GFXEIM3YFZYGQ4A 一看就是编码后的 1nD3x.php 访问 得到源代码 <?php // 高亮显示当前 PHP 文件的源代码&#xff0c;用于调试或展示代码结构 highlight_file(__FILE__); // 关闭所有 PHP 错误报告&#xff0c;防止错误信息泄露可能的安全漏洞 erro…

基于深度学习的输电线路缺陷检测算法研究(论文+源码)

输电线路关键部件的缺陷检测对于电网安全运行至关重要&#xff0c;传统方法存在效率低、准确性不高等问题。本研究探讨了利用深度学习技术进行输电线路关键组件的缺陷检测&#xff0c;目的是提升检测的效率与准确度。选用了YOLOv8模型作为基础&#xff0c;并通过加入CA注意力机…

3、从langchain到rag

文章目录 本文介绍向量和向量数据库向量向量数据库 索引开始动手实现rag加载文档数据并建立索引将向量存放到向量数据库中检索生成构成一条链 本文介绍 从本节开始&#xff0c;有了上一节的langchain基础学习&#xff0c;接下来使用langchain实现一个rag应用&#xff0c;并稍微…