Bert中文文本分类

这是一个经典的文本分类问题,使用google的预训练模型BERT中文版bert-base-chinese来做中文文本分类。可以先在Huggingface上下载预训练模型备用。https://huggingface.co/google-bert/bert-base-chinese/tree/main

我使用的训练环境是

pip install torch==2.0.0;
pip install transformers==4.30.2;
pip install gensim==4.3.3;
pip install huggingface-hub==0.15.1;
pip install modelscope==1.20.1;

一、准备训练数据

1.1 准备中文文本分类任务的训练数据

这里Demo数据如下:

各银行信用卡挂失费迥异 北京银行收费最高    0
莫泰酒店流拍 大摩叫价或降至6亿美元 4
乌兹别克斯坦议会立法院主席获连任   6
德媒披露鲁能引援关键人物 是他力荐德甲亚洲强人    7
辉立证券给予广汽集团持有评级 2
图文-业余希望赛海南站第二轮 球场的菠萝蜜  7
陆毅鲍蕾:近乎完美的爱情(组图)(2)    9
7000亿美元救市方案将成期市毒药  0
保诚启动210亿美元配股交易以融资收购AIG部门   2

分类class类别文件:

finance
realty
stocks
education
science
society
politics
sports
game
entertainment

1.2 数据读取和截断,使满足BERT模型输入

读取训练数据,对文本进行处理,如截取过长的文本、补齐较短的文本,加上起始标示、对文本进行编码、添加掩码、转为tensor等操作。

import os
from config import parsers
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import torchfrom transformers import AutoTokenizer, AutoModelForMaskedLMdef read_data(file):# 读取文件all_data = open(file, "r", encoding="utf-8").read().split("\n")# 得到所有文本、所有标签、句子的最大长度texts, labels, max_length = [], [], []for data in all_data:if data:text, label = data.split("\t")max_length.append(len(text))texts.append(text)labels.append(label)# 根据不同的数据集返回不同的内容if os.path.split(file)[1] == "train.txt":max_len = max(max_length)return texts, labels, max_lenreturn texts, labels,class MyDataset(Dataset):def __init__(self, texts, labels, max_length):self.all_text = textsself.all_label = labelsself.max_len = max_lengthself.tokenizer = BertTokenizer.from_pretrained(parsers().bert_pred)
#         self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")def __getitem__(self, index):# 取出一条数据并截断长度text = self.all_text[index][:self.max_len]label = self.all_label[index]# 分词text_id = self.tokenizer.tokenize(text)# 加上起始标志text_id = ["[CLS]"] + text_id# 编码token_id = self.tokenizer.convert_tokens_to_ids(text_id)# 掩码  -》mask = [1] * len(token_id) + [0] * (self.max_len + 2 - len(token_id))# 编码后  -》长度一致token_ids = token_id + [0] * (self.max_len + 2 - len(token_id))# str -》 intlabel = int(label)# 转化成tensortoken_ids = torch.tensor(token_ids)mask = torch.tensor(mask)label = torch.tensor(label)return (token_ids, mask), labeldef __len__(self):# 得到文本的长度return len(self.all_text)

将文本处理后,就可以使用torch.utils.data中自带的DataLoader模块来加载训练数据了。

二、微调BERT模型

我们是微调BERT模型,需要获取BERT最后一个隐藏层的输出作为输入到下一个全连接层。

至于选择BERT模型的哪个输出作为linear层的输入,可以通过实验尝试,或者遵循常理。

pooler_output:这是通过将最后一层的隐藏状态的第一个token(通常是[CLS] token)通过一个线性层和激活函数得到的输出,常用于分类任务。
last_hidden_state:这是模型所有层的最后一个隐藏状态的输出,包含了整个序列的上下文信息,适用于序列级别的任务。

简单调用下BERT模型,打印出来最后一层看下:

import torch
import time
import torch.nn as nn
from transformers import BertTokenizer
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLMdef process_text(text, bert_pred):tokenizer = BertTokenizer.from_pretrained(bert_pred)token_id = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokenizer.tokenize(text))mask = [1] * len(token_id) + [0] * (38 + 2 - len(token_id))token_ids = token_id + [0] * (38 + 2 - len(token_id))token_ids = torch.tensor(token_ids).unsqueeze(0)mask = torch.tensor(mask).unsqueeze(0)x = torch.stack([token_ids, mask])return xdevice = "cpu"
bert = BertModel.from_pretrained('./bert-base-chinese/')
texts = ["沈腾和马丽的电影《独行月球》挺好看"]
for text in texts:x = process_text(text, './bert-base-chinese/')input_ids, attention_mask = x[0].to(device), x[1].to(device)hidden_out = bert(input_ids, attention_mask=attention_mask,output_hidden_states=False) print(hidden_out)

 输出结果:

2.1 文本分类任务,选择使用pooler_output作为线性层的输入。

import torch.nn as nn
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
from config import parsers
import torchclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.args = parsers()self.device = "cuda:0" if torch.cuda.is_available() else "cpu"  self.bert = BertModel.from_pretrained(self.args.bert_pred) # bert 模型进行微调for param in self.bert.parameters():param.requires_grad = True# 一个全连接层self.linear = nn.Linear(self.args.num_filters, self.args.class_num)def forward(self, x):input_ids, attention_mask = x[0].to(self.device), x[1].to(self.device)hidden_out = self.bert(input_ids, attention_mask=attention_mask,output_hidden_states=False)  # 是否输出所有encoder层的结果# shape (batch_size, hidden_size)  pooler_output -->  hidden_out[0]pred = self.linear(hidden_out.pooler_output)# 返回预测结果return pred

2.2 优化器使用Adam、损失函数使用交叉熵损失函数

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = MyModel().to(device)
opt = AdamW(model.parameters(), lr=args.learn_rate)
loss_fn = nn.CrossEntropyLoss()

三、训练模型

3.1 参数配置

def parsers():parser = argparse.ArgumentParser(description="Bert model of argparse")parser.add_argument("tx_date",nargs='?') #可选输入参数,计算日期parser.add_argument("--train_file", type=str, default=os.path.join("./data_all", "train.txt"))parser.add_argument("--dev_file", type=str, default=os.path.join("./data_all", "dev.txt"))parser.add_argument("--test_file", type=str, default=os.path.join("./data_all", "test.txt"))parser.add_argument("--classification", type=str, default=os.path.join("./data_all", "class.txt"))parser.add_argument("--bert_pred", type=str, default="./bert-base-chinese")parser.add_argument("--class_num", type=int, default=12)parser.add_argument("--max_len", type=int, default=38)parser.add_argument("--batch_size", type=int, default=32)parser.add_argument("--epochs", type=int, default=10)parser.add_argument("--learn_rate", type=float, default=1e-5)parser.add_argument("--num_filters", type=int, default=768)parser.add_argument("--save_model_best", type=str, default=os.path.join("model", "all_best_model.pth"))parser.add_argument("--save_model_last", type=str, default=os.path.join("model", "all_last_model.pth"))args = parser.parse_args()return args

3.2 模型训练

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn as nn
from sklearn.metrics import accuracy_score
import timeif __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"print("device:", device)train_text, train_label, max_len = read_data(args.train_file)dev_text, dev_label = read_data(args.dev_file)args.max_len = max_lentrain_dataset = MyDataset(train_text, train_label, args.max_len)train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)dev_dataset = MyDataset(dev_text, dev_label, args.max_len)dev_dataloader = DataLoader(dev_dataset, batch_size=args.batch_size, shuffle=False)model = MyModel().to(device)opt = AdamW(model.parameters(), lr=args.learn_rate)loss_fn = nn.CrossEntropyLoss()acc_max = float("-inf")for epoch in range(args.epochs):loss_sum, count = 0, 0model.train()for batch_index, (batch_text, batch_label) in enumerate(train_dataloader):batch_label = batch_label.to(device)pred = model(batch_text)loss = loss_fn(pred, batch_label)opt.zero_grad()loss.backward()opt.step()loss_sum += losscount += 1# 打印内容if len(train_dataloader) - batch_index <= len(train_dataloader) % 1000 and count == len(train_dataloader) % 1000:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0if batch_index % 1000 == 999:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0model.eval()all_pred, all_true = [], []with torch.no_grad():for batch_text, batch_label in dev_dataloader:batch_label = batch_label.to(device)pred = model(batch_text)pred = torch.argmax(pred, dim=1).cpu().numpy().tolist()label = batch_label.cpu().numpy().tolist()all_pred.extend(pred)all_true.extend(label)acc = accuracy_score(all_pred, all_true)print(f"dev acc:{acc:.4f}")if acc > acc_max:print(acc, acc_max)acc_max = acctorch.save(model.state_dict(), args.save_model_best)print(f"以保存最佳模型")torch.save(model.state_dict(), args.save_model_last)end = time.time()print(f"运行时间:{(end-start)/60%60:.4f} min")

模型保存为:

-rw-rw-r--  1 gaoToby gaoToby 391M Dec 24 14:02 all_best_model.pth
-rw-rw-r--  1 gaoToby gaoToby 391M Dec 24 14:02 all_last_model.pth

四、模型推理预测

准备预测文本文件,加载模型,进行文本的类别预测。


def text_class_name(pred):result = torch.argmax(pred, dim=1)print(torch.argmax(pred, dim=1).cpu().numpy().tolist())result = result.cpu().numpy().tolist()classification = open(args.classification, "r", encoding="utf-8").read().split("\n")classification_dict = dict(zip(range(len(classification)), classification))print(f"文本:{text}\t预测的类别为:{classification_dict[result[0]]}")if __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"model = load_model(device, args.save_model_best)texts = ["沈腾和马丽的新电影《独行月球》好看", "最近金融环境不太好,投资需谨慎"]print("模型预测结果:")for text in texts:x = process_text(text, args.bert_pred)with torch.no_grad():pred = model(x)text_class_name(pred)end = time.time()print(f"耗时为:{end - start} s")

以上,基本流程完成。当然模型还需要调优来改进预测效果的。

代码是实际跑通的,我训练和预测均使用的是GPU。如果是使用GPU做模型训练,再使用CPU做推理预测的情况,推理预测加载模型的时候注意修改下:

 myModel.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

Done

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/890848.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【无标题】学生信息管理系统界面

网页是vue框架&#xff0c;后端直接python写的没使用框架

macos安装maven以及.bash_profile文件优化

文章目录 下载和安装maven本地仓库配置国内镜像仓库配置.bash_profile文件优化 下载和安装maven maven下载地址 存放在/Library/Java/env/maven目录 本地仓库配置 在maven-3.9.9目录下创建maven-repo目录作为本地文件仓库打开setting配置文件 在setting标签下&#xff0c;添…

用Excel表格在线发布期末考试成绩单

每到期末&#xff0c;发布学生的期末考试成绩单便是老师们的一项重要任务。以往&#xff0c;传统的纸质成绩单分发效率低还易出错&#xff0c;而借助 Excel 表格在线发布&#xff0c;则开启了全新高效模式。 老师们先是精心整理各科成绩&#xff0c;录入精准无误的分数到 Excel…

WPF 绘制过顶点的圆滑曲线(样条,贝塞尔)

项目中要用到样条曲线&#xff0c;必须过顶点&#xff0c;圆滑后还不能太走样&#xff0c;捣鼓一番&#xff0c;发现里面颇有玄机&#xff0c;于是把我多方抄来改造的方法发出来&#xff0c;方便新手&#xff1a; 如上图&#xff0c;看代码吧&#xff1a; -------------------…

python监控数据处理应用服务Socket心跳解决方案

1. 概述 从网页、手机App上抓取数据应用服务&#xff0c;涉及到多个系统集成协同工作&#xff0c;依赖工具较多。例如&#xff0c;使用Frida进行代码注入和动态分析&#xff0c;以实现对网络通信的监控和数据捕获。在这样的集成环境中&#xff0c;手机模拟器、手机中应用、消息…

商品线上个性定制,并实时预览3D定制效果,是如何实现的?

商品线上3D个性化定制的实现涉及多个环节和技术&#xff0c;以下是详细的解释&#xff1a; 一、实现流程 产品3D建模&#xff1a; 是实现3D可视化定制的前提&#xff0c;需要对产品进行三维建模。可通过三维扫描仪或建模师进行建模&#xff0c;将产品的外观、结构、材质等细…

Python PyMupdf 去除PDF文档中Watermark标识水印

通过PDF阅读或编辑工具&#xff0c;可在PDF中加入Watermark标识的PDF水印&#xff0c;如下图&#xff1a; 该类水印特点 这类型的水印&#xff0c;会在文件的字节流中出现/Watermark、EMC等标识&#xff0c;那么&#xff0c;我们可以通过改变文件字节内容&#xff0c;清理掉…

旧衣回收小程序开发,绿色生活,便捷回收

随着绿色生活、资源回收利用理念的影响&#xff0c;人们逐渐开始关注旧衣回收&#xff0c;选择将断舍离等闲置衣物进行回收&#xff0c;在资源回收的同时也能够减少资金浪费。目前&#xff0c;旧衣回收的方式也迎来了数字化发展&#xff0c;相比传统的回收方式更加便捷&#xf…

Bluetooth Spec【0】蓝牙核心架构

蓝牙核心系统由一个主机、一个主控制器和零个或多个辅助控制器组成蓝牙BR/ EDR核心系统的最小实现包括了由蓝牙规范定义的四个最低层和相关协议&#xff0c;以及一个公共服务层协议&#xff1b;服务发现协议&#xff08;SDP&#xff09;和总体配置文件要求在通用访问配置文件&a…

vulnhub靶场-matrix-breakout-2-morpheus攻略(截止至获取shell)

扫描出ip为192.168.121.161 访问该ip&#xff0c;发现只是一个静态页面什么也没有 使用dir dirsearch 御剑都只能扫描到/robots.txt /server-status 两个页面&#xff0c;前者提示我们什么也没有&#xff0c;后面两个没有权限访问 扫描端口&#xff0c;存在81端口 访问&#x…

Java - 日志体系_Apache Commons Logging(JCL)日志接口库

文章目录 官网1. 什么是JCL&#xff1f;2. JCL的主要特点3. JCL的核心组件4. JCL的实现机制5. SimpleLog 简介6. CodeExample 1 &#xff1a; 默认日志实现 (JCL 1.3.2版本)Example 2 &#xff1a; JCL (1.2版本&#xff09; Log4J 【安全风险高&#xff0c;请勿使用】 7. 使用…

C++-----------映射

探索 C 中的映射与查找表 在 C 编程中&#xff0c;映射&#xff08;Map&#xff09;和查找表&#xff08;Lookup Table&#xff09;是非常重要的数据结构&#xff0c;它们能够高效地存储和检索数据&#xff0c;帮助我们解决各种实际问题。今天&#xff0c;我们就来深入探讨一下…

免费 IP 归属地接口

免费GEOIP&#xff0c;查询IP信息&#xff0c;支持IPV4 IPV6 ,包含国家地理位置&#xff0c;维度&#xff0c;asm,邮编 等&#xff0c;例如 例如查询1.1.1.1 http://geoip.91hu.top/?ip1.1.1.1 返回json 对象

Linux应用软件编程-多任务处理(进程)

多任务&#xff1a;让系统具备同时处理多个事件的能力。让系统具备并发性能。方法&#xff1a;进程和线程。这里先讲进程。 进程&#xff08;process&#xff09;&#xff1a;正在执行的程序&#xff0c;执行过程中需要消耗内存和CPU。 进程的创建&#xff1a;操作系统在进程创…

认识计算机网络

单单看这一个词语&#xff0c;有熟悉又陌生&#xff0c;让我们来重新认识一下这位大角色——计算机网络。 一、是什么 以及 怎么来的 计算机网络是指将地理位置不同的具有独立功能的多台计算机及其外部设备&#xff0c;通过通信线路和通信设备连接起来&#xff0c;在网络操作…

3. Kafka入门—安装与基本命令

Kafka基础操作 一. 章节简介二. kafka简介三. Kafka安装1. 准备工作2. Zookeeper安装2.1 配置文件2.2 启动相关命令3. Kafka安装3.1 配置文件3.2 启动相关命令-------------------------------------------------------------------------------------------------------------…

【Redis】 数据淘汰策略

面试官询问缓存过多而内存有限时内存被占满的处理办法&#xff0c;引出 Redis 数据淘汰策略。 数据淘汰策略与数据过期策略不同&#xff0c; 过期策略针对设置过期时间的 key 删除&#xff0c; 淘汰策略是在内存不够时按规则删除内存数据。 八种数据淘汰策略介绍 no evision&…

meshy的文本到3d的使用

Meshy官方网站&#xff1a; 中文官网&#xff1a; Meshy官网中文站 ​编辑 Opens in a new window ​编辑www.meshycn.com Meshy AI 中文官网首页 英文官网&#xff1a; Meshy目前似乎还没有单独的英文官网&#xff0c;但您可以在中文官网上找到英文界面或相关英文资料。 链…

计算机网络压缩版

计算机网络到现在零零散散也算过了三遍&#xff0c;一些协议大概了解&#xff0c;但总是模模糊糊的印象&#xff0c;现在把自己的整体认识总结一下&#xff0c;&#xff08;本来想去起名叫《看这一篇就够了》&#xff0c;但是发现网上好的文章太多了&#xff0c;还是看这篇吧&a…

C++-----线性结构

C线性结构模板 概念&#xff1a;线性结构是一种数据元素之间存在一对一线性关系的数据结构&#xff0c;如数组、链表、栈、队列等。C中的模板可以让我们编写通用的代码&#xff0c;适用于不同的数据类型&#xff0c;而不必为每种数据类型都重复编写相同的代码结构。作用&#…