Bert 在 OCNLI 训练微调

目录

  • 0 资料
  • 1 预训练权重
  • 2 wandb
  • 3 Bert-OCNLI
    • 3.1 目录结构
    • 3.2 导入的库
    • 3.3 数据集
      • 自然语言推断
      • 数据集路径
      • 读取数据集
      • 数据集样例展示
      • 数据集类别统计
      • 数据集类
      • 加载数据
    • 3.4 Bert
    • 3.4 训练
  • 4 训练微调结果
    • 3k
    • 10k
    • 50k

0 资料

【数据集微调】

阿里天池比赛 微调BERT的数据集(“任务1:OCNLI–中文原版自然语言推理”)

数据集地址:https://tianchi.aliyun.com/competition/entrance/531841/information

由于这个比赛已经结束,原地址提交不了榜单看测试结果,请参照下面的信息,下载数据集、提交榜单测试。

  • “任务1:OCNLI–中文原版自然语言推理”数据集的GitHub地址:https://github.com/CLUEbenchmark/OCNLI

  • 榜单提交地址:https://www.cluebenchmarks.com/index.html

  • 榜单提交步骤:

    • 打开“榜单提交地址”,点击“立即测评”——填写相关信息(github地址填https://github.com/CLUEbenchmark/CLUE,其他信息任意填)。
    • 上传一个.zip压缩文件,在压缩文件里存放我们模型预测结果的文件。
    • 点击提交。
  • 【注意】预测结果文件的格式:https://storage.googleapis.com/cluebenchmark/tasks/clue_submit_examples.zip

15.4. 自然语言推断与数据集:https://zh-v2.d2l.ai/chapter_natural-language-processing-applications/natural-language-inference-and-dataset.html

15.7. 自然语言推断:微调BERT:https://zh-v2.d2l.ai/chapter_natural-language-processing-applications/natural-language-inference-bert.html#id3

保姆级教程,用PyTorch和BERT进行文本分类:https://zhuanlan.zhihu.com/p/524487313

1 预训练权重

在国内,一般是手动下载预训练权重,而非网络自动下载。

我们将用到 chinese-macbert-base 这个预训练文件,下载网址如下:

https://huggingface.co/hfl/chinese-macbert-base/tree/main

除了叉掉的,其余都要下载。
在这里插入图片描述

2 wandb

pip install wandb

WandB 是一个用于实验跟踪、版本控制和结果可视化的工具,主要用于机器学习项目。
wandb使用教程(一):基础用法:https://zhuanlan.zhihu.com/p/493093033

3 Bert-OCNLI

3.1 目录结构

在这里插入图片描述

3.2 导入的库

import os
import torch
from torch import nn
import pandas as pd
from transformers import BertModel, BertTokenizer
from torch.optim import Adam
from tqdm import tqdm

3.3 数据集

自然语言推断

自然语言推断(natural language inference)主要研究 假设(hypothesis)是否可以从前提(premise)中推断出来, 其中两者都是文本序列。 换言之,自然语言推断决定了一对文本序列之间的逻辑关系。这类关系通常分为三种类型:

蕴涵(entailment):假设可以从前提中推断出来。矛盾(contradiction):假设的否定可以从前提中推断出来。中性(neutral):所有其他情况。

自然语言推断也被称为识别文本蕴涵任务。 例如,下面的一个文本对将被贴上“蕴涵”的标签,因为假设中的“表白”可以从前提中的“拥抱”中推断出来。

前提:两个女人拥抱在一起。假设:两个女人在示爱。

下面是一个“矛盾”的例子,因为“运行编码示例”表示“不睡觉”,而不是“睡觉”。

前提:一名男子正在运行Dive Into Deep Learning的编码示例。假设:该男子正在睡觉。

第三个例子显示了一种“中性”关系,因为“正在为我们表演”这一事实无法推断出“出名”或“不出名”。

前提:音乐家们正在为我们表演。假设:音乐家很有名。

自然语言推断一直是理解自然语言的中心话题。它有着广泛的应用,从信息检索到开放领域的问答。为了研究这个问题,我们将首先研究一个流行的自然语言推断基准数据集。

数据集路径

# 数据集路径
data_dir = 'OCNLI/data/ocnli'

读取数据集

# 读ocnli,两个参数,data_dir是数据集的路径,is_train为bool类型,True代表训练,False代表验证
def read_ocnli(data_dir, is_train):# 将ocnli解析为前提、假设、标签# labels_map是标签映射,0、1、2代表三类,3代表无法分类(或者应该去除的数据)。labels_map = {'entailment':0, 'neutral':1, 'contradiction':2, '-': 3}file_name = os.path.join(data_dir, 'train.3k.json' if is_train else 'dev.json')rows = pd.read_json(file_name, lines=True)premises = [sentence1 for sentence1 in rows['sentence1'] ]  # 前提hypotheses = [sentence2 for sentence2 in rows['sentence2'] ] # 假设# if label != '-' 是为了去除无法分类的标签labels = [labels_map[label] for label in rows['label'] if label != '-'] # 标签return premises, hypotheses, labels

数据集样例展示

# 样例展示
train_data = read_ocnli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):print("前提:", x0)print("假设:", x1)print("标签:", y)

结果:

前提: 现在,我代表国务院,向大会报告政府工作,请予审议,并请全国政协委员提出意见
假设: 全国政协委员无权提出建议
标签: 2
前提: 不过以后呢,两年增加一次工资.
假设: 多年之后工资很高
标签: 1
前提: 一万块,嗯那头盔要八千.
假设: 说话的人很有钱
标签: 1

数据集类别统计

# 类别数据统计
val_data = read_ocnli(data_dir, is_train=False)label_set = [0, 1, 2]for data in [train_data, val_data]:print([[row for row in data[2]].count(i) for i in label_set])

结果:

[974, 1054, 966]
[947, 1103, 900]

数据集类

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
class OCNLI_Dataset(torch.utils.data.Dataset):def __init__(self, dataset):sentence1 = [sentence1 for sentence1 in dataset[0]]sentence2 = [sentence2 for sentence2 in dataset[1]]# 用 _ 将前提和假设拼接在一起,但这应该不是好的做法sentence1_2 = ['{}_{}'.format(a, b) for a, b in zip(sentence1, sentence2)]self.texts = [tokenizer(sentence, padding='max_length', # bert最大可以设置到512,对OCNLI的统计计算中,# 发现所有数据没有超过128,max_length越大,计算量越大max_length = 128, truncation=True,return_tensors="pt") for sentence in sentence1_2 ] self.labels = torch.tensor(dataset[2])def __len__(self):return len(self.labels)def __getitem__(self, idx):return self.texts[idx], self.labels[idx]

加载数据

train_set = OCNLI_Dataset(read_ocnli(data_dir, True))
test_set = OCNLI_Dataset(read_ocnli(data_dir, False))
print(len(train_set))
# for train_input, train_label in train_set:
#     print(train_input)
#     print(train_label)
#     input()

结果:

3000

3.4 Bert

class BertClassifier(nn.Module):def __init__(self, dropout=0.5):super(BertClassifier, self).__init__()self.bert = BertModel.from_pretrained('bert-base-chinese')self.dropout = nn.Dropout(dropout)self.linear = nn.Linear(768, 3) # 这里的3代表输出的类别self.relu = nn.ReLU()def forward(self, input_id, mask):_, pooled_output = self.bert(input_ids= input_id, attention_mask=mask,return_dict=False)dropout_output = self.dropout(pooled_output)linear_output = self.linear(dropout_output)final_layer = self.relu(linear_output)return final_layer

3.4 训练

def train(model, train_data, val_data, learning_rate, epochs):# 通过Dataset类获取训练和验证集train, val = OCNLI_Dataset(train_data), OCNLI_Dataset(val_data)# DataLoader根据batch_size获取数据,训练时选择打乱样本train_dataloader = torch.utils.data.DataLoader(train, batch_size=32, shuffle=True)val_dataloader = torch.utils.data.DataLoader(val, batch_size=32)# 判断是否使用GPUuse_cuda = torch.cuda.is_available()device = torch.device("cuda" if use_cuda else "cpu")# 定义损失函数和优化器criterion = nn.CrossEntropyLoss()optimizer = Adam(model.parameters(), lr=learning_rate)if use_cuda:model = model.cuda()criterion = criterion.cuda()# 开始进入训练循环for epoch_num in range(epochs):# 定义两个变量,用于存储训练集的准确率和损失total_acc_train = 0total_loss_train = 0# 进度条函数tqdmfor train_input, train_label in tqdm(train_dataloader):train_label = train_label.to(device)mask = train_input['attention_mask'].to(device)input_id = train_input['input_ids'].squeeze(1).to(device)# 通过模型得到输出output = model(input_id, mask)# 计算损失batch_loss = criterion(output, train_label)# input()total_loss_train += batch_loss.item()# print("total_loss_train:",total_loss_train)# 计算精度acc = (output.argmax(dim=1) == train_label).sum().item()total_acc_train += acc# 模型更新model.zero_grad()batch_loss.backward()optimizer.step()# ------ 验证模型 -----------# 定义两个变量,用于存储验证集的准确率和损失total_acc_val = 0total_loss_val = 0# 不需要计算梯度with torch.no_grad():# 循环获取数据集,并用训练好的模型进行验证for val_input, val_label in val_dataloader:# 如果有GPU,则使用GPU,接下来的操作同训练val_label = val_label.to(device)mask = val_input['attention_mask'].to(device)input_id = val_input['input_ids'].squeeze(1).to(device)output = model(input_id, mask)batch_loss = criterion(output, val_label)total_loss_val += batch_loss.item()acc = (output.argmax(dim=1) == val_label).sum().item()total_acc_val += accprint(f'''Epochs: {epoch_num + 1} | Train Loss: {total_loss_train / len(train): .3f} | Train Accuracy: {total_acc_train / len(train): .3f} | Val Loss: {total_loss_val / len(train): .3f} | Val Accuracy: {total_acc_val / len(train): .3f}''')     print("total_loss_train:",total_loss_train)print("total_acc_train:",total_acc_train)print("total_loss_val:",total_loss_val)print("total_acc_val:",total_acc_val)print("len(train_data):",len(train))          
EPOCHS = 50
model = BertClassifier()
LR = 1e-6
train(model, read_ocnli(data_dir, True), read_ocnli(data_dir, False), LR, EPOCHS)

在这里插入图片描述

4 训练微调结果

3k

10k

50k

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/9456.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

想学PR的有福了,一小时学会PR剪视频

想学PR的有福了,一小时学会PR剪视频 Pr是什么软件?教程介绍及教程展示教程领取结语下期更新预报 Pr是什么软件? Pr是指Adobe Premiere Pro,它是由Adobe公司开发的一款专业级的视频编辑软件。这款软件广泛应用于电影、电视和网页视…

SQL统计语句记录

1.达梦数据库 统计指定单位的12个月份的业务数据 SELECT a.DEPT_ID, b.dept_name, a.USER_NAME, count(a.dept_id) as count, sum(case when to_char(a.CREATE_TIME,yyyy-mm) 2023-01 THEN 1 else 0 end) as one,sum(case when to_char(a.CREATE_TIME,yyyy-mm) 2023-02 T…

【JavaEE 初阶(四)】多线程进阶

❣博主主页: 33的博客❣ ▶️文章专栏分类:JavaEE◀️ 🚚我的代码仓库: 33的代码仓库🚚 🫵🫵🫵关注我带你了解更多线程知识 目录 1.前言2.常见的锁策略2.1悲观锁vs乐观锁2.2轻量级锁vs重量级锁2.3自旋锁vs挂起锁2.4读写…

【数据结构(邓俊辉)学习笔记】栈与队列01——栈应用(栈混洗、前缀后缀表达式、括号匹配)

文章目录 0. 概述1. 操作与接口2. 操作实例3. 实现4. 栈与递归5. 应用5.1 逆序输出5.1.1 进制转换5.1.1.1 思路5.1.1.2 算法实现 5.2 递归嵌套5.2.1 栈混洗5.2.1.1 混洗5.2.1.2 计数5.2.1.3 甄别 5.2.2 括号匹配5.2.2.1 构思5.2.2.2 实现5.2.2.3 实例 5.3 延迟缓冲5.3.1 中缀表…

(✌)粤嵌—2024/5/9—寻找两个正序数组的中位数

代码实现&#xff1a; int binary_search(int *arr, int n, int key) {int head 0, tail n - 1, mid;while (head < tail) {mid (head tail) / 2;if (arr[mid] key) {return mid;}if (arr[mid] > key) {tail mid - 1;} else {head mid 1;}}return head; }void in…

JetBrains的Java集成开发环境IntelliJ 2024.1版本在Windows/Linux系统的下载与安装配置

目录 前言一、IntelliJ在Windows安装二、IntelliJ在Linux安装三、Windows下使用配置四、Linux下使用配置总结 前言 ​ “ IntelliJ IDEA Ultimate是一款功能强大的Java集成开发环境&#xff08;IDE&#xff09;。它提供了丰富的功能和工具&#xff0c;可以帮助开发人员更高效地…

1067 试密码(测试点2测试点5)

solution 测试点2,5 : The test may have space,so you should use getline() function but not cin() function #include<iostream> #include<string> using namespace std; int main(){string ans, test;int n, cnt 0;cin >> ans >> n;getchar();…

基于 C# 开源的 EF Core 查询计划可视化神器

介绍 EFCore.Visualizer 是 Entity Framework Core 查询计划调试器&#xff0c;一个开源的 EF Core 查询计划可视化工具, 您可以直接在 Visual Studio 中查看查询的查询计划&#xff0c;开箱即用&#xff0c;非常方便。目前&#xff0c;可视化工具支持 SQL Server 和 PostgreS…

java后端15问!

前言 最近一位粉丝去面试一个中厂&#xff0c;Java后端。他说&#xff0c;好几道题答不上来&#xff0c;于是我帮忙整理了一波答案 G1收集器JVM内存划分对象进入老年代标志你在项目中用到的是哪种收集器&#xff0c;怎么调优的new对象的内存分布局部变量的内存分布Synchroniz…

笨方法学习python(七)

输入 一般软件做的事情主要就是下面几条&#xff1a; 接受人的输入。改变输入。打印出改变了的输入。 前面几节都是print输出&#xff0c;这节了解一下输入input&#xff1b;在python2中使用的是raw_input&#xff0c;python3就只是input。 print ("How old are you?&…

springboot如何查看版本号之间的相互依赖

第一种&#xff1a; 查看本地项目maven的依赖&#xff1a; ctrl鼠标左键&#xff1a;按下去可以进入maven的下一层&#xff1a; ctrl鼠标左键&#xff1a;按下去可以进入maven的再下一层&#xff1a; 就可以查看springboot的一些依赖版本号了&#xff1b; 第二种&#xff1a; 还…

RuoYi-Vue-Plus (Echarts 图表)

一、echarts 图表介绍和使用 官网地址:目前echarts以及贡献给Apache Apache EChartshttps://echarts.apache.org/zh/index.htmlecharts配置项手册 Documentation - Apache EChartshttps://echarts.apache.org/z

【快捷部署】022_ZooKeeper(3.5.8)

&#x1f4e3;【快捷部署系列】022期信息 编号选型版本操作系统部署形式部署模式复检时间022ZooKeeper3.5.8Ubuntu 20.04tar包单机2024-05-07 一、快捷部署 #!/bin/bash ################################################################################# # 作者&#xff…

宏的优缺点?C++有哪些技术替代宏?(const)权限的平移、缩小

宏的优缺点&#xff1f; 优点&#xff1a; 1.增强代码的复用性。【减少冗余代码】 2.提高性能&#xff0c;提升代码运行效率。 缺点&#xff1a; 1.不方便调试宏。&#xff08;因为预编译阶段进行了替换&#xff09; 2.导致代码可读性差&#xff0c;可维护性差&#xff0…

OpenSSL实现AES的ECB和CBC加解密,可一次性加解密任意长度的明文字符串或字节流(QT C++环境)

本篇博文讲述如何在Qt C的环境中使用OpenSSL实现AES-ECB/CBC-Pkcs7加/解密&#xff0c;可以一次性加解密一个任意长度的明文字符串或者字节流&#xff0c;但不适合分段读取加解密的&#xff08;例如&#xff0c;一个4GB的大型文件需要加解密&#xff0c;要分段读取&#xff0c;…

基于无监督学习算法的滑坡易发性评价的实施(k聚类、谱聚类、Hier聚类)

基于无监督学习算法的滑坡易发性评价的实施 1. k均值聚类2. 谱聚类3. Hier聚类4. 基于上述聚类方法的易发性实施本研究中的数据集和代码可从以下链接下载: 数据集实施代码1. k均值聚类 K-Means 聚类是一种矢量量化方法,最初来自信号处理,旨在将 N 个观测值划分为 K 个聚类,…

我悟了!24年软考架构就这100道母题,历史重复率90%

距离软考考试的时间越来越近了&#xff0c;趁着这两周赶紧准备起来 今天给大家整理了——系统架构设计师100道经典母题&#xff0c;有PDF&#xff0c;可打印&#xff0c;每天刷几道。 一、计算机系统基础&#xff08;12&#xff09; 1. 计算机采用分级存储体系的主要目的是为了…

深度学习笔记001

目录 一、批量规范化 二、残差网络ResNet 三、稠密连接网络&#xff08;DenseNet&#xff09; 四、循环神经网络 五、信息论 六、梯度截断 本篇blog仅仅是本人在学习《动手学深度学习 Pytorch版》一书中做的一些笔记&#xff0c;感兴趣的读者可以去官网http://zh.gluon.a…

中小学校活动向媒体投稿报道宣传有哪些好方法

作为一所中小学校的教师,我肩负着向外界展示学校风采、宣传校园文化活动的重要使命。起初,每当学校举办特色活动或取得教学成果时,我都会满怀热情地撰写新闻稿,希望通过媒体的平台让更多人了解我们的故事。然而,理想丰满,现实骨感,我很快发现,通过电子邮件向媒体投稿的过程充满…

技术速递|Python in Visual Studio Code 2024年4月发布

排版&#xff1a;Alan Wang 我们很高兴地宣布 Visual Studio Code 的 Python 和 Jupyter 扩展 2024 年 4 月发布&#xff01; 此版本包括以下公告&#xff1a; 改进了 Flask 和 Django 的调试配置流程Jupyter Run Dependent Cells with Pylance 的模块和导入分析Hatch 环境发…