李沐70_bert微调——自学笔记

微调BERT

1.BERT滴哦每一个词元返回抽取了上下文信息的特征向量

2.不同的任务使用不同的特性

句子分类

将cls对应的向量输入到全连接层分类

命名实体识别

1.识别应该词元是不是命名实体,例如人名、机构、位置

2.将非特殊词元放进全连接层分类

问题回答

1.给定应该问题和描述文字,找出一个片段作为回答

2.对片段中的每个词元预测它是不是回答的开头或结束

总结

1.即使下游任务各有不同,使用BERT微调时君只需要增加输出层

2.但根据任务的不同,输入的表示,和使用Bert特征也会不一样

!pip install d2l==0.17.6  ### 很重要,不要下载错了,对于colab

Natural Language Inference and Dataset

斯坦福自然语言推断(SNLI)数据集

import os
import re
import torch
from torch import nn
from d2l import torch as d2ld2l.DATA_HUB['SNLI'] = ('https://nlp.stanford.edu/projects/snli/snli_1.0.zip','9fcde07509c7e87ec61c640c1b2753d9041758e4')data_dir = d2l.download_extract('SNLI')

读取数据集

定义函数read_snli以仅提取数据集的一部分,然后返回前提、假设及其标签的列表。*

def read_snli(data_dir, is_train):"""将SNLI数据集解析为前提、假设和标签"""def extract_text(s):# 删除我们不会使用的信息s = re.sub('\\(', '', s)s = re.sub('\\)', '', s)# 用一个空格替换两个或多个连续的空格s = re.sub('\\s{2,}', ' ', s)return s.strip()label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}file_name = os.path.join(data_dir, 'snli_1.0_train.txt'if is_train else 'snli_1.0_test.txt')with open(file_name, 'r') as f:rows = [row.split('\t') for row in f.readlines()[1:]]premises = [extract_text(row[1]) for row in rows if row[0] in label_set]hypotheses = [extract_text(row[2]) for row in rows if row[0] \in label_set]labels = [label_set[row[0]] for row in rows if row[0] in label_set]return premises, hypotheses, labels

打印前3对前提和假设,以及它们的标签(“0”“1”和“2”分别对应于“蕴涵”“矛盾”和“中性”)。

train_data = read_snli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):print('前提:', x0)print('假设:', x1)print('标签:', y)
前提: A person on a horse jumps over a broken down airplane .
假设: A person is training his horse for a competition .
标签: 2
前提: A person on a horse jumps over a broken down airplane .
假设: A person is at a diner , ordering an omelette .
标签: 1
前提: A person on a horse jumps over a broken down airplane .
假设: A person is outdoors , on a horse .
标签: 0

训练集约有550000对,测试集约有10000对。下面显示了训练集和测试集中的三个标签“蕴涵”“矛盾”和“中性”是平衡的。

test_data = read_snli(data_dir, is_train=False)
for data in [train_data, test_data]:print([[row for row in data[2]].count(i) for i in range(3)])
[183416, 183187, 182764]
[3368, 3237, 3219]

定义一个用于加载SNLI数据集的类。类构造函数中的变量num_steps指定文本序列的长度,使得每个小批量序列将具有相同的形状。

class SNLIDataset(torch.utils.data.Dataset):"""用于加载SNLI数据集的自定义数据集"""def __init__(self, dataset, num_steps, vocab=None):self.num_steps = num_stepsall_premise_tokens = d2l.tokenize(dataset[0])all_hypothesis_tokens = d2l.tokenize(dataset[1])if vocab is None:self.vocab = d2l.Vocab(all_premise_tokens + \all_hypothesis_tokens, min_freq=5, reserved_tokens=['<pad>'])else:self.vocab = vocabself.premises = self._pad(all_premise_tokens)self.hypotheses = self._pad(all_hypothesis_tokens)self.labels = torch.tensor(dataset[2])print('read ' + str(len(self.premises)) + ' examples')def _pad(self, lines):return torch.tensor([d2l.truncate_pad(self.vocab[line], self.num_steps, self.vocab['<pad>'])for line in lines])def __getitem__(self, idx):return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]def __len__(self):return len(self.premises)

我们可以调用read_snli函数和SNLIDataset类来下载SNLI数据集,并返回训练集和测试集的DataLoader实例,以及训练集的词表。

def load_data_snli(batch_size, num_steps=50):"""下载SNLI数据集并返回数据迭代器和词表"""num_workers = d2l.get_dataloader_workers()data_dir = d2l.download_extract('SNLI')train_data = read_snli(data_dir, True)test_data = read_snli(data_dir, False)train_set = SNLIDataset(train_data, num_steps)test_set = SNLIDataset(test_data, num_steps, train_set.vocab)train_iter = torch.utils.data.DataLoader(train_set, batch_size,shuffle=True,num_workers=num_workers)test_iter = torch.utils.data.DataLoader(test_set, batch_size,shuffle=False,num_workers=num_workers)return train_iter, test_iter, train_set.vocab

在这里,我们将批量大小设置为128时,将序列长度设置为50,并调用load_data_snli函数来获取数据迭代器和词表。然后我们打印词表大小。

train_iter, test_iter, vocab = load_data_snli(128, 50)
len(vocab)
read 549367 examples
read 9824 examples/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.warnings.warn(_create_warning_msg(18678

现在我们打印第一个小批量的形状。与情感分析相反,我们有分别代表前提和假设的两个输入X[0]和X[1]。

for X, Y in train_iter:print(X[0].shape)print(X[1].shape)print(Y.shape)break
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.self.pid = os.fork()torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128])

Fine-Tuning BERT

import json
import multiprocessing
import os
import torch
from torch import nn
from d2l import torch as d2l

加载预训练的BERT

提供了两个版本的预训练的BERT:“bert.base”与原始的BERT基础模型一样大,需要大量的计算资源才能进行微调,而“bert.small”是一个小版本,以便于演示。

d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip','225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip','c72329e68a732bef0452e4b96a1c341c8910f81f')

两个预训练好的BERT模型都包含一个定义词表的“vocab.json”文件和一个预训练参数的“pretrained.params”文件。我们实现了以下load_pretrained_model函数来加载预先训练好的BERT参数。

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,num_heads, num_layers, dropout, max_len, devices):data_dir = d2l.download_extract(pretrained_model)# 定义空词表以加载预定义词表vocab = d2l.Vocab()vocab.idx_to_token = json.load(open(os.path.join(data_dir,'vocab.json')))vocab.token_to_idx = {token: idx for idx, token in enumerate(vocab.idx_to_token)}bert = d2l.BERTModel(len(vocab), num_hiddens, norm_shape=[256],ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,num_heads=4, num_layers=2, dropout=0.2,max_len=max_len, key_size=256, query_size=256,value_size=256, hid_in_features=256,mlm_in_features=256, nsp_in_features=256)# 加载预训练BERT参数bert.load_state_dict(torch.load(os.path.join(data_dir,'pretrained.params')))return bert, vocab

加载和微调经过预训练BERT的小版本(“bert.small”)

devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model('bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,num_layers=2, dropout=0.1, max_len=512, devices=devices)
Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...

微调BERT数据集

定义了一个定制的数据集类SNLIBERTDataset。在每个样本中,前提和假设形成一对文本序列,并被打包成一个BERT输入序列,如 图15.6.2所示。回想 14.8.4节,片段索引用于区分BERT输入序列中的前提和假设。利用预定义的BERT输入序列的最大长度(max_len),持续移除输入文本对中较长文本的最后一个标记,直到满足max_len。为了加速生成用于微调BERT的SNLI数据集,我们使用4个工作进程并行生成训练或测试样本。

class SNLIBERTDataset(torch.utils.data.Dataset):def __init__(self, dataset, max_len, vocab=None):all_premise_hypothesis_tokens = [[p_tokens, h_tokens] for p_tokens, h_tokens in zip(*[d2l.tokenize([s.lower() for s in sentences])for sentences in dataset[:2]])]self.labels = torch.tensor(dataset[2])self.vocab = vocabself.max_len = max_len(self.all_token_ids, self.all_segments,self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)print('read ' + str(len(self.all_token_ids)) + ' examples')def _preprocess(self, all_premise_hypothesis_tokens):pool = multiprocessing.Pool(4)  # 使用4个进程out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)all_token_ids = [token_ids for token_ids, segments, valid_len in out]all_segments = [segments for token_ids, segments, valid_len in out]valid_lens = [valid_len for token_ids, segments, valid_len in out]return (torch.tensor(all_token_ids, dtype=torch.long),torch.tensor(all_segments, dtype=torch.long),torch.tensor(valid_lens))def _mp_worker(self, premise_hypothesis_tokens):p_tokens, h_tokens = premise_hypothesis_tokensself._truncate_pair_of_tokens(p_tokens, h_tokens)tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \* (self.max_len - len(tokens))segments = segments + [0] * (self.max_len - len(segments))valid_len = len(tokens)return token_ids, segments, valid_lendef _truncate_pair_of_tokens(self, p_tokens, h_tokens):# 为BERT输入中的'<CLS>'、'<SEP>'和'<SEP>'词元保留位置while len(p_tokens) + len(h_tokens) > self.max_len - 3:if len(p_tokens) > len(h_tokens):p_tokens.pop()else:h_tokens.pop()def __getitem__(self, idx):return (self.all_token_ids[idx], self.all_segments[idx],self.valid_lens[idx]), self.labels[idx]def __len__(self):return len(self.all_token_ids)

下载完SNLI数据集后,我们通过实例化SNLIBERTDataset类来生成训练和测试样本。这些样本将在自然语言推断的训练和测试期间进行小批量读取。

# 如果出现显存不足错误,请减少“batch_size”。在原始的BERT模型中,max_len=512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,num_workers=num_workers)
read 549367 examples
read 9824 examples

微调BERT

用于自然语言推断的微调BERT只需要一个额外的多层感知机,该多层感知机由两个全连接层组成(请参见下面BERTClassifier类中的self.hidden和self.output)。这个多层感知机将特殊的“”词元的BERT表示进行了转换,该词元同时编码前提和假设的信息为自然语言推断的三个输出:蕴涵、矛盾和中性。

class BERTClassifier(nn.Module):def __init__(self, bert):super(BERTClassifier, self).__init__()self.encoder = bert.encoderself.hidden = bert.hiddenself.output = nn.Linear(256, 3)def forward(self, inputs):tokens_X, segments_X, valid_lens_x = inputsencoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)return self.output(self.hidden(encoded_X[:, 0, :]))

预训练的BERT模型bert被送到用于下游应用的BERTClassifier实例net中。

net = BERTClassifier(bert)

为了允许具有陈旧梯度的参数,标志ignore_stale_grad=True在step函数d2l.train_batch_ch13中被设置。我们通过该函数使用SNLI的训练集(train_iter)和测试集(test_iter)对net模型进行训练和评估。

lr, num_epochs = 1e-4, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)
loss 0.520, train acc 0.791, test acc 0.780
2536.5 examples/sec on [device(type='cuda', index=0)]

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/4676.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QT c++ 代码布局原则 简单例子

本文描述QT c widget代码布局遵循的原则&#xff1a;实中套虚&#xff0c;虚中套实。 本文最后列出了代码下载链接。 在QT6.2.4 msvc2019编译通过。 所谓实是实体组件&#xff1a;比如界面框、文本标签、组合框、文本框、按钮、表格、图片框等。 所谓虚是Layout组件&#x…

Redis哈希槽和一致性哈希

前言 单点的Redis有一定的局限&#xff1a; 单点发生故障&#xff0c;数据丢失&#xff0c;影响整体服务应用自身资源有限&#xff0c;无法承载更多资源分配并发访问&#xff0c;给服务器主机带来压力&#xff0c;性能瓶颈 我们想提升系统的容量、性能和可靠性&#xff0c;就…

sentinel-1.8.7与nacos-2.3.0实现动态规则配置、双向同步

&#x1f60a; 作者&#xff1a; 一恍过去 &#x1f496; 主页&#xff1a; https://blog.csdn.net/zhuocailing3390 &#x1f38a; 社区&#xff1a; Java技术栈交流 &#x1f389; 主题&#xff1a; sentinel-1.8.7与nacos-2.3.0实现动态规则配置、双向同步 ⏱️ 创作时…

unity的特性AttriBute详解

unity的特性AttriBute曾经令我大为头疼。因为不动使用的法则&#xff0c;但是教程都是直接就写&#xff0c;卡住就不能继续学下去。令我每一次看到&#xff0c;直接不敢看了。 今天使用文心一言搜索一番&#xff0c;发现&#xff0c;恐惧都是自己想象的&#xff0c;实际上这个…

Kotlin泛型之 循环引用泛型(A的泛型是B的子类,B的泛型是A的子类)

IDE(编辑器)报错 循环引用泛型是我起的名字&#xff0c;不知道官方的名字是什么。这个问题是我在定义Android 的MVP时提出来的。具体是什么样的呢&#xff1f;我们看一下我的基础的MVP定义&#xff1a; interface IPresenter<V> { fun getView(): V }interface IVie…

Nodejs 第六十八章(远程桌面)

远程桌面 远程桌面&#xff08;Remote Desktop&#xff09;是一种技术&#xff0c;允许用户通过网络远程连接到另一台计算机&#xff0c;并在本地计算机上控制远程计算机的操作。通过远程桌面&#xff0c;用户可以在不同地点的计算机之间共享屏幕、键盘和鼠标&#xff0c;就像…

宝塔面板安装教程(linux)

宝塔官网地址 宝塔官网linux安装地址 针对Ubuntu系统的安装命令&#xff1a; wget -O install.sh https://download.bt.cn/install/install-ubuntu_6.0.sh && sudo bash install.sh ed8484bec 安装过程中&#xff0c;中途会出现一个 Y&N ? 的选项&#xf…

OpenCV如何模板匹配

返回:OpenCV系列文章目录&#xff08;持续更新中......&#xff09; 上一篇&#xff1a;OpenCV如何实现背投 下一篇 &#xff1a;OpenCV在图像中寻找轮廓 目标 在本教程中&#xff0c;您将学习如何&#xff1a; 使用 OpenCV 函数 matchTemplate()搜索图像贴片和输入图像之间…

如何下载AndroidStudio旧版本

文章目录 1. Android官方网站2. 往下滑找到历史版本归档3. 同意软件下载条款协议4. 下载旧版本Androidstudio1. Android官方网站 点击 Android官网AS下载页面 https://developer.android.google.cn/studio 进入AndroidStuido最新版下载页面,如下图: 2. 往下滑找到历史版本归…

一本书了解AI的下一个风口:AI Agent

在数字化浪潮中&#xff0c;人工智能&#xff08;AI&#xff09;已成为推动现代社会前进的强劲引擎。 从智能手机的智能助手到自动驾驶汽车的精准导航&#xff0c;AI技术的应用已经渗透到生活的方方面面。 随着技术的飞速发展&#xff0c;我们正站在一个新的转折点上&#xff…

构建本地大语言模型知识库问答系统

MaxKB 2024 年 4 月 12 日&#xff0c;1Panel 开源项目组正式对外介绍了其官方出品的开源子项目 ——MaxKB&#xff08;github.com/1Panel-dev/MaxKB&#xff09;。MaxKB 是一款基于 LLM&#xff08;Large Language Model&#xff09;大语言模型的知识库问答系统。MaxKB 的产品…

[论文笔记]GAUSSIAN ERROR LINEAR UNITS (GELUS)

引言 今天来看一下GELU的原始论文。 作者提出了GELU(Gaussian Error Linear Unit,高斯误差线性单元)非线性激活函数&#xff1a; GELU x Φ ( x ) \text{GELU} x\Phi(x) GELUxΦ(x)&#xff0c;其中 Φ ( x ) \Phi(x) Φ(x)​是标准高斯累积分布函数。与ReLU激活函数通过输入…

网盘—上传文件

本文主要讲解网盘里面关于文件操作部分的上传文件&#xff0c;具体步骤如下 目录 1、实施步骤&#xff1a; 2、代码实现 2.1、添加上传文件协议 2.2、添加上传文件槽函数 2.3、添加槽函数定义 2.4、关联上传槽函数 2.5、服务器端 2.6、在服务器端添加上传文件请求的ca…

算法学习(5)-图的遍历

目录 什么是深度和广度优先 图的深度优先遍历-城市地图 图的广度优先遍历-最少转机 什么是深度和广度优先 使用深度优先搜索来遍历这个图的过程具体是&#xff1a; 首先从一个未走到过的顶点作为起始顶点&#xff0c; 比如以1号顶点作为起点。沿1号顶点的边去尝试访问其它未…

提升编码技能:学习如何使用 C# 和 Fizzler 获取特价机票

引言 五一假期作为中国的传统节日&#xff0c;也是旅游热门的时段之一&#xff0c;特价机票往往成为人们关注的焦点。在这个数字化时代&#xff0c;利用爬虫技术获取特价机票信息已成为一种常见的策略。通过结合C#和Fizzler库&#xff0c;我们可以更加高效地实现这一目标&…

2024年---蓝桥杯网络安全赛道部分WP

一、题目名称&#xff1a;packet 1、下载附件是一个流量包 2、用wireshark分析&#xff0c;看到了一个cat flag的字样 3、追踪http数据流&#xff0c;在下面一行看到了base64编码。 4、解码之后得到flag 二、题目名称&#xff1a;cc 1、下载附件&#xff0c;打开是一个html …

Docker构建LNMP部署WordPress

前言 使用容器化技术如 Docker 可以极大地简化应用程序的部署和管理过程&#xff0c;本文将介绍如何利用 Docker 构建 LNMP 环境&#xff0c;并通过部署 WordPress 来展示这一过程。 目录 一、环境准备 1. 项目需求 2. 安装包下载 3. 服务器环境 4. 规划工作目录 5. 创…

CAPS Wizard for Mac:打字输入辅助应用

CAPS Wizard for Mac是一款专为Mac用户设计的打字输入辅助应用&#xff0c;以其简洁、高效的功能&#xff0c;为用户带来了全新的打字体验。 CAPS Wizard for Mac v5.3激活版下载 该软件能够智能预测用户的输入内容&#xff0c;实现快速切换和自动大写锁定&#xff0c;从而大大…

OmniReader Pro for Mac:强大且全面的阅读工具

OmniReader Pro for Mac是一款专为Mac用户设计的强大且全面的阅读工具&#xff0c;它集阅读、编辑、管理等多种功能于一身&#xff0c;为用户提供了卓越的阅读体验。 OmniReader Pro for Mac v2.9.5激活版下载 该软件支持多种文件格式的阅读&#xff0c;包括PDF、Word、Excel、…

pycharm配置wsl开发环境(conda)

背景 在研究qanything项目的过程中&#xff0c;为了进行二次开发&#xff0c;需要在本地搭建开发环境。然后根据文档说明发现该项目并不能直接运行在windows开发环境&#xff0c;但可以运行在wsl环境中。于是我需要先创建wsl环境并配置pycharm。 wsl环境创建 WSL是“Windows Su…