TensorFlow2实战-系列教程12:RNN文本分类4

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

8、压缩版本网络模型

class Model(tf.keras.Model):def __init__(self, params):super().__init__()self.embedding = tf.Variable(np.load('./vocab/word.npy'),dtype=tf.float32,name='pretrained_embedding',trainable=False,)self.drop1 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop2 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop3 = tf.keras.layers.Dropout(params['dropout_rate'])self.rnn1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn3 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.drop_fc = tf.keras.layers.Dropout(params['dropout_rate'])self.fc = tf.keras.layers.Dense(2*params['rnn_units'], tf.nn.elu)self.out_linear = tf.keras.layers.Dense(2)def call(self, inputs, training=False):if inputs.dtype != tf.int32:inputs = tf.cast(inputs, tf.int32)batch_sz = tf.shape(inputs)[0]rnn_units = 2*params['rnn_units']x = tf.nn.embedding_lookup(self.embedding, inputs)x = tf.reshape(x, (batch_sz*10*10, 10, 50))x = self.drop1(x, training=training)x = self.rnn1(x)x = tf.reduce_max(x, 1)x = tf.reshape(x, (batch_sz*10, 10, rnn_units))x = self.drop2(x, training=training)x = self.rnn2(x)x = tf.reduce_max(x, 1)x = tf.reshape(x, (batch_sz, 10, rnn_units))x = self.drop3(x, training=training)x = self.rnn3(x)x = tf.reduce_max(x, 1)x = self.drop_fc(x, training=training)x = self.fc(x)x = self.out_linear(x)return x

这是另外一个版本的自定义网络,网络定义部分是一样的,只是在前向传播的过程中,对每一个rnn的输出做了特征压缩,每次只取10个特征中值最大的,特征数量因此变为了1/10,所以这个版本的训练速度回更快

9、模型训练参数

params = {'vocab_path': './vocab/word.txt','train_path': './data/train.txt','test_path': './data/test.txt','num_samples': 25000,'num_labels': 2,'batch_size': 32,'max_len': 1000,'rnn_units': 200,'dropout_rate': 0.2,'clip_norm': 10.,'num_patience': 3,'lr': 3e-4,
}

语料表路径、训练数据路径、验证数据路径
句子数量、标签输出值个数、batch_size
句子最大长度、rnn_units隐层神经元个数、dropout比例
梯度截断(避免梯度剧烈变化,控制过拟合)、多少次损失没下降停止训练、学习率

def is_descending(history: list):history = history[-(params['num_patience']+1):]for i in range(1, len(history)):if history[i-1] <= history[i]:return Falsereturn True  

根据损失值、准确率来判断有没有提升效果,如果num_patience次数都没提升,就停止训练

word2idx = {}
with open(params['vocab_path'],encoding='utf-8') as f:for i, line in enumerate(f):line = line.rstrip()word2idx[line] = i
params['word2idx'] = word2idx
params['vocab_size'] = len(word2idx) + 1

读进语料表进行id映射

model = Model(params)
model.build(input_shape=(None, None))
decay_lr = tf.optimizers.schedules.ExponentialDecay(params['lr'], 1000, 0.95)#相当于加了一个指数衰减函数
optim = tf.optimizers.Adam(params['lr'])
global_step = 0
history_acc = []
best_acc = .0t0 = time.time()
logger = logging.getLogger('tensorflow')
logger.setLevel(logging.INFO)
  1. 构建模型
  2. 设置输入的大小,或者fit时候也能自动找到
  3. 学习率衰减
  4. 优化器
  5. 迭代次数计数变量
  6. 保存历史准确率
  7. 最佳准确率
  8. 获取当前时间
  9. 打印日志的设置参数

10、模型训练

while True:# 训练模型for texts, labels in dataset(is_training=True, params=params):with tf.GradientTape() as tape:logits = model(texts, training=True)loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)loss = tf.reduce_mean(loss)optim.lr.assign(decay_lr(global_step))grads = tape.gradient(loss, model.trainable_variables)grads, _ = tf.clip_by_global_norm(grads, params['clip_norm']) optim.apply_gradients(zip(grads, model.trainable_variables))if global_step % 50 == 0:logger.info("Step {} | Loss: {:.4f} | Spent: {:.1f} secs | LR: {:.6f}".format(global_step, loss.numpy().item(), time.time()-t0, optim.lr.numpy().item()))t0 = time.time()global_step += 1# 验证集效果m = tf.keras.metrics.Accuracy()for texts, labels in dataset(is_training=False, params=params):logits = model(texts, training=False)y_pred = tf.argmax(logits, axis=-1)m.update_state(y_true=labels, y_pred=y_pred)acc = m.result().numpy()logger.info("Evaluation: Testing Accuracy: {:.3f}".format(acc))history_acc.append(acc)if acc > best_acc:best_acc = acclogger.info("Best Accuracy: {:.3f}".format(best_acc))if len(history_acc) > params['num_patience'] and is_descending(history_acc):logger.info("Testing Accuracy not improved over {} epochs, Early Stop".format(params['num_patience']))break
  1. 按照batch取数据
  2. 梯度带,记录所有在上下文中的操作,并且通过调用.gradient()获得任何上下文中计算得出的张量的梯度
  3. 当前输入经过模型的输出结果
  4. 计算损失
  5. 计算平均损失
  6. 根据自定义的学习率更新策略 更新学习率
  7. 根据梯度带计算梯度值
  8. 将梯度限制一下,有的时候回更新太猛,防止过拟合
  9. 更新梯度
  10. 每隔50次打印一下当前训练的结果
  11. 使用当前训练的网络对验证集的数据进行测试
  12. 3次没有提升准确率就停止训练
  13. 如果准确率超过阈值后停止训练

部分训练过程日志:
Reading ./data/train.txt
INFO:tensorflow:Step 0 | Loss: 0.6997 | Spent: 7.5 secs | LR: 0.000300

INFO:tensorflow:Evaluation: Testing Accuracy: 0.872
INFO:tensorflow:Best Accuracy: 0.879
Reading ./data/train.txt
INFO:tensorflow:Step 10200 | Loss: 0.2801 | Spent: 640.2 secs | LR: 0.000178
INFO:tensorflow:Step 10250 | Loss: 0.1747 | Spent: 77.9 secs | LR: 0.000177
INFO:tensorflow:Step 10300 | Loss: 0.2829 | Spent: 77.7 secs | LR: 0.000177

INFO:tensorflow:Step 10900 | Loss: 0.2204 | Spent: 77.7 secs | LR: 0.000172
Reading ./data/test.txt
INFO:tensorflow:Evaluation: Testing Accuracy: 0.863
INFO:tensorflow:Best Accuracy: 0.879
INFO:tensorflow:Testing Accuracy not improved over 3 epochs, Early Stop

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/655539.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

选矿二厂电气系统维管工作项目招标公告

选矿二厂电气系统维管工作项目招标公告 (招标编号&#xff1a;JDCRY-ZB2024-05) 项目所在地区&#xff1a;河南省,洛阳市,汝阳县 一、招标条件 本选矿二厂电气系统维管工作项目已由项目审批/核准/备案机关批准&#xff0c;项目资金来源为自筹资金438万元&#xff0c;招标人…

QT C++语言格式化输出wchar_t * 中文乱码

在 Qt 中&#xff0c;如果你使用 wprintf 或 wcout 进行宽字符输出&#xff0c;而且你的字符串包含中文字符&#xff0c;确保使用 Unicode 字符集&#xff0c;并将字符串编码为 UTF-16。此外&#xff0c;确保你的输出流和终端都能正确地处理宽字符。 下面是一个简单的例子&…

[GN] 设计模式—— 创建型模式

文章目录 创建型模式单例模式 -- 确保对象唯一性饿汉式懒汉式优缺点使用场景 简单工厂模式例子&#xff1a;优化优缺点适用场景 工厂方法模式--多态工厂的实现例子优缺点适用场景 创建型模式 单例模式 – 确保对象唯一性 用TaskManager类。通过以下三步进行重构 为了确保Ta…

[足式机器人]Part3 机构运动学与动力学分析与建模 Ch01-2 完整定常系统——杆组RRR

本文仅供学习使用,总结很多本现有讲述运动学或动力学书籍后的总结,从矢量的角度进行分析,方法比较传统,但更易理解,并且现有的看似抽象方法,两者本质上并无不同。 2024年底本人学位论文发表后方可摘抄 若有帮助请引用 本文参考: 《空间机构的分析与综合(上册)》-张启先…

redis-4 集群

应用场景 为什么需要redis集群&#xff1f; 当主备复制场景&#xff0c;无法满足主机的单点故障时&#xff0c;需要引入集群配置。 一般数据库要处理的读请求远大于写请求 &#xff0c;针对这种情况&#xff0c;我们优化数据库可以采用读写分离的策略。我们可以部 署一台主服…

【网络基础】mac地址

以太网 以太网" 不是一种具体的网络, 而是一种技术标准; 既包含了数据链路层的内容, 也包含了一些物理层的内容. 例如: 规定了网络拓扑结构, 访问控制方式, 传输速率等; 例如以太网中的网线必须使用双绞线; 传输速率有10M, 100M, 1000M等; 以太网是当前应用最广泛的局域…

docker 的常用命令

随着容器技术的快速发展&#xff0c;Docker 已经成为了一种流行的容器化工具。它能够使开发者将应用程序及其依赖项打包到一个可移植的容器中&#xff0c;并在不同的环境中快速部署。以下是 Docker 的常用命令&#xff0c;帮助你更好地管理和使用容器。 安装 Docker 在安装 D…

Win10无法完成更新正在撤销更改的解决方法

在Win10电脑操作过程中&#xff0c;用户看到了“无法完成更新正在撤销更改”的错误提示&#xff0c;这样系统就不能成功完成更新&#xff0c;不知道如何操作才能解决此问题&#xff1f;以下小编分享最简单的解决方法&#xff0c;帮助大家轻松解决Win10电脑无法完成更新正在撤销…

Python实现时间序列分析AR定阶自回归模型(ar_select_order算法)项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档视频讲解&#xff09;&#xff0c;如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 时间序列分析中&#xff0c;AR定阶自回归模型&#xff08;AR order selection&#xff09;是指确定自回…

Vue学习之使用开发工具创建项目、gitcode管理项目

Vue学习之使用开发工具创建项目、gitcode管理项目 翻阅与学习了vue的开发工具&#xff0c;通过对比最终采用HBuilderX作为开发工具&#xff0c;以下章节对HBuilder安装与基础使用介绍 1. HBuilder 下载 从HbuildX官网&#xff08;http://www.dcloud.io/hbuilderx.html&#…

SpringBoot 实现定时任务

在项目我们会有很多需要在某一特定时刻自动触发某一时间的需求&#xff0c;例如我们提交订单但未支付的超过一定时间后需要自动取消订单。 定时任务实现的几种方式&#xff1a; Timer&#xff1a;java自带的java.util.Timer类&#xff0c;使用这种方式允许你调度一个java.util…

「HDLBits题解」Finite State Machines

本专栏的目的是分享可以通过HDLBits仿真的Verilog代码 以提供参考 各位可同时参考我的代码和官方题解代码 或许会有所收益 题目链接&#xff1a;Fsm1 - HDLBits module top_module(input clk,input areset, // Asynchronous reset to state Binput in,output out);// para…

制作OpenSSH 9.6 for openEuler 22.03 LTS的rpm升级包

OpenSSH作为操作系统底层管理平台软件&#xff0c;需要保持更新以免遭受安全攻击&#xff0c;编译生成rpm包是生产环境中批量升级的最佳途径。本文在国产openEuler 22.03 LTS系统上完成OpenSSH 9.6的编译工作。 一、编译环境 1、准备环境 基于vmware workstation发布的x86虚…

知识笔记(一百)———什么是okhttp?

OkHttp简介&#xff1a; OkHttp 是一个开源的、高效的 HTTP 客户端库&#xff0c;由 Square 公司开发和维护。它为 Android 和 Java 应用程序提供了简单、强大、灵活的 HTTP 请求和响应的处理方式。OkHttp 的设计目标是使网络请求变得更加简单、快速、高效&#xff0c;并且支持…

gitee仓库项目迁移到gitlab仓库

背景 之前一直使用gitee代码仓库提交代码&#xff0c;现在需要将gitee仓库中的代码迁移到gitlab中&#xff0c;并保留原有的提交记录。 前提 配置好了本地git&#xff0c;并本地与gitlab仓库已连接。 我这里使用 ssh方式拉去代码&#xff0c;因此需要配置ssh密钥 步骤 也可以直…

剑指offer面试题13 在O(1)时间删除链表结点

考察点 链表知识点 链表的删除正常情况下需要O(n)的时间&#xff0c;因为需要找到待删除结点的前置结点题目 分析 我们都知道链表删除往往需要O(n)遍历链表&#xff0c;找到待删除结点的前置结点&#xff0c;把前置结点的next指针指向待删除结点的后置结点。现在要求O(1)时间…

23种设计模式-结构型模式

1.代理模式 在软件开发中,由于一些原因,客户端不想或不能直接访问一个对象,此时可以通过一个称为"代理"的第三者来实现间接访问.该方案对应的设计模式被称为代理模式. 代理模式(Proxy Design Pattern ) 原始定义是&#xff1a;让你能够提供对象的替代品或其占位符。…

复杂链表的复制

作者简介&#xff1a;大家好&#xff0c;我是smart哥&#xff0c;前中兴通讯、美团架构师&#xff0c;现某互联网公司CTO 联系qq&#xff1a;184480602&#xff0c;加我进群&#xff0c;大家一起学习&#xff0c;一起进步&#xff0c;一起对抗互联网寒冬 学习必须往深处挖&…

在Ubuntu Linux上安装Chrome浏览器的最佳方法

我们可以使用GUI和命令行方法在Ubuntu Linux上安装Google Chrome浏览器&#xff0c;但是&#xff0c;终端是配置Chrome浏览器的最佳方式。在这里&#xff0c;我们讨论如何使用它。 有数十种浏览器&#xff0c;甚至Linux系统如Ubuntu也带有自己的默认浏览器Mozilla Firefox。然…

电磁波的波长与频率是什么关系?

摘要: 电磁波的波长&#xff08;λ&#xff09;与频率&#xff08;f&#xff09;之间的关系可以通过以下公式来表示&#xff1a; f c/λ cλf 其中&#xff1a; c 是光速&#xff0c;即电磁波在真空中的传播速度&#xff0c;约为 3 x 10⁸ 米/秒&#xff08;m/s&#xff09;λ…