TensorFlow 2.0 - 自定义模型、训练过程

文章目录

    • 1. 自定义模型
    • 2. 学习流程

学习于:简单粗暴 TensorFlow 2

1. 自定义模型

  • 重载 call() 方法,pytorch 是重载 forward() 方法
import tensorflow as tf
X = tf.constant([[1.0, 2.0, 3.0],[4.0, 5.0, 6.0]])
y = tf.constant([[10.0],[20.0]])class Linear(tf.keras.Model):def __init__(self):super().__init__()self.dense = tf.keras.layers.Dense(units=1,activation=None,kernel_initializer=tf.zeros_initializer(),bias_initializer=tf.zeros_initializer())def call(self, input): # 重载 call 方法output = self.dense(input)return outputmodel = Linear()# 优化器
optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)for i in range(100):with tf.GradientTape() as tape: # 梯度记录器y_pred = model(X)loss = tf.reduce_mean(tf.square(y_pred-y)) # 损失grads = tape.gradient(loss, model.variables) # 求导# 更新参数optimizer.apply_gradients(grads_and_vars=zip(grads,model.variables))

2. 学习流程

  • 加载手写数字数据集
class MNistLoader():def __init__(self):data = tf.keras.datasets.mnist# 加载数据(self.train_data, self.train_label),(self.test_data, self.test_label) = data.load_data()# 扩展维度,灰度图1通道 [batch_size, 28, 28, chanels=1]self.train_data = np.expand_dims(self.train_data.astype(np.float32) / 255.0, axis=-1)self.test_data = np.expand_dims(self.test_data.astype(np.float32) / 255.0, axis=-1)self.train_label = self.train_label.astype(np.int32)self.test_label = self.test_label.astype(np.int32)# 样本个数self.num_train_data, self.num_test_data = self.train_data.shape[0], self.test_data.shape[0]def get_batch(self, batch_size):# 从训练集里随机取出 batch_size 个样本idx = np.random.randint(0, self.num_train_data, batch_size)return self.train_data[idx, :], self.train_label[idx]
  • 定义模型
# 自定义多层感知机模型
class MLPmodel(tf.keras.Model):def __init__(self):super().__init__()# 除第一维以外的维度展平self.flatten = tf.keras.layers.Flatten()self.dense1 = tf.keras.layers.Dense(units=100, activation='relu')self.dense2 = tf.keras.layers.Dense(units=10)def call(self, input):x = self.flatten(input)x = self.dense1(x)x = self.dense2(x)output = tf.nn.softmax(x)return output
  • 训练
# 参数
num_epochs = 5
batch_size = 50
learning_rate = 1e-4# 模型实例
mymodel = MLPmodel()
# 数据加载
data_loader = MNistLoader()
# adam 优化器
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)num_batches = int(data_loader.num_train_data//batch_size * num_epochs)
# 训练
for idx in range(num_batches):# 取出数据X,y = data_loader.get_batch(batch_size)with tf.GradientTape() as tape: # 梯度记录y_pred = mymodel(X) # 预测# 计算交叉熵损失loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)loss = tf.reduce_mean(loss)print("batch {}, loss {}".format(idx, loss.numpy()))# 计算梯度grads = tape.gradient(loss, mymodel.variables)# 更新参数optimizer.apply_gradients(grads_and_vars=zip(grads, mymodel.variables))
  • 预测
# 评估标准
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
num_batches = int(data_loader.num_test_data // batch_size)
# 预测
for idx in range(num_batches):# 数据区间start, end = idx*batch_size, (idx+1)*batch_size# 放入模型,预测y_pred = mymodel.predict(data_loader.test_data[start : end])# 统计更新 预测信息    sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start:end],y_pred=y_pred)
print("test 准确率:{}".format(sparse_categorical_accuracy.result()))
# test 准确率:0.9455000162124634

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/473081.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

谁动了我的产品

2014年3月中旬离开了自己奋斗三年的公司,这是一家海关政府公司,三年里无论是做项目需求分析、项目开发、项目测试、项目上线实施、项目上线跟踪、收集反馈、做项目版本修改,我和我的团队都在一个有非常明确目标、有非常明确思路的过程中&…

LeetCode 352. 将数据流变为多个不相交区间(map二分查找)

文章目录1. 题目2. 解题1. 题目 给定一个非负整数的数据流输入 a1,a2,…,an,…,将到目前为止看到的数字总结为不相交的区间列表。 例如,假设数据流中的整数为 1,3,7,2&…

windows键按了没反应_windows快捷键使用 - 小怜

1、总的参考图:2、ctrl的组合使用:1与shift键结合:2 ctrl shift del # 快速清除浏览器缓存记录3 ctrl shift N # 浏览器当中,快速打开无痕新窗口。chrome内核的应该都可以,chrome和新…

Python倒计时自动发微信(电脑版微信)

一、前言: Python倒计时自动发微信(电脑版微信登录状态) 二、主要思路及步骤: 1、先启动微信 2、定位到搜索框 3、搜索微信 4、进入聊天窗口 5、粘贴文本内容 6、发送 7、关闭微信窗口 三、代码: import …

Javascript日期函数使用需要注意地方

当我们需要获取未来某个时间的毫秒数时,可能第一时间想到的方法是 (new Date("2014-12-08 12:00:00")).getTime() 这个在方法在chrome下是会返回一个数字的。 但是在IE下返回的是一个NaN,是因为IE下认为 YYYY-mm-dd HH:ii:ss不是一个合理的日期…

python判断是工作日还是休息日

一、概述 最近在做数据分析,需要判断一个日期是否为工作日,节假日。 找到一个现成的插件,蛮好用的。 1.1、插件介绍 chinesecalendar PyPI 判断某年某月某一天是不是工作日/节假日。 1.2、安装 pip install chinese_calendar该模块常…

win10主题更换_还不升级? win10精简版不到10G,运行比win7还快,旧电脑的福音

即使现在win7系统已经停止了服务,但是还有许多人宁愿面对随时有可能出现问题的win7,还是不愿升级win10系统。至于原因,五花八门,比如win7兼容性和稳定性更好,比如win10经常更新,还有许多软件无法在win10环境…

LeetCode 1732. 找到最高海拔

文章目录1. 题目2. 解题1. 题目 有一个自行车手打算进行一场公路骑行,这条路线总共由 n 1 个不同海拔的点组成。 自行车手从海拔为 0 的点 0 开始骑行。 给你一个长度为 n 的整数数组 gain ,其中 gain[i] 是点 i 和点 i 1 的 净海拔高度差&#xff0…

小案例:利用python估算最外轮廓区域面积

一、需求: 给出一张图片,估算最外轮廓区域面积 二、步骤: 1、读取图片信息 2、利用open-cv,自适应分割图片 3、提取最外轮廓像素值 4、利用像素值标记轮廓 5、计算轮廓面积 三、代码: import cv2# 读取图片信息…

ApplicationContext容器的设计原理

1.在ApplicationContext容器中,我们以常用的FileSystemXmlApplicationContext的实现为例来说明ApplicationContext容器的设计原理。 2.在FileSystemXmlApplicationContext的设计中,我们看到ApplicationContext应用上下文的主要功能已经在FileSystemXmlAp…

使用c++查看linux服务器某个进程正在使用的内存_Linux 系统管理

1、进程管理介绍什么是进程程序是人使用计算机语言编写的,可以实现一定功能,并且可以执行的代码集合进程是正在执行当中的程序。程序在执行时,执行人的权限和属性、以及程序的代码都会被加载进内存,操作系统给这个进程分配一个 ID…

小案例:利用Python写个教师常用的点名软件

一、需求: 教师上课常用的点名软件 二、python库安装: openpyxl是Python中用于读写excel文件tkinter是Python中GUI编程非常好用的库,而且是标准库,不需要安装,导入即可使用random库是Python中用于实现随机功能的库&…

如何用DNS+GeoIP+Nginx+Varnish做世界级的CDN

如何用DNSGeoIPNginxVarnish做世界级的CDN 如何用BIND, GeoIP, Nginx, Varnish来创建你自己的高效的CDN网络?CDN,意思是Content Distrubtion Network,意思是内容分发网络,简单的说,就是全地域范围内的负载均衡&#xf…

python contains类似函数_01--实际工作中,python基础理念和数据处理

1.工作中遇到的python坑1.1 合并文件问题:正常将文件依次读取并append时,莫名出现很多空行。解决:在append前删除空行:data_tmp 1.2 重复数据行问题: append多日文件时,由于人工误操作,容易存在…

小案例:利用Python实现图片上下、左右翻转

一、前言需求: 对图片进行操作,使图片上下、左右翻转 二、函数库: 使用Pillow模块提供的transpose()方法可以让图像翻转,上下翻转,或者左右翻转 三、操作说明: 原图如下: 图片上下翻转代码…

LeetCode 1736. 替换隐藏数字得到的最晚时间

文章目录1. 题目2. 解题1. 题目 给你一个字符串 time ,格式为 hh:mm(小时:分钟),其中某几位数字被隐藏(用 ? 表示)。 有效的时间为 00:00 到 23:59 之间的所有时间,包括 00:00 和…

【SSH进阶之路】一步步重构MVC实现Struts框架——封装业务逻辑和跳转路径(四)...

目录: 【SSH进阶之路】Struts基本原理 实现简单登录(二) 【SSH进阶之路】一步步重构MVC实现Struts框架——从一个简单MVC开始(三) 【SSH进阶之路】一步步重构MVC实现Struts框架——封装业务逻辑和跳转路径&#xff08…

实用工具:推荐Pycharm常用的几款插件

相信对于不少的Python程序员们都是用Pycharm作为开发时候的IDE来使用的,今天来分享几个好用到爆的Pycharm插件,在安装上之后,你的编程效率、工作效率都能够得到极大地提升。 Pycharm插件安装教程 打开file---settings---plugins&#xff0c…

dataframe 添加一行_R语言Data Frame数据框常用操作

来源 | R友舍Data Frame一般被翻译为数据框,感觉就像是R中的表,由行和列组成,与Matrix不同的是,每个列可以是不同的数据类型,而Matrix是必须相同的。Data Frame每一列有列名,每一行也可以指定行名。如果不指…

LeetCode 1737. 满足三条件之一需改变的最少字符数(计数)

文章目录1. 题目2. 解题1. 题目 给你两个字符串 a 和 b ,二者均由小写字母组成。 一步操作中,你可以将 a 或 b 中的 任一字符 改变为 任一小写字母 。 操作的最终目标是满足下列三个条件 之一 : a 中的 每个字母 在字母表中 严格小于 b 中…