使用 BERT 和逻辑回归进行文本分类及示例验证

使用 BERT 和逻辑回归进行文本分类及示例验证

一、引言

在自然语言处理领域中,文本分类是一项至关重要的任务。本文将详细介绍如何结合 BERT 模型与逻辑回归算法来实现文本分类,并通过实际示例进行验证。

二、环境准备

为了运行本文中的代码,你需要安装以下库:

  • pandas:用于数据处理。
  • sklearn:包含机器学习算法。
  • torch:用于深度学习任务。
  • transformers:用于加载预训练语言模型。

三、代码实现

(一)读取数据集

首先,从 CSV 文件中读取数据集。假设该数据集包含两列,分别是content(文本内容)和labels(文本标签)。

import pandas as pd# 从 CSV 文件读取数据集
print("正在读取数据集...")
df = pd.read_csv('training_data.csv', encoding='utf-8-sig')
print("数据集读取完成,共包含 {} 条数据.".format(len(df)))

(二)分割数据集

接着,提取特征和目标,并将数据集分割为训练集和测试集。

# 提取特征和目标
X = df['content']
y = df['labels']# 分割数据集
print("正在分割数据集...")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("训练集大小: {}, 测试集大小: {}".format(len(X_train), len(X_test)))

(三)加载 BERT 模型和分词器

然后,加载 BERT 模型和分词器,以便将文本转化为特征向量。

import torch
from transformers import BertTokenizer, BertModel# 加载 BERT 模型和分词器
print("加载 BERT 模型和分词器...")
tokenizer = BertTokenizer.from_pretrained('D:\\bert-base-chinese')
model = BertModel.from_pretrained('D:\\bert-base-chinese')

(四)文本转化为特征向量

定义一个函数get_embeddings,用于将文本转化为特征向量。该函数利用 BERT 模型对文本进行编码,然后获取[CLS]标记的输出作为文本的特征向量。

# 文本转化为特征向量
def get_embeddings(texts):print("正在生成文本特征向量...")inputs = tokenizer(texts.tolist(), padding=True, truncation=True, return_tensors='pt')with torch.no_grad():outputs = model(**inputs)# 获取[CLS]标记的输出作为文本的特征向量return outputs.last_hidden_state[:, 0, :].numpy()

(五)训练分类模型

使用逻辑回归算法作为分类模型。先将训练集转化为 BERT 特征,然后训练分类模型。

from sklearn.linear_model import LogisticRegression# 转换训练集和测试集为 BERT 特征
X_train_bert = get_embeddings(X_train)
X_test_bert = get_embeddings(X_test)# 训练分类模型
print("正在训练分类模型...")
classifier = LogisticRegression(max_iter=1000)  # 使用逻辑回归
classifier.fit(X_train_bert, y_train)
print("模型训练完成.")

(六)预测

使用训练好的分类模型对测试集进行预测,并打印预测结果。

# 预测
print("正在进行预测...")
predictions = classifier.predict(X_test_bert)# 打印预测结果
print("预测结果:", predictions)

(七)示例数据验证

最后,添加一些示例数据进行验证。将示例数据转化为 BERT 特征,然后使用分类模型进行预测,并打印预测结果。

# 添加示例数据进行验证
sample_texts = ["音乐有助力放松大脑,心情愉悦。","热爱生活,享受人生",
]# 将示例数据转换为 BERT 特征
print("正在对示例数据进行预测...")
sample_embeddings = get_embeddings(pd.Series(sample_texts))
sample_predictions = classifier.predict(sample_embeddings)# 打印示例数据预测结果
for text, prediction in zip(sample_texts, sample_predictions):print(f"文本: \"{text}\" 预测标签: {prediction}")

四、总结

本文介绍了如何运用 BERT 和逻辑回归进行文本分类,并通过示例数据进行了验证。借助 BERT 模型学习到的文本上下文信息,能够显著提高文本分类的准确性。同时,逻辑回归算法的快速性使得我们可以高效地对大量文本进行分类。

五、完整代码

text_categorize_and_tag.py

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import torch
from transformers import BertTokenizer, BertModel# 从CSV文件读取数据集
print("正在读取数据集...")
df = pd.read_csv('training_data.csv', encoding='utf-8-sig')
print("数据集读取完成,共包含 {} 条数据.".format(len(df)))# 提取特征和目标
X = df['content']
y = df['labels']# 分割数据集
print("正在分割数据集...")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("训练集大小: {}, 测试集大小: {}".format(len(X_train), len(X_test)))# 加载BERT模型和分词器
print("加载BERT模型和分词器...")
tokenizer = BertTokenizer.from_pretrained('D:\\bert-base-chinese')
model = BertModel.from_pretrained('D:\\bert-base-chinese')# 文本转化为特征向量
def get_embeddings(texts):print("正在生成文本特征向量...")inputs = tokenizer(texts.tolist(), padding=True, truncation=True, return_tensors='pt')with torch.no_grad():outputs = model(**inputs)# 获取[CLS]标记的输出作为文本的特征向量return outputs.last_hidden_state[:, 0, :].numpy()# 转换训练集和测试集为BERT特征
X_train_bert = get_embeddings(X_train)
X_test_bert = get_embeddings(X_test)# 训练分类模型
print("正在训练分类模型...")
classifier = LogisticRegression(max_iter=1000)  # 使用逻辑回归
classifier.fit(X_train_bert, y_train)
print("模型训练完成.")# 预测
print("正在进行预测...")
predictions = classifier.predict(X_test_bert)# 打印预测结果
print("预测结果:", predictions)# 添加示例数据进行验证
sample_texts = ["音乐有助力放松大脑,心情愉悦。","热爱生活,享受人生",
]# 将示例数据转换为BERT特征
print("正在对示例数据进行预测...")
sample_embeddings = get_embeddings(pd.Series(sample_texts))
sample_predictions = classifier.predict(sample_embeddings)# 打印示例数据预测结果
for text, prediction in zip(sample_texts, sample_predictions):print(f"文本: \"{text}\" 预测标签: {prediction}")

training_data.csv

content,labels
"Python 是一种广泛使用的高级编程语言。","编程"
"自然语言处理是人工智能领域的重要研究方向。","NLP"
"机器学习是分析数据的重要工具。","机器学习"
"数据科学结合了统计学和计算机科学。","数据科学"
"人工智能正在改变我们的生活方式。","人工智能"
"深度学习能够处理复杂的数据集。","机器学习"
"很多企业开始应用人工智能技术以提高效率。","人工智能"
"数据分析是理解客户行为的重要工具。","数据科学"
"编程不仅是技术,更是一种思维方式。","编程"
"算法在大数据时代发挥着重要作用。","数据科学"
"音乐可以影响人的情绪和认知。","音乐"
"学习音乐可以提高学生的创造力。","教育"
"现场音乐会可以提供独特的视听体验。","娱乐"
"教育科技正在变革传统的学习方式。","教育"
"学习一门乐器有助于提升专注力。","音乐"
"电影和电视节目是现代娱乐的重要部分。","娱乐"
"音乐治疗被广泛应用于心理健康。","音乐"
"在线教育平台为学习者提供灵活的选择。","教育"
"综艺节目为观众提供了丰富的娱乐内容。","娱乐"
"这是一篇关于机器学习的文章。","科技"
"我喜欢户外活动和旅游。","生活"
"COVID-19疫情对全球经济产生了深远的影响。","财经"
"人工智能正在改变我们的生活方式。","科技"
"旅游是一种能让人开阔视野的活动。","生活"
"金融科技让我们的投资变得更加智能。","财经"
"环境保护对我们的未来至关重要。","环保"

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/58496.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

第8天:数据存储-补充材料——‘User.kt‘和‘UserDao‘解读

下面是对“第8天:数据存储”该文学习的更深层次的补充材料,对 ‘User.kt’和’UserDao’ 文件的理解。 下面对’User.kt’文件中每一行进行详细解释: 这段代码定义了一个数据类User,它用于与Room数据库中的表进行交互。下面是逐句…

网络安全包含哪些方面?如何加强网络安全建设?

系统安全、应用安全、物理安全、管理安全等都属于网络安全。 从大的角度,如系统安全来看,可以理解为在系统生命周期内应用系统安全工程和系统安全管理方法,辨识系统中的隐患,并采取有效的控制措施使其危险性最小。这包括操作系统的…

隐私保护下的数据提取策略

在隐私保护下进行数据提取,需要采取一系列策略来确保个人隐私得到妥善保护,同时满足数据使用的需求。以下是一些关键的策略和方法: 一、数据最小化原则 定义:仅收集和提取必要的数据,避免收集过多的个人信息或不相关…

qt QStackedLayout详解

QStackedLayout类提供了一种布局方式,使得在同一时间内只有一个子部件(或称为页面)是可见的。这些子部件被维护在一个堆栈中,用户可以通过切换来显示不同的子部件,适合用在需要动态显示不同界面的场景,如向…

【Web前端】JavaScript 对象原型与继承机制

JavaScript 是一种动态类型的编程语言,其核心特性之一就是对象和原型链。理解原型及其工作机制对于掌握 JavaScript 的继承和对象关系非常重要。 什么是原型 每个对象都有一个内部属性 ​​[[Prototype]]​​​,这个属性指向创建该对象的构造函数的原型…

请详细介绍python三大神器:迭代器、生成器、装饰器

Python三大神器分别是迭代器、生成器和装饰器,它们都是Python高级特性,可以提高程序的效率和灵活性。 迭代器(Iterators): 迭代器是一个对象,它允许逐个访问容器中的元素,而不需要提前把容器中的…

基于YOLO11/v10/v8/v5深度学习的危险驾驶行为检测识别系统设计与实现【python源码+Pyqt5界面+数据集+训练代码】

《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…

python实战(五)——构建自己的大模型助手

一、任务目标 本文将利用大语言模型强大的对话能力,搭建一个PC端问答助手。具体来说,我们将使用API来调用我们想要的大模型,并结合Prompt让大模型根据任务类型生成对应的输出。为了更方便地调用大模型助手,我们将结合python第三方…

FPGA技术优势

在当今数字化时代,现场可编程门阵列(FPGA)因其高灵活性和强大的处理能力而广泛应用于各种行业。FPGA允许用户在单个芯片上实现大量数字逻辑,以相对较高的速度并且无需依赖传统顺序程序。这种独特的能力使得FPGA能够在许多复杂应用中脱颖而出。 FPGA的基本特性 FPGA是一种…

Android OpenGL ES详解——裁剪Scissor

目录 一、概念 二、如何使用 1、开启裁剪测试 2、关闭裁剪测试 3、指定裁剪窗口(位置和大小) 4、裁剪应用举例 三、窗口、视⼝和裁剪区域三者区别 四、源码下载 一、概念 定义1: 裁剪是OpenGL中提⾼渲染的⼀种方式,只刷新…

江协科技STM32学习- P22 实验-ADC单通道/ADC多通道

🚀write in front🚀 🔎大家好,我是黄桃罐头,希望你看完之后,能对你有所帮助,不足请指正!共同学习交流 🎁欢迎各位→点赞👍 收藏⭐️ 留言📝​…

PostgreSQL和MySQL在MVCC

PostgreSQL和MySQL在MVCC(多版本并发控制)机制上的不同主要体现在以下几个方面: MVCC实现方式 : PostgreSQL将数据记录的多个版本保存在数据库中,当这些版本不再需要时,垃圾收集器会回收这些记录。MySQL/…

【数据结构】树-二叉树-堆(下)

🍃 如果觉得本系列文章内容还不错,欢迎订阅🚩 🎊个人主页:小编的个人主页 🎀 🎉欢迎大家点赞👍收藏⭐文章 ✌️ 🤞 🤟 🤘 🤙 👈 &…

D365 FO开发参考

FO数据库表参考 InventTable in Main - Common Data Model - Common Data Model | Microsoft Learn FO编程Query参考 Query.addDataSource Method | Microsoft Learn FO核心Enums查询 BatchStatus Enum (Microsoft.Dynamics.Ax.Xpp) | Microsoft Learn FO Function 开发参…

前端笔面试查漏补缺

面试笔试的知识总结&#xff0c;查漏补缺 一、CSS样式隔离 CSS样式隔离用于确保组件或元素的样式不会与其他组件或元素的样式发生冲突。 1.scoped css -- <style scoped> 构建工具&#xff08;vue-loader&#xff09;会在编译阶段对css特殊处理&#xff0c;给当前组…

-XSS-

链接 https://github.com/do0dl3/xss-labs 搭建过程非常容易的 搭建好之后&#xff0c;就可以点击图片开始闯关了 第一关--JS弹窗函数alert() 显示payload的长度是4 level1.php?nametest level1.php?nametest1 发现只要改变name的值就显示什么在页面上 没有什么过滤的 …

【数据结构与算法】《Java 算法宝典:探秘从排序到回溯的奇妙世界》

目录 标题&#xff1a;《Java 算法宝典&#xff1a;探秘从排序到回溯的奇妙世界》一、排序算法1、冒泡排序2、选择排序3、插入排序4、快速排序5、归并排序 二、查找算法1、线性查找2、二分查找 三、递归算法四、动态规划五、图算法1. 深度优先搜索&#xff08;DFS&#xff09;2…

sqlserver、达梦、mysql的差异

差异项sqlserver达梦mysql单行注释---- 1、-- &#xff0c;--后面带个空格 2、# 包裹对象名称&#xff0c;如表、表字段等 [tableName] "tableName"tableName表字段自增IDENTITY(1, 1)IDENTITY(1, 1)AUTO_INCREMENT二进制数据类型IMAGEIMAGE、BLOBBLOB 存储一个汉字需…

transformControls THREE.Object3D.add: object not an instance of THREE.Object3D.

把scene.add(transformControls);改为scene.add(transformControls.getHelper());

python删除文件夹下面的所有文件示例

要删除指定路径下的文件夹和文件&#xff0c;可以使用 Python 的 shutil 模块。这将允许你递归地删除文件夹及其内容。以下是一个示例&#xff1a; import shutil from pathlib import Path# 指定要删除的目录路径 directory Path(path/to/your/directory)# 检查目录是否存在…