Tensorflow2.0笔记 - metrics做损失和准确度信息度量

        本笔记主要记录metrics相关的内容,详细内容请参考代码注释,代码本身只使用了Accuracy和Mean。本节的代码基于上篇笔记FashionMnist的代码经过简单修改而来,上篇笔记链接如下:

Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博客文章浏览阅读339次。本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。关于FashionMnist的介绍,请自行百度。https://blog.csdn.net/vivo01/article/details/136921592?spm=1001.2014.3001.5502

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__#加载fashion mnist数据集
def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels#预处理数据
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32)x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)y = tf.convert_to_tensor(y, dtype=tf.int32)return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)batch_size = 128train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)#定义网络模型
model = Sequential([#Layer 1: [b, 784] => [b, 256]layers.Dense(256, activation=tf.nn.relu),#Layer 2: [b, 256] => [b, 128]layers.Dense(128, activation=tf.nn.relu),#Layer 3: [b, 128] => [b, 64]layers.Dense(64, activation=tf.nn.relu),#Layer 4: [b, 64] => [b, 32]layers.Dense(32, activation=tf.nn.relu),#Layer 5: [b, 32] => [b, 10], 输出类别结果layers.Dense(10)
])#编译网络
model.build(input_shape=[None, 28*28])
model.summary()#进行训练
total_epoches = 5
learn_rate = 0.01#Metrics统计
#参考资料:https://zhuanlan.zhihu.com/p/42438077
#1. 新建meter
#acc_meter = metrics.Accuracy()
#loss_meter = metrics.Mean()
#2. 更新状态, update_state()
#loss_meter.update_state(loss)
#acc_meter.update_state(y, pred)
#3.获取结果, result()
#print(step, 'loss:', loss_meter.result().numpy())
#print(step, 'Evaluate Acc:', total_correct/total, acc_meter.result().numpy())
#4.清除度量信息,reset_states()
#loss_meter.reset_states()
#acc_meter.reset_states()#新建准确度和loss度量对象
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()optimizer = optimizers.Adam(learning_rate = learn_rate)
for epoch in range(total_epoches):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:logits = model(x)y_onehot = tf.one_hot(y, depth=10)#使用交叉熵作为lossloss_ce = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))#调用update_state更新loss度量信息loss_meter.update_state(loss_ce)#计算梯度grads = tape.gradient(loss_ce, model.trainable_variables)#更新梯度optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print("Epoch[", epoch, "]: step-", step, "\tloss: ", loss_meter.result().numpy())loss_meter.reset_states()#使用测试集进行验证total_correct = 0total_num = 0#清除准确度的统计信息acc_meter.reset_states()for x,y in test_db:logits = model(x)#使用softmax得到各个类别的概率prob = tf.nn.softmax(logits, axis=1)#求出概率最大的结果参数位置,作为预测的分类结果pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)#比较结果correct = tf.equal(pred, y)correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))#计算精度total_correct += int(correct)total_num += x.shape[0]#使用metircs的update_state进行更新acc_meter.update_state(y, pred)acc = total_correct / total_numprint("Epoch[", epoch, "] Manual Accuracy:", acc, " Metrics Accuracy:", acc_meter.result().numpy())

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/778527.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SAP BTP云上一个JVM与DB Connection纠缠的案例

前言 最近在CF (Cloud Foundry) 云平台上遇到一个比较经典的案例。因为牵扯到JVM (app进程)与数据库连接两大块,稍有不慎,很容易引起不快。 在云环境下,有时候相互扯皮的事蛮多。如果是DB的问题,就会找DB…

DDos系列攻击原理与防御原理

七层防御体系 静态过滤 命中黑名单 对确定是攻击的流量直接加入黑名单(源地址命中黑名单直接丢弃,缺乏机动性和扩展性) 畸形报文过滤 畸形报文攻击 TCP包含多个标记位,排列组合有规律 • 现象:TCP标记位全为1 …

docker:在ubuntu中运行docker容器

前言 1 本笔记本电脑运行的ubuntu20.04系统 2 docker运行在ubuntu20.04系统 3 docker镜像使用的是ubuntu18.04,这样拉的 docker pull ubuntu:18.04 4 docker容器中运行的是ubuntu18.04的系统,嗯就是严谨 5 这纯粹是学习笔记,实际上没啥价值。…

文件的输入与输出(史上最全)

文件的输入与输出 一个文件是一个存储在磁盘中带有指定名称和目录路径的数据集合。当打开文件进行读写时,它变成一个流。 从根本上说,流是通过通信路径传递的字节序列。有两个主要的流:输入流 和 输出流。输入流用于从文件读取数据&#xf…

指纹浏览器是什么?有哪些好用的推荐?

在网络世界中,保护您的在线隐私和安全非常重要。反检测浏览器是专门为此诞生的工具,旨在通过更改浏览器指纹来帮助您做到这一点,它们使网站、广告商和其他人很难跟踪您的在线行为。 一、什么是反检测浏览器? 您是否想过网站如何检…

C++万物起源:类与对象(二)

一、类的6个默认成员函数 如果一个类中什么成员都没有,简称为空类。 空类中真的什么都没有吗? 并不是,任何类在什么都不写时,编译器会自动生成以下6个默认成员 函数。 默认成员函数:用户没有显式实现,…

动态规划算法及Java实例

动态规划算法的基本概念 动态规划算法是一种解决复杂问题的有效方法,它通过将大问题分解为小问题,然后逐个解决这些小问题,最终通过组合小问题的解来得到大问题的解。这种方法的特点是充分利用了问题的重叠子问题和最优子结构的特性&#xf…

篮球论坛系统的设计与实现|Springboot+ Mysql+Java+ B/S结构(可运行源码+数据库+设计文档)

本项目包含可运行源码数据库LW,文末可获取本项目的所有资料。 推荐阅读100套最新项目持续更新中..... 2024年计算机毕业论文(设计)学生选题参考合集推荐收藏(包含Springboot、jsp、ssmvue等技术项目合集) 目录 1. …

Linux根据时间删除文件或目录

《liunx根据时间删除文件》和 《Linux 根据时间删除文件或者目录》已经讲述了根据时间删除文件或目录的方法。 下面我做一些补充,讲述一个具体例子。以删除/home目录下的文件为例。 首先通过命令: ls -l --time-style"%Y-%m-%d %H:%M:%S"…

Redis、Mysql双写情况下,如何保证数据一致

Redis、Mysql双写情况下,如何保证数据一致 场景谈谈数据一致性三个经典的缓存模式Cache-Aside Pattern读流程写流程 Read-Through/Write-Through(读写穿透)Write behind (异步缓存写入) 操作缓存的时候,删除…

【tensorflow框架神经网络实现鸢尾花分类】

文章目录 1、数据获取2、数据集构建3、模型的训练验证可视化训练过程 1、数据获取 从sklearn中获取鸢尾花数据,并合并处理 from sklearn.datasets import load_iris import pandas as pdx_data load_iris().data y_data load_iris().targetx_data pd.DataFrame…

git泄露

git泄露 CTFHub技能树-Web-信息泄露-备份文件下载 当前大量开发人员使用git进行版本控制,对站点自动部署。如果配置不当,可能会将.git文件夹直接部署到线上环境。这就引起了git泄露漏洞。 工具GitHack使用:python2 GitHack.py URL地址/.git/ git命令…

怎样才能把重建大师的空三导进去CC?

导出空三文件xml两者都是通用的,cc和photoscan都可以兼容。 重建大师是一款专为超大规模实景三维数据生产而设计的集群并行处理软件,输入倾斜照片,激光点云,POS信息及像控点,输出高精度彩色网格模型,可一键…

ros2相关代码记录

1.ros2概述 ROS2(Robot Operating System 2)是一个用于机器人应用程序的开源软件框架。它是ROS(Robot Operating System)的下一代版本,旨在改进和扩展原始ROS的特性,以适应更广泛的机器人应用场景和需求。…

Unity 实现鼠标左键进行射击

发射脚本实现思路 分析 确定用户交互方式:通过鼠标左键点击发射子弹。确定子弹发射逻辑:每次点击后有一定时间间隔才能再次发射。确定子弹发射源和方向:子弹从枪口(Transform)位置发射,沿枪口方向前进。 变…

Qt扫盲-QAssisant 集成其他qch帮助文档

QAssisant 集成其他qch帮助文档 一、概述二、Cmake qch例子1. 下载 Cmake.qch2. 添加qch1. 直接放置于Qt 帮助的目录下2. 在 QAssisant中添加 一、概述 QAssisant是一个很好的帮助文档,他提供了供我们在外部添加新的 qch帮助文档的功能接口,一般有两中添…

八大技术趋势案例(虚拟现实增强现实)

科技巨变,未来已来,八大技术趋势引领数字化时代。信息技术的迅猛发展,深刻改变了我们的生活、工作和生产方式。人工智能、物联网、云计算、大数据、虚拟现实、增强现实、区块链、量子计算等新兴技术在各行各业得到广泛应用,为各个领域带来了新的活力和变革。 为了更好地了解…

Composer常见错误解决

Composer 是 PHP 社区广泛使用的一个依赖管理工具,它帮助开发者定义、管理和安装项目所需的库。在使用 Composer 的过程中,可能会遇到各种错误和问题。以下是一些常见的 Composer 错误及其解决方法: 1. 内存限制错误 错误信息:P…

QT QInputDialog弹出消息框用法

使用QInputDialog类的静态方法来弹出对话框获取用户输入,缺点是不能自定义按钮的文字,默认为OK和Cancel: int main(int argc, char *argv[]) {QApplication a(argc, argv);bool isOK;QString text QInputDialog::getText(NULL, "Input …

李宏毅【生成式AI导论 2024】第6讲 大型语言模型修炼_第一阶段_ 自我学习累积实力

背景知识:机器怎么学会做文字接龙 详见:https://blog.csdn.net/qq_26557761/article/details/136986922?spm=1001.2014.3001.5501 在语言模型的修炼中,我们需要训练资料来找出数十亿个未知参数,这个过程叫做训练或学习。找到参数后,我们可以使用函数来进行文字接龙,拿…