TensorFlow2实战-系列教程11:RNN文本分类3

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

6、构建训练数据

  • 所有的输入样本必须都是相同shape(文本长度,词向量维度等)
  • tf.data.Dataset.from_tensor_slices(tensor):将tensor沿其第一个维度切片,返回一个含有N个样本的数据集,这样做的问题就是需要将整个数据集整体传入,然后切片建立数据集类对象,比较占内存。
  • tf.data.Dataset.from_generator(data_generator,output_data_type,output_data_shape):从一个生成器中不断读取样本
def data_generator(f_path, params):with open(f_path,encoding='utf-8') as f:print('Reading', f_path)for line in f:line = line.rstrip()label, text = line.split('\t')text = text.split(' ')x = [params['word2idx'].get(w, len(word2idx)) for w in text]#得到当前词所对应的IDif len(x) >= params['max_len']:#截断操作x = x[:params['max_len']]else:x += [0] * (params['max_len'] - len(x))#补齐操作y = int(label)yield x, y
  1. 定义一个生成器函数,传进来读数据的路径、和一些有限制的参数,params 在是一个字典,它包含了最大序列长度(max_len)、词到索引的映射(word2idx)等关键信息
  2. 打开文件
  3. 打印文件路径
  4. 遍历每行数据
  5. 获取标签和文本
  6. 文本按照空格分离出单词
  7. 获取当前句子的所有词对应的索引,for w in text取出这个句子的每一个单词,[params[‘word2idx’]取出params中对应的word2idx字典,.get(w, len(word2idx))从word2idx字典中取出该单词对应的索引,如果有这个索引则返回这个索引,如果没有则返回len(word2idx)作为索引,这个索引表示unknow
  8. 如果当前句子大于预设的最大句子长度
  9. 进行截断操作
  10. 如果小于
  11. 补充0
  12. 标签从str转换为int类型
  13. yield 关键字:用于从一个函数返回一个生成器(generator)。与 return 不同,yield 不会退出函数,而是将函数暂时挂起,保存当前的状态,当生成器再次被调用时,函数会从上次 yield 的地方继续执行,使用 yield 的函数可以在处理大数据集时节省内存,因为它允许逐个生成和处理数据,而不是一次性加载整个数据集到内存中

也就是说yield 会从上一次取得地方再接着去取数据,而return却不会

def dataset(is_training, params):_shapes = ([params['max_len']], ())_types = (tf.int32, tf.int32)if is_training:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['train_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.shuffle(params['num_samples'])ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)else:ds = tf.data.Dataset.from_generator(lambda: data_generator(params['test_path'], params),output_shapes = _shapes,output_types = _types,)ds = ds.batch(params['batch_size'])ds = ds.prefetch(tf.data.experimental.AUTOTUNE)return ds
  1. 定义一个制作数据集的函数,is_training表示是否是训练,这个函数在验证和测试也会使用,训练的时候设置为True,验证和测试为False
  2. 当前shape值
  3. 1
  4. 是否在训练,如果是:
  5. 构建一个Dataset
  6. 传进我们刚刚定义的生成器函数,并且传进实际的路径和配置参数
  7. 输出的shape值
  8. 输出的类型
  9. 指定shuffle
  10. 指定 batch_size
  11. 设置缓存序列,根据可用的CPU动态设置并行调用的数量,说白了就是加速
  12. 如果不是在训练,则:
  13. 验证和测试不同的就是路径不同,以及没有shuffle操作,其他都一样
  14. 最后把做好的Datasets返回回去

7、自定义网络模型

在这里插入图片描述
一条文本变成一组向量/矩阵的基本流程:

  1. 拿到一个英文句子
  2. 通过查语料表将句子变成一组索引
  3. 通过词嵌入表结合索引将每个单词都变成一组向量,一条句子就变成了一个矩阵,这就是特征了

在这里插入图片描述
在这里插入图片描述
BiLSTM即双向LSTM,就是在原本的LSTM增加了一个从后往前走的模块,这样前向和反向两个方向都各自生成了一组特征,把两个特征拼接起来得到一组新的特征,得到翻倍的特征。其他前面和后续的处理操作都是一样的。

class Model(tf.keras.Model):def __init__(self, params):super().__init__()self.embedding = tf.Variable(np.load('./vocab/word.npy'), dtype=tf.float32, name='pretrained_embedding', trainable=False,)self.drop1 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop2 = tf.keras.layers.Dropout(params['dropout_rate'])self.drop3 = tf.keras.layers.Dropout(params['dropout_rate'])self.rnn1 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn2 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=True))self.rnn3 = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(params['rnn_units'], return_sequences=False))self.drop_fc = tf.keras.layers.Dropout(params['dropout_rate'])self.fc = tf.keras.layers.Dense(2*params['rnn_units'], tf.nn.elu)self.out_linear = tf.keras.layers.Dense(2)  def call(self, inputs, training=False):if inputs.dtype != tf.int32:inputs = tf.cast(inputs, tf.int32)batch_sz = tf.shape(inputs)[0]rnn_units = 2*params['rnn_units']x = tf.nn.embedding_lookup(self.embedding, inputs)        x = self.drop1(x, training=training)x = self.rnn1(x)x = self.drop2(x, training=training)x = self.rnn2(x)x = self.drop3(x, training=training)x = self.rnn3(x)x = self.drop_fc(x, training=training)x = self.fc(x)x = self.out_linear(x)return x
  1. 自定义一个模型,继承tf.keras.Model模块
  2. 初始化函数
  3. 初始化
  4. 词嵌入,把之前保存好的词嵌入文件向量读进来
  5. 定义一层dropout1
  6. 定义一层dropout2
  7. 定义一层dropout3
  8. 定义一个rnn1,rnn_units表示得到多少维的特征,return_sequences表示是返回一个序列还是最后一个输出
  9. 定义一个rnn2,最后一层的rnn肯定只需要最后一个输出,前后两个rnn的堆叠肯定需要返回一个序列
  10. 定义一个rnn3 ,tf.keras.layers.LSTM()直接就可以定义一个LSTM,在外面再封装一层API:tf.keras.layers.Bidirectional就实现了双向LSTM
  11. 定义全连接层的dropout
  12. 定义一个全连接层,因为是双向的,这里就需要把参数乘以2
  13. 定义最后输出的全连接层,只需要得到是正例还是负例,所以是2
  14. 定义前向传播函数,传进来一个batch的数据和是否是在训练
  15. 如果输入数据不是tf.int32类型
  16. 转换成tf.int32类型
  17. 取出batch_size
  18. 设置LSTM神经元个数,双向乘以2
  19. 使用 TensorFlow 的 embedding_lookup 函数将输入的整数索引转换为词向量
  20. 数据通过第1个 Dropout 层
  21. 数据通过第1个rnn
  22. 数据通过第2个 Dropout 层
  23. 数据通过第2个rnn
  24. 数据通过第3个 Dropout 层
  25. 数据通过第3个rnn
  26. 经过全连接层对应的Dropout
  27. 数据通过一个全连接层
  28. 最后,数据通过一个输出层
  29. 返回最终的模型输出

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/657225.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

视网膜长尾数据

视网膜长尾数据 问题:视网膜疾病分类,解法:深度学习模型问题:数据复杂性处理,解法:多任务框架(同时处理多种疾病)和少量样本学习(提高对罕见疾病的识别)问题&…

【Delphi】IDE 工具栏错乱恢复

由于经常会在4K和2K显示器上切换Delphi开发环境(IDE),导致IDE工具栏错乱,咋样设置都无法恢复,后来看到红鱼儿的博客,说是通过操作注册表的方法,能解决,试了一下,果真好用,非常感谢分…

Java面试架构篇【一览众山小】

文章目录 🚡 简介☀️ Spring🐥 体系结构🐠 生命周期 🍁 SpringMVC🌰 执行流程 🌜 SpringBoot🌍 核心组件🎍 自动装配🎑 3.0升级 🔅 spring Cloud Alibaba&am…

获取依赖aar包的两种方式-在android studio里引入 如:glide

背景:我需要获取aar依赖到内网开发,内网几乎代表没网。 一、 如何需要获取依赖aar包 方式一:在官方的github中下载,耗时不建议 要从开发者网站、GitHub 存储库或其他来源获取 ‘com.github.bumptech.glide:glide:4.12.0’ AAR 包&#xff…

vue使用antv-x6 绘制流程图DAG图(二)

代码&#xff1a; <template><div class"graph-wrap" click.stop"hideFn"><Toobar :graph"graph"></Toobar><!-- 小地图 --><div id"minimap" class"mini-map-container"></div>…

基于SSM的游泳会员管理系统

末尾获取源码作者介绍&#xff1a;大家好&#xff0c;我是墨韵&#xff0c;本人4年开发经验&#xff0c;专注定制项目开发 更多项目&#xff1a;CSDN主页YAML 学如逆水行舟&#xff0c;不进则退。学习如赶路&#xff0c;不能慢一步。 目录 一、项目简介 二、开发技术与环境配…

【云原生】consul自动注册,实现负载均衡器与节点服务应用解耦,批量管理容器

目录 一、consul解决了什么问题&#xff1f; 二、consul的模式 三、consul的工作原理 四、实操consul连接负载均衡与容器 步骤一&#xff1a;完成consul的部署 步骤二&#xff1a;完成gliderlabs/registrator:latest镜像的拉取&#xff0c;并完成启动 步骤三&#xff1a;…

重写Sylar基于协程的服务器(1、日志模块的架构)

重写Sylar基于协程的服务器&#xff08;1、日志模块的架构&#xff09; 重写Sylar基于协程的服务器系列&#xff1a; 重写Sylar基于协程的服务器&#xff08;0、搭建开发环境以及项目框架 || 下载编译简化版Sylar&#xff09; 重写Sylar基于协程的服务器&#xff08;1、日志模…

15EG在PL端使用外部时钟驱动led灯

创建PL端的vivado工程在 pl_only_led中已经介绍过了&#xff0c;这里直接从创建好设计文件开始 打开设计文件修改里面代码&#xff0c;代码我也提供了&#xff0c;在工程文件夹下的file文件夹中 15eg这块板子有俩个外部晶振&#xff0c;分别时200M和74.25 M我们可以在原理图中找…

16- OpenCV:轮廓的发现和轮廓绘制、凸包

目录 一、轮廓发现 1、轮廓发现(find contour in your image) 的含义 2、相关的API 以及代码演示 二、凸包 1、凸包&#xff08;Convex Hull&#xff09;的含义 2、Graham扫描算法- 概念介绍 3、cv::convexHull 以及代码演示 三、轮廓周围绘制矩形和圆形框 一、轮廓发现…

【图文详解】阿里云服务器放行高防IP加入安全组

打开阿里云的云服务器配置面板&#xff0c;在要操作实例的操作列找到更多 > 网络和安全组 > 安全组配置。 对已有安全组配置规则&#xff0c;或者直接添加安全组规则。 根据需要放通高防IP在内的IP段相应协议类型的端口访问。

QEMU - e1000全虚拟化前端与TAP/TUN后端流程简析

目录 1. Host -> Guest 2.Guest ->Host 3. 如何修改以支持TUN设备的后端&#xff1f; 4. 相关 QEMU 源码 5. 实验 1. Host -> Guest 2.Guest ->Host 3. 如何修改以支持TUN设备的后端&#xff1f; 1. 简单通过后端网卡名字来判断是TUN还是TAP。 2. 需要前端全…

Adobe Camera Raw forMac/win:掌控原始之美的秘密武器

Adobe Camera Raw&#xff0c;这款由Adobe开发的插件&#xff0c;已经成为摄影师和设计师们的必备工具。对于那些追求完美、渴望探索更多创意可能性的专业人士来说&#xff0c;它不仅仅是一个插件&#xff0c;更是一个能够释放无尽创造力的平台。 在数字摄影时代&#xff0c;R…

使用Win32API实现贪吃蛇小游戏

目录 C语言贪吃蛇项目 基本功能 需要的基础内容 Win32API 介绍 控制台程序部分指令 设置控制台窗口的长宽 设置控制台的名字 控制台在屏幕上的坐标位置结构体COORD 检索指定标准设备的句柄&#xff08;标准输入、标准输出或标准错误&#xff09; 光标信息结构体类型CONSOLE_CUR…

洛谷P8599 [蓝桥杯 2013 省 B] 带分数

[蓝桥杯 2013 省 B] 带分数 题目描述 100 100 100 可以表示为带分数的形式&#xff1a; 100 3 69258 714 100 3 \frac{69258}{714} 100371469258​。 还可以表示为&#xff1a; 100 82 3546 197 100 82 \frac{3546}{197} 100821973546​。 注意特征&#xff1a;带分…

条款32:确定你的public继承塑模出 is-a 关系

如果你编写类D(“派生类”)public继承类B(“基类”)&#xff0c;就是在告诉C编译器(以及代码的读者)每个类型D的对象都是类型B的对象&#xff0c;但反之则不然。 class Person {...}; class Student: public Person {...}; void eat(const Person& p); // 素有的Person都…

云计算HCIE备考经验分享

大家好&#xff0c;我是来自深圳信息职业技术学院22级鲲鹏3-1班的刘同学&#xff0c;在2023年9月19日成功通过了华为云计算HCIE认证&#xff0c;并且取得了A的成绩。下面把我的考证经验分享给大家。 转专业进鲲鹏班考HCIE 大一上学期的时候&#xff0c;在上Linux课程的时候&…

ES Serverless让日志检索更加便捷

前言 在项目中,或者开发过程中,出现bug或者其他线上问题,开发人员可以通过查看日志记录来定位问题。通过日志定位 bug 是一种常见的软件开发和运维技巧,只有观察日志才能追踪到具体代码。在软件开发过程中,开发人员会在代码中添加日志记录,以记录程序的运行情况和异常信…

发现了一款宝藏学习项目,包含了Web全栈的知识体系,JS、Vue、React知识就靠它了!

前言 在当今互联网时代&#xff0c;一切以页面、UI为主要呈现方式&#xff0c;web全栈开发工程师的需求越来越大。 然而&#xff0c;市场上大多数工程师只会使用api而不了解其原理&#xff0c;这种情况使得他们变得可替代。 因此&#xff0c;成为一个高级开发工程师需要具备…

用React给XXL-JOB开发一个新皮肤(四):实现用户管理模块

目录 一. 简述二. 模块规划 2.1. 页面规划2.2. 模型实体定义 三. 模块实现 3.1. 用户分页搜索3.2. Modal 配置3.3. 创建用户表单3.4. 修改用户表单3.5. 删除 四. 结束语 一. 简述 上一篇文章我们实现登录页面和管理页面的 Layout 骨架&#xff0c;并对接登录和登出接口。这篇…