TensorFlow 2.0 - Keras Pipeline、自定义Layer、Loss、Metric

文章目录

    • 1. Keras Sequential / Functional API
    • 2. 自定义 layer
    • 3. 自定义 loss
    • 4. 自定义 评估方法

学习于:简单粗暴 TensorFlow 2

1. Keras Sequential / Functional API

  • tf.keras.models.Sequential([layers...]),但是它不能表示更复杂的模型
mymodel = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(100, activation='relu'),tf.keras.layers.Dense(10),tf.keras.layers.Softmax()
])
  • Functional API 可以表示更复杂的模型
inp = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Flatten()(inp)
x = tf.keras.layers.Dense(units=100, activation=tf.nn.relu)(x)
x = tf.keras.layers.Dense(units=10)(x)
out = tf.keras.layers.Softmax()(x)
mymodel = tf.keras.Model(inputs=inp, outputs=out)
# 配置模型:优化器,损失函数,评估方法
mymodel.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate),loss=tf.keras.losses.sparse_categorical_crossentropy,metrics=[tf.keras.metrics.sparse_categorical_accuracy]
)# 训练:X,Y,batch_size, epochs
mymodel.fit(data_loader.train_data, data_loader.train_label,batch_size=batch_size,epochs=num_epochs)
# 测试
res = mymodel.evaluate(data_loader.test_data, data_loader.test_label)
print(res) # [loss, acc]

2. 自定义 layer

  • 继承 tf.keras.layers.Layer,重写 __init__buildcall 三个方法
import tensorflow as tf# 实现一个 线性layer
class myLayer(tf.keras.layers.Layer):def __init__(self, units):super().__init__()self.units = unitsdef build(self, input_shape):  # input_shape 是一个tensor# input_shape 是第一次运行 call() 时参数inputs的形状# 第一次使用该层的时候,调用buildself.w = self.add_weight(name='w',shape=[input_shape[-1], self.units],initializer=tf.zeros_initializer())self.b = self.add_weight(name='b',shape=[self.units],initializer=tf.zeros_initializer())def call(self, inputs):y_pred = tf.matmul(inputs, self.w) + self.breturn y_pred
  • 使用自定义的 layer
class LinearModel(tf.keras.Model):def __init__(self):super().__init__()self.dense = myLayer(units=1) # 使用def call(self, inputs):output = self.dense(inputs)return output
  • 简单的线性回归
import numpy as np# 原始数据
X_raw = np.array([0.0, 1., 2., 3., 4.], dtype=np.float32)
y_raw = np.array([0.01, 2., 4., 5.98, 8.], dtype=np.float32)X = np.expand_dims(X_raw, axis=-1)
y = np.expand_dims(y_raw, axis=-1)# 转成张量
X = tf.constant(X)
y = tf.constant(y)model = LinearModel()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),loss=tf.keras.losses.MeanSquaredError()
)
model.fit(X, y, batch_size=6, epochs=10000)
print(model.variables)X_test = tf.constant([[5.1], [6.1]])
res = model.predict(X_test)
print(res)

输出:

[<tf.Variable 'linear_model/my_layer/w:0' shape=(1, 1) dtype=float32, numpy=array([[1.9959974]], dtype=float32)>, 
<tf.Variable 'linear_model/my_layer/b:0' shape=(1,) dtype=float32, numpy=array([0.00600523], dtype=float32)>]
[[10.185592][12.181589]]

3. 自定义 loss

  • 继承 tf.keras.losses.Loss,重写 call 方法
class myError(tf.keras.losses.Loss):def call(self, y_true, y_pred):return tf.reduce_mean(tf.square(y_true - y_pred))model = LinearModel()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),loss=myError() # 使用 自定义的loss
)

4. 自定义 评估方法

  • 继承 tf.keras.metrics.Metric ,重写 __init__update_stateresult 三个方法
class myMetric(tf.keras.metrics.Metric):def __init__(self):super().__init__()self.total = self.add_weight(name='total',dtype=tf.int32,initializer=tf.zeros_initializer())self.count = self.add_weight(name='count',dtype=tf.int32,initializer=tf.zeros_initializer())def update_state(self, y_true, y_pred, sample_weight=None):values = tf.cast(tf.abs(y_true - y_pred) < 0.1, tf.int32)# 这里简单的判断误差 < 0.1, 算 trueself.total.assign_add(tf.shape(y_true)[0])self.count.assign_add(tf.reduce_sum(values))def result(self):return self.count / self.totalmodel = LinearModel()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),loss=myError(),metrics=[myMetric()] # 调用自定义的 metric
)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/473044.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python去重复元素_Python实现去除列表中重复元素的方法总结【7种方法】

这里首先给出来我很早之前写的一篇博客&#xff0c;Python实现去除列表中重复元素的方法小结【4种方法】&#xff0c;感兴趣的话可以去看看&#xff0c;今天是在实践过程中又积累了一些方法&#xff0c;这里一并总结放在这里。 由于内容很简单&#xff0c;就不再过多说明了&…

oracle取差值集合

Oracle Minus关键字 SQL中的MINUS关键字 SQL中有一个MINUS关键字&#xff0c;它运用在两个SQL语句上&#xff0c;它先找出第一条SQL语句所产生的结果&#xff0c;然后看这些结果有没有在第二个SQL语句的结果中。如果有的话&#xff0c;那这一笔记录就被去除&#xff0c;而不会…

TensorFlow 2.0 - Checkpoint 保存变量、TensorBoard 训练可视化

文章目录1. Checkpoint 保存变量2. TensorBoard 训练过程可视化学习于&#xff1a;简单粗暴 TensorFlow 2 1. Checkpoint 保存变量 tf.train.Checkpoint 可以保存 tf.keras.optimizer 、 tf.Variable 、 tf.keras.Layer 、 tf.keras.Model path "./checkp.ckpt" …

coturn的负载均衡特性_高性能负载均衡

单服务器无论如何优化&#xff0c;无论采用多好的硬件&#xff0c;总会有一个性能天花板&#xff0c;当单服务器的性能无法满足业务需求时&#xff0c;就需要设计高性能集群来提升系统整体的处理性能。高性能集群的本质很简单&#xff0c;通过增加更多的服务器来提升系统整体的…

LintCode MySQL 1928. 网课上课情况分析 I

文章目录1. 题目2. 解题1. 题目 online_class_situation 表展示了一些同学上网课的行为活动。 每行数据记录了一名同学在退出网课之前&#xff0c;当天使用同一台设备登录课程后听过的课程数目&#xff08;可能是0个&#xff09;。 写一条 SQL 语句&#xff0c;查询每位同学第…

poj1284:欧拉函数+原根

何为原根&#xff1f;由费马小定理可知 如果a于p互质 则有a^(p-1)≡1(mod p)对于任意的a是不是一定要到p-1次幂才会出现上述情况呢&#xff1f;显然不是&#xff0c;当第一次出现a^k≡1(mod p)时&#xff0c; 记为ep&#xff08;a&#xff09;k 当k(p-1)时&#xff0c;称a是p的…

python输入十个数输出最大值_python输入十个数如何输出最大值

python输入十个数输出最大值的方法&#xff1a;1、如果是整数的话&#xff0c;使用函数【a, b, c map(int, input().split())】&#xff1b;2、使用函数【Xinput().split()】。 相关免费学习推荐&#xff1a;python视频教程 python输入十个数输出最大值的方法&#xff1a; 第一…

SQL Server 2005远程连接连不上的解决办法收藏 Microsoft给的方法

SQL Server 2005远程连接连不上的解决办法收藏 Microsoft给的方法http://support.microsoft.com/kb/914277 是可以的,但我怕以后还会遇到这问题,干脆我也写到blog中来. 我的情况是别人怎么连也连不上我本地的DB,我装了2005的sp2也不行,后来发现关了防火墙就可以了,但我总不能什…

LintCode MySQL 1921. 从不充值的玩家(where not in)

文章目录1. 题目2. 解题1. 题目 描述 A game database contains two tables, player table and recharge table. Write a SQL query to find all players who never recharge. 样例 https://www.lintcode.com/problem/players-who-never-recharge/description 2. 解题 -- …

古风一棵桃花树简笔画_广东有个现实版的“桃花源”,藏于秘境之中,最适合情侣来度假!...

上学时&#xff0c;初闻“芳草鲜美&#xff0c;落英缤纷”&#xff0c;并没有多大感触。直到后来长大离家&#xff0c;每每为生活奔波劳累时&#xff0c;为工作琐碎忧心费神时&#xff0c;才骤然明了当年五柳先生所描绘的“桃花源”该是多少人的脑中所想、心中所向……原以为这…

关于NIOS ii烧写的几种方式(转)

源&#xff1a;http://www.cnblogs.com/bingoo/p/3450850.html 1. 方法一&#xff1a;.sof和.elf全部保存在FPGA内&#xff0c;程序加载和运行也是在FPGA内部。 把FPGA的配置文件.sof通过JTAG方式下载(其实是在线运行)进入FPGA本身&#xff0c;此时在NIOS II的界面中&#xff…

clob和blob是不是可以进行模糊查询_你知道什么是 MySQL 的模糊查询?

作者 | luanhz责编 | 郭芮本文对MySQL中几种常用的模糊搜索方式进行了介绍&#xff0c;包括LIKE通配符、RegExp正则匹配、内置字符串函数以及全文索引&#xff0c;最后给出了性能对比。引言MySQL根据不同的应用场景&#xff0c;支持的模糊搜索方式有多种&#xff0c;例如应用最…

一个长文档里,包括封面、不同的章节,如果我想封面不设置页眉页脚,每个章节的页眉都不同,请问应该如何设置页眉页脚?

问&#xff1a;在一个长文档里&#xff0c;包括封面、不同的章节&#xff0c;如果我想封面不设置页眉页脚&#xff0c;每个章节的页眉都不同&#xff0c;请问应该如何设置页眉页脚&#xff1f; 答&#xff1a;如果只需要首页不同&#xff0c;可选择“文件”菜单下的“页面设置”…

LintCode 1917. 切割剩余金属

文章目录1. 题目2. 解题1. 题目 描述 金属棒工厂的厂长拥有 n 根多余的金属棒。 当地的一个承包商提出&#xff0c;只要所有的棒材具有相同的长度&#xff08;用 saleLength 表示棒材的长度&#xff09;&#xff0c;就将金属棒工厂的剩余棒材全部购买。 厂长可以通过将每根棒…

太原理工电子信焦工程_电气工程及其自动化专业毕业后做什么工作?近几年就业和收入怎样...

本文内容为各大高校往届大学生真实的现身说法内容&#xff0c;但因为是往届&#xff0c;每年该专业的大学情况可能会发生略微变化&#xff0c;所以部分内容较今年&#xff0c;明年甚至以后几年&#xff0c;实际情况可能会略有不同但是对于本专业的相关信息还是非常有参考价值的…

js定时器和linux命令locate

js定时器如果带有参数&#xff0c;应该采用如下方式 setTimeout(function(){function(param)},1000); 匿名函数的方法。 linux locate基于数据库的查找方法。转载于:https://www.cnblogs.com/birdskyws/p/3974556.html

编程竞赛控制系统(PC2)使用说明书

编程竞赛控制系统(PC2)使用说明书 1. 系统简介 PC2是由美国加利福尼亚大学为国际大学生编程竞赛开发研制的竞赛控制系统。目前主要用于ACM/ICPC等国际编程竞赛。PC2最新的版本是8.5d&#xff0c;系统采用JAVA语言编写&#xff0c;可以运行在任何支持JAVA的平台(windows…

怎么查看linux日志里请求量最高的url访问最多的_实用的Linux高级命令,开发运维都要懂!...

在运维的坑里摸爬滚打好几年了&#xff0c;我还记得我刚开始的时候&#xff0c;我只会使用一些简单的命令&#xff0c;写脚本的时候&#xff0c;也是要多简单有多简单&#xff0c;所以有时候写出来的脚本又长又臭。像一些高级点的命令&#xff0c;比如说 Xargs 命令、管道命令、…

ggplot2箱式图两两比较_第十九章_使用ggplot2进行高级绘图

介绍ggplot2包使用形状、颜色和尺寸来对多元数据进行可视化用刻面图比较各组自定义ggplot2图19.1 R中的四种图形系统基础gridlatticeggplot2(用的较多)gghub需要的R包ggpolt2gridExtra(可以拼图)car19.2 ggplot2介绍library(ggplot2)ggplot(datamtcars, aes(xwt, ympg)) geom_p…

centos7 編譯 chmsee

安装libchm及相关的devel包&#xff0c;安装 xulrunner 及 devel 包&#xff01;否则后面make的时候会出错&#xff01; 到解压出来的chmsee/src目录下&#xff0c;找到与你系统对应的Makefile文件&#xff0c;我选的是Makefile.fedora&#xff0c;cp Makefile.fedora Makefile…