Tensorflow2.0笔记 - metrics做损失和准确度信息度量

        本笔记主要记录metrics相关的内容,详细内容请参考代码注释,代码本身只使用了Accuracy和Mean。本节的代码基于上篇笔记FashionMnist的代码经过简单修改而来,上篇笔记链接如下:

Tensorflow2.0笔记 - FashionMnist数据集训练-CSDN博客文章浏览阅读339次。本笔记使用FashionMnist数据集,搭建一个5层的神经网络进行训练,并统计测试集的精度。本笔记中FashionMnist数据集是直接下载到本地加载的方式,不涉及用梯子。关于FashionMnist的介绍,请自行百度。https://blog.csdn.net/vivo01/article/details/136921592?spm=1001.2014.3001.5502

#Fashion Mnist数据集本地下载和加载(不用梯子)
#https://blog.csdn.net/scar2016/article/details/115361245 (百度网盘)
#https://blog.csdn.net/weixin_43272781/article/details/110006990 (github)
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metricstf.__version__#加载fashion mnist数据集
def load_mnist(path, kind='train'):import osimport gzipimport numpy as np"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels#预处理数据
def preprocess(x, y):x = tf.cast(x, dtype=tf.float32)x = tf.convert_to_tensor(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)y = tf.convert_to_tensor(y, dtype=tf.int32)return x, y
#训练数据
train_data, train_labels = load_mnist("./datasets")
print(train_data.shape, train_labels.shape)
#测试数据
test_data, test_labels = load_mnist("./datasets", "t10k")
print(test_data.shape, test_labels.shape)batch_size = 128train_db = tf.data.Dataset.from_tensor_slices((train_data, train_labels))
train_db = train_db.map(preprocess).shuffle(10000).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((test_data, test_labels))
test_db = test_db.map(preprocess).batch(batch_size)train_db_iter = iter(train_db)
sample = next(train_db_iter)
print('Batch:', sample[0].shape, sample[1].shape)#定义网络模型
model = Sequential([#Layer 1: [b, 784] => [b, 256]layers.Dense(256, activation=tf.nn.relu),#Layer 2: [b, 256] => [b, 128]layers.Dense(128, activation=tf.nn.relu),#Layer 3: [b, 128] => [b, 64]layers.Dense(64, activation=tf.nn.relu),#Layer 4: [b, 64] => [b, 32]layers.Dense(32, activation=tf.nn.relu),#Layer 5: [b, 32] => [b, 10], 输出类别结果layers.Dense(10)
])#编译网络
model.build(input_shape=[None, 28*28])
model.summary()#进行训练
total_epoches = 5
learn_rate = 0.01#Metrics统计
#参考资料:https://zhuanlan.zhihu.com/p/42438077
#1. 新建meter
#acc_meter = metrics.Accuracy()
#loss_meter = metrics.Mean()
#2. 更新状态, update_state()
#loss_meter.update_state(loss)
#acc_meter.update_state(y, pred)
#3.获取结果, result()
#print(step, 'loss:', loss_meter.result().numpy())
#print(step, 'Evaluate Acc:', total_correct/total, acc_meter.result().numpy())
#4.清除度量信息,reset_states()
#loss_meter.reset_states()
#acc_meter.reset_states()#新建准确度和loss度量对象
acc_meter = metrics.Accuracy()
loss_meter = metrics.Mean()optimizer = optimizers.Adam(learning_rate = learn_rate)
for epoch in range(total_epoches):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:logits = model(x)y_onehot = tf.one_hot(y, depth=10)#使用交叉熵作为lossloss_ce = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True))#调用update_state更新loss度量信息loss_meter.update_state(loss_ce)#计算梯度grads = tape.gradient(loss_ce, model.trainable_variables)#更新梯度optimizer.apply_gradients(zip(grads, model.trainable_variables))if step % 100 == 0:print("Epoch[", epoch, "]: step-", step, "\tloss: ", loss_meter.result().numpy())loss_meter.reset_states()#使用测试集进行验证total_correct = 0total_num = 0#清除准确度的统计信息acc_meter.reset_states()for x,y in test_db:logits = model(x)#使用softmax得到各个类别的概率prob = tf.nn.softmax(logits, axis=1)#求出概率最大的结果参数位置,作为预测的分类结果pred = tf.cast(tf.argmax(prob, axis=1), dtype=tf.int32)#比较结果correct = tf.equal(pred, y)correct = tf.reduce_sum(tf.cast(correct, dtype=tf.int32))#计算精度total_correct += int(correct)total_num += x.shape[0]#使用metircs的update_state进行更新acc_meter.update_state(y, pred)acc = total_correct / total_numprint("Epoch[", epoch, "] Manual Accuracy:", acc, " Metrics Accuracy:", acc_meter.result().numpy())

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/778527.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SAP BTP云上一个JVM与DB Connection纠缠的案例

前言 最近在CF (Cloud Foundry) 云平台上遇到一个比较经典的案例。因为牵扯到JVM (app进程)与数据库连接两大块,稍有不慎,很容易引起不快。 在云环境下,有时候相互扯皮的事蛮多。如果是DB的问题,就会找DB…

DDos系列攻击原理与防御原理

七层防御体系 静态过滤 命中黑名单 对确定是攻击的流量直接加入黑名单(源地址命中黑名单直接丢弃,缺乏机动性和扩展性) 畸形报文过滤 畸形报文攻击 TCP包含多个标记位,排列组合有规律 • 现象:TCP标记位全为1 …

docker:在ubuntu中运行docker容器

前言 1 本笔记本电脑运行的ubuntu20.04系统 2 docker运行在ubuntu20.04系统 3 docker镜像使用的是ubuntu18.04,这样拉的 docker pull ubuntu:18.04 4 docker容器中运行的是ubuntu18.04的系统,嗯就是严谨 5 这纯粹是学习笔记,实际上没啥价值。…

指纹浏览器是什么?有哪些好用的推荐?

在网络世界中,保护您的在线隐私和安全非常重要。反检测浏览器是专门为此诞生的工具,旨在通过更改浏览器指纹来帮助您做到这一点,它们使网站、广告商和其他人很难跟踪您的在线行为。 一、什么是反检测浏览器? 您是否想过网站如何检…

C++万物起源:类与对象(二)

一、类的6个默认成员函数 如果一个类中什么成员都没有,简称为空类。 空类中真的什么都没有吗? 并不是,任何类在什么都不写时,编译器会自动生成以下6个默认成员 函数。 默认成员函数:用户没有显式实现,…

篮球论坛系统的设计与实现|Springboot+ Mysql+Java+ B/S结构(可运行源码+数据库+设计文档)

本项目包含可运行源码数据库LW,文末可获取本项目的所有资料。 推荐阅读100套最新项目持续更新中..... 2024年计算机毕业论文(设计)学生选题参考合集推荐收藏(包含Springboot、jsp、ssmvue等技术项目合集) 目录 1. …

Linux根据时间删除文件或目录

《liunx根据时间删除文件》和 《Linux 根据时间删除文件或者目录》已经讲述了根据时间删除文件或目录的方法。 下面我做一些补充,讲述一个具体例子。以删除/home目录下的文件为例。 首先通过命令: ls -l --time-style"%Y-%m-%d %H:%M:%S"…

Redis、Mysql双写情况下,如何保证数据一致

Redis、Mysql双写情况下,如何保证数据一致 场景谈谈数据一致性三个经典的缓存模式Cache-Aside Pattern读流程写流程 Read-Through/Write-Through(读写穿透)Write behind (异步缓存写入) 操作缓存的时候,删除…

【tensorflow框架神经网络实现鸢尾花分类】

文章目录 1、数据获取2、数据集构建3、模型的训练验证可视化训练过程 1、数据获取 从sklearn中获取鸢尾花数据,并合并处理 from sklearn.datasets import load_iris import pandas as pdx_data load_iris().data y_data load_iris().targetx_data pd.DataFrame…

ros2相关代码记录

1.ros2概述 ROS2(Robot Operating System 2)是一个用于机器人应用程序的开源软件框架。它是ROS(Robot Operating System)的下一代版本,旨在改进和扩展原始ROS的特性,以适应更广泛的机器人应用场景和需求。…

Unity 实现鼠标左键进行射击

发射脚本实现思路 分析 确定用户交互方式:通过鼠标左键点击发射子弹。确定子弹发射逻辑:每次点击后有一定时间间隔才能再次发射。确定子弹发射源和方向:子弹从枪口(Transform)位置发射,沿枪口方向前进。 变…

Qt扫盲-QAssisant 集成其他qch帮助文档

QAssisant 集成其他qch帮助文档 一、概述二、Cmake qch例子1. 下载 Cmake.qch2. 添加qch1. 直接放置于Qt 帮助的目录下2. 在 QAssisant中添加 一、概述 QAssisant是一个很好的帮助文档,他提供了供我们在外部添加新的 qch帮助文档的功能接口,一般有两中添…

八大技术趋势案例(虚拟现实增强现实)

科技巨变,未来已来,八大技术趋势引领数字化时代。信息技术的迅猛发展,深刻改变了我们的生活、工作和生产方式。人工智能、物联网、云计算、大数据、虚拟现实、增强现实、区块链、量子计算等新兴技术在各行各业得到广泛应用,为各个领域带来了新的活力和变革。 为了更好地了解…

QT QInputDialog弹出消息框用法

使用QInputDialog类的静态方法来弹出对话框获取用户输入,缺点是不能自定义按钮的文字,默认为OK和Cancel: int main(int argc, char *argv[]) {QApplication a(argc, argv);bool isOK;QString text QInputDialog::getText(NULL, "Input …

李宏毅【生成式AI导论 2024】第6讲 大型语言模型修炼_第一阶段_ 自我学习累积实力

背景知识:机器怎么学会做文字接龙 详见:https://blog.csdn.net/qq_26557761/article/details/136986922?spm=1001.2014.3001.5501 在语言模型的修炼中,我们需要训练资料来找出数十亿个未知参数,这个过程叫做训练或学习。找到参数后,我们可以使用函数来进行文字接龙,拿…

【数据分析面试】3.编写数据选取函数(Python)

题目 给定了一个名为 students_df 的学生数据表格 nameagefavorite_colorgradeTim Voss19red91Nicole Johnson20yellow95Elsa Williams21green82John James20blue75Catherine Jones23green93 编写一个名为 grades_colors 的函数,以选择仅当学生喜欢的颜色是绿色或…

2024最新Guitar Pro 8.1中文版永久许可证激活

Guitar Pro是一款非常受欢迎的音乐制作软件,它可以帮助用户创建和编辑各种音乐曲谱。从其诞生以来就送专门为了编写吉他谱而研发迭代的。 尽管这款产品可能已经成为全球最受欢迎的吉他打谱软件,在编写吉他六线谱和乐队总谱中始终处于行业领先地位&#x…

ESCTF-密码赛题WP

*小学生的爱情* Base64解码获得flag *中学生的爱情* 社会主义核心价值观在线解码得到flag http://www.atoolbox.net/Tool.php?Id850 *高中生的爱情* U2FsdG开头为rabbit密码,又提示你密钥为love。本地toolfx密码工具箱解密。不知道为什么在线解密不行。 *大学生的爱情* …

jira安装与配置

1. 环境准备 环境要求 1) JDK1.8以上环境配置 2) Mysql数据库5.7.13 3) Jira版本7及破解包 1.1 JDK1.8安装配置 1) 首先下载 JDK1.8, - 网址:https://www.oracle.com/cn/java/technologies/javase/javase-jdk8-downloads.html - windows64 版&am…

机器学习优化算法(深度学习)

目录 预备知识 梯度 Hessian 矩阵(海森矩阵,或者黑塞矩阵) 拉格朗日中值定理 柯西中值定理 泰勒公式 黑塞矩阵(Hessian矩阵) Jacobi 矩阵 优化方法 梯度下降法(Gradient Descent) 随机…