昇思25天学习打卡营第13天|BERT

一、简介:

BERT全称是来自变换器的双向编码器表征量(Bidirectional Encoder Representations from Transformers),它是Google于2018年末开发并发布的一种新型语言模型。与BERT模型相似的预训练语言模型例如问答、命名实体识别、自然语言推理、文本分类等在许多自然语言处理任务中发挥着重要作用。模型是基于Transformer中的Encoder并加上双向的结构,因此一定要熟练掌握Transformer的Encoder的结构。

BERT模型的主要创新点都在pre-train方法上,即用了Masked Language Model和Next Sentence Prediction两种方法分别捕捉词语和句子级别的representation。

在用Masked Language Model方法训练BERT的时候,随机把语料库中15%的单词做Mask操作。对于这15%的单词做Mask操作分为三种情况:80%的单词直接用[Mask]替换、10%的单词直接替换成另一个新的单词、10%的单词保持不变。

因为涉及到Question Answering (QA) 和 Natural Language Inference (NLI)之类的任务,增加了Next Sentence Prediction预训练任务,目的是让模型理解两个句子之间的联系。与Masked Language Model任务相比,Next Sentence Prediction更简单些,训练的输入是句子A和B,B有一半的几率是A的下一句,输入这两个句子,BERT模型预测B是不是A的下一句。

BERT预训练之后,会保存它的Embedding table和12层Transformer权重(BERT-BASE)或24层Transformer权重(BERT-LARGE)。使用预训练好的BERT模型可以对下游任务进行Fine-tuning,比如:文本分类、相似度判断、阅读理解等。

二、环境准备:

在导入包之前,首先咱们要确定已经在实验的环境中安装了mindspore和mindnlp:mindspore的安装可以参考昇思25天学习打卡营第1天|快速入门-CSDN博客,mindnlp的安装则直接运行下面的命令即可:

pip install mindnlp==0.3.1

安装完成之后,导入我们下面训练需要的包:

import os
import timeimport mindspore
from mindspore.dataset import text, GeneratorDataset, transforms
from mindspore import nn, contextfrom mindnlp._legacy.engine import Trainer, Evaluator
from mindnlp._legacy.engine.callbacks import CheckpointCallback, BestModelCallback
from mindnlp._legacy.metrics import Accuracy

三、数据集:

1、数据集下载:

这里使用的是来自于百度飞桨的一份已标注的、经过分词预处理的机器人聊天数据集。数据由两列组成,以制表符('\t')分隔,第一列是情绪分类的类别(0表示消极;1表示中性;2表示积极),第二列是以空格分词的中文文本,如下示例,文件为 utf8 编码。

label--text_a

0--谁骂人了?我从来不骂人,我骂的都不是人,你是人吗 ?

1--我有事等会儿就回来和你聊

2--我见到你很高兴谢谢你帮我

wget https://baidu-nlp.bj.bcebos.com/emotion_detection-dataset-1.0.0.tar.gz -O emotion_detection.tar.gz
tar xvf emotion_detection.tar.gz

2、数据预处理:

新建 process_dataset 函数用于数据加载和数据预处理:

import numpy as npdef process_dataset(source, tokenizer, max_seq_len=64, batch_size=32, shuffle=True):is_ascend = mindspore.get_context('device_target') == 'Ascend'column_names = ["label", "text_a"]dataset = GeneratorDataset(source, column_names=column_names, shuffle=shuffle)# transformstype_cast_op = transforms.TypeCast(mindspore.int32)def tokenize_and_pad(text):if is_ascend:tokenized = tokenizer(text, padding='max_length', truncation=True, max_length=max_seq_len)else:tokenized = tokenizer(text)return tokenized['input_ids'], tokenized['attention_mask']# map datasetdataset = dataset.map(operations=tokenize_and_pad, input_columns="text_a", output_columns=['input_ids', 'attention_mask'])dataset = dataset.map(operations=[type_cast_op], input_columns="label", output_columns='labels')# batch datasetif is_ascend:dataset = dataset.batch(batch_size)else:dataset = dataset.padded_batch(batch_size, pad_info={'input_ids': (None, tokenizer.pad_token_id),'attention_mask': (None, 0)})return dataset

昇腾NPU环境下暂不支持动态Shape,数据预处理部分采用静态Shape处理,也就是说数据预处理阶段需要采用固定的数据形状。

定义bert的分词器:

from mindnlp.transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
tokenizer.pad_token_id

将数据集划分为训练集、开发集和测试集:

class SentimentDataset:"""Sentiment Dataset"""def __init__(self, path):self.path = pathself._labels, self._text_a = [], []self._load()def _load(self):with open(self.path, "r", encoding="utf-8") as f:dataset = f.read()lines = dataset.split("\n")for line in lines[1:-1]:label, text_a = line.split("\t")self._labels.append(int(label))self._text_a.append(text_a)def __getitem__(self, index):return self._labels[index], self._text_a[index]def __len__(self):return len(self._labels)dataset_train = process_dataset(SentimentDataset("data/train.tsv"), tokenizer)
dataset_val = process_dataset(SentimentDataset("data/dev.tsv"), tokenizer)
dataset_test = process_dataset(SentimentDataset("data/test.tsv"), tokenizer, shuffle=False)dataset_train.get_col_names()
print(next(dataset_train.create_tuple_iterator()))

四、模型构建:

通过 BertForSequenceClassification 构建用于情感分类的 BERT 模型,加载预训练权重,设置情感三分类的超参数自动构建模型。后面对模型采用自动混合精度操作,提高训练的速度,然后实例化优化器,紧接着实例化评价指标,设置模型训练的权重保存策略,最后就是构建训练器,模型开始训练。

from mindnlp.transformers import BertForSequenceClassification, BertModel
from mindnlp._legacy.amp import auto_mixed_precision# set bert config and define parameters for training
model = BertForSequenceClassification.from_pretrained('bert-base-chinese', num_labels=3)
model = auto_mixed_precision(model, 'O1')optimizer = nn.Adam(model.trainable_params(), learning_rate=2e-5)metric = Accuracy()
# define callbacks to save checkpoints
ckpoint_cb = CheckpointCallback(save_path='checkpoint', ckpt_name='bert_emotect', epochs=1, keep_checkpoint_max=2)
best_model_cb = BestModelCallback(save_path='checkpoint', ckpt_name='bert_emotect_best', auto_load=True)trainer = Trainer(network=model, train_dataset=dataset_train,eval_dataset=dataset_val, metrics=metric,epochs=5, optimizer=optimizer, callbacks=[ckpoint_cb, best_model_cb])%%time
# start training
trainer.run(tgt_columns="labels")

五、模型验证:

将验证数据集加再进训练好的模型,对数据集进行验证,查看模型在验证数据上面的效果,此处的评价指标为准确率。

evaluator = Evaluator(network=model, eval_dataset=dataset_test, metrics=metric)
evaluator.run(tgt_columns="labels")

六、模型推理:

遍历推理数据集,将结果与标签进行统一展示。

dataset_infer = SentimentDataset("data/infer.tsv")def predict(text, label=None):label_map = {0: "消极", 1: "中性", 2: "积极"}text_tokenized = Tensor([tokenizer(text).input_ids])logits = model(text_tokenized)predict_label = logits[0].asnumpy().argmax()info = f"inputs: '{text}', predict: '{label_map[predict_label]}'"if label is not None:info += f" , label: '{label_map[label]}'"print(info)from mindspore import Tensorfor label, text in dataset_infer:predict(text, label)

输入一句话测试一下(doge):

predict("家人们咱就是说一整个无语住了 绝绝子叠buff")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/38689.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2.3章节Python中的数值类型

1.整型数值 2.浮点型数值 3.复数   Python中的数值类型清晰且丰富,主要分为以下几种类型,每种类型都有其特定的用途和特性。 一、整型数值 1.定义:整数类型用于表示整数值,如1、-5、100等。 2.特点: Python 3中的…

卡尔曼滤波公式推导笔记

视频见B站上DR_CAN的卡尔曼滤波器 【卡尔曼滤波器】3_卡尔曼增益超详细数学推导 ~全网最完整_哔哩哔哩_bilibili

动手学深度学习5.6 GPU-笔记练习(PyTorch)

以下内容为结合李沐老师的课程和教材补充的学习笔记,以及对课后练习的一些思考,自留回顾,也供同学之人交流参考。 本节课程地址:17 使用和购买 GPU【动手学深度学习v2】_哔哩哔哩_bilibili 本节教材地址:5.6. GPU —…

数据库定义语言(DDL)

数据库定义语言(DDL) 一、数据库操作 1、 查询所有的数据库 SHOW DATABASES;效果截图: 2、使用指定的数据库 use 2403 2403javaee;效果截图: 3、创建数据库 CREATE DATABASE 2404javaee;效果截图: 4、删除数据…

面向阿克曼移动机器人(自行车模型)的LQR(最优二次型调节器)路径跟踪方法

线性二次调节器(Linear Quadratic Regulator,LQR)是针对线性系统的最优控制方法。LQR 方法标准的求解体系是在考虑到损耗尽可能小的情况下, 以尽量小的代价平衡其他状态分量。一般情况下,线性系统在LQR 控制方法中用状态空间方程描…

Android super.img结构及解包和重新组包

Android super.img结构及解包和重新组包 从Android10版本开始,Android系统使用动态分区,system、vendor、 odm等都包含在super.img里面,编译后的最终镜像不再有这些单独的 image,取而代之的是一个总的 super.img. 1. 基础知识 …

鸿蒙:页面动画-属性动画、显示动画

1.属性动画是通过设置组件的animation属性来给组件添加动画,当组件的width、height、backgroundColor、scale等属性变更时可以实现过渡渐变效果。 2.显示动画是通过全局animateTo函数来修改组件的属性,实现属性变化时的渐变过渡效果。 核心属性

【你也能从零基础学会网站开发】关系型数据库中的表(Table)设计结构以及核心组成部分

🚀 个人主页 极客小俊 ✍🏻 作者简介:程序猿、设计师、技术分享 🐋 希望大家多多支持, 我们一起学习和进步! 🏅 欢迎评论 ❤️点赞💬评论 📂收藏 📂加关注 关系型数据库中…

【Git 学习笔记】Ch1.1 Git 简介 + Ch1.2 Git 对象

还是绪个言吧 今天整理 GitHub 仓库,无意间翻到了几年前自学 Git 的笔记。要论知识的稳定性,Git 应该能挤进前三——只要仓库还在,理论上当时的所有开发细节都可以追溯出来。正好过段时间会用到 Git,现在整理出来就当温故知新了。…

DiffBIR: Towards Blind Image Restoration with Generative Diffusion Prior

深圳先进研究院&上海ai lab&港中文https://github.com/XPixelGroup/DiffBIRhttps://arxiv.org/pdf/2308.15070 问题引入 使用一个统一的框架来处理image restoration任务,包含图片超分BSR,图片去噪BID和人脸restoration BFR,分为两…

【从零开始学架构 架构基础】五 架构设计的复杂度来源:低成本、安全、规模

架构设计的复杂度来源其实就是架构设计要解决的问题,主要有如下几个:高性能、高可用、可扩展、低成本、安全、规模。复杂度的关键,就是新旧技术之间不是完全的替代关系,有交叉,有各自的特点,所以才需要具体…

微信AI机器人智能助手:利用大模型定制训练知识库

随着人工智能技术的迅速发展,AI已经渗透到了我们生活得方方面面。AI文本撰写、AI绘画、AI生成视频、AI换脸等各类应用层出不穷。作为领先的创新人工智能和元宇宙厂商,道可云凭借自身在人工智能、元宇宙、虚拟数字人等领域的技术积累,将AI技术…

跨越界限,巴比达带你访问远程桌面【内网穿透技术分享】

在远程工作的时代,远程桌面访问成为了许多职场人士的日常。Windows系统默认的远程桌面服务监听在3389端口,但对于内网环境下的机器来说,直接从外部访问这个端口常常面临重重阻碍。不过,有了巴比达内网穿透,这一切都将不…

141个图表,完美展示数据分类别关系!

本文介绍使用Python工具seaborn详细实现分类关系图表,包含8类图141个代码模版。 分类关系图表用于展示数字变量和一个或多个分类变量之间的关系,可以进一步分为:箱形图(box plot)、增强箱形图(enhanced bo…

STM32第十四课:低功耗模式和RTC实时时钟

文章目录 需求一、低功耗模式1.睡眠模式2.停止模式3.待机模式 二、RTC实现实时时钟1.寄存器配置流程2.标准库开发3.主函数调用 三、需求实现代码 需求 1.实现睡眠模式、停止模式和待机模式。 2.实现RTC实时时间显示。 一、低功耗模式 电源对电子设备的重要性不言而喻&#xff…

UE5(c++)开发日志(3):将前面写的输出日志的方法进行封装

Public下新增一个c类: 选择无属性,因为不需要添加任何东西进去, 也不需要借助里面任何东西。 创建一个命名空间Debug,可以在命名空间内写一点静态方法 : namespace Debug{} static void Print(const FString& message, con…

Jenkins教程-12-发送html邮件测试报告

上一小节我们学习了发送钉钉测试报告通知的方法,本小节我们讲解一下发送html邮件测试报告的方法。 1、自动化用例执行完后,使用pytest_terminal_summary钩子函数收集测试结果,存入本地status.txt文件中,供Jenkins调用 #conftest…

全球AI新闻速递6.28

全球AI新闻速递 1.首款 Transformer 专用 AI 芯片 Sohu 登场。 2.钉钉:宣布对所有AI大模型厂商开放,首批7家接入。 3.华为联合清华大学发布《AI 终端白皮书》。 4.国家卫生健康委:推动AI技术在制定个性化营养、运动干预方案中的应用。 …

1Python的Pandas:基本简介

1. Pandas的简介 Pandas 是一个开源的 Python 数据分析库,由 Wes McKinney 在 2008 年开始开发,目的是为了解决数据分析任务中的各种需求。Pandas 是基于 NumPy 库构建的,它使得数据处理和分析工作变得更加快速和简单。Pandas 提供了易于使用…

项目实战--Spring Boot实现三次登录容错功能

一、功能描述 项目设计要求输入三次错误密码后,要求隔段时间才能继续进行登录操作,这里简单记录一下实现思路 二、设计方案 有几个问题需要考虑一下: 1.是只有输错密码才锁定,还是账户名和密码任何一个输错就锁定?2…