Bert中文文本分类

这是一个经典的文本分类问题,使用google的预训练模型BERT中文版bert-base-chinese来做中文文本分类。可以先在Huggingface上下载预训练模型备用。https://huggingface.co/google-bert/bert-base-chinese/tree/main

我使用的训练环境是

pip install torch==2.0.0;
pip install transformers==4.30.2;
pip install gensim==4.3.3;
pip install huggingface-hub==0.15.1;
pip install modelscope==1.20.1;

一、准备训练数据

1.1 准备中文文本分类任务的训练数据

这里Demo数据如下:

各银行信用卡挂失费迥异 北京银行收费最高    0
莫泰酒店流拍 大摩叫价或降至6亿美元 4
乌兹别克斯坦议会立法院主席获连任   6
德媒披露鲁能引援关键人物 是他力荐德甲亚洲强人    7
辉立证券给予广汽集团持有评级 2
图文-业余希望赛海南站第二轮 球场的菠萝蜜  7
陆毅鲍蕾:近乎完美的爱情(组图)(2)    9
7000亿美元救市方案将成期市毒药  0
保诚启动210亿美元配股交易以融资收购AIG部门   2

分类class类别文件:

finance
realty
stocks
education
science
society
politics
sports
game
entertainment

1.2 数据读取和截断,使满足BERT模型输入

读取训练数据,对文本进行处理,如截取过长的文本、补齐较短的文本,加上起始标示、对文本进行编码、添加掩码、转为tensor等操作。

import os
from config import parsers
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import torchfrom transformers import AutoTokenizer, AutoModelForMaskedLMdef read_data(file):# 读取文件all_data = open(file, "r", encoding="utf-8").read().split("\n")# 得到所有文本、所有标签、句子的最大长度texts, labels, max_length = [], [], []for data in all_data:if data:text, label = data.split("\t")max_length.append(len(text))texts.append(text)labels.append(label)# 根据不同的数据集返回不同的内容if os.path.split(file)[1] == "train.txt":max_len = max(max_length)return texts, labels, max_lenreturn texts, labels,class MyDataset(Dataset):def __init__(self, texts, labels, max_length):self.all_text = textsself.all_label = labelsself.max_len = max_lengthself.tokenizer = BertTokenizer.from_pretrained(parsers().bert_pred)
#         self.tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")def __getitem__(self, index):# 取出一条数据并截断长度text = self.all_text[index][:self.max_len]label = self.all_label[index]# 分词text_id = self.tokenizer.tokenize(text)# 加上起始标志text_id = ["[CLS]"] + text_id# 编码token_id = self.tokenizer.convert_tokens_to_ids(text_id)# 掩码  -》mask = [1] * len(token_id) + [0] * (self.max_len + 2 - len(token_id))# 编码后  -》长度一致token_ids = token_id + [0] * (self.max_len + 2 - len(token_id))# str -》 intlabel = int(label)# 转化成tensortoken_ids = torch.tensor(token_ids)mask = torch.tensor(mask)label = torch.tensor(label)return (token_ids, mask), labeldef __len__(self):# 得到文本的长度return len(self.all_text)

将文本处理后,就可以使用torch.utils.data中自带的DataLoader模块来加载训练数据了。

二、微调BERT模型

我们是微调BERT模型,需要获取BERT最后一个隐藏层的输出作为输入到下一个全连接层。

至于选择BERT模型的哪个输出作为linear层的输入,可以通过实验尝试,或者遵循常理。

pooler_output:这是通过将最后一层的隐藏状态的第一个token(通常是[CLS] token)通过一个线性层和激活函数得到的输出,常用于分类任务。
last_hidden_state:这是模型所有层的最后一个隐藏状态的输出,包含了整个序列的上下文信息,适用于序列级别的任务。

简单调用下BERT模型,打印出来最后一层看下:

import torch
import time
import torch.nn as nn
from transformers import BertTokenizer
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLMdef process_text(text, bert_pred):tokenizer = BertTokenizer.from_pretrained(bert_pred)token_id = tokenizer.convert_tokens_to_ids(["[CLS]"] + tokenizer.tokenize(text))mask = [1] * len(token_id) + [0] * (38 + 2 - len(token_id))token_ids = token_id + [0] * (38 + 2 - len(token_id))token_ids = torch.tensor(token_ids).unsqueeze(0)mask = torch.tensor(mask).unsqueeze(0)x = torch.stack([token_ids, mask])return xdevice = "cpu"
bert = BertModel.from_pretrained('./bert-base-chinese/')
texts = ["沈腾和马丽的电影《独行月球》挺好看"]
for text in texts:x = process_text(text, './bert-base-chinese/')input_ids, attention_mask = x[0].to(device), x[1].to(device)hidden_out = bert(input_ids, attention_mask=attention_mask,output_hidden_states=False) print(hidden_out)

 输出结果:

2.1 文本分类任务,选择使用pooler_output作为线性层的输入。

import torch.nn as nn
from transformers import BertModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
from config import parsers
import torchclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.args = parsers()self.device = "cuda:0" if torch.cuda.is_available() else "cpu"  self.bert = BertModel.from_pretrained(self.args.bert_pred) # bert 模型进行微调for param in self.bert.parameters():param.requires_grad = True# 一个全连接层self.linear = nn.Linear(self.args.num_filters, self.args.class_num)def forward(self, x):input_ids, attention_mask = x[0].to(self.device), x[1].to(self.device)hidden_out = self.bert(input_ids, attention_mask=attention_mask,output_hidden_states=False)  # 是否输出所有encoder层的结果# shape (batch_size, hidden_size)  pooler_output -->  hidden_out[0]pred = self.linear(hidden_out.pooler_output)# 返回预测结果return pred

2.2 优化器使用Adam、损失函数使用交叉熵损失函数

device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = MyModel().to(device)
opt = AdamW(model.parameters(), lr=args.learn_rate)
loss_fn = nn.CrossEntropyLoss()

三、训练模型

3.1 参数配置

def parsers():parser = argparse.ArgumentParser(description="Bert model of argparse")parser.add_argument("tx_date",nargs='?') #可选输入参数,计算日期parser.add_argument("--train_file", type=str, default=os.path.join("./data_all", "train.txt"))parser.add_argument("--dev_file", type=str, default=os.path.join("./data_all", "dev.txt"))parser.add_argument("--test_file", type=str, default=os.path.join("./data_all", "test.txt"))parser.add_argument("--classification", type=str, default=os.path.join("./data_all", "class.txt"))parser.add_argument("--bert_pred", type=str, default="./bert-base-chinese")parser.add_argument("--class_num", type=int, default=12)parser.add_argument("--max_len", type=int, default=38)parser.add_argument("--batch_size", type=int, default=32)parser.add_argument("--epochs", type=int, default=10)parser.add_argument("--learn_rate", type=float, default=1e-5)parser.add_argument("--num_filters", type=int, default=768)parser.add_argument("--save_model_best", type=str, default=os.path.join("model", "all_best_model.pth"))parser.add_argument("--save_model_last", type=str, default=os.path.join("model", "all_last_model.pth"))args = parser.parse_args()return args

3.2 模型训练

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn as nn
from sklearn.metrics import accuracy_score
import timeif __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"print("device:", device)train_text, train_label, max_len = read_data(args.train_file)dev_text, dev_label = read_data(args.dev_file)args.max_len = max_lentrain_dataset = MyDataset(train_text, train_label, args.max_len)train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)dev_dataset = MyDataset(dev_text, dev_label, args.max_len)dev_dataloader = DataLoader(dev_dataset, batch_size=args.batch_size, shuffle=False)model = MyModel().to(device)opt = AdamW(model.parameters(), lr=args.learn_rate)loss_fn = nn.CrossEntropyLoss()acc_max = float("-inf")for epoch in range(args.epochs):loss_sum, count = 0, 0model.train()for batch_index, (batch_text, batch_label) in enumerate(train_dataloader):batch_label = batch_label.to(device)pred = model(batch_text)loss = loss_fn(pred, batch_label)opt.zero_grad()loss.backward()opt.step()loss_sum += losscount += 1# 打印内容if len(train_dataloader) - batch_index <= len(train_dataloader) % 1000 and count == len(train_dataloader) % 1000:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0if batch_index % 1000 == 999:msg = "[{0}/{1:5d}]\tTrain_Loss:{2:.4f}"print(msg.format(epoch + 1, batch_index + 1, loss_sum / count))loss_sum, count = 0.0, 0model.eval()all_pred, all_true = [], []with torch.no_grad():for batch_text, batch_label in dev_dataloader:batch_label = batch_label.to(device)pred = model(batch_text)pred = torch.argmax(pred, dim=1).cpu().numpy().tolist()label = batch_label.cpu().numpy().tolist()all_pred.extend(pred)all_true.extend(label)acc = accuracy_score(all_pred, all_true)print(f"dev acc:{acc:.4f}")if acc > acc_max:print(acc, acc_max)acc_max = acctorch.save(model.state_dict(), args.save_model_best)print(f"以保存最佳模型")torch.save(model.state_dict(), args.save_model_last)end = time.time()print(f"运行时间:{(end-start)/60%60:.4f} min")

模型保存为:

-rw-rw-r--  1 gaoToby gaoToby 391M Dec 24 14:02 all_best_model.pth
-rw-rw-r--  1 gaoToby gaoToby 391M Dec 24 14:02 all_last_model.pth

四、模型推理预测

准备预测文本文件,加载模型,进行文本的类别预测。


def text_class_name(pred):result = torch.argmax(pred, dim=1)print(torch.argmax(pred, dim=1).cpu().numpy().tolist())result = result.cpu().numpy().tolist()classification = open(args.classification, "r", encoding="utf-8").read().split("\n")classification_dict = dict(zip(range(len(classification)), classification))print(f"文本:{text}\t预测的类别为:{classification_dict[result[0]]}")if __name__ == "__main__":start = time.time()args = parsers()device = "cuda:0" if torch.cuda.is_available() else "cpu"model = load_model(device, args.save_model_best)texts = ["沈腾和马丽的新电影《独行月球》好看", "最近金融环境不太好,投资需谨慎"]print("模型预测结果:")for text in texts:x = process_text(text, args.bert_pred)with torch.no_grad():pred = model(x)text_class_name(pred)end = time.time()print(f"耗时为:{end - start} s")

以上,基本流程完成。当然模型还需要调优来改进预测效果的。

代码是实际跑通的,我训练和预测均使用的是GPU。如果是使用GPU做模型训练,再使用CPU做推理预测的情况,推理预测加载模型的时候注意修改下:

 myModel.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

Done

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/890848.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NLP工具

1、用python提取PDF中各类文本内容的方法 2、文本分块(Chunking)方法&#xff0c;直接影响LLM应用效果 3、

【无标题】学生信息管理系统界面

网页是vue框架&#xff0c;后端直接python写的没使用框架

macos安装maven以及.bash_profile文件优化

文章目录 下载和安装maven本地仓库配置国内镜像仓库配置.bash_profile文件优化 下载和安装maven maven下载地址 存放在/Library/Java/env/maven目录 本地仓库配置 在maven-3.9.9目录下创建maven-repo目录作为本地文件仓库打开setting配置文件 在setting标签下&#xff0c;添…

提示词工程教程(七):小样本和上下文学习

概述 本教程使用 OpenAI 的 GPT 模型和 LangChain 库探索小样本学习和上下文学习的前沿技术。这些方法使 AI 模型能够使用最少的示例执行复杂的任务&#xff0c;从而彻底改变了我们处理机器学习问题的方式。 主题 传统机器学习通常需要大量数据集进行训练&#xff0c;这非常耗…

用Excel表格在线发布期末考试成绩单

每到期末&#xff0c;发布学生的期末考试成绩单便是老师们的一项重要任务。以往&#xff0c;传统的纸质成绩单分发效率低还易出错&#xff0c;而借助 Excel 表格在线发布&#xff0c;则开启了全新高效模式。 老师们先是精心整理各科成绩&#xff0c;录入精准无误的分数到 Excel…

WPF 绘制过顶点的圆滑曲线(样条,贝塞尔)

项目中要用到样条曲线&#xff0c;必须过顶点&#xff0c;圆滑后还不能太走样&#xff0c;捣鼓一番&#xff0c;发现里面颇有玄机&#xff0c;于是把我多方抄来改造的方法发出来&#xff0c;方便新手&#xff1a; 如上图&#xff0c;看代码吧&#xff1a; -------------------…

python监控数据处理应用服务Socket心跳解决方案

1. 概述 从网页、手机App上抓取数据应用服务&#xff0c;涉及到多个系统集成协同工作&#xff0c;依赖工具较多。例如&#xff0c;使用Frida进行代码注入和动态分析&#xff0c;以实现对网络通信的监控和数据捕获。在这样的集成环境中&#xff0c;手机模拟器、手机中应用、消息…

商品线上个性定制,并实时预览3D定制效果,是如何实现的?

商品线上3D个性化定制的实现涉及多个环节和技术&#xff0c;以下是详细的解释&#xff1a; 一、实现流程 产品3D建模&#xff1a; 是实现3D可视化定制的前提&#xff0c;需要对产品进行三维建模。可通过三维扫描仪或建模师进行建模&#xff0c;将产品的外观、结构、材质等细…

开源 SOAP over UDP

简介 看到有人想要实现两个 EXE 之间的互动。这可以采用 RPC 的方式嘛。 Delphi 现成的 RPC 框架&#xff0c;比如 WebService&#xff0c;比如 DataSnap&#xff1b; 当然&#xff0c;github 上面还有第三方开源的 XMLRPC 等等。 为啥要搞一个 UDP Delphi 的 WebService …

【Laravel】接口的访问频率限制器

Laravel 接口的访问频率&#xff0c;你可以在 Laravel 中使用速率限制器&#xff08;Rate Limiter&#xff09;。以下是一个详细的步骤&#xff0c;展示如何为这个特定的 API 路由设置速率限制&#xff1a; 1. 配置 RouteServiceProvider 首先&#xff0c;确保在 App\Provide…

Vue.use()和Vue.component()

当很多页面用到同一个组件&#xff0c;又不想每次都在局部注册时&#xff0c;可以在main.js 中全局注册 Vue.component()一次只能注册一个组件 import CcInput from /components/cc-input.vue Vue.component(CcInput);Vue.use()一次可以注册多个组件 对于自定义的组件&#…

地理数据库Telepg面试内容整理-请描述空间索引的基本概念,如何使用它提高查询性能

空间索引的基本概念 空间索引是专门用于加速空间数据(如地理位置、几何对象等)查询的一种数据结构。空间数据本质上是多维的,包含了坐标、形状、区域等信息,这使得传统的单维索引(如 B+ 树)并不适用。空间索引通过将空间数据映射到特定的索引结构中,使得在进行空间查询时…

Rust : tokio中select!

关于tokio的select宏&#xff0c;有不少的用途。包括超时和竞态选择等。 关于select宏需要关注&#xff0c;相关的异步条件&#xff0c;会同时执行&#xff0c;只是当有一个最早完成时&#xff0c;会执行“抛弃”和“对应”策略。 说明&#xff1a;对本文以下素材的来源表示感…

Python PyMupdf 去除PDF文档中Watermark标识水印

通过PDF阅读或编辑工具&#xff0c;可在PDF中加入Watermark标识的PDF水印&#xff0c;如下图&#xff1a; 该类水印特点 这类型的水印&#xff0c;会在文件的字节流中出现/Watermark、EMC等标识&#xff0c;那么&#xff0c;我们可以通过改变文件字节内容&#xff0c;清理掉…

python EEGPT报错:Cannot cast ufunc ‘clip‘ output from dtype(‘float64‘)

今天在运行EEGPT的时候遇见了下面的问题&#xff0c;首先是nme报错&#xff0c;然后引起了numpy的报错&#xff1a; numpy.core._exceptions._UFuncOutputCastingError: Cannot cast ufunc clip output from dtype(float64)在网上找了好久的教程&#xff0c;但是没有找到。猜测…

旧衣回收小程序开发,绿色生活,便捷回收

随着绿色生活、资源回收利用理念的影响&#xff0c;人们逐渐开始关注旧衣回收&#xff0c;选择将断舍离等闲置衣物进行回收&#xff0c;在资源回收的同时也能够减少资金浪费。目前&#xff0c;旧衣回收的方式也迎来了数字化发展&#xff0c;相比传统的回收方式更加便捷&#xf…

[论文笔记] 从生成到评估:LLM-as-a-judge 的机遇与挑战

https://arxiv.org/pdf/2411.16594 1. LLM-as-a-judge 的引入 传统的评估方法(如 BLEU 和 ROUGE)在处理生成内容的有用性、无害性等细腻属性时表现不足。随着大语言模型(LLM)的发展,提出了 “LLM-as-a-judge”(LLM 作为评估者)的新范式,用于对任务进行评分、排序或选择…

Bluetooth Spec【0】蓝牙核心架构

蓝牙核心系统由一个主机、一个主控制器和零个或多个辅助控制器组成蓝牙BR/ EDR核心系统的最小实现包括了由蓝牙规范定义的四个最低层和相关协议&#xff0c;以及一个公共服务层协议&#xff1b;服务发现协议&#xff08;SDP&#xff09;和总体配置文件要求在通用访问配置文件&a…

【C 基础】C语言代码编译过程

从一个源文件(.c)到可执行程序到底经历了哪几步&#xff0c;我想大多数的人都知道&#xff0c;到时到底每一步都做了什么&#xff0c;我估计也没多少人能够说得清清楚楚&#xff0c;明明白白。 其实总的流程是这样的。 【第一步】编辑hello.c 1 #include <stdio.h> 2 …

数据处理之数据规约

数据处理之数据规约 1. 数据规约概述 数据规约是数据处理中的重要方法&#xff0c;旨在让数据处理更简便、高效&#xff0c;以满足业务需求。当从数据仓库获取的数据量庞大时&#xff0c;直接在海量数据上进行分析和挖掘成本颇高。数据规约可得到数据集的归约表示&#xff0c…