基于BERT的情感分析

基于BERT的情感分析

1. 项目背景

情感分析(Sentiment Analysis)是自然语言处理的重要应用之一,用于判断文本的情感倾向,如正面、负面或中性。随着深度学习的发展,预训练语言模型如BERT在各种自然语言处理任务中取得了显著的效果。本项目利用预训练语言模型BERT,构建一个能够对文本进行情感分类的模型。


2. 项目结构

sentiment-analysis/
├── data/
│   ├── train.csv        # 训练数据集
│   ├── test.csv         # 测试数据集
├── src/
│   ├── preprocess.py    # 数据预处理模块
│   ├── train.py         # 模型训练脚本
│   ├── evaluate.py      # 模型评估脚本
│   ├── inference.py     # 模型推理脚本
│   ├── utils.py         # 工具函数(可选)
├── models/
│   ├── bert_model.pt    # 保存的模型权重
├── logs/
│   ├── training.log     # 训练日志(可选)
├── README.md            # 项目说明文档
├── requirements.txt     # 依赖包列表
└── run.sh               # 一键运行脚本

3. 环境准备

3.1 系统要求

  • Python 3.6 或以上版本
  • GPU(可选,但建议使用以加速训练)

3.2 安装依赖

建议在虚拟环境中运行。安装所需的依赖包:

pip install -r requirements.txt

requirements.txt内容:

torch>=1.7.0
transformers>=4.0.0
pandas
scikit-learn
tqdm

4. 数据准备

4.1 数据格式

数据文件train.csvtest.csv的格式如下:

textlabel
I love this product.1
This is a bad movie.0
  • text:输入文本
  • label:目标标签,1为正面情感,0为负面情感

将数据文件保存至data/目录下。

4.2 数据集划分

可以使用train_test_split将数据划分为训练集和测试集。


5. 代码实现

5.1 数据预处理 (src/preprocess.py)

import pandas as pd
from transformers import BertTokenizer
from torch.utils.data import Dataset
import torchclass SentimentDataset(Dataset):"""自定义的用于情感分析的Dataset。"""def __init__(self, data_path, tokenizer, max_len=128):"""初始化Dataset。Args:data_path (str): 数据文件的路径。tokenizer (BertTokenizer): BERT的分词器。max_len (int): 最大序列长度。"""self.data = pd.read_csv(data_path)self.tokenizer = tokenizerself.max_len = max_lendef __len__(self):"""返回数据集的大小。"""return len(self.data)def __getitem__(self, idx):"""根据索引返回一条数据。Args:idx (int): 数据索引。Returns:dict: 包含input_ids、attention_mask和label的字典。"""text = str(self.data.iloc[idx]['text'])label = int(self.data.iloc[idx]['label'])encoding = self.tokenizer(text, padding='max_length', truncation=True, max_length=self.max_len, return_tensors="pt")return {'input_ids': encoding['input_ids'].squeeze(0),  # shape: [seq_len]'attention_mask': encoding['attention_mask'].squeeze(0),  # shape: [seq_len]'label': torch.tensor(label, dtype=torch.long)  # shape: []}

5.2 模型训练 (src/train.py)

import torch
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, AdamW, BertTokenizer, get_linear_schedule_with_warmup
from preprocess import SentimentDataset
import argparse
import os
from tqdm import tqdmdef train_model(data_path, model_save_path, batch_size=16, epochs=3, lr=2e-5, max_len=128):"""训练BERT情感分析模型。Args:data_path (str): 训练数据的路径。model_save_path (str): 模型保存的路径。batch_size (int): 批次大小。epochs (int): 训练轮数。lr (float): 学习率。max_len (int): 最大序列长度。"""# 初始化分词器和数据集tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')dataset = SentimentDataset(data_path, tokenizer, max_len=max_len)# 划分训练集和验证集train_size = int(0.8 * len(dataset))val_size = len(dataset) - train_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])# 数据加载器train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=batch_size)# 初始化模型model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)# 优化器和学习率调度器optimizer = AdamW(model.parameters(), lr=lr)total_steps = len(train_loader) * epochsscheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)# 设备设置device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)# 训练循环for epoch in range(epochs):model.train()total_loss = 0progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")for batch in progress_bar:optimizer.zero_grad()input_ids = batch['input_ids'].to(device)  # shape: [batch_size, seq_len]attention_mask = batch['attention_mask'].to(device)  # shape: [batch_size, seq_len]labels = batch['label'].to(device)  # shape: [batch_size]outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.lossloss.backward()optimizer.step()scheduler.step()total_loss += loss.item()progress_bar.set_postfix(loss=loss.item())avg_train_loss = total_loss / len(train_loader)print(f"Epoch {epoch + 1}/{epochs}, Average Loss: {avg_train_loss:.4f}")# 验证模型model.eval()val_loss = 0correct = 0total = 0with torch.no_grad():for batch in val_loader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['label'].to(device)outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.losslogits = outputs.logitsval_loss += loss.item()preds = torch.argmax(logits, dim=1)correct += (preds == labels).sum().item()total += labels.size(0)avg_val_loss = val_loss / len(val_loader)val_accuracy = correct / totalprint(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}")# 保存模型os.makedirs(os.path.dirname(model_save_path), exist_ok=True)torch.save(model.state_dict(), model_save_path)print(f"Model saved to {model_save_path}")if __name__ == "__main__":parser = argparse.ArgumentParser(description="Train BERT model for sentiment analysis")parser.add_argument('--data_path', type=str, default='data/train.csv', help='Path to training data')parser.add_argument('--model_save_path', type=str, default='models/bert_model.pt', help='Path to save the trained model')parser.add_argument('--batch_size', type=int, default=16, help='Batch size')parser.add_argument('--epochs', type=int, default=3, help='Number of training epochs')parser.add_argument('--lr', type=float, default=2e-5, help='Learning rate')parser.add_argument('--max_len', type=int, default=128, help='Maximum sequence length')args = parser.parse_args()train_model(data_path=args.data_path,model_save_path=args.model_save_path,batch_size=args.batch_size,epochs=args.epochs,lr=args.lr,max_len=args.max_len)

5.3 模型评估 (src/evaluate.py)

import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from preprocess import SentimentDataset
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizer
import argparse
from tqdm import tqdmdef evaluate_model(data_path, model_path, batch_size=16, max_len=128):"""评估BERT情感分析模型。Args:data_path (str): 测试数据的路径。model_path (str): 训练好的模型的路径。batch_size (int): 批次大小。max_len (int): 最大序列长度。"""# 初始化分词器和数据集tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')dataset = SentimentDataset(data_path, tokenizer, max_len=max_len)loader = DataLoader(dataset, batch_size=batch_size)# 加载模型model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))model.eval()# 设备设置device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)all_preds = []all_labels = []with torch.no_grad():for batch in tqdm(loader, desc="Evaluating"):input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['label'].to(device)outputs = model(input_ids, attention_mask=attention_mask)logits = outputs.logitspreds = torch.argmax(logits, dim=1)all_preds.extend(preds.cpu().numpy())all_labels.extend(labels.cpu().numpy())accuracy = accuracy_score(all_labels, all_preds)precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')print(f"Accuracy: {accuracy:.4f}")print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")if __name__ == "__main__":parser = argparse.ArgumentParser(description="Evaluate BERT model for sentiment analysis")parser.add_argument('--data_path', type=str, default='data/test.csv', help='Path to test data')parser.add_argument('--model_path', type=str, default='models/bert_model.pt', help='Path to the trained model')parser.add_argument('--batch_size', type=int, default=16, help='Batch size')parser.add_argument('--max_len', type=int, default=128, help='Maximum sequence length')args = parser.parse_args()evaluate_model(data_path=args.data_path,model_path=args.model_path,batch_size=args.batch_size,max_len=args.max_len)

5.4 推理 (src/inference.py)

import torch
from transformers import BertTokenizer, BertForSequenceClassification
import argparsedef predict_sentiment(text, model_path, max_len=128):"""对输入的文本进行情感预测。Args:text (str): 输入的文本。model_path (str): 训练好的模型的路径。max_len (int): 最大序列长度。Returns:str: 预测的情感类别。"""# 初始化分词器和模型tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))model.eval()# 设备设置device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)# 数据预处理inputs = tokenizer(text, return_tensors="pt", truncation=True, padding='max_length', max_length=max_len)inputs = {key: value.to(device) for key, value in inputs.items()}# 模型推理with torch.no_grad():outputs = model(**inputs)logits = outputs.logitsprediction = torch.argmax(logits, dim=1).item()sentiment = "Positive" if prediction == 1 else "Negative"return sentimentif __name__ == "__main__":parser = argparse.ArgumentParser(description="Inference script for sentiment analysis")parser.add_argument('--text', type=str, required=True, help='Input text for sentiment prediction')parser.add_argument('--model_path', type=str, default='models/bert_model.pt', help='Path to the trained model')parser.add_argument('--max_len', type=int, default=128, help='Maximum sequence length')args = parser.parse_args()sentiment = predict_sentiment(text=args.text,model_path=args.model_path,max_len=args.max_len)print(f"Input Text: {args.text}")print(f"Predicted Sentiment: {sentiment}")

6. 项目运行

6.1 一键运行脚本 (run.sh)

#!/bin/bash# 训练模型
python src/train.py --data_path=data/train.csv --model_save_path=models/bert_model.pt# 评估模型
python src/evaluate.py --data_path=data/test.csv --model_path=models/bert_model.pt# 推理示例
python src/inference.py --text="I love this movie!" --model_path=models/bert_model.pt

6.2 单独运行

6.2.1 训练模型
python src/train.py --data_path=data/train.csv --model_save_path=models/bert_model.pt --epochs=3 --batch_size=16
6.2.2 评估模型
python src/evaluate.py --data_path=data/test.csv --model_path=models/bert_model.pt
6.2.3 模型推理
python src/inference.py --text="This product is great!" --model_path=models/bert_model.pt

7. 结果展示

7.1 训练结果

  • 损失下降曲线:可以使用matplotlibtensorboard绘制训练过程中的损失变化。
  • 训练日志:在logs/training.log中记录训练过程。

7.2 模型评估

  • 准确率(Accuracy):模型在测试集上的准确率。
  • 精确率、召回率、F1-score:更全面地评估模型性能。

7.3 推理示例

示例:

python src/inference.py --text="I absolutely love this!" --model_path=models/bert_model.pt

输出:

Input Text: I absolutely love this!
Predicted Sentiment: Positive

8. 注意事项

  • 模型保存与加载:确保模型保存和加载时的路径正确,特别是在使用相对路径时。
  • 设备兼容性:代码中已考虑CPU和GPU的兼容性,确保设备上安装了相应的PyTorch版本。
  • 依赖版本:依赖的库版本可能会影响代码运行,建议使用requirements.txt中指定的版本。

9. 参考资料

  • BERT论文
  • Hugging Face Transformers文档
  • PyTorch官方文档

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/60996.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

37.超级简易的计算器 C语言

超级简单&#xff0c;简单到甚至这个计算器输入都比较反人类 但是足够简单 有输入功能有Switch语句支持四种运算还能检查除数是不是0还能打印出完整的式子 #define _CRT_SECURE_NO_WARNINGS// 禁用安全警告 #include <stdio.h>int main() {double num1, num2;// 声明两…

【tokenization分词】WordPiece, Byte-Pair Encoding(BPE), Byte-level BPE(BBPE)的原理和代码

目录 前言 1、word (词粒度) 2、char (字符粒度) 3、subword (子词粒度) WordPiece Byte-Pair Encoding (BPE) Byte-level BPE(BBPE) 总结 前言 Tokenization&#xff08;分词&#xff09; 在自然语言处理(NLP)的任务中是最基本的一步&#xff0c;将文本处理成一串tok…

深入解析 MySQL 数据库:数据库时区问题

在 MySQL 数据库中&#xff0c;时区管理是一个重要且复杂的主题&#xff0c;尤其是在全球化的应用程序中。以下是关于 MySQL 数据库时区问题的深入解析&#xff1a; 1. 时区的概念 时区是指地球表面被分为若干个区域&#xff0c;每个区域的标准时间相对协调世界时 (UTC) 有所…

技术周总结 11.11~11.17 周日(Js JVM XML)

文章目录 一、11.11 周一1.1&#xff09;问题01&#xff1a;js中的prompt弹窗区分出来用户点击的是 确认还是取消进一步示例 1.2&#xff09;问题02&#xff1a;在 prompt弹窗弹出时默认给弹窗中写入一些内容 二、11.12 周二2.1) 问题02: 详解JVM中的本地方法栈本地方法栈的主要…

模拟实现STL中的list

目录 1.设计list的结点 2.设计list的迭代器 3.list类的设计总览 4.list类的迭代器操作 5.list类的四个特殊的默认成员函数 无参的默认构造函数 拷贝构造函数 赋值运算符重载函数 析构函数 6.list类的插入操作 7.list类的删除操作 8.list.hpp源代码 1.设计list的结点…

spark 设置hive.exec.max.dynamic.partition不生效

spark脚本和程序中设置ive.exec.max.dynamic.partition不生效 正确写法&#xff1a; spark-submit \ --master yarn \ --deploy-mode client \ --driver-memory 1G \ --executor-memory 12G \ --num-executors 8 \ --executor-cores 4 \--conf spark.hadoop.hive.exec.max.dyna…

.NET SDK 各操作系统开发环境搭建

一、Win10&#xff08;推荐&#xff09; 1、VS 2022 社区版 # 下载地址 https://visualstudio.microsoft.com/zh-hans/downloads/ 2、.NET 6 SDK # 下载地址 https://dotnet.microsoft.com/zh-cn/download/dotnet/6.0 3、Hello World 如果需要使用旧程序样式时&#xff0c;则…

centos7.9安装mysql社区版

文章目录 场景安装 场景 今天把家里闲置的笔记本安装了centos&#xff0c;设置内网穿透做个人服务器用&#xff0c; 这里记录下安装mysql的过程 安装 安装mysql源 sudo yum install -y wget wget https://dev.mysql.com/get/mysql80-community-release-el7-5.noarch.rpm sudo r…

IDEA怎么定位java类所用maven依赖版本及引用位置

在实际开发中&#xff0c;我们可能会遇到需要搞清楚代码所用依赖版本号及引用位置的场景&#xff0c;便于排查问题&#xff0c;怎么通过IDEA实现呢&#xff1f; 可以在IDEA中打开项目&#xff0c;右键点击maven的pom.xml文件&#xff0c;或者在maven窗口下选中项目&#xff0c;…

CommandLineRunner、ApplicationRunner和@PostConstruct

在Spring Boot中&#xff0c;CommandLineRunner、ApplicationRunner 和 PostConstruct 都是常用的生命周期管理接口或注解&#xff0c;它们有不同的用途和执行时机&#xff0c;帮助开发者在Spring应用启动过程中进行一些初始化操作或执行特定任务。下面分别介绍它们的特点、使用…

【Golang】——Gin 框架中的模板渲染详解

Gin 框架支持动态网页开发&#xff0c;能够通过模板渲染结合数据生成动态页面。在这篇文章中&#xff0c;我们将一步步学习如何在 Gin 框架中配置模板、渲染动态数据&#xff0c;并结合静态资源文件创建一个功能完整的动态网站。 文章目录 1. 什么是模板渲染&#xff1f;1.1 概…

力扣 LeetCode 144. 二叉树的前序遍历(Day6:二叉树)

解题思路&#xff1a; 方法一&#xff1a;递归&#xff08;中左右&#xff09; class Solution {List<Integer> res new ArrayList<>();public List<Integer> preorderTraversal(TreeNode root) {recur(root);return res;}public void recur(TreeNode roo…

高级 SQL 技巧讲解

​ 大家好&#xff0c;我是程序员小羊&#xff01; 前言&#xff1a; SQL&#xff08;结构化查询语言&#xff09;是管理和操作数据库的核心工具。从基本的查询语句到复杂的数据处理&#xff0c;掌握高级 SQL 技巧不仅能显著提高数据分析的效率&#xff0c;还能解决业务中的复…

Java Server Pages (JSP):动态网页开发的基石

在Web开发的广阔领域中&#xff0c;Java Server Pages (JSP) 作为一种将Java代码与HTML内容相结合的服务器端技术&#xff0c;始终占据着举足轻重的地位。作为Java Enterprise Edition (Java EE) 的核心组成部分&#xff0c;JSP不仅为开发者提供了强大的动态网页生成能力&#…

pom中无法下载下来的类外部引用只给一个jar的时候

比如jar在桌面上放着,操作步骤如下&#xff1a; 选择桌面&#xff0c;输入cmd ,执行mvn install:install-file -DgroupIdcom -DartifactIdaspose-words -Dversion15.8.0 -Dpackagingjar -Dclassifierjdk11 -Dfilejar包名称 即可把jar包引入成功。

c语言学习21数组

1.1数组介绍 概念&#xff1a;数组就是 相同数据类型 的一组数据的集合 数组中每个数据 元素 用一个名字来命名这个集合 数组名 用编号区分 下表&#xff08;从0开始自动标号&#xff09; 当处理大量…

No Module named pytorchvideo.losses问题解决

问题描述 最近在跑X3D的源码时发现&#xff0c;在conda powershell prompt中安装了pytorchvideo&#xff0c;但是仍然报错&#xff1a;No Module named pytorchvideo.losses 解决方案&#xff1a; 直接去https://gitcode.com/gh_mirrors/py/pytorchvideo/overview?utm_sour…

LLM学习笔记(1).env文件与Anaconda虚拟环境有何区别?

刚入门&#xff0c;我打算跟着这个随笔走&#xff1a; 前言 LLM 应用开发实践笔记 会先从ChatGPT的中文文档开始&#xff1a; 入门 | OpenAI 官方帮助文档中文版 以上笔记和文档里的内容就不会在我的学习笔记里重复了&#xff0c;最后会把注意力转移到Hugging Face和论文上…

【软件测试】设计测试用例的万能公式

文章目录 概念设计测试用例的万能公式常规思考逆向思维发散性思维万能公式水杯测试弱网测试如何进行弱网测试 安装卸载测试 概念 什么是测试用例&#xff1f; 测试⽤例&#xff08;Test Case&#xff09;是为了实施测试⽽向被测试的系统提供的⼀组集合&#xff0c;这组集合包…