bert文本分类微调笔记

Bert实现文本分类微调Demo

import random
from collections import namedtuple'''
有四种文本需要做分类,请使用bert处理这个分类问题
'''# 使用namedtuple定义一个类别(Category),包含两个字段:名称(name)和样例(samples)
Category = namedtuple('Category', ['name', 'samples'])# 定义四个不同的类别及其对应的样例文本
categories = [Category('Weather Forecast', ['今天北京晴转多云,气温20-25度。', '明天上海有小雨,记得带伞。']),  # 天气预报类别的样例Category('Company Financial Report', ['本季度公司净利润增长20%。', '年度财务报告显示,成本控制良好。']),  # 公司财报类别的样例Category('Company Audit Materials', ['审计发现内部控制存在漏洞。', '审计确认财务报表无重大错报。']),  # 公司审计材料类别的样例Category('Product Marketing Ad', ['新口味可乐,清爽上市!', '买一送一,仅限今日。'])  # 产品营销广告类别的样例
]def generate_data(num_samples_per_category=50):''' 生成模拟数据集输入:- num_samples_per_category: 每个类别生成的样本数量,默认为50输出:- data: 包含文本样本及其对应类别的列表,每项为一个元组(text, label)'''data = []  # 初始化存储数据的列表for category in categories:  # 遍历所有类别for _ in range(num_samples_per_category):  # 对每个类别生成指定数量的样本sample = random.choice(category.samples)  # 从该类别的样例中随机选择一条文本data.append((sample, category.name))  # 将文本及其类别添加到data列表中return data# 调用generate_data函数生成模拟数据集
train_data = generate_data(100)  # 为每个类别生成100个训练样本
test_data = generate_data(6)     # 生成少量(6个)测试样本用于演示'''
train_data = 
[('明天上海有小雨,记得带伞。', 'Weather Forecast'),('明天上海有小雨,记得带伞。', 'Weather Forecast'),('今天北京晴转多云,气温20-25度。', 'Weather Forecast'),('今天北京晴转多云,气温20-25度。', 'Weather Forecast'),('今天北京晴转多云,气温20-25度。', 'Weather Forecast'),('明天上海有小雨,记得带伞。', 'Weather Forecast'),('明天上海有小雨,记得带伞。', 'Weather Forecast'),('明天上海有小雨,记得带伞。', 'Weather Forecast'),('今天北京晴转多云,气温20-25度。', 'Weather Forecast'),]
'''from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, TensorDataset
import torch
import torch.nn.functional as F# 步骤1: 定义类别到标签的映射
label_map = {category.name: index for index, category in enumerate(categories)}
num_labels = len(categories)  # 类别总数# 步骤2: 初始化BERT分词器和模型
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_labels)# 步骤3: 准备数据集
def encode_texts(texts, labels):# 对文本进行编码,得到BERT模型需要的输入格式encodings = tokenizer(texts, truncation=True, padding=True, return_tensors='pt')# 将标签名称转换为对应的索引label_ids = torch.tensor([label_map[label] for label in labels])return encodings, label_idsdef prepare_data(data):texts, labels = zip(*data)  # 解压数据encodings, label_ids = encode_texts(texts, labels)  # 编码数据dataset = TensorDataset(encodings['input_ids'], encodings['attention_mask'], label_ids)  # 创建数据集return DataLoader(dataset, batch_size=8, shuffle=True)  # 创建数据加载器# 步骤4: 准备训练和测试数据
train_loader = prepare_data(train_data)
test_loader = prepare_data(test_data)# 步骤5: 定义训练和评估函数
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)def train_epoch(model, data_loader, optimizer):model.train()total_loss = 0for batch in data_loader:optimizer.zero_grad()input_ids, attention_mask, labels = batchinput_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.losstotal_loss += loss.item()loss.backward()optimizer.step()return total_loss / len(data_loader)def evaluate(model, data_loader):model.eval()total_acc = 0total_count = 0with torch.no_grad():for batch in data_loader:input_ids, attention_mask, labels = batchinput_ids, attention_mask, labels = input_ids.to(device), attention_mask.to(device), labels.to(device)outputs = model(input_ids, attention_mask=attention_mask)predictions = torch.argmax(outputs.logits, dim=1)total_acc += (predictions == labels).sum().item()total_count += labels.size(0)return total_acc / total_count# 步骤6: 训练模型
optimizer = AdamW(model.parameters(), lr=2e-5)for epoch in range(3):  # 训练3个epochtrain_loss = train_epoch(model, train_loader, optimizer)acc = evaluate(model, test_loader)print(f'Epoch {epoch+1}, Train Loss: {train_loss}, Test Accuracy: {acc*100:.2f}%')# 步骤7: 使用微调后的模型进行预测
def predict(text):encodings = tokenizer(text, truncation=True, padding=True, return_tensors='pt')input_ids = encodings['input_ids'].to(device)attention_mask = encodings['attention_mask'].to(device)with torch.no_grad():outputs = model(input_ids, attention_mask=attention_mask)predicted_class_id = torch.argmax(outputs.logits).item()return categories[predicted_class_id].name# 预测一个新文本
new_text = ["明天的天气怎么样?"]  # 注意这里是一个列表
predicted_category = predict(new_text)
print(f'The predicted category for the new text is: {predicted_category}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/32156.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java 和 Kotlin 单例模式写法对比

目录 1、饿汉模式 Java 写法: Kotlin 写法: Kotlin 这段代码反编译&简化后如下: 2、懒汉模式,静态同步方法 Java 写法: Kotlin 写法: Kotlin 这段代码反编译&简化后如下: 3、懒…

Node.js 渲染三维模型并导出为图片

Node.js 渲染三维模型并导出为图片 1. 前言 本文将介绍如何在 Node.js 中使用 Three.js 进行 3D 模型渲染。通过结合 gl 和 canvas 这两个主要依赖库,我们能够在服务器端实现高效的 3D 渲染。这个方法解决了在服务器端生成和处理 3D 图形的需求,使得可…

【mysql】常用操作:维护用户/开启远程/忘记密码/常用命令

一、维护用户 1.1 创建用户 -- 语法 > CREATE USER [username][host] IDENTIFIED BY [password];-- 例子: -- 添加用户user007,密码123456,并且只能在本地可以登录 > CREATE USER user007localhost IDENTIFIED BY 123456; -- 添加用户…

一文搞懂Linux信号【下】

目录 🚩引言 🚩阻塞信号 🚩信号保存 🚩信号捕捉 🚩操作信号集 1.信号集操作函数 2.其它操作函数 🚩总结: 🚩引言 在观看本博客之前,建议大家先看一文搞懂Linux信…

Mysql 分表存储、多段存储

分表存储 分表存储是一种常用的数据库优化技术,特别是当单一表中的数据量非常大时。分表可以帮助提高查询性能、简化数据管理,并优化备份过程。以下是分表存储的一些常见策略和步骤: 1. 选择分表策略 分表可以基于多种策略,常见…

大火《与凤行》演员片酬曝光!网友:难以置信......

近日,大火的《与凤行》演员片酬被网友曝光:女主赵丽颖2000万,男主林更新1200万,而最让人意外的则是辣目洋子,个人片酬高达800万,但是其在剧中戏份比较低,不少网友感叹:难以置信&…

Star、Star求Star

本章是介绍博主自己的一个小工具的。使用的PythonPyQt5开发的。顺带来求一波star🌟🌟!!! 地址:https://gitee.com/qinganan_admin/PyCom Pycom是博主开发的串口工具,要是说对比其他串口工具&…

IOS Swift 从入门到精通: 可选项、展开和类型转换

文章目录 处理缺失数据展开可选值用保护装置解开强制展开隐式解包可选值零合并可选链式调用可选尝试可失败的初始化器类型转换总结 处理缺失数据 我们已经使用诸如 之类的类型Int来保存像 5 这样的值。但是如果您想存储age用户的属性,如果您不知道某人的年龄&#…

第8章:系统质量属性与架构评估

软件系统属性包括功能属性和质量属性,软件架构重点关注的是质量属性。架构的基本需求是在满足功能属性的前提下,关注软件系统质量属性。为了精确、定量地表达系统的质量属性,通常会采用质量属性场景的方式进行描述。   在确定软件系统架构&…

OpenGL3.3_C++_Windows(15)

理解glad: OpenGL只是一个标准/规范,具体的实现是由驱动开发商针对特定显卡实现的,由于OpenGL驱动版本众多,它大多数函数的位置都无法在编译时确定下来,需要在运行时查询,因此开发者需要在运行时获取函数…

Flutter GetX 状态管理 响应式编程(三)

在2021年4月初,我们在应用开发中大量使用了 GetX,目前看来效果还不错,于是我最近也出了一套GetX的从入门到源码原理的分析教程,欢迎大家关注更新。 【1 GetX 基本使用路由管理】【2 GetX 使用入门 程序计数器】 第一步 使用 GetM…

可灵王炸更新,图生视频、视频续写,最长可达3分钟!Runway 不香了 ...

现在视频大模型有多卷? Runway 刚在6月17号 发布Gen3 ,坐上王座没几天; 可灵就在6月21日中午,重新夺回了王座!发布了图生视频功能,视频续写功能! 一张图概括: 二师兄和团队老师第一…

实施高效冷却技术:确保滚珠丝杆稳定运行!

滚珠丝杆在运行过程中,由于摩擦、惯性力等因素,会产生一定的热量,当热量无法及时散发时,滚珠丝杆的温度就会升高,会直接影响滚珠丝杆的精度和稳定性,从而影响最终的产品质量。为了让滚珠丝杆保持应有的精度…

Redis源码学习:ziplist的数据结构和连锁更新问题

ziplist ziplist 是 Redis 中一种紧凑型的列表结构&#xff0c;专门用来存储元素数量少且每个元素较小的数据。它是一个双端链表&#xff0c; 可以在任意一端进行压入/弹出操作&#xff0c;并且该操作的时间复杂度为O(1)。 ziplist数据结构 <zlbytes><zltail>&l…

Linux基础指令(三)

目录 shell 权限指令&#xff1a; 文件的操作权限&#xff1a; 对文件进行操作的用户分类&#xff1a; 用户对文件进行的操作分类&#xff1a; 所有者、所属组、其他的访问权限&#xff1a; 创建用户 沾滞位 匹配查找指令&#xff1a; grep find shell shell&#x…

Ubuntu22.04开机后发现IP地址变成127.0.0.1

开机就是这个样子 解决办法 ip地址可能被释放&#xff0c;需要重新设置成自动分配 sudo dhclient -v可能网卡未加托管 查看方式: nmcli n若是enable就是已被托管,若是disabled&#xff0c;说明网卡未被托管 解决办法: nmcli n on搞定

DataWhale - 吃瓜教程学习笔记(二)

学习视频&#xff1a;第3章-一元线性回归_哔哩哔哩_bilibili 西瓜书对应章节&#xff1a; 3.1 - 3.2 一元线性回归 - 最小二乘法 - 极大似然估计 - 梯度 多元函数的一阶导数 - 海塞矩阵 多元函数的二阶导数 - 机器学习三要素

软件介绍—Fluent Reader (RSS阅读器)

软件介绍—Fluent Reader &#xff08;RSS阅读器&#xff09; 01 RSS介绍 RSS可翻译为简易信息聚合&#xff08;也叫聚合内容&#xff09;是一种基于XML的标准&#xff0c;在互联网上被广泛采用的内容包装和投递协议。简单来讲&#xff0c;就是可以“订阅”一些网站新发布的内…

【Android面试八股文】Kotlin内置标准函数also的原理是什么?

文章目录 原理解析应用场景为什么使用 `also`?also 是 Kotlin 标准库中的一个内置函数,其原理和应用场景可以通过源码和示例来解释。 原理解析 also 的定义如下: /*** Calls the specified function [block] with `this` value as its argument and returns `this` value.…

Android 开发Android Studio创建第一个Android应用

本文讲解如何Android Studio创建第一个Android应用。 启动Android Studio 或打开的项目的界面 点击File-New-New Project 选择“ Empty Views Activity”&#xff0c;点击Next 点击Next&#xff0c;项目创建完成如下&#xff1a; 创建项目完成&#xff0c;自带一个Activity。 …