昇思25天学习打卡营第13天|BERT

一、简介:

BERT全称是来自变换器的双向编码器表征量(Bidirectional Encoder Representations from Transformers),它是Google于2018年末开发并发布的一种新型语言模型。与BERT模型相似的预训练语言模型例如问答、命名实体识别、自然语言推理、文本分类等在许多自然语言处理任务中发挥着重要作用。模型是基于Transformer中的Encoder并加上双向的结构,因此一定要熟练掌握Transformer的Encoder的结构。

BERT模型的主要创新点都在pre-train方法上,即用了Masked Language Model和Next Sentence Prediction两种方法分别捕捉词语和句子级别的representation。

在用Masked Language Model方法训练BERT的时候,随机把语料库中15%的单词做Mask操作。对于这15%的单词做Mask操作分为三种情况:80%的单词直接用[Mask]替换、10%的单词直接替换成另一个新的单词、10%的单词保持不变。

因为涉及到Question Answering (QA) 和 Natural Language Inference (NLI)之类的任务,增加了Next Sentence Prediction预训练任务,目的是让模型理解两个句子之间的联系。与Masked Language Model任务相比,Next Sentence Prediction更简单些,训练的输入是句子A和B,B有一半的几率是A的下一句,输入这两个句子,BERT模型预测B是不是A的下一句。

BERT预训练之后,会保存它的Embedding table和12层Transformer权重(BERT-BASE)或24层Transformer权重(BERT-LARGE)。使用预训练好的BERT模型可以对下游任务进行Fine-tuning,比如:文本分类、相似度判断、阅读理解等。

二、环境准备:

在导入包之前,首先咱们要确定已经在实验的环境中安装了mindspore和mindnlp:mindspore的安装可以参考昇思25天学习打卡营第1天|快速入门-CSDN博客,mindnlp的安装则直接运行下面的命令即可:

pip install mindnlp==0.3.1

安装完成之后,导入我们下面训练需要的包:

import os
import timeimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, contextfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy

三、数据集:

1、数据集下载:

这里使用的是来自于百度飞桨的一份已标注的、经过分词预处理的机器人聊天数据集。数据由两列组成,以制表符('\t')分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本,如下示例,文件为 utf8 编码。

label--text_a

0--谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?

1--我有事等会儿就回来和你聊

2--我见到你很高兴谢谢你帮我

wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
tar xvf emotion_detection.tar.gz

2、数据预处理:

新建 process_dataset 函数用于数据加载和数据预处理:

import numpy as npdef process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):is_ascend = mindspore.get_context('device_target') == 'Ascend'column_names = ["label", "text_a"]dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)# transformstype_cast_op = transforms.TypeCast(mindspore.int32)def tokenize_and_pad(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text)return tokenized['input_ids'], tokenized['attention_mask']# map datasetdataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset

昇腾NPU环境下暂不支持动态Shape,数据预处理部分采用静态Shape处理,也就是说数据预处理阶段需要采用固定的数据形状。

定义bert的分词器:

from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
tokenizer.pad_token_id

将数据集划分为训练集、开发集和测试集:

class SentimentDataset:"""Sentiment Dataset"""def __init__(self, path):self.path = pathself._labels, self._text_a = [], []self._load()def _load(self):with open(self.path, "r", encoding="utf-8") as f:dataset = f.read()lines = dataset.split("\n")for line in lines[1:-1]:label, text_a = line.split("\t")self._labels.append(int(label))self._text_a.append(text_a)def __getitem__(self, index):return self._labels[index], self._text_a[index]def __len__(self):return len(self._labels)dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)dataset_train.get_col_names()
print(next(dataset_train.create_tuple_iterator()))

四、模型构建:

通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。

from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)metric = Accuracy()
# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_val, metrics=metric,epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])%%time
# start training
trainer.run(tgt_columns="labels")

五、模型验证:

将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

六、模型推理:

遍历推理数据集,将结果与标签进行统一展示。

dataset_infer = SentimentDataset("data/infer.tsv")def predict(text, label=None):label_map = {0: "消极", 1: "中性", 2: "积极"}text_tokenized = Tensor([tokenizer(text).input_ids])logits = model(text_tokenized)predict_label = logits[0].asnumpy().argmax()info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"if label is not None:info += f" , label: '{label_map[label]}'"print(info)from mindspore import Tensorfor label, text in dataset_infer:predict(text, label)

输入一句话测试一下(doge):

predict("家人们咱就是说一整个无语住了 绝绝子叠buff")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/38689.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2.3章节Python中的数值类型

1.整型数值 2.浮点型数值 3.复数   Python中的数值类型清晰且丰富,主要分为以下几种类型,每种类型都有其特定的用途和特性。 一、整型数值 1.定义:整数类型用于表示整数值,如1、-5、100等。 2.特点: Python 3中的…

Quartz表达式:定时任务调度的高级配置与应用

Quartz表达式:定时任务调度的高级配置与应用 大家好,我是免费搭建查券返利机器人省钱赚佣金就用微赚淘客系统3.0的小编,也是冬天不穿秋裤,天冷也要风度的程序猿! 1. Quartz表达式概述 Quartz是一个强大的开源作业调…

卡尔曼滤波公式推导笔记

视频见B站上DR_CAN的卡尔曼滤波器 【卡尔曼滤波器】3_卡尔曼增益超详细数学推导 ~全网最完整_哔哩哔哩_bilibili

动手学深度学习5.6 GPU-笔记练习(PyTorch)

以下内容为结合李沐老师的课程和教材补充的学习笔记,以及对课后练习的一些思考,自留回顾,也供同学之人交流参考。 本节课程地址:17 使用和购买 GPU【动手学深度学习v2】_哔哩哔哩_bilibili 本节教材地址:5.6. GPU —…

数据库定义语言(DDL)

数据库定义语言(DDL) 一、数据库操作 1、 查询所有的数据库 SHOW DATABASES;效果截图: 2、使用指定的数据库 use 2403 2403javaee;效果截图: 3、创建数据库 CREATE DATABASE 2404javaee;效果截图: 4、删除数据…

玩转springboot之springboot热部署

springboot热部署 热部署是在服务器运行时重新部署项目&#xff0c;直接加载整个应用&#xff0c;会释放内存&#xff0c;不过比较耗时 配置tomcat实现热部署 有三种方式 方式一 把项目web文件放在webapps目录下 方式二 在tomcat\conf\server.xml中的<host>标签内添加<…

面向阿克曼移动机器人(自行车模型)的LQR(最优二次型调节器)路径跟踪方法

线性二次调节器&#xff08;Linear Quadratic Regulator&#xff0c;LQR&#xff09;是针对线性系统的最优控制方法。LQR 方法标准的求解体系是在考虑到损耗尽可能小的情况下, 以尽量小的代价平衡其他状态分量。一般情况下&#xff0c;线性系统在LQR 控制方法中用状态空间方程描…

opencv c++ python获取摄像头默认分辨率及设置缩放倍数

c代码 #include <opencv2/opencv.hpp> #include <iostream>int main() {// 创建一个VideoCapture对象cv::VideoCapture cap(0); // 参数0表示打开默认摄像头// 检查摄像头是否成功打开if (!cap.isOpened()) {std::cerr << "Error: Could not open came…

Android super.img结构及解包和重新组包

Android super.img结构及解包和重新组包 从Android10版本开始&#xff0c;Android系统使用动态分区&#xff0c;system、vendor、 odm等都包含在super.img里面&#xff0c;编译后的最终镜像不再有这些单独的 image&#xff0c;取而代之的是一个总的 super.img. 1. 基础知识 …

【Unity】RPG2D龙城纷争(七)关卡编辑器之剧情编辑

更新日期:2024年7月1日。 项目源码:第五章发布(正式开始游戏逻辑的章节) 索引 简介一、剧情编辑1.对话数据集2.对话触发方式3.选择对话角色4.设置对话到关卡5.通关条件简介 严格来说,剧情编辑不在关卡编辑器界面中完成,只不过它仍然属于关卡编辑的范畴。 在我们的设想中…

鸿蒙:页面动画-属性动画、显示动画

1.属性动画是通过设置组件的animation属性来给组件添加动画&#xff0c;当组件的width、height、backgroundColor、scale等属性变更时可以实现过渡渐变效果。 2.显示动画是通过全局animateTo函数来修改组件的属性&#xff0c;实现属性变化时的渐变过渡效果。 核心属性

【你也能从零基础学会网站开发】关系型数据库中的表(Table)设计结构以及核心组成部分

&#x1f680; 个人主页 极客小俊 ✍&#x1f3fb; 作者简介&#xff1a;程序猿、设计师、技术分享 &#x1f40b; 希望大家多多支持, 我们一起学习和进步&#xff01; &#x1f3c5; 欢迎评论 ❤️点赞&#x1f4ac;评论 &#x1f4c2;收藏 &#x1f4c2;加关注 关系型数据库中…

【Git 学习笔记】Ch1.1 Git 简介 + Ch1.2 Git 对象

还是绪个言吧 今天整理 GitHub 仓库&#xff0c;无意间翻到了几年前自学 Git 的笔记。要论知识的稳定性&#xff0c;Git 应该能挤进前三——只要仓库还在&#xff0c;理论上当时的所有开发细节都可以追溯出来。正好过段时间会用到 Git&#xff0c;现在整理出来就当温故知新了。…

DiffBIR: Towards Blind Image Restoration with Generative Diffusion Prior

深圳先进研究院&上海ai lab&港中文https://github.com/XPixelGroup/DiffBIRhttps://arxiv.org/pdf/2308.15070 问题引入 使用一个统一的框架来处理image restoration任务&#xff0c;包含图片超分BSR&#xff0c;图片去噪BID和人脸restoration BFR&#xff0c;分为两…

【从零开始学架构 架构基础】五 架构设计的复杂度来源:低成本、安全、规模

架构设计的复杂度来源其实就是架构设计要解决的问题&#xff0c;主要有如下几个&#xff1a;高性能、高可用、可扩展、低成本、安全、规模。复杂度的关键&#xff0c;就是新旧技术之间不是完全的替代关系&#xff0c;有交叉&#xff0c;有各自的特点&#xff0c;所以才需要具体…

微信AI机器人智能助手:利用大模型定制训练知识库

随着人工智能技术的迅速发展&#xff0c;AI已经渗透到了我们生活得方方面面。AI文本撰写、AI绘画、AI生成视频、AI换脸等各类应用层出不穷。作为领先的创新人工智能和元宇宙厂商&#xff0c;道可云凭借自身在人工智能、元宇宙、虚拟数字人等领域的技术积累&#xff0c;将AI技术…

Android SurfaceFlinger ——获取显示屏信息(十八)

经过前面文章对开机启动动画的流程梳理,引出了实际上在开机启动动画中,并没有Activity,而是通过 OpenGL es 进行渲染,最后通过某种方式,把数据交给 Android 渲染系统。 让我们回忆一下开机动画前期准备的相关步骤,大致分为如下几个: 1)getInternalDisplayToken:获取显…

在C++中何时应该使用异常处理

在C中&#xff0c;异常处理是一种用于处理运行时错误的技术&#xff0c;它允许程序在发生错误时优雅地恢复或终止&#xff0c;而不是突然崩溃或产生不可预测的行为。以下是几种情况下应该使用异常处理的场景&#xff1a; 错误恢复&#xff1a;当程序遇到可以从中恢复的错误时&…

跨越界限,巴比达带你访问远程桌面【内网穿透技术分享】

在远程工作的时代&#xff0c;远程桌面访问成为了许多职场人士的日常。Windows系统默认的远程桌面服务监听在3389端口&#xff0c;但对于内网环境下的机器来说&#xff0c;直接从外部访问这个端口常常面临重重阻碍。不过&#xff0c;有了巴比达内网穿透&#xff0c;这一切都将不…

141个图表,完美展示数据分类别关系!

本文介绍使用Python工具seaborn详细实现分类关系图表&#xff0c;包含8类图141个代码模版。 分类关系图表用于展示数字变量和一个或多个分类变量之间的关系&#xff0c;可以进一步分为&#xff1a;箱形图&#xff08;box plot&#xff09;、增强箱形图&#xff08;enhanced bo…