TensorFlow2实战-系列教程15:Resnet实战3

🧡💛💚TensorFlow2实战-系列教程 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传

Resnet实战1
Resnet实战2
Resnet实战3

7、训练脚本train.py解读------配置训练参数

# create modelmodel = get_model()# define loss and optimizerloss_object = tf.keras.losses.SparseCategoricalCrossentropy()optimizer = tf.keras.optimizers.Adam(lr=0.001)train_loss = tf.keras.metrics.Mean(name='train_loss')train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')valid_loss = tf.keras.metrics.Mean(name='valid_loss')valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy')
  1. 前面已经解析了模型的构建
  2. loss_object,多元交叉熵损失函数
  3. optimizer ,Adam优化器,学习率为0.001
  4. train_loss ,返回的是batch的平均损失
  5. train_accuracy ,loss计算方法对应的准确率计算方法
  6. valid_loss 和valid_accuracy 是验证集的平均损失和准确率计算方法

8、训练脚本train.py解读------模型训练

@tf.function
def train_step(images, labels):with tf.GradientTape() as tape:predictions = model(images, training=True)loss = loss_object(y_true=labels, y_pred=predictions)gradients = tape.gradient(loss, model.trainable_variables)optimizer.apply_gradients(grads_and_vars=zip(gradients, model.trainable_variables))train_loss(loss)train_accuracy(labels, predictions)@tf.function
def valid_step(images, labels):predictions = model(images, training=False)v_loss = loss_object(labels, predictions)valid_loss(v_loss)valid_accuracy(labels, predictions)

train_step(images, labels) 函数:

  • 装饰器 @tf.function 将这个函数转换为 TensorFlow 图,这可以提高执行效率。
  • with tf.GradientTape() as tape:这是一个自动微分的上下文管理器,用于记录在其内部执行的所有操作,以便于后续计算梯度
  • predictions = model(images, training=True):通过模型传递输入图像,得到预测结果。training=True 表示模型在训练模式下运行
  • loss = loss_object(y_true=labels, y_pred=predictions):计算真实标签和预测标签之间的损失。
  • gradients = tape.gradient(loss, model.trainable_variables):计算损失相对于模型可训练变量的梯度。
  • optimizer.apply_gradients(grads_and_vars=zip(gradients, model.trainable_variables)):应用梯度下降算法来更新模型的权重。
  • train_loss(loss)train_accuracy(labels, predictions):更新训练损失和准确率的指标。

valid_step(images, labels) 函数,不需要计算梯度,其他都一样

    for epoch in range(config.EPOCHS):train_loss.reset_states()train_accuracy.reset_states()valid_loss.reset_states()valid_accuracy.reset_states()step = 0for images, labels in train_dataset:step += 1train_step(images, labels)print("Epoch: {}/{}, step: {}/{}, loss: {:.5f}, accuracy: {:.5f}".format(epoch + 1, config.EPOCHS, step, math.ceil(train_count / config.BATCH_SIZE), train_loss.result(), train_accuracy.result()))for valid_images, valid_labels in valid_dataset:valid_step(valid_images, valid_labels)print("Epoch: {}/{}, train loss: {:.5f}, train accuracy: {:.5f}, ""valid loss: {:.5f}, valid accuracy: {:.5f}".format(epoch + 1, config.EPOCHS, train_loss.result(), train_accuracy.result(), valid_loss.result(), valid_accuracy.result()))model.save_weights(filepath=config.save_model_dir, save_format='tf')
  1. 逐个epoch执行训练
  2. 重置训练和验证的损失及准确率计算
  3. step 归0
  4. 训练集一个batch一个batch取数据
  5. step +1
  6. 调用train_step()函数训练当前batch数据
  7. 打印当前batch训练信息
  8. 验证集集一个batch一个batch取数据
  9. 调用valid_step()函数验证当前batch数据
  10. 在训练完成后,保存模型的权重

Resnet实战1
Resnet实战2
Resnet实战3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/657568.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

解读4篇混合类型文件Polyglot相关的论文

0. 引入 Polyglot文件指的是混合类型文件,关于混合类型文件的基础,请参考文末给出的第一个链接(参考1)。 1. Toward the Detection of Polyglot Files 1.1 主题 这篇2022年的论文,提出了Polyglot文件的检测方法。虽…

C++核心编程:文件操作 笔记

5.文件操作 程序运行时产生的数据都属于临时数据&#xff0c;程序一旦允许结束都会被释放。通过文件可以将数据持久化 C中对文件操作需要包含头文件<fstream> 文件类型分为两种&#xff1a; 文本文件 - 文件以文本的ASCII码形式存储在计算机中二进制文件 - 文件以文本…

openssl3.2 - .pod文件的查看方法

文章目录 .pod文件的查看方法概述笔记初步的解决方法备注 - pod2html.bat的详细用法好像Perl就自带这个BATEND .pod文件的查看方法 概述 看到openssl源码目录下有很多.pod文件, 软件发布的帮助内容都在里面. 当make install后, 大部分的.pod都会转成html文件, 但是有一部分…

【Java程序设计】【C00215】基于SSM的勤工助学管理系统(论文+PPT)

基于SSM的勤工助学管理系统&#xff08;论文PPT&#xff09; 项目简介项目获取开发环境项目技术运行截图 项目简介 这个一个基于SSM的勤工助学管理系统&#xff0c;本系统共分为三种权限&#xff1a;管理员、教师和学生 管理员&#xff1a;首页、个人中心、教师管理、学生管理…

逆置字符串

将字符串逆序,比如输入abcd,返回dcba void reverse(char*left,char *right) { while (right>left) { char temp *left; *left *right; *right temp; right--; left; } } int main() { char arr[100] { 0 };//定义…

gdp调试—Linux

目录 介绍 使用 介绍 代码分为debug模式和release模式 如果一份代码要被调试&#xff0c;这份代码必须是debug Linux下编译代码默认是是release模式 如果你想代码是debug模式 必须加上 - g 小提&#xff1a; vim默认&#xff1a;命令模式 gcc默认&#xff1a;releas…

操作系统--进程、线程基础知识

一、进程 我们编写的代码只是一个存储在硬盘的静态文件&#xff0c;通过编译后就会生成二进制可执行文件&#xff0c;当我们运行这个可执行文件后&#xff0c;它会被装载到内存中&#xff0c;接着 CPU 会执行程序中的每一条指令&#xff0c;那么这个运行中的程序&#xff0c;就…

ModelArts加速识别,助力新零售电商业务功能的实现

前言 如果说为客户提供最好的商品是产品眼中零售的本质&#xff0c;那么用户的思维是什么呢&#xff1f; 在用户眼中&#xff0c;极致的服务体验与优质的商品同等重要。 企业想要满足上面两项服务&#xff0c;关键在于提升效率&#xff0c;也就是需要有更高效率的零售&#…

C++ //练习 3.8 分别用while循环和传统的for循环重写第一题的程序,你觉得哪种形式更好呢?为什么?

C Primer&#xff08;第5版&#xff09; 练习 3.8 练习 3.8 分别用while循环和传统的for循环重写第一题的程序&#xff0c;你觉得哪种形式更好呢&#xff1f;为什么? 环境&#xff1a;Linux Ubuntu&#xff08;云服务器&#xff09; 工具&#xff1a;vim 代码块 /********…

【三】【C++】类与对象(二)

类的六个默认成员函数 在C中&#xff0c;有六个默认成员函数&#xff0c;它们是编译器在需要的情况下自动生成的成员函数&#xff0c;如果你不显式地定义它们&#xff0c;编译器会自动提供默认实现。这些默认成员函数包括&#xff1a; 默认构造函数 (Default Constructor)&…

C++ 数论相关题目 博弈论:拆分-Nim游戏

给定 n 堆石子&#xff0c;两位玩家轮流操作&#xff0c;每次操作可以取走其中的一堆石子&#xff0c;然后放入两堆规模更小的石子&#xff08;新堆规模可以为 0 &#xff0c;且两个新堆的石子总数可以大于取走的那堆石子数&#xff09;&#xff0c;最后无法进行操作的人视为失…

PMP中的数据收集工具:打开项目成功的钥匙

在项目管理中&#xff0c;数据收集是关键的一环。准确、及时的数据能够为项目决策提供可靠的依据&#xff0c;帮助项目经理更好地监控项目进展、识别潜在风险&#xff0c;并制定有效的应对措施。本文将深入探讨PMP&#xff08;项目管理专业&#xff09;中常用的数据收集工具&am…

二层环路和三层环路

环路的原因&#xff1a;二层环路是由于物理拓扑出现环路&#xff0c;如3台交换机3角形连接。 三层环路一般物理拓扑有环路&#xff0c;并且设备之间路由表形成互指。(物理拓扑不成环&#xff0c;2台设备使用静态路由互指也可能成环&#xff0c;这种特殊情况除外)。 二层设备和三…

力扣题目训练(6)

2024年1月30日力扣题目训练 2024年1月30日力扣题目训练367. 有效的完全平方数374. 猜数字大小383. 赎金信99. 恢复二叉搜索树105. 从前序与中序遍历序列构造二叉树51. N 皇后 2024年1月30日力扣题目训练 2024年1月30日第六天编程训练&#xff0c;今天主要是进行一些题训练&…

I2C 设备驱动

V5.10 参考文档&#xff1a;Documentation/i2c/writing-clients.rst static struct i2c_device_id foo_idtable[] {{ "foo", my_id_for_foo },{ "bar", my_id_for_bar },{ }};MODULE_DEVICE_TABLE(i2c, foo_idtable);static struct i2c_driver foo_drive…

在ubuntu上在安装Squid代理服务器

Squid 是一个代理和缓存服务器&#xff0c;它将请求转发到所需的目的地&#xff0c;同时保存请求的内容&#xff0c;当你再次请求相同内容时&#xff0c;他可以向你提供缓冲内容&#xff0c;从而提高访问速度。Squid代理服务器目前支持的协议有&#xff1a;http、SSL、DNS、FTP…

App测试中ios和Android有哪些区别呢?

App测试中&#xff0c;大家最常问到的问题就是&#xff1a;ios和 Android有什么区别呢&#xff1f; 在Android端&#xff0c;我们经常会使用 JavaScript、 HTML、 CSS等技术来编写一些简单的 UI界面。而 iOS端&#xff0c;我们经常会使用到 UI设计、界面布局、代码结构、 API等…

C++——特殊类

特殊类 文章目录 特殊类一、请设计一个类&#xff0c;不能被拷贝二、请设计一个类&#xff0c;只能在堆上创建对象方案一&#xff1a;析构函数私有化方案二&#xff1a;构造函数私有化 三、请设计一个类&#xff0c;只能在栈上创建对象四、请设计一个类&#xff0c;不能被继承五…

Linux系统MySQL重置root密码

MySQL是一种开源的关系型数据库管理系统&#xff0c;广泛用于Web应用程序的后台数据存储。在MySQL中&#xff0c;root是默认的超级用户&#xff0c;具有最高权限。然而&#xff0c;有时候我们可能会遇到忘记root密码的情况&#xff0c;或者需要重置root密码以增加数据库的安全性…

微软Office Plus与WPS Office的较量:办公软件市场将迎来巨变?

微软Office Plus在功能表现上远超WPS Office&#xff1f; 微软出品的Office套件实力强劲&#xff0c;其不仅在办公场景中扮演着不可或缺的角色&#xff0c;为用户带来高效便捷的体验&#xff0c;而且在娱乐生活管理等多元领域中同样展现出了卓越的应用价值 作为中国本土办公软…