TensorFlow(2)-训练数据载入

tensorflow 训练数据载入

  • 1. tf.data.Dataset
  • 2. dataset 创建数据集的方式
    • 2.1 tf.data.Dataset.from_tensor_slices()
    • 2.2 tf.data.TextLineDataset()
    • 2.3 tf.data.FixedLengthRecordDataset()
    • 2.4 tf.data.TFRecordDataset()
  • 3. dateset 迭代操作iterator
    • 3.1 make_one_shot_iterator()
    • 3.2 make_initializable_iterator()
    • 3.3 reinitializable iterator()
    • 3.4 feedable iterator()
  • 4. dataset的map、batch、shuffle、repeat操作
  • 5. 非eager/eager 模式
    • 5.1 非eager模式demo
    • 5.2 eager模式demo

1. tf.data.Dataset

参考Google官方给出的Dataset API中的类图,Dataset 务于数据读取,构建输入数据的pipeline。在这里插入图片描述
Dataset可以看作是相同类型“元素”的有序列表,可使用Iterator迭代获取Dataset中的元素。

2. dataset 创建数据集的方式

2.1 tf.data.Dataset.from_tensor_slices()

从tensor中创建数据集,数据集元素以tensor第一维度为划分。

import tensorflow as tf
import numpy as np
# 切分传入Tensor的第一个维度,生成相应的dataset。
dataset1 = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) 
# 如果传入字典,那切分结果就是字典按值切分,元素型如{"a":[1],"b":[x,x]}
dataset2 = tf.data.Dataset.from_tensor_slices({"a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),                                       "b": np.random.uniform(size=(5, 2))}
)

2.2 tf.data.TextLineDataset()

读取文件数据创建数据集,数据集元素为文件的每一行

2.3 tf.data.FixedLengthRecordDataset()

从一个文件列表和record_bytes中创建数据集,数据集元素是文件中固定字节数record_bytes的内容。

2.4 tf.data.TFRecordDataset()

读TFRecord文件创建数据集,数据集中的一条数据是一个TFExample。

dataset = tf.data.TFRecordDataset(filenames = [tfrecord_file_name]) # [tfrecord_file_name] tfrecord 文件列表

frecord 文件中的特征一般都经过tf.train.Example 序列化,在使用前需要先解码tf.train.Example.FromString()

raw_example = next(iter(dataset))
parsed = tf.train.Example.FromString(raw_example.numpy())

3. dateset 迭代操作iterator

iterator是从Dataset对象中创建出来的,用于迭代取数据集中的元素。

3.1 make_one_shot_iterator()

dataset.make_one_shot_iterator()–只能从头到尾读取一次dataset。如果一个dataset中元素被读取完了再sess.run()的话,会抛出tf.errors.OutOfRangeError异常。因此可以在外界捕捉这个异常以判断数据是否读取完。

import tensorflow as tf
import numpy as np
# 切分传入Tensor的第一个维度,生成相应的dataset。如果传入字典,那切分结果就是字典按值切分
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) 
iterator = dataset.make_one_shot_iterator()    # 只能从头到尾读取一次
one_element = iterator.get_next()              # 从iterator里取出一个元素。
# 处于非Eager模式,所以one_element只是一个Tensor,并不是一个实际的值。调用sess.run(one_element)后,才能真正地取出一个值。
with tf.Session() as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")   

3.2 make_initializable_iterator()

dataset.make_initializable_iterator()–支持placeholder dataset 的迭代操作,这可以方便通过参数快速定义新的Iterator。

 # limit相当于一个参数,它规定了Dataset中数的上限, 使用make_initializable_iterator
limit = tf.placeholder(dtype=tf.int32, shape=[])
dataset = tf.data.Dataset.from_tensor_slices(tf.range(start=0, limit=limit))
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()  
with tf.Session() as sess:sess.run(iterator.initializer, feed_dict={limit: 10})for i in range(10):value = sess.run(next_element)assert i == value

sess.run(next_element) 每run一次, 数据迭代器指针就会往下移动一个。TF官网学习(9)–使用iterator注意事项

如果在dataset的构建时,一次性读入了所有的数据,会导致计算图变得很大,给传输、保存带来不便。make_initializable_iterator()支持placeholder 操作,仅在需要传输数据时再取数据。

# 从硬盘中读入两个Numpy数组
with np.load("/var/data/training_data.npy") as data:features = data["features"]labels = data["labels"]features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels})

3.3 reinitializable iterator()

dataset.reinitializable iterator() --待补

3.4 feedable iterator()

dataset.feedable iterator()–待补

4. dataset的map、batch、shuffle、repeat操作

map–接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset。

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
dataset = dataset.map(lambda x: x + 1) # 2.0, 3.0, 4.0, 5.0, 6.0

batch–将多个元素组合成一个batch

dataset = dataset.batch(16)    # 将数据集划分为batch size为16的小批次

shuffle– 打乱dataset中的元素,参数buffersize。打乱的实现机理:从buffer_size 大小的部buffer中随机抽取元素,组成打乱后的数据集。buffer中被抽走的元素由原数据集中的后续元素补位置。 重复‘抽取-补充’这个过程,直至buffer为空。
会在batch之间打乱数据–疑问多tfrecord 文件是一次性构建数据集还是一条一条的构建

buffer_size 的大小详见tf.data.Dataset.shuffle(buffer_size)中buffer_size的理解

dataset = dataset.shuffle(buffer_size=10000)

repeat– 将整个序列重复多次,用来处理机器学习中的epoch,假设原始数据是一个epoch,使用repeat(5)就可以将之变成5个epoch

dataset = dataset.repeat(5)

5. 非eager/eager 模式

5.1 非eager模式demo

在非Eager模式下,Dataset中读出的一个元素一般对应一个batch的Tensor,我们可以使用这个Tensor在计算图中构建模型。

import tensorflow as tf
import numpy as np
# 切分传入Tensor的第一个维度,生成相应的dataset。如果传入字典,那切分结果就是字典按值切分
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0])) 
iterator = dataset.make_one_shot_iterator()    # 只能从头到尾读取一次
one_element = iterator.get_next()              # 从iterator里取出一个元素。
# 处于非Eager模式,所以one_element只是一个Tensor,并不是一个实际的值。调用sess.run(one_element)后,才能真正地取出一个值。
with tf.Session() as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")   

5.2 eager模式demo

在Eager模式下,Dataset建立Iterator的方式有所不同,此时通过读出的数据就是含有值的Tensor,方便调试。

import tensorflow.contrib.eager as tfe
tfe.enable_eager_execution()
dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
for one_element in tfe.Iterator(dataset):print(one_element)             # 可直接读取数据

参考文献:TensorFlow全新的数据读取方式:Dataset API入门教程


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/444643.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

leetcode14. 最长公共前缀

编写一个函数来查找字符串数组中的最长公共前缀。 如果不存在公共前缀,返回空字符串 ""。 示例 1: 输入: ["flower","flow","flight"] 输出: "fl" 示例 2: 输入: ["dog","racecar",&quo…

Android在子线程里使用Toast报错Can't toast on a thread that has not called Looper.prepare()

在接android SDK的时候有时候为了方便debug调试查看,通过Toast输出相关信息, 实际上这个是在子线程中输出的,在logcat里查看有如下报错java.lang.RuntimeException: Cant toast on a thread that has not called Looper.prepare()。 解决办法…

虚拟机安装windows2012和虚拟机安装国产系统deepin

虚拟机安装windows2012和虚拟机安装国产系统deepin 一.安装windows20121.安装VMWare虚拟机2.1.注意点一:VMWare虚拟网卡2.2.注意点二:配置虚拟网络编辑器3.安装配置Windows Server 2012 R2 二.虚拟机安装deepin1.deepin官网下载ios镜像2.deepin下载合适的…

leetcode876 链表中间的结点

给定一个带有头结点 head 的非空单链表,返回链表的中间结点。 如果有两个中间结点,则返回第二个中间结点。 示例 1: 输入:[1,2,3,4,5] 输出:此列表中的结点 3 (序列化形式:[3,4,5]) 返回的结点值为 3 。 …

PlayFab(二)如何通过Demo应用来进一步熟悉Playfab

有时候刚开始接触新的平台会两眼一麻黑,不过这个文章希望能给读者一些启示,Playfab默认会给开发者提供一个应用,这里我暂且叫他”我的游戏“; 我通过官网提供的DEMO测试地址: https://www.vanguardoutrider.com/#/ 来为该应用配置服务器。 如果你是第一次进入这个页面想为…

leetcode718 最长重复子数组

给两个整数数组 A 和 B &#xff0c;返回两个数组中公共的、长度最长的子数组的长度。 示例 1: 输入: A: [1,2,3,2,1] B: [3,2,1,4,7] 输出: 3 解释: 长度最长的公共子数组是 [3, 2, 1]。 说明: 1 < len(A), len(B) < 1000 0 < A[i], B[i] < 100 思路&#xf…

leetcode108 将有序数组转换为二叉搜索树

将一个按照升序排列的有序数组&#xff0c;转换为一棵高度平衡二叉搜索树。 本题中&#xff0c;一个高度平衡二叉树是指一个二叉树每个节点 的左右两个子树的高度差的绝对值不超过 1。 示例: 给定有序数组: [-10,-3,0,5,9], 一个可能的答案是&#xff1a;[0,-3,9,-10,null,…

MachineLearning(12)- RNN-LSTM-tf.nn.rnn_cell

RNN-LSTM1.RNN2.LSTM3. tensorflow 中的RNN-LSTM3.1 tf.nn.rnn_cell.BasicRNNCell()3.2 tf.nn.rnn_cell.BasicLSTMCell()3.3 tf.nn.dynamic_rnn()--多步执行循环神经网络1.RNN RNN-Recurrent Neural Network-循环神经网络 RNN用来处理序列数据。多层感知机MLP层间节点全联接&…

AWS的VPC使用经验(二)

上文说了如何创建自定义VPC网络的EC2实例&#xff0c;这节说如何在多个VPC之间创建对等连接。 这里分别填写自己的VPC和对方的VPC的ID信息&#xff0c;然后在对方的VPC里就能看到有连接请求&#xff0c;在对方的连接请求里选择 “操作”->接受。 到这里已经快要收尾了&…

ubuntu nginx配置负载均衡篇(一)

Nginx 代理服务的配置说明 1、设置 404 页面导向地址 error_page 404 https://www.runnob.com; #错误页 proxy_intercept_errors on; #如果被代理服务器返回的状态码为400或者大于400,设置的error_page配置起作用。默认为off。 2、如果我们的代理只允许接受get,post请求…

坦克大战

效果 map.js var map4 [[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],[0,2,2,0,0,2,2,0,0,0,2,2,0,0,2,2,0,0,2,2,0,2,2,0],[0,2,2,0,0,2,2,0,0,0,2,2,3,3,2,2…

windows安装 MySQL5.7服务端

1,安装https://cdn.mysql.com//Downloads/MySQL-5.7/mysql-5.7.30-winx64.zip mysql安装包。 2. 自己配置my.ini [client] port=3306 [mysql] default-character-set=utf8[mysqld] character-set-server=utf8 port=3306 basedir="D:/mysql-5.7.30-winx64/" datad…

screen命令使用说明

有些程序写的很操蛋&#xff0c;比如放到后台执行&#xff0c;但后边还需要再切回前台来重新执行&#xff0c;这个时候我们选择screen工具&#xff1a; screen -d -m -S LoginServer[6001] ./run_login_server.sh 具体的screen命令包含哪些参数&#xff0c;可以参考scree…

看这玩意复习你还会挂科?《数据结构篇》

一&#xff0e;绪论 1.何谓程序设计&#xff1f; 程序 算法 数据结构 2.数据结构的定义 是相互之间存在一种或多种特定关系的数据元素的集合 3.数据、数据元素、数据对象的概念 数据&#xff08;data&#xff09;&#xff1a;对客观事物的符号表示&#xff0c;含义很广&am…

苹果订阅服务器端开发

有时候我们想做一个苹果订阅功能,需要在苹果开发者后台添加订阅商品productid/ 订阅需要增加一个参数: password: 秘钥, 就可以了, 但是官方文档说秘钥仅仅用在自动续订上面 大家叫后台加个验证,如果苹果验证返回21004的话(21004 你提供的共享密钥和账户的共享密钥不一致)…

nginx代理配置根据ip地址来转发到不同的地址端口

最近我们在开发的某SLG游戏的某业务要做如下场景: 要求在全球各个区域访问离他最近的服务器节点:用户通过访问域名A,在服务器端解析用户来源,根据ip地址来源来转发到对应的最近的服务器节点。 由于我们之前的业务一些设计很难调整,所以我将通过代码层面来进行做转发处理,…

做了nginx反向代理之后常见问题汇总

1.客户端无缘无故的主动断开和服务器的连接&#xff0c;如图&#xff1a; 服务器端收到了FIN包&#xff0c;查看了nginx 的配置有个选项&#xff1a;proxy_timeout选项 设置为30s。 注意&#xff1a;“proxy_timeout”这个参数可以写在stream节点下&#xff0c;所有server都生效…

在GoogPlay上发布的包Facebook登录失败提示签名问题

在googplay提审的包发布后,发现Facebook登录功能异常,提示如下: 意识到可能是hashkey出问题了,但是之前测试都是好的,原来是上传包到googlePlay后有个二次签名,会修改hashkey的,所以需要在Facebook后台添加下重新签名的hashkey。 基本签名信息在Google Play 上都能查看…

根据当前docker容器生成镜像提交到远端服务器

docker commit 4d6883e5fa21 gaoke/koa_ios docker push gaoke/koa_ios 然后在远端可看到

2019我做成的事情

1、ccpc河北金 这个省赛可能是退役赛了&#xff0c;因为下半年写项目&#xff0c;明年实习&#xff0c;没机会参加省赛、区预赛了。 2019.5大二的时候参加的&#xff0c;记得敲了个区间dp&#xff0c;大模拟&#xff0c;队友数学没搞出来&#xff0c;有一个搜索也是胆子不够大…