NLP 算法实战项目:使用 BERT 进行模型微调,进行文本情感分析

本篇我们使用公开的微博数据集(weibo_senti_100k)进行训练,此数据集已经进行标注,0: 负面情绪,1:正面情绪。数据集共计82718条(包含标题)。如下图:

图片

下面我们使用bert-base-chinese预训练模型进行微调并进行测试。 技术交流,文末获取。

1. 导入必要的库

import torch
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import DataLoader, Dataset, random_split
import pandas as pd
from tqdm import tqdm
import random

2. 加载数据集和预训练模型

# 读取训练数据集
df = pd.read_csv("weibo_senti_100k.csv")  # 替换为你的训练数据集路径
# 加载预训练的BERT模型和分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
model = BertForSequenceClassification.from_pretrained('bert-base-chinese')

3. 对数据集进行预处理

注意:此处需要打乱数据行,为了快速训练展示,下面程序只加载了1500条数据。

# 设置随机种子以确保可重复性
random.seed(42)
# 随机打乱数据行
df = df.sample(frac=1).reset_index(drop=True)
# 数据集中1为正面,0为反面
class SentimentDataset(Dataset):def __init__(self, dataframe, tokenizer, max_length=128):self.dataframe = dataframeself.tokenizer = tokenizerself.max_length = max_lengthdef __len__(self):return len(self.dataframe)def __getitem__(self, idx):text = self.dataframe.iloc[idx]['review']label = self.dataframe.iloc[idx]['label']encoding = self.tokenizer(text, padding='max_length', truncation=True, max_length=self.max_length, return_tensors='pt')return {'input_ids': encoding['input_ids'].flatten(),'attention_mask': encoding['attention_mask'].flatten(),'labels': torch.tensor(label, dtype=torch.long)}# 创建数据集对象
dataset = SentimentDataset(df[:1500], tokenizer)

4. 将数据集分为训练集、验证集

# 创建数据集对象
dataset = SentimentDataset(df[:1500], tokenizer)# 划分训练集和验证集
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

5. 设置训练参数

# 设置训练参数
optimizer = AdamW(model.parameters(), lr=5e-5)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

6. 训练模型

# 训练模型
model.train()
for epoch in range(3):  # 3个epoch作为示例for batch in tqdm(train_loader, desc="Epoch {}".format(epoch + 1)):input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)optimizer.zero_grad()outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.lossloss.backward()optimizer.step()
# 输出
Epoch 1: 100%|██████████| 150/150 [00:28<00:00,  5.28it/s]
Epoch 2: 100%|██████████| 150/150 [00:29<00:00,  5.15it/s]
Epoch 3: 100%|██████████| 150/150 [00:27<00:00,  5.36it/s]

7. 评估模型

# 评估模型
model.eval()
total_eval_accuracy = 0
for batch in tqdm(val_loader, desc="Evaluating"):input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['labels'].to(device)with torch.no_grad():outputs = model(input_ids, attention_mask=attention_mask)logits = outputs.logitspreds = torch.argmax(logits, dim=1)accuracy = (preds == labels).float().mean()total_eval_accuracy += accuracy.item()average_eval_accuracy = total_eval_accuracy / len(val_loader)
print("Validation Accuracy:", average_eval_accuracy)
# 输出
Evaluating: 100%|██████████| 38/38 [00:02<00:00, 16.57it/s]Validation Accuracy: 0.9407894736842105

8. 进行预测

# 使用微调后的模型进行预测
def predict_sentiment(sentence):inputs = tokenizer(sentence, padding='max_length', truncation=True, max_length=128, return_tensors='pt').to(device)with torch.no_grad():outputs = model(**inputs)logits = outputs.logitsprobs = torch.softmax(logits, dim=1)positive_prob = probs[0][1].item()  # 1表示正面print("Positive Probability:", positive_prob)# 测试一个句子
predict_sentiment("我要发火了")
# 输出
Positive Probability: 0.19748596847057343

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了NLP技术与面试交流群, 想要进交流群、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2060,备注:技术交流

用通俗易懂方式讲解系列

  • 用通俗易懂的方式讲解:自然语言处理初学者指南(附1000页的PPT讲解)
  • 用通俗易懂的方式讲解:1.6万字全面掌握 BERT
  • 用通俗易懂的方式讲解:NLP 这样学习才是正确路线
  • 用通俗易懂的方式讲解:28张图全解深度学习知识!
  • 用通俗易懂的方式讲解:不用再找了,这就是 NLP 方向最全面试题库
  • 用通俗易懂的方式讲解:实体关系抽取入门教程
  • 用通俗易懂的方式讲解:灵魂 20 问帮你彻底搞定Transformer
  • 用通俗易懂的方式讲解:图解 Transformer 架构
  • 用通俗易懂的方式讲解:大模型算法面经指南(附答案)
  • 用通俗易懂的方式讲解:十分钟部署清华 ChatGLM-6B,实测效果超预期
  • 用通俗易懂的方式讲解:内容讲解+代码案例,轻松掌握大模型应用框架 LangChain
  • 用通俗易懂的方式讲解:如何用大语言模型构建一个知识问答系统
  • 用通俗易懂的方式讲解:最全的大模型 RAG 技术概览
  • 用通俗易懂的方式讲解:利用 LangChain 和 Neo4j 向量索引,构建一个RAG应用程序
  • 用通俗易懂的方式讲解:使用 Neo4j 和 LangChain 集成非结构化知识图增强 QA
  • 用通俗易懂的方式讲解:面了 5 家知名企业的NLP算法岗(大模型方向),被考倒了。。。。。
  • 用通俗易懂的方式讲解:NLP 算法实习岗,对我后续找工作太重要了!。
  • 用通俗易懂的方式讲解:理想汽车大模型算法工程师面试,被问的瑟瑟发抖。。。。
  • 用通俗易懂的方式讲解:基于 Langchain-Chatchat,我搭建了一个本地知识库问答系统
  • 用通俗易懂的方式讲解:面试字节大模型算法岗(实习)
  • 用通俗易懂的方式讲解:大模型算法岗(含实习)最走心的总结
  • 用通俗易懂的方式讲解:大模型微调方法汇总

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/734500.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

STC89C52串口通信详解

目录 前言 一.通信基本原理 1.1串行通信与并行通信 1.2同步通信和异步通信 1.2.1异步通信 1.2.2同步通信 1.3单工、半双工与全双工通信 1.4通信速率 二.串口通信简介 2.1接口标准 2.2串口内部结构 2.3串口相关寄存器 三.串口工作方式 四.波特率计算 五.串口初始化步骤 六.实验…

万界星空科技MES系统中的车间管理的作用

在了解mes生产管理系统的作用包括哪些方面之前&#xff0c;我们先来了解一下作为生产管理信息化的关键部分&#xff0c;车间管理系统包含哪几个部分&#xff1a;一、mes系统中的车间管理通常包含以下部分&#xff1a; 1、设备管理&#xff1a;用于监控车间内的设备状态&#xf…

新规正式发布 | 百度深度参编《生成式人工智能服务安全基本要求》

2024年2月29日&#xff0c;全国网络安全标准化技术委员会&#xff08; TC260 &#xff09;正式发布《生成式人工智能服务安全基本要求》&#xff08;以下简称《基本要求》&#xff09;。《基本要求》规定了生成式人工智能服务在安全方面的基本要求&#xff0c;包括语料安全、模…

springboot整合shiro的实战教程(二)

文章目录 整合思路1.创建springboot项目2.引入依赖3.创建Shiro Filter0.创建配置类1.配置shiroFilterFactoryBean2.配置WebSecurityManager3.创建自定义Relm4.配置自定义realm5.编写控制器跳转至index.html6.加入资源的权限控制7. 常见过滤器 登录认证实现登录界面开发controll…

目标网站屏蔽右键检查(使用开发者工具)

问题&#xff1a; 通过网络触手中想要获取某网站的数据出现&#xff1a;鼠标右击&#xff0c;或按ctrl F10 键 无反应&#xff08;也就是打不开类似谷歌的开发工具&#xff09; 问题同等与&#xff1a; 解决网页屏蔽F12或右键打开审查元素 引用&#xff1a; 作者&#xff…

C/C++ BM19 寻找峰值

文章目录 前言题目解决方案一1.1 思路阐述1.2 源码 解决方案二2.1 思路阐述2.2 源码 总结 前言 这道题第一遍做的时候题目条件没有好好的审阅&#xff0c;导致在判断边界问题的时候出了不少岔子。 我的方法是时间复杂度为O(N)的&#xff0c;官方的logN可能更好一些。我的就是简…

启发式算法:遗传算法

文章目录 遗传算法-引例交叉变异遗传算法遗传算法流程遗传算法应用遗传算法-引例 在一代代演化过程中,父母扇贝的基因组合产生新扇贝,所以遗传算法会选择两个原有的扇贝,然后对这两个扇贝的染色体进行随机交叉形成新的扇贝。迭代演化也会造成基因突变,遗传算法让新产生扇贝…

Mysql索引底层数据结构

Mysql索引底层数据结构 一、数据结构1.1.索引的本质1.2.MySQl特点1.3.索引数据结构1.4.B-Tree结构1.5.BTree结构1.6.查看mysql文件页大小&#xff08;16K&#xff09;1.7.为什么mysql页文件默认16K&#xff1f;1.8.Hash结构 二、存储引擎2.1.InnoDB2.1.1.聚集索引2.1.2.为什么建…

力扣:数组篇

1、数组理论基础 数组是存放在连续内存空间上的相同类型数据的集合。 需要两点注意的是 数组下标都是从0开始的。数组内存空间的地址是连续的 因为数组的在内存空间的地址是连续的&#xff0c;所以我们在删除或者增添元素的时候&#xff0c;就难免要移动其他元素的地址。 …

【你也能从零基础学会网站开发】Web建站之javascript入门篇 JavaScript中的表达式、运算符、位运算、递增递减

&#x1f680; 个人主页 极客小俊 ✍&#x1f3fb; 作者简介&#xff1a;web开发者、设计师、技术分享 &#x1f40b; 希望大家多多支持, 我们一起学习和进步&#xff01; &#x1f3c5; 欢迎评论 ❤️点赞&#x1f4ac;评论 &#x1f4c2;收藏 &#x1f4c2;加关注 JavaScript…

kali当中不同的python版本切换(超简单)

kali当中本身就是自带两个python版本的 配置 update-alternatives --install /usr/bin/python python /usr/bin/python2 100 update-alternatives --install /usr/bin/python python /usr/bin/python3 150 切换版本 update-alternatives --config python 0 1 2编号选择一个即可…

【MySQL篇】 MySQL基础学习

文章目录 前言基础数据类型DDL数据库操作查询数据库创建数据库删除数据库使用数据库 DDL表操作创建表查询表修改表删除 DML-增删改添加数据更改数据删除数据 DQL-查询基础查询条件查询聚合函数分组查询排序查询分页查询编写顺序 DML-用户及权限用户管理权限控制 函数字符串函数…

挑战杯 基于设深度学习的人脸性别年龄识别系统

文章目录 0 前言1 课题描述2 实现效果3 算法实现原理3.1 数据集3.2 深度学习识别算法3.3 特征提取主干网络3.4 总体实现流程 4 具体实现4.1 预训练数据格式4.2 部分实现代码 5 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 基于深度学习机器视觉的…

浅谈2024 年 AI 辅助研发趋势!

目录 ​编辑 引言 一、AI辅助研发现状 1. 技术发展 2. 工具集成 3. 应用场景 二、AI辅助研发趋势 1. 更高的自动化程度 2. 更高的智能化程度 3. 更多的领域应用 4. 更高的重视度 三、结论 四. 完结散花 悟已往之不谏&#xff0c;知来者犹可追 创作不易&#xff…

(南京观海微电子)——I3C协议介绍

特点 两线制总线&#xff1a;I2C仅使用两条线——串行数据线&#xff08;SDA&#xff09;和串行时钟线&#xff08;SCL&#xff09;进行通信&#xff0c;有效降低了连接复杂性。多主多从设备支持&#xff1a;I2C支持多个主设备和多个从设备连接到同一总线上。每个设备都有唯一…

017-$route、$router

$route、$router 1、$route2、$router 1、$route $route 对象表示当前的路由信息&#xff0c;包含了当前 URL 解析得到的信息。包含当前的路径&#xff0c;参数&#xff0c;query对象等。 使用场景&#xff1a; 获取路由传参&#xff1a;this.$route.query、this.$route.par…

【布局:1688,阿里海外的新筹码?】1688重新布局跨境海外市场:第一步开放1688API数据采集接口

2023年底&#xff0c;阿里巴巴“古早”业务1688突然成为“重头戏”&#xff0c;尤其宣布正式布局跨境业务的消息&#xff0c;一度引发电商圈讨论。1688重新布局跨境海外市场&#xff1a;第一步开放1688API数据采集接口 2023年11月中旬&#xff0c;阿里财报分析师电话会上&…

VUE——v-cloak指令

VUE——v-cloak指令 属性选择器&#xff0c;可以控制vue实例化完成前的dom样式 功能&#xff1a;利用vue实例化后v-cloak属性会消失&#xff0c;设置其样式 官网介绍 没用前效果&#xff1a;当vue没渲染完前&#xff0c;界面效果会看到{{aboutCloak}}字符&#xff0c;影响用户…

UDP与TCP:了解这两种网络协议的不同之处

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

红包题第一弹

下载附件&#xff0c;发现有86个压缩包 现每个压缩包里面都有图片&#xff0c;010打开图片末尾都有base64部分&#xff0c;并且每个压缩包里面图片末尾的base64长度一样&#xff0c;刚好每一张的base64长度为100。猜测需要拼接起来然后解码 写个python脚本 import os import …