RNN文本分类任务实战

  1. 递归神经网络 (RNN):
    定义:RNN 是一类专为顺序数据处理而设计的人工神经网络。
    顺序处理:RNN 保持一个隐藏状态,该状态捕获有关序列中先前输入的信息,使其适用于涉及顺序依赖关系的任务。
  2. 词嵌入:
    定义:词嵌入是捕获语义关系的词的密集向量表示。
    重要性:它们允许神经网络学习上下文信息和单词之间的关系。
    实现:使用预先训练的词嵌入(Word2Vec、GloVe)或在模型中包含嵌入层。
  3. 文本标记化和填充:
    代币化:将文本分解为单个单词或子单词。
    填充:通过添加零或截断来确保所有序列具有相同的长度。
  4. Keras 中的顺序模型:
    实现:利用 Keras 库中的 Sequential 模型创建线性层堆栈。
  5. 嵌入层:
    实现:向模型添加嵌入层,将单词转换为密集向量。
    配置:指定输入维度、输出维度(嵌入大小)和输入长度。
  6. 循环层(LSTM 或 GRU):
    LSTM 和 GRU:长短期记忆 (LSTM) 和门控循环单元 (GRU) 层有助于捕获长期依赖关系。
    实现:将一个或多个 LSTM 或 GRU 层添加到模型中。
  7. 致密层:
    目的:密集层用于最终分类输出。
    实现:添加一个或多个具有适当激活函数的密集层。
  8. 激活功能:
    选择:ReLU(整流线性单元)或tanh是隐藏层中激活函数的常见选择。
  9. 损失函数和优化器:
    损失函数:稀疏分类交叉熵通常用于文本分类任务。
    优化:Adam 或 RMSprop 是常用的优化器。
  10. 批处理和排序:
    批处理:在批量输入序列上训练模型。
    处理不同长度的物料:使用填充来处理不同长度的序列。
  11. 培训流程:
    汇编:使用所选的损失函数、优化器和指标编译模型。
    训练:将模型拟合到训练数据,在单独的集合上进行验证。
  12. 防止过拟合:
    技术:实现 dropout 或 recurrent dropout 层以防止过拟合。
    正规化:如果需要,请考虑 L1 或 L2 正则化。
  13. 超参数调优:
    参数:根据验证性能调整超参数,例如学习率、批量大小和循环单元数。
  14. 评估指标:
    指标:选择适当的指标,如准确率、精确率、召回率或 F1 分数进行评估。

# 文本分类任务实战
# 数据集构建:影评数据集进行情感分析
# 词向量模型:加载训练好的词向量或者自己训练
# 序列网络模型:训练好RNN模型进行识别import  os
import warnings
warnings.filterwarnings('ignore')
import tensorflow as tf
import numpy as np
import pprint
import logging
import time
from collections import  Counterfrom pathlib import Path
from tqdm import tqdm#加载影评数据集,可以自动下载放到对应位置
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.imdb.load_data()
# a=x_train.shape
# print(a)
# 读进来的数据是已经转换成ID映射的,一般的数据读进来都是词语,都需要手动转换成ID映射的_word2idx = tf.keras.datasets.imdb.get_word_index()
word2idx = {w: i+3 for w, i in _word2idx.items()}
word2idx['<pad>'] = 0
word2idx['<start>'] = 1
word2idx['<unk>'] = 2
idx2word = {i: w for w, i in word2idx.items()}# 按文本长度大小进行排序def sort_by_len(x, y):x, y = np.asarray(x), np.asarray(y)idx = sorted(range(len(x)), key=lambda i: len(x[i]))return x[idx], y[idx]# 将中间结果保存到本地,万一程序崩了还得重玩,保存的是文本数据,不是IDx_train, y_train = sort_by_len(x_train, y_train)
x_test, y_test = sort_by_len(x_test, y_test)def write_file(f_path, xs, ys):with open(f_path, 'w',encoding='utf-8') as f:for x, y in zip(xs, ys):f.write(str(y)+'\t'+' '.join([idx2word[i] for i in x][1:])+'\n')write_file('./data/train.txt', x_train, y_train)
write_file('./data/test.txt', x_test, y_test)# 构建语料表,基于词频来进行统计counter = Counter()
with open('./data/train.txt',encoding='utf-8') as f:for line in f:line = line.rstrip()label, words = line.split('\t')words = words.split(' ')counter.update(words)words = ['<pad>'] + [w for w, freq in counter.most_common() if freq >= 10]
print('Vocab Size:', len(words))Path('./vocab').mkdir(exist_ok=True)with open('./vocab/word.txt', 'w',encoding='utf-8') as f:for w in words:f.write(w+'\n')# 得到新的word2id映射表word2idx = {}
with open('./vocab/word.txt',encoding='utf-8') as f:for i, line in enumerate(f):line = line.rstrip()word2idx[line] = i# embedding层
# 可以基于网络来训练,也可以直接加载别人训练好的,一般都是加载预训练模型
# 这里有一些常用的:https://nlp.stanford.edu/projects/glove/#做了一个大表,里面有20598个不同的词,【20599*50】
embedding = np.zeros((len(word2idx)+1, 50)) # + 1 表示如果不在语料表中,就都是unknowwith open('./data/glove.6B.50d.txt',encoding='utf-8') as f: #下载好的count = 0for i, line in enumerate(f):if i % 100000 == 0:print('- At line {}'.format(i)) #打印处理了多少数据line = line.rstrip()0sp = line.split(' ')word, vec = sp[0], sp[1:]if word in word2idx:count += 1embedding[word2idx[word]] = np.asarray(vec, dtype='float32') #将词转换成对应的向量# 现在已经得到每个词索引所对应的向量print("[%d / %d] words have found pre-trained values"%(count, len(word2idx)))
np.save('./vocab/word.npy', embedding)
print('Saved ./vocab/word.npy')# 构建训练数据
# 注意所有的输入样本必须都是相同shape(文本长度,词向量维度等)
# 数据生成器
# tf.data.Dataset.from_tensor_slices(tensor):将tensor沿其第一个维度切片,返回一个含有N个样本的数据集,这样做的问题就是需要将整个数据集整体传入,然后切片建立数据集类对象,比较占内存。
#
# tf.data.Dataset.from_generator(data_generator,output_data_type,output_data_shape):从一个生成器中不断读取样本def data_generator(f_path, params):with open(f_path,encoding='utf-8') as f:print('Reading', f_path)for line in f:line = line.rstrip()label, text = line.split('\t')text = text.split(' ')x = [params['word2idx'].get(w, len(word2idx)) for w in text]#得到当前词所对应的IDif len(x) >= params['max_len']:#截断操作x = x[:params['max_len']]else:x += [0] * (params['max_len'] - len(x))#补齐操作y = int(label)yield x, ydef dataset(is_training, params):_shapes = ([params['max_len']], ())_types = (tf.int32, tf.int32)if is_training:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['train_path'], params),output_shapes=_shapes,output_types=_types, )ds = ds.shuffle(params['num_samples'])ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)  # 设置缓存序列,根据可用的CPU动态设置并行调用的数量,说白了就是加速else:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['test_path'], params),output_shapes=_shapes,output_types=_types, )ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)return ds# 自定义网络模型class Model(tf.keras.Model):def __init__(self, params):super().__init__()self.embedding = tf.Variable(np.load('./vocab/word.npy'),dtype=tf.float32,name='pretrained_embedding',trainable=False, )self.drop1 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop2 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop3 = tf.keras.layers.Dropout(params['dropout_rate'])self.rnn1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn3 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.drop_fc = tf.keras.layers.Dropout(params['dropout_rate'])self.fc = tf.keras.layers.Dense(2 * params['rnn_units'], tf.nn.elu)self.out_linear = tf.keras.layers.Dense(2)def call(self, inputs, training=False):if inputs.dtype != tf.int32:inputs = tf.cast(inputs, tf.int32)batch_sz = tf.shape(inputs)[0]rnn_units = 2 * params['rnn_units']x = tf.nn.embedding_lookup(self.embedding, inputs)x = tf.reshape(x, (batch_sz * 10 * 10, 10, 50))x = self.drop1(x, training=training)x = self.rnn1(x)x = tf.reduce_max(x, 1)x = tf.reshape(x, (batch_sz * 10, 10, rnn_units))x = self.drop2(x, training=training)x = self.rnn2(x)x = tf.reduce_max(x, 1)x = tf.reshape(x, (batch_sz, 10, rnn_units))x = self.drop3(x, training=training)x = self.rnn3(x)x = tf.reduce_max(x, 1)x = self.drop_fc(x, training=training)x = self.fc(x)x = self.out_linear(x)return x# 设置参数params = {'vocab_path': './vocab/word.txt','train_path': './data/train.txt','test_path': './data/test.txt','num_samples': 25000,'num_labels': 2,'batch_size': 32,'max_len': 1000,'rnn_units': 200,'dropout_rate': 0.2,'clip_norm': 10.,'num_patience': 3,'lr': 3e-4,
}def is_descending(history: list):history = history[-(params['num_patience']+1):]for i in range(1, len(history)):if history[i-1] <= history[i]:return Falsereturn Trueword2idx = {}
with open(params['vocab_path'],encoding='utf-8') as f:for i, line in enumerate(f):line = line.rstrip()word2idx[line] = i
params['word2idx'] = word2idx
params['vocab_size'] = len(word2idx) + 1model = Model(params)
model.build(input_shape=(None, None))#设置输入的大小,或者fit时候也能自动找到
#pprint.pprint([(v.name, v.shape) for v in model.trainable_variables])#链接:https://tensorflow.google.cn/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay?version=stable
#return initial_learning_rate * decay_rate ^ (step / decay_steps)
decay_lr = tf.optimizers.schedules.ExponentialDecay(params['lr'], 1000, 0.95)#相当于加了一个指数衰减函数
optim = tf.optimizers.Adam(params['lr'])
global_step = 0history_acc = []
best_acc = .0t0 = time.time()
logger = logging.getLogger('tensorflow')
logger.setLevel(logging.INFO)while True:# 训练模型for texts, labels in dataset(is_training=True, params=params):with tf.GradientTape() as tape:  # 梯度带,记录所有在上下文中的操作,并且通过调用.gradient()获得任何上下文中计算得出的张量的梯度logits = model(texts, training=True)loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)loss = tf.reduce_mean(loss)optim.lr.assign(decay_lr(global_step))grads = tape.gradient(loss, model.trainable_variables)grads, _ = tf.clip_by_global_norm(grads, params['clip_norm'])  # 将梯度限制一下,有的时候回更新太猛,防止过拟合optim.apply_gradients(zip(grads, model.trainable_variables))  # 更新梯度if global_step % 50 == 0:logger.info("Step {} | Loss: {:.4f} | Spent: {:.1f} secs | LR: {:.6f}".format(global_step, loss.numpy().item(), time.time() - t0, optim.lr.numpy().item()))t0 = time.time()global_step += 1# 验证集效果m = tf.keras.metrics.Accuracy()for texts, labels in dataset(is_training=False, params=params):logits = model(texts, training=False)y_pred = tf.argmax(logits, axis=-1)m.update_state(y_true=labels, y_pred=y_pred)acc = m.result().numpy()logger.info("Evaluation: Testing Accuracy: {:.3f}".format(acc))history_acc.append(acc)if acc > best_acc:best_acc = acclogger.info("Best Accuracy: {:.3f}".format(best_acc))if len(history_acc) > params['num_patience'] and is_descending(history_acc):logger.info("Testing Accuracy not improved over {} epochs, Early Stop".format(params['num_patience']))break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/601692.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

linux 安装minio

下载minio安装包&#xff0c;linux下minio为单文件 创建minio安装目录&#xff0c;将minio安装包复制到minio安装目录下 如&#xff1a; 安装目录&#xff1a;/usr/local/software/minio 创建日志文件&#xff1a;touch minio.log 创建启动脚本: touch minio-start.sh 编…

Supershell反溯源配置

简介 项目地址&#xff1a;https://github.com/tdragon6/Supershell Supershell是一个集成了reverse_ssh服务的WEB管理平台&#xff0c;使用docker一键部署&#xff08;快速构建&#xff09;&#xff0c;支持团队协作进行C2远程控制&#xff0c;通过在目标主机上建立反向SSH隧…

【Java EE初阶六】多线程案例(单例模式)

1. 单例模式 单例模式是一种设计模式&#xff0c;设计模式是我们必须要掌握的一个技能&#xff1b; 1.1 关于框架和设计模式 设计模式是软性的规定&#xff0c;且框架是硬性的规定&#xff0c;这些都是技术大佬已经设计好的&#xff1b; 一般来说设计模式有很多种&#xff0c;…

谈谈Mongodb insertMany的一些坑

概述 Mongodb提供了多种方法向集合中插入数据 插入一条数据 db.collection.insertOne() 插入多个文档 db.collection.insertMany() 更新集合中不存在的文档数据&#xff0c;指定{upsert: true}时插入数据 db.collection.updateOne() db.collection.updateMany() db.coll…

Go语言之父:开源14年,Go不止是编程语言,究竟做对了哪些?

提及编程语言&#xff0c;2023 年&#xff0c;除了老牌的 C 和新晋之秀 Rust 热度最高之外&#xff0c;就要数 Go 了。 从 2009 年由 C 语言获取灵感而发布&#xff0c;到如今风靡已久的高性能语言&#xff0c;Go 已经走过了 14 个年头。 “Go是一个项目&#xff0c;不只是一门…

基于ssm的智慧社区电子商务系统+vue论文

目 录 目 录 I 摘 要 III ABSTRACT IV 1 绪论 1 1.1 课题背景 1 1.2 研究现状 1 1.3 研究内容 2 2 系统开发环境 3 2.1 vue技术 3 2.2 JAVA技术 3 2.3 MYSQL数据库 3 2.4 B/S结构 4 2.5 SSM框架技术 4 3 系统分析 5 3.1 可行性分析 5 3.1.1 技术可行性 5 3.1.2 操作可行性 5 3…

Go语言中的init函数的执行时机

init函数的执行时机 这个涉及到 init 函数的作用和执行顺序相同个文件和不同文件中以及在不同的包中init的执行顺序go文件初始化的顺序 一、init 函数的作用和执行顺序 作用 init 函数是用于程序执行前做包的初始化的函数&#xff0c;比如初始化包里面的一些变量等等通常在…

HTML5大作业-精致版个人博客空间模板源码

文章目录 1.设计来源1.1 博客主页界面1.2 博主信息界面1.3 我的文章界面1.4 我的相册界面1.5 我的工具界面1.6 我的源码界面1.7 我的日记界面1.8 我的留言板界面1.9 联系博主界面 2.演示效果和结构及源码2.1 效果演示2.2 目录结构2.3 源代码 源码下载 作者&#xff1a;xcLeigh …

从零学Java 包装类

Java 包装类 文章目录 Java 包装类1 什么是包装类?2 为什么需要包装类?3 包装类对应4 包装类的基本操作4.1 装箱4.2 拆箱4.3 自动装箱/拆箱面试题整数缓冲区 5 包装类的类型转换5.1 字符串转基本数据类型5.2 基本数据类型转字符串5.2.1 方法一: 号字符串拼接5.2.2 方法二: 使…

在MS中基于perl脚本实现氢键统计

氢原子与电负性大的原子X以共价键结合&#xff0c;若与电负性大、半径小的原子Y&#xff08;O F N等&#xff09;接近&#xff0c;在X与Y之间以氢为媒介&#xff0c;生成X-H…Y形式的一种特殊的分子间或分子内相互作用&#xff0c;称为氢键。 氢键通常是物质在液态时形成的&…

【华为OD真题 Python】解密犯罪时间

文章目录 题目描述输入描述输出描述示例1输入输出说明示例1输入输出说明示例2输入输出说明备注代码实现题目描述 警察在侦破一个案件时,得到了线人给出的可能

第1章 线性回归

一、基本概念 1、线性模型 2、线性模型可以看成&#xff1a;单层的神经网络 输入维度&#xff1a;d 输出维度&#xff1a;1 每个箭头代表权重 一个输入层&#xff0c;一个输出层 单层神经网络&#xff1a;带权重的层为1&#xff08;将权重和输入层放在一起&#xff09; 3、…

数据库设计——DML

D M L \huge{DML} DML DML&#xff1a;数据库操作语言&#xff0c;用来对数据库中的数据进行增删改查。 增&#xff08;INSERT&#xff09; 使用insert来向数据库中增加数据。 示例&#xff1a; -- DML : 数据操作语言 -- DML : 插入数据 - insert -- 1. 为 tb_emp 表的 us…

Kubernetes二进制部署 单节点

一、环境准备 k8s集群master1&#xff1a;192.168.229.90 kube-apiserver kube-controller-manager kube-scheduler etcd k8s集群node1: 192.168.229.80 kubelet kube-proxy docker flannel k8s集群node2: 192.168.229.70 kubelet kube-proxy docker flannel 至少2C2G 常见的k…

Flutter3.X基础入门教程(2024完整版)

Flutter介绍&#xff1a; Flutter是谷歌公司开发的一款开源、免费的UI框架&#xff0c;可以让我们快速的在Android和iOS上构建高质量App。它最大的特点就是跨平台、以及高性能。 目前Flutter已经支持 iOS、Android、Web、Windows、macOS、Linux的跨平台开发。 教程所讲内容支持…

独立式键盘控制步进电机实验

#include<reg51.h> //包含51单片机寄存器定义的头文件 sbit S1P1^4; //将S1位定义为P1.4引脚 sbit S2P1^5; //将S2位定义为P1.5引脚 sbit S3P1^6; //将S3位定义为P1.6引脚 unsigned char keyval; //储存按键值 unsigned char ID; …

bat批处理文件_命令汇总(2)

文章目录 1、换行2、返回上一级目录cd..3、隐藏指令回显echo off4、开启指令回显echo on5、用关闭echo off指令本身的回显6、echo提示信息 1、换行 cd.. echo. echo. echo. pause2、返回上一级目录cd… 3、隐藏指令回显echo off echo off echo hello1 echo hello2 pause4、开…

tomcat session cookie值设置逻辑

tomcat session cookie 值设置&#xff0c;tomcat jsessionid设置 ##调用request.getSession() Controller RequestMapping("/cookie") public class CookieController {RequestMapping("/tomcatRequest")ResponseBodypublic String tomcatRequest(HttpS…

软件测试|什么是Python构造方法,构造方法如何使用?

构造方法&#xff08;Constructor&#xff09;是面向对象编程中的重要概念&#xff0c;它在创建对象时用于初始化对象的实例变量。在Python中&#xff0c;构造方法是通过特殊的名称__init__()来定义的。本文将介绍Python构造方法的基本概念、语法和用法。 什么是构造方法&…

轻松获取CHATGPT API:免费、无验证、带实例

免费获取和使用ChatGPT API的方法 快速开始&#xff1a;视频教程 章节一&#xff1a;GPT-API-Free开源项目介绍 GPT-API-Free 是一个开源项目&#xff0c;它提供了一个中转API KEY&#xff0c;使用户能够调用多个GPT模型&#xff0c;包括gpt-3.5-turbo、embedding和gpt-4。这…