一、创建Dataset
# 可以接收一个numpy.ndarray、tuple、dict
dataset = tf.data.Dataset.from_tensor_slices(np.arange(10).reshape((5,2)))
dataset = tf.data.Dataset.from_tensor_slices(([1,2,3,4,5,6],[10,20,30,40,50,60]))
dataset = tf.data.Dataset.from_tensor_slices({"x":[1,2,3,4,5,6],"y":[10,20,30,40,50,60]})dataset = dataset.batch(3)
for batch in dataset:print(batch)
分别输出:
tf.Tensor(
[[0 1][2 3][4 5]], shape=(3, 2), dtype=int32)
tf.Tensor(
[[6 7][8 9]], shape=(2, 2), dtype=int32)
#------------------------------------------------------------------------(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3])>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([10, 20, 30])>)
(<tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6])>, <tf.Tensor: shape=(3,), dtype=int32, numpy=array([40, 50, 60])>)
#------------------------------------------------------------------------
{'x': <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3])>, 'y': <tf.Tensor: shape=(3,), dtype=int32, numpy=array([10, 20, 30])>}
{'x': <tf.Tensor: shape=(3,), dtype=int32, numpy=array([4, 5, 6])>, 'y': <tf.Tensor: shape=(3,), dtype=int32, numpy=array([40, 50, 60])>}
二、数据预处理
1、map
def func(x, y):x = x/1y = y/10return x, ytrain_data = [1,2,3,4,5,6]
train_label = [10,20,30,40,50,60]
dataset = tf.data.Dataset.from_tensor_slices((train_data, train_label))
dataset = dataset .map(func)for x, y in dataset :print(x,y)
输出
tf.Tensor(1.0, shape=(), dtype=float64) tf.Tensor(1.0, shape=(), dtype=float64)
tf.Tensor(2.0, shape=(), dtype=float64) tf.Tensor(2.0, shape=(), dtype=float64)
tf.Tensor(3.0, shape=(), dtype=float64) tf.Tensor(3.0, shape=(), dtype=float64)
tf.Tensor(4.0, shape=(), dtype=float64) tf.Tensor(4.0, shape=(), dtype=float64)
tf.Tensor(5.0, shape=(), dtype=float64) tf.Tensor(5.0, shape=(), dtype=float64)
tf.Tensor(6.0, shape=(), dtype=float64) tf.Tensor(6.0, shape=(), dtype=float64)
dataset = dataset .map(map_func=func, num_parallel_calls=tf.data.experimental.AUTOTUNE)
num_parallel_calls:将数据加载与变换过程并行到多个CPU线程上
tf.data.experimental.AUTOTUNE:自动设置为最大的可用线程数
2、shuffle 和 batch
dataset = tf.data.Dataset.from_tensor_slices([1,2,3,4,5])
dataset = dataset.batch(2).shuffle(2) # 对batch进行shuffle,batch内部不shuffle
dataset = dataset.shuffle(2).batch(2) # 先将数据进行shuffle,再进行batch划分
for d in dataset:print(d)print("----------------------------")
结果1
tf.Tensor([1 2], shape=(2,), dtype=int32)
----------------------------
tf.Tensor([5], shape=(1,), dtype=int32)
----------------------------
tf.Tensor([3 4], shape=(2,), dtype=int32)
----------------------------
结果2
tf.Tensor([1 3], shape=(2,), dtype=int32)
----------------------------
tf.Tensor([2 4], shape=(2,), dtype=int32)
----------------------------
tf.Tensor([5], shape=(1,), dtype=int32)
----------------------------
3、repeat
dataset = tf.data.Dataset.from_tensor_slices([1,2])
dataset = dataset.repeat(2)
for d in dataset:print(d)
结果
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
三、并行化策略
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
Dataset.prefetch() :让数据集对象 Dataset 在训练时预取出若干个元素,使得在 GPU 训练的同时 CPU 可以准备数据,从而提升训练流程的效率。
四、模型使用数据集
Keras 支持使用 tf.data.Dataset 直接作为输入。当调用 tf.keras.Model 的 fit() 和 evaluate() 方法时,可以将参数中的输入数据 x 指定为一个元素格式为 (输入数据, 标签数据) 的 Dataset ,并忽略掉参数中的标签数据 y 。
常规的 Keras 训练方式:
model.fit(x=train_data, y=train_label, epochs=num_epochs, batch_size=batch_size)
使用 tf.data.Dataset 训练方式:
model.fit(dataset, epochs=num_epochs)