Tensorflow2.0模型构建与训练

模型构建


class Encoder(layers.Layer):def __init__(self, latent_dim=32, intermediate_dim=64, name="encoder", **kwargs):super(Encoder, self).__init__(name=name, **kwargs)'''w_init = tf.random_normal_initializer()self.w = tf.Variable(initial_value=w_init(shape=(input_dim, units), dtype="float32"),trainable=True)b_init = tf.zeros_initializer()self.b = tf.Variable(initial_value=b_init(shape=(units,), dtype="float32"), trainable=True)'''# 简洁写法self.w = self.add_weight(shape=(input_dim, units), initializer="random_normal", trainable=True)self.b = self.add_weight(shape=(units,), initializer="zeros", trainable=True)# 可具有不可训练权重self.total = tf.Variable(initial_value=tf.zeros((input_dim,)), trainable=False)# 可以延迟权重创建在得知输出形状后:https://www.tensorflow.org/guide/keras/custom_layers_and_modelsdef call(self, inputs):# ...class Decoder(layers.Layer):def __init__(self, original_dim, intermediate_dim=64, name="decoder", **kwargs):super(Decoder, self).__init__(name=name, **kwargs)self.dense_proj = layers.Dense(intermediate_dim, activation="relu")self.dense_output = layers.Dense(original_dim, activation="sigmoid")def call(self, inputs):x = self.dense_proj(inputs)return self.dense_output(x)class VariationalAutoEncoder(keras.Model):def __init__(self,original_dim,intermediate_dim=64,latent_dim=32,name="autoencoder",**kwargs):super(VariationalAutoEncoder, self).__init__(name=name, **kwargs)self.original_dim = original_dimself.encoder = Encoder(latent_dim=latent_dim, intermediate_dim=intermediate_dim)self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)def call(self, inputs):z_mean, z_log_var, z = self.encoder(inputs)reconstructed = self.decoder(z)# Add KL divergence regularization loss.kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)self.add_loss(kl_loss)return reconstructed

模型训练

# 数据集加载
(x_train, _), _ = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(60000, 784).astype("float32") / 255
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)# 模型初始化
model = VariationalAutoEncoder(784, 64, 32)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.MeanSquaredError()# 模型训练
for epoch in range(3):for x_batch_train in train_dataset:with tf.GradientTape() as tape:reconstructed = model(x_batch_train)loss = loss_fn(x_batch_train, reconstructed) # Compute reconstruction lossloss += sum(model.losses)  # Add KLD regularization lossgrads = tape.gradient(loss, model.trainable_weights)optimizer.apply_gradients(zip(grads, model.trainable_weights))print("step %d: mean loss = %.4f" % (epoch, loss.numpy()))# 由于模型是 Model 子类化的结果,它具有内置的训练循环。因此,您也可以用以下方式训练它:
model.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
model.fit(x_train, x_train, epochs=2, batch_size=64)

模型保存和加载

# 模型保存
model.save('path/to/location')# 模型加载
model = keras.models.load_model('path/to/location')# 其他详细内容:https://www.tensorflow.org/guide/keras/save_and_serialize

案例二

# 自定义一个Layer
class Linear(keras.layers.Layer):def __init__(self, units=32, input_dim=32):super(Linear, self).__init__()# ...def call(self, inputs):# ...# 层递归组合
class MLPBlock(keras.layers.Model):def __init__(self):super(MLPBlock, self).__init__()self.linear_1 = Linear(64, 32)self.linear_2 = Linear(32, 16)self.linear_3 = Linear(16, 1)def call(self, inputs):x = self.linear_1(inputs)x = tf.nn.relu(x)x = self.linear_2(x)x = tf.nn.relu(x)return self.linear_3(x)# 自定义损失函数和评估方法 add_loss()/add_metric():https://www.tensorflow.org/guide/keras/custom_layers_and_modelsd_optimizer = keras.optimizers.Adam(learning_rate=0.001)
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()
model = MLPBlock()@tf.function
def train_step(x, y):with tf.GradientTape() as tape:predictions = model(x, training=True)loss_value = loss_fn(y, predictions)grads = tape.gradient(loss_value, model.trainable_weights)d_optimizer.apply_gradients(zip(grads, model.trainable_weights))@tf.function
def test_step(x, y):predictions = model(x, training=False)val_acc_metric.update_state(y, predictions)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/479490.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从ReentrantLock的实现看AQS的原理及应用

前言 Java中的大部分同步类(Lock、Semaphore、ReentrantLock等)都是基于AbstractQueuedSynchronizer(简称为AQS)实现的。AQS是一种提供了原子式管理同步状态、阻塞和唤醒线程功能以及队列模型的简单框架。本文会从应用层逐渐深入到…

论文浅尝 | 利用知识-意识阅读器改进的不完整知识图谱问答方法

论文笔记整理:谭亦鸣,东南大学博士生,研究方向为知识库问答。来源:ACL2019链接:https://www.aclweb.org/anthology/P19-1417/本文提出了一种融合不完整知识图谱与文档集信息的end2end问答模型,旨在利用结构…

机器学习十大经典算法之岭回归和LASSO回归

机器学习十大经典算法之岭回归和LASSO回归(学习笔记整理:https://blog.csdn.net/weixin_43374551/article/details/83688913

MVP模式在Android中的应用(附UML高清大图,使用RecyclerView举例)

传了一张图,图比较大,请移步下载:http://download.csdn.net/detail/u011064099/9266245 在看代码之前,首先简单看一下什么是MVP模式:http://www.cnblogs.com/end/archive/2011/06/02/2068512.html MVP最核心就是将界面…

Facebook大公开:解决NLG模型落地难题!工业界的新一波春天?

文 | 小喂老师编 | 小轶作为NLP领域的“三高”用户(高产、高能、高钞),FaceBook最近(2020年11月)又发表了一篇高水准文章,目前已被COLING-2020接收,号称解决了自然语言生成(NLG&…

Tensorflow2.0 tf.function和AutoGraph模式

一个简单记录,后续慢慢补充。。。。。 一、函数 # 类似一个tensorflow操作 tf.function def add(a, b):return ab # 即使传入数字,函数运算也是python基本运算,发返回值的类型也会变成tensor。print(add(1,2)) # tf.Tensor(3, shape(),…

论文浅尝 | 如何利用外部知识提高预训练模型在阅读理解任务中的性能

论文笔记整理:吴桐桐,东南大学博士生,研究方向为自然语言处理。链接:https://www.aclweb.org/anthology/P19-1226/近年来,机器阅读理解已经逐渐发展为自然语言理解方向的主流任务之一。最近,预训练模型尤其…

美团外卖前端容器化演进实践

背景 提单页的位置 提单页是美团外卖交易链路中非常关键的一个页面。外卖下单的所有入口,包括首页商家列表、订单列表页再来一单、二级频道页的今日推荐等,最终都会进入提单页,在确认各项信息之后,点击提交订单按钮,完…

LeetCode 807. 保持城市天际线

文章目录1. 题目2. 解题1. 题目 在二维数组grid中,grid[i][j]代表位于某处的建筑物的高度。 我们被允许增加任何数量(不同建筑物的数量可能不同)的建筑物的高度。 高度 0 也被认为是建筑物。 最后,从新数组的所有四个方向&#…

机器学习数据集汇总(附下载地址)

大学公开数据集(Stanford)69G大规模无人机(校园)图像数据集【Stanford】http://cvgl.stanford.edu/projects/uav_data/人脸素描数据集【CUHK】http://mmlab.ie.cuhk.edu.hk/archive/facesketch.html自然语言推理(文本蕴含标记)数据集【NYU】https://www.nyu.edu/projects/bowma…

提供一个Android原生的Progress——SwipeToRefreshLayout下拉刷新时的等待动画

先来上个图看看效果: 这里我为什么要单独把这个拿出来呢,因为最近才开始接触Android最新的东西,也就是5.0以上的东西,发现Android提供的SwipeToRefreshLayout是没有上拉加载更多的,在网上找了不少第三方提供加载更多的…

tensorflow2.0 Dataset创建和使用

一、创建Dataset # 可以接收一个numpy.ndarray、tuple、dict dataset tf.data.Dataset.from_tensor_slices(np.arange(10).reshape((5,2))) dataset tf.data.Dataset.from_tensor_slices(([1,2,3,4,5,6],[10,20,30,40,50,60])) dataset tf.data.Dataset.from_tensor_slices…

导师实验室对学生影响有多大?

读博士导师非常重要,比你们想象得还要更重要。一个优秀的导师不仅在科研帮上很多忙,而且让你懂得怎么做科研,更重要的他教会你怎么做一个合格的学者。 跟这种导师工作,你会发现科研其实是一件非常有趣的事情,它带来的乐…

论文浅尝 | 使用孪生BERT网络生成句子的嵌入表示

论文笔记整理:吴杨,浙江大学计算机学院,知识图谱、NLP方向。https://www.ctolib.com/https://arxiv.org/abs/1908.10084动机谷歌的 BERT 预训练模型,已经能够在两个句子的语义相似度匹配等需要输入一对句子的任务上取得了非常好的…

美团点评效果广告实验配置平台的设计与实现

一. 背景 效果广告的主要特点之一是可量化,即广告系统的所有业务指标都是可以计算并通过数字进行展示的。因此,可以通过业务指标来表示广告系统的迭代效果。那如何在全量上线前确认迭代的结果呢?通用的方法是采用AB实验(如图1&…

LeetCode 832. 翻转图像(异或^)

文章目录1. 题目2. 解题1. 题目 给定一个二进制矩阵 A,我们想先水平翻转图像,然后反转图像并返回结果。 水平翻转图片就是将图片的每一行都进行翻转,即逆序。例如,水平翻转 [1, 1, 0] 的结果是 [0, 1, 1]。 反转图片的意思是图…

MVP模式在Android中的应用之图片展示选择功能的框架设计

前言:虽然安卓出现的时间比其它平台软件比较晚,但是在我们的安卓开发中,一样可以使用我们所熟知的设计模式来给它一个合理、完善的结构,这样,才可以使我们在平常开发的时候减少冗余代码的发生,真正的提高效…

2020年8个效率最高的爬虫框架

一些较为高效的Python爬虫框架。分享给大家。 1.Scrapy Scrapy是一个为了爬取网站数据,提取结构性数据而编写的应用框架。 可以应用在包括数据挖掘,信息处理或存储历史数据等一系列的程序中。。用这个框架可以轻松爬下来如亚马逊商品信息之类的数据。 …

抑制过拟合之正则化与Dropout

避免过拟合: 1、增大数据集合 – 使用更多的数据,噪声点比减少(减少数据扰动所造成的影响) 2、减少数据特征 – 减少数据维度,高维空间密度小(减少模型复杂度) 3、正则化 / dropout / 数据增强…

谈谈神经网络的大规模训练优化

文 | 立交桥跳水冠军源 | 知乎大规模神经网络训练一般会涉及到几百个分布式节点同时工作,模型的参数量以及运算量往往很大,作者认为在这个task下当前的工作主要归结为以下三种:对通信本身的优化,神经网络训练通信的优化&#xff0…