CNN文本分类(tensorflow实现)

前言
  • 实现步骤
    • 1.安装tensorflow
    • 2.导入所需要的tensorflow库和其它相关模块
    • 3.设置随机种子
    • 4.定义模型相关超参数
    • 5.加载需要的数据集
    • 6.对加载的文本内容进行填充和截断
    • 7.构建自己模型
    • 8.训练构建的模型
    • 9.评估完成的模型
  • CNN(卷积神经网络)在文本分类任务中具有良好的特征提取能力、位置不变性、参数共享和处理大规模数据的优势,能够有效地学习文本的局部和全局特征,提高模型性能和泛化能力,所以本文将以CNN实现文本分类。
    CNN对文本分类的支持主要提现在:

    特征提取:CNN能够有效地提取文本中的局部特征。卷积层通过应用多个卷积核来捕获不同大小的n-gram特征,从而能够识别关键词、短语和句子结构等重要信息。

    位置不变性:对于文本分类任务,特征的位置通常是不重要的。CNN中的池化层(如全局最大池化)能够保留特征的最显著信息,同时忽略其具体位置,这对于处理可变长度的文本输入非常有帮助。

    参数共享:CNN中的卷积核在整个输入上共享参数,这意味着相同的特征可以在不同位置进行识别。这种参数共享能够极大地减少模型的参数量,降低过拟合的风险,并加快模型的训练速度。

    处理大规模数据:CNN可以高效地处理大规模的文本数据。由于卷积和池化操作的局部性质,CNN在处理文本序列时具有较小的计算复杂度和内存消耗,使得它能够适应大规模的文本分类任务。

    上下文建模:通过使用多个卷积核和不同的大小,CNN可以捕捉不同尺度的上下文信息。这有助于提高模型对文本的理解能力,并能够捕捉更长范围的依赖关系。

    实现步骤之前首先安装完成tensorflow
  • 使用这个代码安装的前提是你的深度学习已经环境存在
  • 例如:conda、pytorch、cuda、cudnn等环境
  • conda create -n tf python=3.8
    conda activate tf
    #tensorflow的安装
    pip install tensorflow-gpu -i https://pypi.douban.com/simple
    

    一. 测试tensorflow是否安装成功

  • 有三种方法

  • 方法一:

import tensorflow as tf 
print(tf.__version__)
#输出'2.0.0-alpha0'
print(tf.test.is_gpu_available())
#会输出True,则证明安装成功
#新版本的tf把tf.test.is_gpu_available()换成如下命令
import tensorflow as tf 
tf.config.list_physical_devices('GPU')
  • 方法二:
  • import tensorflow as tf 
    with tf.device('/GPU:0'):a = tf.constant(3)
    

    方法三:

  • #输入python,进入python环境
    import tensorflow as tf
    #查看tensorflow版本
    print(tf.__version__)
    #输出'2.0.0-alpha0'
    #测试GPU能否调用,先查看显卡使用情况
    import os 
    os.system("nvidia-smi")
    #调用显卡
    @tf.function
    def f():pass
    f()
    #这时会打印好多日志
    #再次查询显卡
    os.system("nvidia-smi")
    可以对比两次使用情况
    

    二、打开pycharm倒入你创建的tf环境,新建py文件开始构建代码

  • 1.导入所需的库和模块:

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Embedding, Conv1D, GlobalMaxPooling1D

其中提前安装TensorFlow来用于构建和训练模型,以及Keras中的各种层和模型类

2.设置随机种子:

np.random.seed(42)

在CNN(卷积神经网络)中设置随机种子主要是为了保证实验的可重复性。由于深度学习模型中涉及大量的随机性,如权重的初始化、数据的打乱(shuffle)等,设置随机种子可以使得每次实验的随机过程都保持一致,从而使得实验结果可以复现

3.定义模型超参数:

max_features = 5000  # 词汇表大小
max_length = 100  # 文本最大长度
embedding_dims = 50  # 词嵌入维度
filters = 250  # 卷积核数量
kernel_size = 3  # 卷积核大小
hidden_dims = 250  # 全连接层神经元数量
batch_size = 32  # 批处理大小
epochs = 5  # 训练迭代次数

超参数影响模型的结构和训练过程,可自行调整。

4.加载数据集:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)

示例中,使用的IMDB电影评论数据集,其中包含以数字表示的评论文本和相应的情感标签(正面或负面),使用tf.keras.datasets.imdb.load_data函数可以方便地加载数据集,并指定num_words参数来限制词汇表的大小。

5.对文本进行填充和截断:

x_train = sequence.pad_sequences(x_train, maxlen=max_length)
x_test = sequence.pad_sequences(x_test, maxlen=max_length)

由于每条评论的长度可能不同,需要将它们统一到相同的长度。sequence.pad_sequences函数用于在文本序列前后进行填充或截断,使它们具有相同的长度。

6.构建模型:

model = Sequential()
model.add(Embedding(max_features, embedding_dims, input_length=max_length))
model.add(Dropout(0.2))
model.add(Conv1D(filters, kernel_size, padding='valid', activation='relu', strides=1))
model.add(GlobalMaxPooling1D())
model.add(Dense(hidden_dims, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(1, activation='sigmoid'))

这个模型使用Sequential模型类构建,依次添加了嵌入层(Embedding)、卷积层(Conv1D)、全局最大池化层(GlobalMaxPooling1D)和两个全连接层(Dense)。嵌入层将输入的整数序列转换为固定维度的词嵌入表示,卷积层通过应用多个卷积核来提取特征,全局最大池化层获取每个特征通道的最大值,而两个全连接层用于分类任务。

7.编译模型:

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

在编译模型之前,需要指定损失函数、优化器和评估指标。使用二元交叉熵作为损失函数,Adam优化器进行参数优化,并使用准确率作为评估指标。

8.训练模型:

model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test))

使用fit函数对模型进行训练。需要传入训练数据、标签,批处理大小、训练迭代次数,并可以指定验证集进行模型性能评估。

9.评估模型:

scores = model.evaluate(x_test, y_test, verbose=0)
print("Test accuracy:", scores[1])

使用evaluate函数评估模型在测试集上的性能,计算并打印出测试准确率。

完整代码

import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Embedding, Conv1D, GlobalMaxPooling1D# 设置随机种子
np.random.seed(42)# 定义模型超参数
max_features = 5000  # 词汇表大小
max_length = 100  # 文本最大长度
embedding_dims = 50  # 词嵌入维度
filters = 250  # 卷积核数量
kernel_size = 3  # 卷积核大小
hidden_dims = 250  # 全连接层神经元数量
batch_size = 32  # 批处理大小
epochs = 5  # 训练迭代次数# 加载数据集
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data(num_words=max_features)# 对文本进行填充和截断,使其具有相同的长度
x_train = sequence.pad_sequences(x_train, maxlen=max_length)
x_test = sequence.pad_sequences(x_test, maxlen=max_length)# 构建模型
model = Sequential()
model.add(Embedding(max_features, embedding_dims, input_length=max_length))
model.add(Dropout(0.2))
model.add(Conv1D(filters, kernel_size, padding='valid', activation='relu', strides=1))
model.add(GlobalMaxPooling1D())
model.add(Dense(hidden_dims, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(1, activation='sigmoid'))# 编译模型
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])# 训练模型
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test))# 评估模型
scores = model.evaluate(x_test, y_test, verbose=0)
print("Test accuracy:", scores[1])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/715932.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【GPU驱动开发】-mesa简介

前言 不必害怕未知,无需恐惧犯错,做一个Creator! 一、mesa介绍 Mesa 是一个开源的3D图形库,它实现了多种图形API,包括 OpenGL、Vulkan 和 OpenCL。Mesa 的目标是提供一个开源、跨平台的图形库,使得开发者…

ABAP - SALV教程08 列设置热点及绑定点击事件

实现思路:将列设置成热点,热点列是可点击的,再给SALV实例对象注册点击事件即可,一般作用于点击单号跳转到前台等功能 "设置热点方法METHODS:set_hotspot CHANGING co_alv TYPE REF TO cl_salv_table...."事件处理方法M…

SMART原则

在软件研发领域,项目管理和目标设定尤为关键。一个成功的软件项目不仅需要先进的技术支持,还需要一个清晰、明确且可实现的目标。SMART原则,作为一种高效的目标设定和管理方法,为软件研发提供了有力的指导。SMART是五个英文单词首…

合宙esp32-c3 进入深度睡眠无法唤醒解决一例

手贱,昨天收到了嘉立创最新的esp32 s3,想测试一下电流功耗,于是顺便测试了一下以前的合宙esp32 c3 无串口芯片的版本 打算对比一下c3和s3的功耗相差多少,结果把自己玩死了: void setup() {esp_deep_sleep_start();// esp_light_s…

oppo手机备忘录记录怎么转移到华为手机?

oppo手机备忘录记录怎么转移到华为手机?使用oppo手机已经有三四年了,因为平时习惯,在手机系统的备忘录中记录了很多重要的笔记,比如工作会议的要点、读书笔记、购物清单、朋友的生日提醒等。这些记录对我来说非常重要,我可以通过…

STM32 HAL库 串口使用问题记录

文章目录 STM32 HAL库 串口使用问题记录情况一:串口导致程序假死机情况二:其它程序正常运行,串口不再接收数据 STM32 HAL库 串口使用问题记录 情况一:串口导致程序假死机 多数应该出现在未开启DMA模式使用中断方式接收数据的情况…

钾是人体内重要的电解质之一

钾是人体内重要的电解质之一,是维持细胞生理活动的主要阳离子,在保持机体的正常渗透压及酸碱平衡,维持内环境的稳定性,参与糖及蛋白质代谢,保证神经肌肉的正常功能,在兴奋性等方面具有重要的作用。人体内的…

2000-2021年300+地级市进出口总额数据

2000-2021年300地级市进出口总额数据 1、时间:2000-2021年 2、指标:进出口总额 3、单位:万美元 4、来源:城市年鉴、各省年鉴、城市公报、2021年为城市统计年鉴中进口额出口额加总之后换算成万美元,已尽最大可能进行…

20240303

1.在优势、劣势、机会与威胁(SWOT)的分析期间,团队发现另一个项目通过与该团队合作可能从规模经济中获益。两个项目的成本都可能大幅降低,并可能实现公司的利益,项目经理应该怎么做? A.在风险登记册中记录该发现 B.询问项目发起人的意见 …

1.亿级积分数据分库分表:总体方案设计

项目背景 以一个积分系统为例,积分系统最核心的有积分账户表和积分明细表: 积分账户表:每个用户在一个品牌下有一个积分账户记录,记录了用户的积分余额,数据量在千万级积分明细表:用户每次积分发放、积分扣…

数据结构——Top-k问题

Top-k问题 方法一:堆排序(升序)(时间复杂度O(N*logN))向上调整建堆(时间复杂度:O(N * logN) )向下调整建堆(时间复杂度:O(N) )堆排序代码 方法二&…

LeetCode---386周赛

题目列表 3046. 分割数组 3047. 求交集区域内的最大正方形面积 3048. 标记所有下标的最早秒数 I 3049. 标记所有下标的最早秒数 II 一、分割数组 这题简单的思维题,要想将数组分为两个数组,且分出的两个数组中数字不会重复,很显然一个数…

Redis 的哨兵模式配置

1.配置 vim sentinel.conf# mymaster 给主机起的名字 # 192.168.205.128 主机的ip地址 # 6379 端口号 # 2 当几个哨兵发现主观宕机,则判定为客观宕机。 原则上是大于一半。比如三个哨兵,则设置为 2 sentinel monitor mymaster 192.168.205.128 63…

【动态规划入门】01背包问题

每日一道算法题之01背包问题 一、题目描述二、思路三、C++代码四、结语一、题目描述 题目来源:Acwing 有N件物品和一个容量是 V的背包。每件物品只能使用一次。第 i件物品的体积是 vi,价值是 wi。 求解将哪些物品装入背包,可使这些物品的总体积不超过背包容量,且总价值最大…

LeetCode题练习与总结:合并K个升序链表

一、题目 给你一个链表数组,每个链表都已经按升序排列。 请你将所有链表合并到一个升序链表中,返回合并后的链表。 二、解题思路 创建一个最小堆(优先队列)来存储所有链表的头节点。这样我们可以始终取出当前所有链表中值最小…

人工智能指数报告2023

人工智能指数报告2023 主要要点第 1 章 研究与开发第 2 章 技术性能第 3 章 人工智能技术伦理第 4 章 经济第 5 章 教育第 6 章 政策与治理第 7 章 多样性第 8 章 舆论 人工智能指数是斯坦福大学以人为本的人工智能研究所(HAI)的一项独立倡议&#xff0c…

Java 石头剪刀布小游戏

一、任务 编写一个剪刀石头布游戏的程序。程序启动后会随机生成1~3的随机数,分别代表剪刀、石头和布,玩家通过键盘输入剪刀、石头和布与电脑进行5轮的游戏,赢的次数多的一方为赢家。若五局皆为平局,则最终结果判为平局。 二、实…

redis 为什么会阻塞

目录 前言 客户端交换时的阻塞 redis 磁盘交换的阻塞 主从节点交互的阻塞 切片集群交互时的阻塞 异步执行的演变 redis 异步执行如何实现的 前言 大家对redis 比较熟悉吧,只要做项目都会用到redis,提高系统的吞吐。小米商城抢购高峰18k的qps&…

KubeSphere平台安装系列之三【Linux多节点部署KubeSphere】(3/3)

**《KubeSphere平台安装系列》** 【Kubernetes上安装KubeSphere(亲测–实操完整版)】(1/3) 【Linux单节点部署KubeSphere】(2/3) 【Linux多节点部署KubeSphere】(3/3) **《KubeS…

一句话讲清楚数据库中事务的隔离级别(通俗易懂版)

为什么我只说通俗易懂版不说严谨版? 因为严谨版遍地都是, 但是他们却有一个缺点就是让人看得云里雾里, 所以这就是我写通俗易懂版的初衷! 但是既然是通俗易懂版就必然有缺陷, 只为了各位在开发过程中头脑更加清晰, 如有错误还望兄弟们不吝赐教! 在MySQL数据库中,事务一共有4…