bert分类模型使用

使用 bert-bert-chinese 预训练模型去做分类任务,这里找了新闻分类数据,数据有 20w,来自https://github.com/649453932/Bert-Chinese-Text-Classification-Pytorch/tree/master/THUCNews

数据 20w ,18w 训练数据,1w 验证数据, 1w 测试数据,10个类别我跑起来后,预测要7天7夜,于是吧数据都缩小了一些,每个类别抽一些,1800 训练数据,150 验证数据, 150 测试数据,都跑了 1.5 小时, cpu ,电脑 gpu 只有 2g 显存,带不起来

bert- base-chinses 模型下载:bert预训练模型下载-CSDN博客

训练

现在是大模型时代了,这篇文章的代码是利用大模型帮我写的的,通过大模型修正代码,并解释代码一直到可用,代码都写了注释了,整个分类流程就这样,算是一个通用模板了吧

train.py 

# 导入所需的库
import torch
import os
import time
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW# 定义数据集类,符合高内聚原则
class NewsTitleDataset(Dataset):def __init__(self, file_path, tokenizer, max_len=128):self.data = []with open(file_path, 'r', encoding='utf-8') as f:for line in f.readlines():title, label = line.strip().split('\t')inputs = tokenizer(title, padding='max_length', truncation=True, max_length=max_len)self.data.append({'input_ids': inputs['input_ids'], 'attention_mask': inputs['attention_mask'], 'label': int(label)})def __len__(self):return len(self.data)def __getitem__(self, idx):'''在使用DataLoader加载数据进行训练或验证时被调用'''return self.data[idx]# 训练函数(部分代码,实际训练时应包含更多细节如损失计算、模型更新等)
def train_model(model, train_loader, val_loader, optimizer, epochs=3, model_save_path='../output/bert_news_classifier'):# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 判断是否使用GPUdevice = torch.device("cpu")model.to(device)best_val_accuracy = None  # 初始化最优验证集准确率# 创建保存目录(如果不存在)os.makedirs(os.path.dirname(model_save_path), exist_ok=True)# 训练几次模型for epoch in range(epochs):model.train()  # 开启训练模式,会更新参数for batch in train_loader:input_ids = batch['input_ids'].to(device)  # 直接通过键名访问'input_ids'attention_mask = batch['attention_mask'].to(device)  # 直接通过键名访问'attention_mask'labels = batch['label'].to(device)  # 直接通过键名访问'label'outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.loss  # 获取损失optimizer.zero_grad()  # 清零梯度loss.backward()  # 反向传播optimizer.step()  # 更新权重# 在每个epoch结束时评估模型性能model.eval()with torch.no_grad():val_loss = 0correct_predictions = 0total_samples = len(val_data)  # 计算总样本数,用于后续计算准确率for batch in val_loader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['label'].squeeze().to(device)# 计算logits而不是直接获取lossoutputs = model(input_ids=input_ids, attention_mask=attention_mask)# logits 是模型对输入数据计算出的未归一化的类别概率分布,通常是一个形状为 (batch_size, num_classes) 的张量logits = outputs.logits# 手动计算loss(假设labels已转换为one-hot编码或数值标签)# 交叉熵损失函数,多分类问题中常用的损失函数,特别适合于处理像BERT这样的预训练模型输出的logits,并且与one-hot编码的目标标签一起使用loss_fct = torch.nn.CrossEntropyLoss()# labels 是实际的类别标签,需要转换成一个形状为 (batch_size,) 的张量以匹配logits的展开维度。# view(-1, model.num_labels) 会将logits展平为 (batch_size * num_classes) 的向量,使得每个样本的每个类别都有一个单独的概率值对应loss = loss_fct(logits.view(-1, model.num_labels), labels.view(-1))# .item() 方法用于从损失张量提取标量值。val_loss += loss.item()# 找出每个样本的最大概率对应的类别索引,即模型预测的结果。# dim=1 时,表示在第二个维度上找到最大值_, preds = torch.max(logits, dim=1)correct_predictions += (preds == labels).sum().item()val_accuracy = correct_predictions / total_samplesprint(f'Epoch {epoch + 1}, Validation Loss: {val_loss / len(val_loader):.4f}, Accuracy: {val_accuracy * 100:.2f}%')# 如果当前验证集上的准确率优于之前保存的最佳模型,则保存当前模型if best_val_accuracy is None or val_accuracy > best_val_accuracy:best_val_accuracy = val_accuracytorch.save(model.state_dict(), model_save_path)  # 保存模型参数# 定义评估函数
def evaluate_model(model, data_loader):device = next(model.parameters()).devicemodel.eval()correct_predictions = 0total_samples = 0with torch.no_grad():for batch in data_loader:inputs = {key: batch[key].to(device) for key in ['input_ids', 'attention_mask']}labels = batch['label'].to(device)outputs = model(**inputs)_, preds = torch.max(outputs.logits, dim=1)correct_predictions += (preds == labels).sum().item()total_samples += len(labels)return correct_predictions / total_samplesdef collate_to_tensors(batch):input_ids = torch.tensor([example['input_ids'] for example in batch])attention_mask = torch.tensor([example['attention_mask'] for example in batch])labels = torch.tensor([example['label'] for example in batch])return {'input_ids': input_ids, 'attention_mask': attention_mask, 'label': labels}start = time.time()# 加载预训练的tokenizer和模型
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
with open('../data/class.txt', 'r', encoding='utf8') as f:class_labels = f.readlines()
model = BertForSequenceClassification.from_pretrained('../bert-base-chinese', num_labels=len(class_labels))  # 假设class_labels是一个包含所有类别的列表# 加载训练、验证和测试数据集
train_data = NewsTitleDataset('../data/train.txt', tokenizer)
val_data = NewsTitleDataset('../data/dev.txt', tokenizer)
test_data = NewsTitleDataset('../data/test.txt', tokenizer)# 创建DataLoader,用于批处理数据
# collate_to_tensors 调用函数,保证模型接受的数据参数类型一定为 pytorch 的张量类型
# shuffle=True 于防止模型过拟合和提高泛化性能至关重要,因为它确保了模型不会因为训练数据的顺序而产生依赖性。
# batch_size 示每次迭代从数据集中取出多少个样本作为一个批次(batch)进行训练。设置合理的批量大小有助于平衡计算效率和内存使用。
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, collate_fn=collate_to_tensors)
val_loader = DataLoader(val_data, batch_size=32, collate_fn=collate_to_tensors)
test_loader = DataLoader(test_data, batch_size=32, collate_fn=collate_to_tensors)# 设置优化器与学习率
# model.parameters():这是PyTorch中的一个方法,用于获取模型的所有可训练参数。
# lr代表学习率(Learning Rate),它是一个超参数,决定了在每个训练步骤中更新模型参数的幅度大小。给定值 2e-5 表示0.00002
optimizer = AdamW(model.parameters(), lr=2e-5)# 开始训练
train_model(model, train_loader, val_loader, optimizer, model_save_path='../output/best_bert_news_classifier.pth')# 测试模型(仅评估,不更新参数)
test_acc = evaluate_model(model, test_loader)
print(f'Test Accuracy: {test_acc * 100:.2f}%')
print(time.time() - start)

运行结果

 

预测

假如只想输入一个文本,直接得到疯了及结果,可以使用一下代码

import torch
from transformers import BertTokenizer
from transformers import BertForSequenceClassification# 假设 model_state_dict 是从文件加载的模型参数
with open('../data/class.txt', 'r', encoding='utf8') as f:class_labels = f.readlines()
model = BertForSequenceClassification.from_pretrained('../bert-base-chinese', num_labels=len(class_labels))  # 初始化模型结构,并指定分类类别数量# 假设 tokenizer 是您在训练时使用的 BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('../bert-base-chinese')
# 加载模型参数,训练好输出的模型参数
model.load_state_dict(torch.load('../output/best_bert_news_classifier.pth'))
model.eval()  # 设置模型为评估模式def predict_news_category(model, tokenizer, text):# 对文本进行预处理并编码inputs = tokenizer.encode_plus(text,add_special_tokens=True,max_length=128,  # 根据实际情况调整最大长度padding='max_length',truncation=True,return_tensors='pt')input_ids = inputs['input_ids'].to(model.device)attention_mask = inputs['attention_mask'].to(model.device)# 将数据传递给模型以获取logitswith torch.no_grad():outputs = model(input_ids=input_ids, attention_mask=attention_mask)# 获取分类结果logits = outputs.logits_, prediction = torch.max(logits, dim=1)# 返回预测类别索引,实际应用中可能需要将其映射回原始类别标签return prediction.item()# 示例:输入一条新闻标题并预测类别
text = "车载大模型是原子弹还是茶叶蛋?"
predicted_category = predict_news_category(model, tokenizer, text)
print(f"预测的新闻类别是:{predicted_category}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/674884.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录算法训练营Day51|309.最佳买卖股票时机含冷冻期、714.买卖股票的最佳时机含手续费、股票问题总结

目录 309.最佳买卖股票时机含冷冻期 前言 思路 算法实现 714.买卖股票的最佳时机含手续费 前言 思路 算法实现 股票问题总结 309.最佳买卖股票时机含冷冻期 题目链接 文章链接 前言 本题在买卖股票II的基础上增加了一个冷冻期,因此就不能简单分为持有股票和卖…

【JavaWeb】头条新闻纯JavaWeb项目实现 项目搭建 数据库工具类导入 跨域问题 Postman 第一期 (前端Vue3+Vite)

文章目录 一、项目简介1.1 微头条业务简介1.2 技术栈介绍 二、项目部署三、准备工具类3.1 异步响应规范格式类3.2 MD5加密工具类3.3 JDBCUtil连接池工具类3.4 JwtHelper工具类3.4 JSON转换的WEBUtil工具类 四、准备各层的接口和实现类4.1 准备实体类和VO对象4.2 DAO层接口和实现…

222. 完全二叉树的节点个数 - 力扣(LeetCode)

题目描述 给你一棵 完全二叉树 的根节点 root ,求出该树的节点个数。 完全二叉树 的定义如下:在完全二叉树中,除了最底层节点可能没填满外,其余每层节点数都达到最大值,并且最下面一层的节点都集中在该层最左边的若干…

[职场] 线束设计求职简历范文 #媒体#其他#笔记

线束设计求职简历范文 线束设计是指根据汽车电气系统的需求和规范,进行车载线束的布局、连接和组装的过程。下面是线束设计求职简历范文,供大家参考。 个人信息 姓名:蓝山 年龄:26岁 地址:东莞 工作经验&#xff…

C++面试宝典第27题:完全平方数之和

题目 给定正整数 n,找到若干个完全平方数(比如:1、4、9、16、...),使得它们的和等于n。你需要让组成和的完全平方数的个数最少。 示例1: 输入:n = 12 输出:3 解释:12 = 4 + 4 + 4。 示例2: 输入:n = 13 输出:2 解释:13 = 4 + 9。 解析 这道题主要考察应聘者对于…

MySQL- 运维-分库分表-Mycat

一、Mycat概述 1、安装 2、概念介绍 二、Mycat入门 启动服务 三、Mycat配置 1、schema.xml 2、rule.xml 3、server.xml 四、Mycat分片 1、垂直分库 2、水平分表 五、Mycat管理及监控 1、Mycat原理 2、Mycat管理工具 (1)、命令行 (2&#…

【RPA】浅谈RPA技术及其应用

摘要:随着信息技术的飞速发展,企业对于自动化、智能化的需求日益增强。RPA(Robotic Process Automation,机器人流程自动化)技术应运而生,为企业提供了全新的自动化解决方案。本文首先介绍了RPA技术的基本概…

visiontransformerVIT

虽然 Transformer 架构已成为自然语言处理任务的事实标准,但其在计算机视觉中的应用仍然有限。在视觉上,注意力要么与卷积网络结合使用,要么用于替换卷积网络的某些组件,同时保持其整体结构不变。我们表明,这种对 CNN …

软件应用实例分享,电玩计时计费怎么算,佳易王PS5游戏计时器系统程序教程

软件应用实例分享,电玩计时计费怎么算,佳易王PS5游戏计时器系统程序教程 一、前言 以下软件教程以 佳易王电玩计时计费管理系统软件V17.9为例说明 软件文件下载可以点击最下方官网卡片——软件下载——试用版软件下载 点击开始计时后,图片…

使用easyExcel 定义表头 字体 格式 颜色等,定义表内容,合计

HeadStyle 表头样式注解 HeadFontStyle 表头字体样式 HeadStyle(fillPatternType FillPatternTypeEnum.SOLID_FOREGROUND, fillForegroundColor 22) HeadFontStyle(fontHeightInPoints 12) 以下为实现效果

Unity BuffSystem buff系统

Unity BuffSystem buff系统 一、介绍二、buff系统架构三、架构讲解四、框架使用buff数据Json数据以及工具ShowTypeBuffTypeMountTypeBuffOverlapBuffShutDownTypeBuffCalculateType时间和层数这里也不过多说明了如何给生物添加buff 五、总结 一、介绍 现在基本做游戏都会需要些…

【Nicn的刷题日常】之有序序列合并

1.题目描述 描述 输入两个升序排列的序列,将两个序列合并为一个有序序列并输出。 数据范围: 1≤�,�≤1000 1≤n,m≤1000 , 序列中的值满足 0≤���≤30000 0≤val≤30000 输入描述…

C++基础知识点预览

一.绪论: 1.1 C简史: 与C的关系: 被设计为C语言的继任者,C语言是一种过程型语言,程序员使用它定义执行特定操作的函数,而C是一种面向对象的语言,实现了继承、抽象、多态和封装等概念。C支持类&…

【RPA】智能自动化的未来:AI + RPA

伴随着人工智能(AI)技术的迅猛进步,机器人流程自动化(RPA)正在经历一场翻天覆地的变革。AI为RPA注入了新的活力,尤其在处理复杂任务和制定决策方面。通过融合自然语言处理(NLP)、机器…

从0开始学Docker ---Docker安装教程

Docker安装教程 本安装教程参考Docker官方文档,地址如下: https://docs.docker.com/engine/install/centos/ 1.卸载旧版 首先如果系统中已经存在旧的Docker,则先卸载: yum remove docker \docker-client \docker-client-latest…

《学成在线》微服务实战项目实操笔记系列(P1~P83)【上】

史上最详细《学成在线》项目实操笔记系列【上】,跟视频的每一P对应,全系列12万字,涵盖详细步骤与问题的解决方案。如果你操作到某一步卡壳,参考这篇,相信会带给你极大启发。 一、前期准备 1.1 项目介绍 P2 To C面向…

jvm垃圾收集器之七种武器

目录 1.回收算法 1.1 标记-清除算法(Mark-Sweep) 1.2 复制算法(Copying) 1.3 标记-整理算法(Mark-Compact) 2.HotSpot虚拟机的垃圾收集器 2.1 新生代的收集器 Serial 收集器(复制算法) ParNew 收集器 (复制算法) Parallel Scavenge 收集器 (复制…

熔断机制解析:如何用Hystrix保障微服务的稳定性

微服务与系统的弹性设计 大家好,我是小黑,在讲Hystrix之前,咱们得先聊聊微服务架构。想象一下,你把一个大型应用拆成一堆小应用,每个都负责一部分功能,这就是微服务。这样做的好处是显而易见的,更新快,容错性强,每个服务可以独立部署,挺美的对吧?但是,问题也随之而…

Win10系统备份的几种方案,以后不重装系统,备份系统恢复Backup,系统映像备份

Win10系统备份的几种方案 其实都不想重装系统,每次都不愿意去安装各种软件,麻烦,其实win10有几种备份的方案,可以参考一下。 如果下次出问题,我就将系统恢复到这个状态即可,真的不想重装系统,还…

Stable Diffusion 模型下载:Schematics(原理图)

文章目录 模型介绍生成案例案例一案例二案例三案例四案例五案例六案例七案例八案例九案例十下载地址模型介绍 “Schematics”是一个非常个性化的LORA,我的目标是创建一个整体风格,但主要面向某些风格美学,因此它可以用于人物、物体、风景等。这次你会得到“连线”和“方案”…