基于LSTM的文本多分类任务

概述:

LSTM(Long Short-Term Memory,长短时记忆)模型是一种特殊的循环神经网络(RNN)架构,由Hochreiter和Schmidhuber于1997年提出。LSTM被设计来解决标准RNN在处理序列数据时遇到的长期依赖问题,即难以学习时间序列中相隔较远的事件之间的关联。

LSTM模型的核心是它的细胞(cell)状态和三个控制门结构:遗忘门(forget gate)、输入门(input gate)和输出门(output gate)。

以下是对LSTM模型关键组成部分的简述:

细胞状态(Cell State):细胞状态是LSTM的核心,它贯穿于整个LSTM单元,可以传输信息到网络的遥远部分。细胞状态可以看作是信息流动的“高速公路”,它允许信息在序列的不同部分之间长期传递。

遗忘门(Forget Gate):遗忘门决定了哪些信息应该从细胞状态中丢弃。它通过一个称为sigmoid的激活函数查看上一个隐藏状态(( h_{t-1} ))和当前输入(( x_t )),并输出一个介于0到1之间的数值给每个在细胞状态中的数字。1表示“完全保留这个信息”,而0表示“完全丢弃这个信息”。

输入门(Input Gate):输入门负责更新细胞状态。首先,一个sigmoid函数决定哪些值我们将要更新,然后一个tanh函数创建一个新的候选值向量,( \tilde{C}_t ),它可以被加到状态中。在遗忘门忘记旧状态的信息后,我们将这个候选值与sigmoid门的输出相乘,决定实际要更新的状态部分。

输出门(Output Gate):最后,我们需要决定输出值。输出值是基于细胞状态的,但会是一个过滤后的版本。首先,我们运行一个sigmoid函数来决定细胞状态的哪些部分将输出。然后,我们将细胞状态通过tanh(得到一个介于-1到1之间的值)并乘以sigmoid门的输出,以决定最终的输出。

代码案例

数据采用推特上对于新冠病毒的评级

代码详情如下

加载数据与依赖

import numpy as np
import pandas as pd
import os 
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from wordcloud import WordCloud
import re
from nltk.corpus import stopwords #模块包含了英语和其他语言的停用词列表。停用词是指在语言中非常常见的单词,
#加载数据
os.chdir('E:\python code\文本分类')train_Data = pd.read_excel('Corona_NLP_train.xlsx')
test_Data = pd.read_excel('Corona_NLP_test.xlsx')#train_data = pd.read_csv(train_path, encoding="ISO-8859-1") 

数据处理


"""
----------------------------------------------------------------------------
###################       数据处理               ###################
----------------------------------------------------------------------------
"""
print(train_Data.head())
print(train_Data.columns)
print(train_Data['Sentiment'].value_counts())
print(train_Data.shape)
print(test_Data.shape)
print(train_Data.info())#查看详情
for i in range(3):print(i)print(train_Data['OriginalTweet'][i].lower())#lower 转小写train_Data['OriginalTweet'] = train_Data['OriginalTweet'].astype(str)train_Data=train_Data.dropna(subset=['Location'])
test_Data=test_Data.dropna(subset=['Location'])#调整标签
def change_sen(sentiment):if sentiment == "Extremely Positive":return 'positive'elif sentiment == "Extremely Negative":return 'negative'elif sentiment == "Positive":return 'positive'elif sentiment == "Negative":return 'negative'else:return 'netural'train_Data['Sentiment'] = train_Data['Sentiment'].apply(lambda x: change_sen(x))
test_Data['Sentiment'] = test_Data['Sentiment'].apply(lambda x: change_sen(x))

EDA

----------------------------------------------------------------------------
###################       EDA               ###################
----------------------------------------------------------------------------
"""# 筛选前20的地区
top_20 = train_Data['Location'].value_counts().head(20)# 标记颜色
colors = ['#FF6347', '#FF7F50', '#FFD700', '#ADFF2F', '#00CED1', '#8A2BE2', '#A52A2A', '#5F9EA0', '#D2691E', '#FF1493', '#00BFFF', '#696969', '#008080', '#FFD700', '#9ACD32', '#FF4500', '#2E8B57', '#8B0000', '#B8860B', '#B0E0E6']# 构建柱形图
top_20.plot(kind='bar', color=colors, rot=45, figsize=(12, 6))# Add title and labels
plt.title("Top 20 Tweet Locations by Frequency")
plt.ylabel('Frequency')
plt.xlabel('Location')
plt.show()# 查看标签的分布
plt.figure(figsize=(8, 6))
sns.countplot(x='Sentiment', data=train_Data, color='#422e9e')
plt.title("Sentiment Distribution")
plt.xlabel("Sentiment")
plt.ylabel("Count")
plt.show()#查看内容的分布
#isinstance() 是一个内置函数,用来检查一个对象是否是一个特定类或继承自该类的实例。
text = ' '.join(tweet for tweet in train_Data['OriginalTweet'] if isinstance(tweet,str))Wordcloud = WordCloud(width=800 , height= 400,background_color='white').generate(text)plt.figure(figsize=(10,5))
plt.imshow(Wordcloud,interpolation='bilinear')
plt.axis('off')
plt.show()#查看文本的平均长度
text_len = [len(i) for i in train_Data['OriginalTweet']]
# 绘制箱型图
plt.boxplot(text_len)  # 设置vert=False让箱型图水平显示
plt.title('Boxplot of String Lengths')
plt.xlabel('Length of Strings')
plt.xticks([])  # 不显示x轴的刻度
plt.show()# 绘制柱形图
sns.histplot(text_len, bins=30, kde=True, color="#eb4034")
plt.title("Tweet Length Distribution")
plt.show()

前20地区的分布
在这里插入图片描述
类别分布
在这里插入图片描述
中间出现的词汇频率
在这里插入图片描述

特征工程


"""
----------------------------------------------------------------------------
###################      特征工程              ###################
----------------------------------------------------------------------------
"""
X = train_Data['OriginalTweet'].copy()y = train_Data['Sentiment'].copy()def data_cleaner(tweet):# 删除 http#sub 是re模块中的一个函数,用于替换字符串中符合正则表达式的部分。#\S+ 匹配一个或多个非空白字符# 删除 http 开头的连续的字符直到第一个空格tweet = re.sub(r'http\S+', ' ', tweet)#test = re.sub(r'http\S+', ' ', 'http:www.baidu.com test')#print(test)# 去除<>#.*? 是一个非贪婪匹配,.匹配除了换行符之外的任何单个字符,* 表示“零个或多个”的意思,? 使得.*变成非贪婪模式,意味着它会匹配尽可能少的字符。#*? 无线的匹配,如果精确的匹配加 .tweet = re.sub(r'<.*?>',' ', tweet)#test = re.sub(r'--*?', ' ', '<a---> test')#print(test)# 删除数字#\d 匹配任何数字字符(0-9)#+ 表示匹配前面的字符(在这里是\d)一次或多次。tweet = re.sub(r'\d+',' ', tweet)#test = re.sub(r'\d+',' ', '<a-123--> test')#print(test)# 删除一些和字符组合在一起的脏数据 # tweet = re.sub(r'#\w+',' ', tweet)#test = re.sub(r'#\w+',' ', 'Hello #world, this --s a #test tweet')#print(test)# 删除和字母组合在一起的脏数据 @tweet = re.sub(r'@\w+',' ', tweet)#添加停止测tweet = tweet.split()tweet = " ".join([word for word in tweet if not word in stop_words])return tweetstop_words = stopwords.words('english')
#调整字符
X_cleaned = X.apply(data_cleaner)
#查看数据
X_cleaned.head()

token 转化

from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences# 加载token
tokenizer = Tokenizer()
tokenizer.fit_on_texts(X_cleaned)#转换
X = tokenizer.texts_to_sequences(X_cleaned)# 向量表
vocab_size = len(tokenizer.word_index) + 1
print(f"向量表: {vocab_size}")# 查看赌赢数据的详情
print(f"\nSentence:\n{X_cleaned[6]}")
print(f"\nAfter tokenizing:\n{X[6]}")#对数据长度和截断和填充 默认最大长度 ,从尾部填充
# X_padded = pad_sequences(X, maxlen=5, padding='post')
X = pad_sequences(X, padding='post')
print(f"\nAfter padding:\n{X[6]}")

调整标签


"""
----------------------------------------------------------------------------
###################      调整标签              ###################
----------------------------------------------------------------------------
"""text = {"netural":0, "positive":1,"negative":2}
train_Data['Sentiment'] = train_Data['Sentiment'].map(text)y.replace(text, inplace=True)print(y.shape)

模型训练


import tensorflow as tf
from tensorflow.keras import layers as L
from tensorflow.keras.losses import SparseCategoricalCrossentropy #适用于稀疏标签数据的交叉熵损失函数# Hyperparameters
EPOCHS = 10
BATCH_SIZE = 32
embedding_dim = 16
units = 256# Define the model
model = tf.keras.Sequential([# 用于将输入的整数序列转换为密集的向量表示。vocab_size应该被替换为词汇表的大小。L.Embedding(vocab_size, embedding_dim),  #一个双向的LSTM层,它能够处理序列数据并且提供前向和后向的上下文信息。#units是LSTM层中单元的数量。return_sequences=True表示LSTM层的每个时间步都会返回一个输出,#这在后面接GlobalMaxPool1D层时是必需的L.Bidirectional(L.LSTM(units, return_sequences=True)),#全局最大池化层,它会沿着时间维度对序列进行最大值池化,从而减少输出的维度。L.GlobalMaxPool1D(),L.Dropout(0.4),#层:一个全连接层,这里用于实现非线性变换,activation="relu"指定了Rectified Linear Unit激活函数。L.Dense(64, activation="relu"),L.Dropout(0.4),L.Dense(3)  #最后输出3个结果
])# Compile the model
model.compile(#定义损失函数损失函数是SparseCategoricalCrossentropy,它适用于整数标签的稀疏分类问题#并且设置from_logits=True表示输入的是未经激活的logitsloss=SparseCategoricalCrossentropy(from_logits=True),optimizer='adam',metrics=['accuracy']
)# 清除之前的TensorFlow会话,释放资源,并确保后续的模型训练不受之前会话的影响。
tf.keras.backend.clear_session()history = model.fit(X, y, epochs=EPOCHS, validation_split=0.12, batch_size=BATCH_SIZE)

结果如下:
Epoch 1/10
896/896 [] - 78s 82ms/step - loss: 0.7185 - accuracy: 0.6824 - val_loss: 0.4261 - val_accuracy: 0.8526
Epoch 2/10
896/896 [
] - 57s 64ms/step - loss: 0.3591 - accuracy: 0.8832 - val_loss: 0.3745 - val_accuracy: 0.8741
Epoch 3/10
896/896 [] - 68s 76ms/step - loss: 0.2382 - accuracy: 0.9257 - val_loss: 0.4173 - val_accuracy: 0.8677
Epoch 4/10
896/896 [
] - 73s 81ms/step - loss: 0.1755 - accuracy: 0.9465 - val_loss: 0.4795 - val_accuracy: 0.8529
Epoch 5/10
896/896 [] - 73s 82ms/step - loss: 0.1394 - accuracy: 0.9556 - val_loss: 0.5664 - val_accuracy: 0.8450
Epoch 6/10
896/896 [
] - 79s 88ms/step - loss: 0.1119 - accuracy: 0.9642 - val_loss: 0.6328 - val_accuracy: 0.8401
Epoch 7/10
896/896 [] - 58s 64ms/step - loss: 0.0923 - accuracy: 0.9699 - val_loss: 0.7140 - val_accuracy: 0.8281
Epoch 8/10
896/896 [
] - 80s 89ms/step - loss: 0.0731 - accuracy: 0.9760 - val_loss: 0.7973 - val_accuracy: 0.8191
Epoch 9/10
896/896 [] - 74s 83ms/step - loss: 0.0566 - accuracy: 0.9822 - val_loss: 0.9219 - val_accuracy: 0.8133
Epoch 10/10
896/896 [
] - 52s 58ms/step - loss: 0.0472 - accuracy: 0.9851 - val_loss: 1.0420 - val_accuracy: 0.8140

模型验证


"""
----------------------------------------------------------------------------
###################      模型验证             ###################
----------------------------------------------------------------------------
"""plt.figure(figsize=(10, 6))
plt.plot(history.history['accuracy'], label='Training Accuracy', color='blue')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy', color='orange')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.show()#测试处理
X_test = test_Data['OriginalTweet'].copy()
y_test = test_Data['Sentiment'].copy()X_test = X_test.apply(data_cleaner)X_test = tokenizer.texts_to_sequences(X_test)X_test = pad_sequences(X_test, padding='post')y_test.replace(text, inplace=True)loss, acc = model.evaluate(X_test,y_test,verbose=0)
print('测试集损失: {}'.format(loss))
print('测试集准确率: {}'.format(acc))pred = model.predict(X_test).argmax(axis=1)
#混淆矩阵
print("Unique values in y_test:", y_test.unique())
print("Unique values in pred:", np.unique(pred))pred = pred.astype(int)from sklearn.metrics import confusion_matrix
conf = confusion_matrix(y_test, pred)labels = ['neutral', 'positive', 'negative']
cm = pd.DataFrame(conf, index=labels, columns=labels)import matplotlib.pyplot as plt
import seaborn as snsplt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/62701.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

frp软件实现网络穿透

1. 名词 1.1. 网络穿透 网络穿透是一种技术&#xff0c;用于解决内网设备或服务无法直接被外部网络访问的问题。通常&#xff0c;内网设备位于路由器后面&#xff0c;并没有公网 IP 地址&#xff0c;因此外部用户不能直接连接到这些设备。网络穿透通过一些特定的技术手段&…

leetcode3250. 单调数组对的数目 I,仅需1s

题目&#xff1a; https://leetcode.cn/problems/find-the-count-of-monotonic-pairs-i/description/ 不为别的&#xff0c;只是记录下这个超过100%&#xff0c;而且比原先最快的快了一个量级 不知道咋分析&#xff0c;反正得出结论就是&#xff0c;变大不变&#xff0c;变小…

使用docker-compose部署搜索引擎ElasticSearch6.8.10

背景 Elasticsearch 是一个开源的分布式搜索和分析引擎&#xff0c;基于 Apache Lucene 构建。它被广泛用于实时数据搜索、日志分析、全文检索等应用场景。 Elasticsearch 支持高效的全文搜索&#xff0c;并提供了强大的聚合功能&#xff0c;可以处理大规模的数据集并进行快速…

Zabbix 模板翻译自动化教程

在企业 IT 运维管理中&#xff0c;Zabbix 作为一款强大的开源监控平台被广泛应用。而 Zabbix 模板作为监控配置的重要组成部分&#xff0c;用来定义监控项、触发器、图形等。随着国际化的需求增加&#xff0c;Zabbix 模板的翻译工作变得日益重要&#xff0c;特别是在需要为不同…

Springboot小知识(1):启动类与配置

一、启动类&#xff08;引导类&#xff09; 在通常情况下&#xff0c;你创建的Spring应用项目都会为你自动生成一个启动类&#xff0c;它是这个应用的起点。 在Spring Boot中&#xff0c;引导类&#xff08;也称为启动类&#xff0c;通常是main方法所在的类&#xff09;是整个…

数据集-目标检测系列- 海边漫步锻炼人检测数据集 person >> DataBall

数据集-目标检测系列- 海边漫步锻炼人检测数据集 person >> DataBall DataBall 助力快速掌握数据集的信息和使用方式&#xff0c;会员享有 百种数据集&#xff0c;持续增加中。 需要更多数据资源和技术解决方案&#xff0c;知识星球&#xff1a; “DataBall - X 数据球…

NLP信息抽取大总结:三大任务(带Prompt模板)

信息抽取大总结 1.NLP的信息抽取的本质&#xff1f;2.信息抽取三大任务&#xff1f;3.开放域VS限定域4.信息抽取三大范式&#xff1f;范式一&#xff1a;基于自定义规则抽取&#xff08;2018年前&#xff09;范式二&#xff1a;基于Bert下游任务建模抽取&#xff08;2018年后&a…

手机中的核心SOC是什么?

大家好&#xff0c;我是山羊君Goat。 常常听说CPU&#xff0c;中央处理器等等的&#xff0c;它是一个电脑或单片机系统的核心&#xff0c;但是对于SOC可能相比于CPU了解的人没有那么广泛。 所以SOC是什么&#xff1f; SOC全称是System on Chip&#xff0c;就是片上系统&#…

网络--socket编程--基础

1、网络字节序 已知:内存中的很多数据都有大小端之分,在网络这,网络数据流也是有大小端之分的。 TCP/IP协议规定:网络数据流采用大端字节序(即低地址处放高位字节)。 因此,小端机器发送网络数据流之前,必须转为大端(一般的机器会自动转换): 在网络-本地字节序转换…

Transformers在计算机视觉领域中的应用【第1篇:ViT——Transformer杀入CV界之开山之作】

目录 1 模型结构2 模型的前向过程3 思考4 结论 论文&#xff1a; AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE 代码&#xff1a;https://github.com/google-research/vision_transformer Huggingface&#xff1a;https://github.com/huggingf…

<数据集>路面坑洼识别数据集<目标检测>

数据集格式&#xff1a;VOCYOLO格式 图片数量&#xff1a;665张 标注数量(xml文件个数)&#xff1a;665 标注数量(txt文件个数)&#xff1a;665 标注类别数&#xff1a;1 标注类别名称&#xff1a;[pothole] 序号类别名称图片数框数1pothole6651740 使用标注工具&#x…

PySide6 QSS(Qt Style Sheets) Reference: PySide6 QSS参考指南

Qt官网参考资料&#xff1a; QSS介绍&#xff1a; Styling the Widgets Application - Qt for Pythonhttps://doc.qt.io/qtforpython-6/tutorials/basictutorial/widgetstyling.html#tutorial-widgetstyling QSS 参考手册&#xff1a; Qt Style Sheets Reference | Qt Widge…

07.ES11 08.ES12

7.1、Promise.allSettled 调用 allsettled 方法&#xff0c;返回的结果始终是成功的&#xff0c;返回的是promise结果值 <script>//声明两个promise对象const p1 new Promise((resolve, reject) > {setTimeout(() > {resolve("商品数据 - 1");}, 1000)…

qt QGraphicsRotation详解

1、概述 QGraphicsRotation 是 Qt 框架中 QGraphicsTransform 的一个子类&#xff0c;它专门用于处理图形项的旋转变换。通过 QGraphicsRotation&#xff0c;你可以对 QGraphicsItem&#xff08;如形状、图片等&#xff09;进行旋转操作&#xff0c;从而创建动态和吸引人的视觉…

Unity Plane API解释

构造函数解释&#xff0c;d的解释为&#xff1a;距离是沿着平面法线从平面到原点的距离。注意&#xff0c;这意味着为正值的distance值将导致平面朝向原点。负的距离值会导致平面朝向远离原点。 试验&#xff1a; GetSide方法检测点是否位于平面的正向侧&#xff0c;结果显示…

通讯专题4.1——CAN通信之计算机网络与现场总线

从通讯专题4开始&#xff0c;来学习CAN总线的内容。 为了更好的学习CAN&#xff0c;先从计算机网络与现场总线开始了解。 1 计算机网络体系的结构 在我们生活当中&#xff0c;有许多的网络&#xff0c;如交通网&#xff08;铁路、公路等&#xff09;、通信网&#xff08;电信、…

深度学习模型:LSTM (Long Short-Term Memory) - 长短时记忆网络详解

一、引言 在深度学习领域&#xff0c;循环神经网络&#xff08;RNN&#xff09;在处理序列数据方面具有独特的优势&#xff0c;例如语音识别、自然语言处理等任务。然而&#xff0c;传统的 RNN 在处理长序列数据时面临着严重的梯度消失问题&#xff0c;这使得网络难以学习到长…

算法笔记:力扣24. 两两交换链表中的节点

思路&#xff1a; 本题最简单的就是通过递归的形式去实现 class Solution {public ListNode swapPairs(ListNode head) {if(head null || head.next null){return head;}ListNode next head.next;head.next swapPairs(next.next);next.next head;return next;} } 对于链…

ehr系统建设方案,人力资源功能模块主要分为哪些,hrm平台实际案例源码,springboot人力资源系统,vue,JAVA语言hr系统(源码)

eHR人力资源管理系统&#xff1a;功能强大的人力资源管理工具 随着企业规模的不断扩大和业务需求的多样化&#xff0c;传统的人力资源管理模式已无法满足现代企业的需求。eHR人力资源管理系统作为一种先进的管理工具&#xff0c;能够为企业提供高效、准确、实时的人力资源管理。…

【Android】从事件分发开始:原理解析如何解决滑动冲突

【Android】从事件分发开始&#xff1a;原理解析如何解决滑动冲突 文章目录 【Android】从事件分发开始&#xff1a;原理解析如何解决滑动冲突Activity层级结构浅析Activity的setContentView源码浅析AppCompatActivity的setContentView源码 触控三分显纷争&#xff0c;滑动冲突…