tf.data详解

转自https://www.cnblogs.com/hellcat/p/8569651.html

Dataset有两个重要的类:Dataset和Iterator。

Dataset可以看作是相同类型“元素”的有序列表。在实际使用时,单个“元素”可以是向量,也可以是字符串、图片,甚至是tuple或者dict。

迭代器对象实例化(非Eager模式下):

iterator = dataset.make_one_shot_iterator()

one_element = iterator.get_next()

综合起来效果如下:

import tensorflow as tf
import numpy as npdataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))
iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session() as sess:for i in range(5):print(sess.run(one_element))

输出:1.0  2.0  3.0  4.0  5.0

读取结束异常:

如果一个dataset中元素被读取完了,再尝试sess.run(one_element)的话,就会抛出tf.errors.OutOfRangeError异常,这个行为与使用队列方式读取数据的行为是一致的。

在实际程序中,可以在外界捕捉这个异常以判断数据是否读取完,综合以上三点请参考下面的代码:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session(config=config) as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")

输出:1.0  2.0  3.0  4.0  5.0 end!

高维数据集使用

tf.data.Dataset.from_tensor_slices真正作用是切分传入Tensor的第一个维度,生成相应的dataset,即第一维表明数据集中数据的数量,之后切分batch等操作都以第一维为基础。

dataset = tf.data.Dataset.from_tensor_slices(np.random.uniform(size=(5, 2)))

传入的数值是一个矩阵,它的形状为(5, 2),tf.data.Dataset.from_tensor_slices就会切分它形状上的第一个维度,最后生成的dataset中一个含有5个元素,每个元素的形状是(2, ),即每个元素是矩阵的一行。

dataset = tf.data.Dataset.from_tensor_slices(np.random.uniform(size=(5, 2)))iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session(config=config) as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")
[0.09787406 0.71672957]
[0.25681324 0.81974072]
[0.35186046 0.39362398]
[0.75228199 0.6534702 ]
[0.39695169 0.9341708 ]
end!

字典使用

在实际使用中,我们可能还希望Dataset中的每个元素具有更复杂的形式,如每个元素是一个Python中的元组,或是Python中的词典。例如,在图像识别问题中,一个元素可以是{“image”: image_tensor, “label”: label_tensor}的形式,这样处理起来更方便,

注意,image_tensor、label_tensor和上面的高维向量一致,第一维表示数据集中数据的数量。相较之下,字典中每一个key值可以看做数据的一个属性,value则存储了所有数据的该属性值。

dataset = tf.data.Dataset.from_tensor_slices({"a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),                                      "b": np.random.uniform(size=(5, 2))})iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session(config=config) as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")
{'a': 1.0, 'b': array([0.31721037, 0.33378767])}
{'a': 2.0, 'b': array([0.99221946, 0.65894961])}
{'a': 3.0, 'b': array([0.98405468, 0.11478854])}
{'a': 4.0, 'b': array([0.95311317, 0.57432678])}
{'a': 5.0, 'b': array([0.46067428, 0.19716722])}
end!

 

复杂的tuple组合数据

类似的,可以使用组合的特征进行拼接:

dataset = tf.data.Dataset.from_tensor_slices((np.array([1.0, 2.0, 3.0, 4.0, 5.0]), np.random.uniform(size=(5, 2)))
)iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session(config=config) as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")
(1.0, array([6.55877282e-04, 6.63244735e-01]))
(2.0, array([0.04756927, 0.44968581]))
(3.0, array([0.97841076, 0.06465231]))
(4.0, array([0.46639246, 0.39146086]))
(5.0, array([0.61085016, 0.61609538]))
end!

四、数据集处理方法

Dataset支持一类特殊的操作:Transformation。一个Dataset通过Transformation变成一个新的Dataset。通常我们可以通过Transformation完成数据变换,打乱,组成batch,生成epoch等一系列操作。

常用的Transformation有:

  • map
  • batch
  • shuffle
  • repeat

map

和python中的map类似,map接收一个函数,Dataset中的每个元素都会被当作这个函数的输入,并将函数返回值作为新的Dataset:

dataset = tf.data.Dataset.from_tensor_slices(np.array([1.0, 2.0, 3.0, 4.0, 5.0]))dataset = dataset.map(lambda x: x + 1) # <-----iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session(config=config) as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")

输出:2.0  3.0  4.0  5.0  6.0  end!

注意map函数可以使用num_parallel_calls参数加速(第五部分有介绍)。

batch

batch就是将多个元素组合成batch,如上所说,按照输入元素第一个维度:

dataset = tf.data.Dataset.from_tensor_slices({"a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),                                      "b": np.random.uniform(size=(5, 2))})dataset = dataset.batch(2) # <-----iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session(config=config) as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")
{'a': array([1., 2.]), 'b': array([[0.87466134, 0.21519021], [0.6123372 , 0.95722733]])}
{'a': array([3., 4.]), 'b': array([[0.76964374, 0.22445015], [0.08313089, 0.60531841]])}
{'a': array([5.]), 'b': array([[0.37901654, 0.3955096 ]])}
end!

shuffle

shuffle的功能为打乱dataset中的元素,它有一个参数buffersize,表示打乱时使用的buffer的大小,建议舍的不要太小,一般是1000:

dataset = tf.data.Dataset.from_tensor_slices({"a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),                                      "b": np.random.uniform(size=(5, 2))})dataset = dataset.shuffle(buffer_size=5) # <-----iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session(config=config) as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")
{'a': 3.0, 'b': array([0.82048268, 0.39821839])}
{'a': 4.0, 'b': array([0.42775421, 0.36749283])}
{'a': 1.0, 'b': array([0.09588742, 0.01954797])}
{'a': 2.0, 'b': array([0.10992948, 0.24416772])}
{'a': 5.0, 'b': array([0.15447616, 0.09005545])}
end!

repeat

repeat的功能就是将整个序列重复多次,主要用来处理机器学习中的epoch,假设原先的数据是一个epoch,使用repeat(2)就可以将之变成2个epoch:

dataset = tf.data.Dataset.from_tensor_slices({"a": np.array([1.0, 2.0, 3.0, 4.0, 5.0]),                                      "b": np.random.uniform(size=(5, 2))})dataset = dataset.repeat(2) # <-----iterator = dataset.make_one_shot_iterator()
one_element = iterator.get_next()
with tf.Session(config=config) as sess:try:while True:print(sess.run(one_element))except tf.errors.OutOfRangeError:print("end!")
{'a': 1.0, 'b': array([0.85180201, 0.1703507 ])}
{'a': 2.0, 'b': array([0.37874819, 0.81303628])}
{'a': 3.0, 'b': array([0.99560094, 0.56446562])}
{'a': 4.0, 'b': array([0.86341794, 0.69984075])}
{'a': 5.0, 'b': array([0.85026424, 0.74761098])}
{'a': 1.0, 'b': array([0.85180201, 0.1703507 ])}
{'a': 2.0, 'b': array([0.37874819, 0.81303628])}
{'a': 3.0, 'b': array([0.99560094, 0.56446562])}
{'a': 4.0, 'b': array([0.86341794, 0.69984075])}
{'a': 5.0, 'b': array([0.85026424, 0.74761098])}
end!

注意,如果直接调用repeat()的话,生成的序列就会无限重复下去,没有结束,因此也不会抛出tf.errors.OutOfRangeError异常。

更多的Dataset创建方法

除了tf.data.Dataset.from_tensor_slices外,目前Dataset API还提供了另外三种创建Dataset的方式:

  • tf.data.TextLineDataset():这个函数的输入是一个文件的列表,输出是一个dataset。dataset中的每一个元素就对应了文件中的一行。可以使用这个函数来读入CSV文件。
  • tf.data.FixedLengthRecordDataset():这个函数的输入是一个文件的列表和一个record_bytes,之后dataset的每一个元素就是文件中固定字节数record_bytes的内容。通常用来读取以二进制形式保存的文件,如CIFAR10数据集就是这种形式。
  • tf.data.TFRecordDataset():顾名思义,这个函数是用来读TFRecord文件的,dataset中的每一个元素就是一个TFExample。

更多的Iterator创建方法

在非Eager模式下,最简单的创建Iterator的方法就是通过dataset.make_one_shot_iterator()来创建一个one shot iterator。

除了这种one shot iterator外,还有三个更复杂的Iterator,即:

  • initializable iterator
  • reinitializable iterator
  • feedable iterator

initializable iterator方法要在使用前通过sess.run()来初始化,使用initializable iterator,可以将placeholder代入Iterator中,实现更为灵活的数据载入,实际上占位符引入了dataset对象创建中,我们可以通过feed来控制数据集合的实际情况。

limit = tf.placeholder(dtype=tf.int32, shape=[])dataset = tf.data.Dataset.from_tensor_slices(tf.range(start=0, limit=limit))iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()with tf.Session() as sess:sess.run(iterator.initializer, feed_dict={limit: 10})for i in range(10):value = sess.run(next_element)print(value)assert i == value

输出:0 1 2 3 4 5 6 7 8 9

initializable iterator还有一个功能:读入较大的数组。

在使用tf.data.Dataset.from_tensor_slices(array)时,实际上发生的事情是将array作为一个tf.constants保存到了计算图中。当array很大时,会导致计算图变得很大,给传输、保存带来不便。这时,我们可以用一个placeholder取代这里的array,并使用initializable iterator,只在需要时将array传进去,这样就可以避免把大数组保存在图里,示例代码为(来自官方例程):

# 从硬盘中读入两个Numpy数组
with np.load("/var/data/training_data.npy") as data:features = data["features"]labels = data["labels"]features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict={features_placeholder: features,labels_placeholder: labels})

可见,在上面程序中,feed也遵循着类似字典一样的规则,创建两个占位符(keys),给data_holder去feed数据文件,给label_holder去feed标签文件。

reinitializable iterator和feedable iterator相比initializable iterator更复杂,也更加少用,如果想要了解它们的功能,可以参阅官方介绍,这里就不再赘述了。

总结

在非Eager模式下,Dataset中读出的一个元素一般对应一个batch的Tensor,我们可以使用这个Tensor在计算图中构建模型。

在Eager模式下,Dataset建立Iterator的方式有所不同,此时通过读出的数据就是含有值的Tensor,方便调试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/491641.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

nodejs设置x-xss-protection解决xss问题

在Node.js中设置X-XSS-Protection可以通过使用helmet库来完成。 首先&#xff0c;确保已经安装了helmet库。如果没有安装&#xff0c;可以运行以下命令进行安装&#xff1a; npm install helmet --save 然后&#xff0c;在你的Node.js应用程序中引入并配置helmet库&#xff…

OpenCV与图像处理学习三——线段、矩形、圆、椭圆、多边形的绘制以及文字的添加

OpenCV与图像处理学习三——线段、矩形、圆、椭圆、多边形的绘制以及文字的添加一、OpenCV中的绘图函数1.1 线段绘制1.2 矩形绘制1.3 圆绘制1.4 椭圆的绘制1.5 多边形绘制1.6 添加文字上两次笔记主要知识点回顾&#xff1a; 数字图像基本概念图像的读取、显示与保存图像直方图…

第二百一十九天 how can I 坚持

今天好冷&#xff0c;白天在家待了一天&#xff0c;晚上&#xff0c;老贾生日&#xff0c;生日快乐&#xff0c;去海底捞吃了个火锅&#xff0c;没感觉呢。 今天还发现了个好游戏&#xff0c;纪念碑谷&#xff0c;挺新颖&#xff0c;就是难度有点大了。 好累。睡觉&#xff0c;…

AI英雄 | 论人工智能与自由意志,请看尤瓦尔与李飞飞的这场“激辩”

来源&#xff1a;Towards Data Science尤瓦尔赫拉利和李飞飞在斯坦福大学展开了一场别开生面的对话&#xff0c;他们所提出的问题已经远远超出了我们可以解答的范围。《连线》杂志主编尼古拉斯•汤普森在座无虚席的纪念礼堂主持了这场90分钟的谈话。赫拉利&#xff08;Harari&a…

OpenCV与图像处理学习四——图像几何变换:平移、缩放、旋转、仿射变换与透视变换

OpenCV与图像处理学习四——图像几何变换&#xff1a;平移、缩放、旋转、仿射变换与透视变换二、图像的几何变换2.1 图像平移2.2 图像缩放&#xff08;上采样与下采样&#xff09;2.3 图像旋转2.4 仿射变换2.5 透视变化2.6 几何变化小结续上次的笔记&#xff1a;OpenCV与图像处…

python打乱顺序的洗牌函数

numpy.random.shuffle&#xff08;x&#xff09; x&#xff1a;序列或者数组 对于多维数组&#xff0c;只对第一维进行洗牌&#xff0c;子数组的顺序改变了&#xff0c;但是它们的内容保持不变。 >>> arr np.arange(10) >>> np.random.shuffle(arr) >…

课后作业和动手动脑

一&#xff0c;运行TestInherits.java 通过super调用基类构造方法&#xff0c;必是子类构造方法中的第一个语句。 二.为什么子类的构造方法在运行之前&#xff0c;必须调用父类的构造方法&#xff1f;能不能反过来&#xff1f;为什么不能反过来&#xff1f; 构造函数的主要作用…

OpenCV与图像处理学习五——图像滤波与增强:线性、非线性滤波、直方图均衡化与Gamma变换

OpenCV与图像处理学习五——图像滤波与增强&#xff1a;线性、非线性滤波、直方图均衡化与Gamma变换三、图像滤波与增强3.1 线性滤波3.1.1 方框滤波3.1.2 均值滤波3.1.3 高斯滤波3.1.4 一般卷积滤波3.2 非线性滤波3.2.1 中值滤波3.2.2 双边滤波3.3 图像直方图均衡化3.3.1 单通道…

张钹院士:人工智能技术已进入第三代

来源&#xff1a;经济观察报近日&#xff0c;中科院院士、清华大学人工智能研究院院长张钹教授接受记者采访时认为&#xff0c;目前基于深度学习的人工智能在技术上已经触及天花板。从长远来看&#xff0c;必须得走人类智能这条路&#xff0c;最终要发展人机协同&#xff0c;人…

软件工程作业

典型用户1 名字老陈性别&#xff0c;年龄男&#xff0c;40岁职业教师收入两万/年知识层次和能力本科&#xff0c;熟练计算机操作生活、工作情况教书&#xff0c;辅导孩子完成作业动机&#xff0c;目的&#xff0c;困难希望节省辅导孩子的时间&#xff0c;用于自己的业务工作用户…

numpy.ndarray索引/切片方式

注意&#xff1a;获得多维数组的前三个子数组不能用array[0,1,2]&#xff0c;应该用 array[0:3]&#xff0c;如下例子&#xff1a; a np.random.random([85, 7794, 64]) b a[0:3] print(np.shape(b)) # (3, 7794, 64)

OpenCV与图像处理学习六——图像形态学操作:腐蚀、膨胀、开、闭运算、形态学梯度、顶帽和黑帽

OpenCV与图像处理学习六——图像形态学操作&#xff1a;腐蚀、膨胀、开、闭运算、形态学梯度、顶帽和黑帽四、图像形态学操作4.1 腐蚀和膨胀4.1.1 图像腐蚀4.1.2 图像膨胀4.2 开运算与闭运算4.2.1 开运算4.2.2 闭运算4.3 形态学梯度&#xff08;Gradient&#xff09;4.4顶帽和黑…

python 求复数的模

abs()即可求绝对值&#xff0c;也可以求复数的模 import numpy as np a 1-2j print(abs(a)) #2.23606797749979 print(np.sqrt(5)) #2.23606797749979

RS学习笔记(二)

1、OSPF&#xff1a;路由条目1万多条。收敛时间1s&#xff1b;ISIS:路由条目可以达2万多条&#xff0c;收敛时间50ms()。ISIS在链路层上面&#xff0c;不依赖IP这层&#xff0c;这样给了它很多可能。比如IPv4, IPv6路由的混合承载&#xff0c;给运营商网络平滑迁移提供了便捷。…

解密硅谷大骗局

来源&#xff1a;硅谷封面企鹅号、腾讯科技在许多为人称道的科技创业故事中&#xff0c;总不乏硅谷的名字。从英特尔、IBM到微软、苹果&#xff0c;从雅虎、谷歌到Twitter、Facebook&#xff0c;这里诞生了很多知名科技企业。对于全球的创业者来说&#xff0c;硅谷就是梦想中的…

OpenCV与图像处理学习七——传统图像分割之阈值法(固定阈值、自适应阈值、大津阈值)

OpenCV与图像处理学习七——传统图像分割之阈值法&#xff08;固定阈值、自适应阈值、大津阈值&#xff09;一、固定阈值图像分割1.1 直方图双峰法1.2 OpenCV中的固定阈值分割二、自动阈值图像分割2.1 自适应阈值法2.2 迭代法阈值分割2.3 Otsu大津阈值法前面的笔记介绍了一些Op…

linux删除文件夹和文件

rm -rf 删除文件夹实例&#xff1a; rm -rf /var/log/httpd 将会强制删除httpd这个文件夹 删除文件使用实例&#xff1a; rm -f /var/log/httpd/access.log 将会强制删除/var/log/httpd/access.log这个文件

Foxmail 绑定企业邮箱

转载于:https://www.cnblogs.com/wu628/p/4955017.html

边缘计算将吞掉云计算!

来源&#xff1a;CSDN以下为译文&#xff1a;边缘计算已成为物联网的重要趋势。高德纳咨询公司认为边缘计算是2019年的一项技术趋势。各个物联网公司发现在将数据发送到云之前&#xff0c;通过边缘计算处理数据有很大的好处。最近Micron/Forrester的调查证实了这一趋势&#xf…

OpenCV与图像处理学习八——图像边缘提取(Canny检测代码)

OpenCV与图像处理学习八——图像边缘提取&#xff08;Canny检测代码&#xff09;一、图像梯度1.1 梯度1.2 图像梯度二、梯度图与梯度算子2.1模板卷积2.2 梯度图2.3 梯度算子2.3.1 Roberts交叉算子2.3.2 Prewitt算子2.3.3 Sobel算子三、Canny边缘检测算法&#xff08;代码实现&a…