模型构建
class Sampling(layers.Layer):
    """Sample z from (z_mean, z_log_var) with the reparameterization trick.

    Computes ``z = z_mean + exp(0.5 * z_log_var) * epsilon`` with
    ``epsilon ~ N(0, I)``. Sampling this way keeps ``z`` differentiable with
    respect to ``z_mean`` and ``z_log_var``.
    """

    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.random.normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon


class Encoder(layers.Layer):
    """Encode inputs into the triple (z_mean, z_log_var, z).

    Args:
        latent_dim: size of the latent space (width of z_mean / z_log_var / z).
        intermediate_dim: width of the hidden ReLU projection.
        name: layer name.
        **kwargs: forwarded to ``layers.Layer``.

    Note: weights can also be created by hand with ``self.add_weight(...)``,
    including non-trainable ones (``trainable=False``). Weight creation can be
    deferred to ``build()`` until the input shape is known; see
    https://www.tensorflow.org/guide/keras/custom_layers_and_models
    """

    def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        self.dense_mean = layers.Dense(latent_dim)
        self.dense_log_var = layers.Dense(latent_dim)
        self.sampling = Sampling()

    def call(self, inputs):
        """Return ``(z_mean, z_log_var, z)`` for a batch of inputs."""
        x = self.dense_proj(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z


class Decoder(layers.Layer):
    """Decode a latent vector z back to the original input space.

    Args:
        original_dim: width of the reconstructed output.
        intermediate_dim: width of the hidden ReLU projection.
        name: layer name.
        **kwargs: forwarded to ``layers.Layer``.
    """

    def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.dense_proj = layers.Dense(intermediate_dim, activation="relu")
        # Sigmoid keeps every reconstructed value in (0, 1).
        self.dense_output = layers.Dense(original_dim, activation="sigmoid")

    def call(self, inputs):
        x = self.dense_proj(inputs)
        return self.dense_output(x)


class VariationalAutoEncoder(keras.Model):
    """Combine Encoder and Decoder into an end-to-end model for training.

    The forward pass returns the reconstruction. It also registers the KL
    divergence term through ``add_loss`` so a training loop can read it from
    ``model.losses``.
    """

    def __init__(
        self,
        original_dim,
        intermediate_dim=64,
        latent_dim=32,
        name="autoencoder",
        **kwargs
    ):
        super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)
        self.original_dim = original_dim
        self.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)
        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Add KL divergence regularization loss.
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1
        )
        self.add_loss(kl_loss)
        return reconstructed
模型训练
# Dataset loading: only the MNIST training images are used. Labels are
# discarded because the autoencoder reconstructs its own input.
(x_train, _), _ = tf.keras.datasets.mnist.load_data()
# Flatten each image to a 784-element vector and scale the uint8 pixels into
# [0, 1]. -1 lets NumPy infer the number of samples instead of hard-coding it.
x_train = x_train.reshape(-1, 784).astype("float32") / 255
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)

# Model initialization: 784-dim input, 64-unit hidden layers, 32-dim latent space.
model = VariationalAutoEncoder(original_dim=784, intermediate_dim=64, latent_dim=32)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# Reconstruction loss used by the custom training loop.
loss_fn = tf.keras.losses.MeanSquaredError()
# Model training
# Custom training loop. The total loss is the reconstruction loss plus the KL
# divergence term that the model registers with add_loss() in its call().
for epoch in range(3):
    # Running mean of the per-batch loss, so the printed value is the epoch
    # mean rather than the loss of the last batch only.
    epoch_loss = tf.keras.metrics.Mean()
    for x_batch_train in train_dataset:
        with tf.GradientTape() as tape:
            reconstructed = model(x_batch_train)
            loss = loss_fn(x_batch_train, reconstructed)  # Compute reconstruction loss
            loss += sum(model.losses)  # Add KLD regularization loss
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        epoch_loss.update_state(loss)
    print("epoch %d: mean loss = %.4f" % (epoch, epoch_loss.result().numpy()))

# Because the model subclasses Model, it also has a built-in training loop,
# so it can alternatively be trained as follows:
# Built-in loop: MSE is the reconstruction loss. Losses registered with
# add_loss() (the KL term) are added to it by Keras during fit().
model.compile(optimizer=optimizer, loss=tf.keras.losses.MeanSquaredError())
# Autoencoder: the targets are the inputs themselves.
model.fit(x=x_train, y=x_train, batch_size=64, epochs=2)
模型保存和加载
# Save the model ('path/to/location' is a placeholder path).
model_dir = 'path/to/location'
model.save(model_dir)
# Load the model back from the same location.
# NOTE(review): under Keras 3, model.save() expects a ".keras"/".h5" suffix, and
# reloading a subclassed model needs its classes registered or passed via
# custom_objects — confirm against the installed TensorFlow/Keras version.
model = keras.models.load_model(model_dir)
# More details: https://www.tensorflow.org/guide/keras/save_and_serialize
案例二
# Define a custom layer
class Linear(keras.layers.Layer):
    """Fully connected layer computing ``inputs @ w + b`` (no activation).

    Args:
        units: number of output features (columns of ``w``).
        input_dim: number of input features (rows of ``w``). Inputs passed to
            ``call`` must have this size on their last axis.
    """

    def __init__(self, units=32, input_dim=32):
        super(Linear, self).__init__()
        # add_weight is the concise equivalent of creating a tf.Variable with
        # an initializer; the variables are tracked as trainable weights.
        self.w = self.add_weight(
            shape=(input_dim, units), initializer="random_normal", trainable=True
        )
        self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b


# Recursive composition of layers
class MLPBlock(keras.Model):
    """Three stacked Linear layers with ReLU in between, ending in one output unit."""

    def __init__(self):
        super(MLPBlock, self).__init__()
        # Linear's constructor is (units, input_dim). Keywords spell out the
        # 64 -> 32 -> 16 -> 1 chain so each layer's input_dim equals the
        # previous layer's units.
        # NOTE(review): assumes inputs have 64 features — confirm against the data.
        self.linear_1 = Linear(units=32, input_dim=64)
        self.linear_2 = Linear(units=16, input_dim=32)
        self.linear_3 = Linear(units=1, input_dim=16)

    def call(self, inputs):
        x = self.linear_1(inputs)
        x = tf.nn.relu(x)
        x = self.linear_2(x)
        x = tf.nn.relu(x)
        return self.linear_3(x)


# Custom losses and metrics via add_loss()/add_metric():
# https://www.tensorflow.org/guide/keras/custom_layers_and_models
d_optimizer = keras.optimizers.Adam(learning_rate=0.001)
# The model outputs raw logits, hence from_logits=True.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
# loss_fn treats each output as a binary logit, so accuracy must be binary too.
# Thresholding the logit at 0.0 is equivalent to thresholding sigmoid(logit) at 0.5.
val_acc_metric = keras.metrics.BinaryAccuracy(threshold=0.0)
model = MLPBlock()


@tf.function
def train_step(x, y):
    """Run one optimization step on the batch (x, y) and return its loss."""
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        loss_value = loss_fn(y, predictions)
    grads = tape.gradient(loss_value, model.trainable_weights)
    d_optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss_value


@tf.function
def test_step(x, y):
    """Accumulate validation accuracy for the batch (x, y) into val_acc_metric."""
    predictions = model(x, training=False)
    val_acc_metric.update_state(y, predictions)