Tensorflow2.0笔记 - 使用卷积神经网络层做CIFA100数据集训练(类VGG13)

        本笔记记录CNN做CIFAR100数据集的训练相关内容,代码中使用了类似VGG13的网络结构,做了两个Sequetial(CNN和全连接层),没有用Flatten层而是用reshape操作做CNN和全连接层的中转操作。由于网络层次较深,参数量相比之前的网络多了不少,因此只做了10次epoch(RTX4090),没有继续跑了,最终准确率大概在33.8%左右。

import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Inputos.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__#如果下载很慢,可以使用迅雷下载到本地,迅雷的链接也可以直接用官网URL:
#      https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
#下载好后,将cifar-100.python.tar.gz放到 .keras\datasets 目录下(我的环境是C:\Users\Administrator\.keras\datasets)
# 参考:https://blog.csdn.net/zy_like_study/article/details/104219259
(x_train,y_train), (x_test, y_test) = datasets.cifar100.load_data()
print("Train data shape:", x_train.shape)
print("Train label shape:", y_train.shape)
print("Test data shape:", x_test.shape)
print("Test label shape:", y_test.shape)def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)return x,yy_train = tf.squeeze(y_train, axis=1)
y_test = tf.squeeze(y_test, axis=1)batch_size = 128
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(1000).map(preprocess).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(batch_size)sample = next(iter(train_db))
print("Train data sample:", sample[0].shape, sample[1].shape, tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))#创建CNN网络,总共4个unit,每个unit主要是两个卷积层和Max Pooling池化层
cnn_layers = [#unit 1layers.Conv2D(64, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(64, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),#unit 2layers.Conv2D(128, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(128, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),#unit 3layers.Conv2D(256, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(256, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),#unit 4layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),#unit 5layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),layers.Conv2D(512, kernel_size=[3,3], padding='same', activation='relu'),#layers.MaxPool2D(pool_size=[2,2], strides=2, padding='same'),layers.MaxPool2D(pool_size=[2,2], strides=2),
]def main():#[b, 32, 32, 3] => [b, 1, 1, 512]cnn_net = Sequential(cnn_layers)cnn_net.build(input_shape=[None, 32, 32, 3])#测试一下卷积层的输出#x = tf.random.normal([4, 32, 32, 3])#out = cnn_net(x)#print(out.shape)#创建全连接层, 输出为100分类fc_net = Sequential([layers.Dense(256, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(100, activation=None),])fc_net.build(input_shape=[None, 512])#设置优化器optimizer = optimizers.Adam(learning_rate=1e-4)#记录cnn层和全连接层所有可训练参数, 实现的效果类似list拼接,比如# [1, 2] + [3, 4] => [1, 2, 3, 4]variables = cnn_net.trainable_variables + fc_net.trainable_variables#进行训练num_epoches = 10for epoch in range(num_epoches):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:#[b, 32, 32, 3] => [b, 1, 1, 512]out = cnn_net(x)#flatten打平 => [b, 512]out = tf.reshape(out, [-1, 512])#使用全连接层做100分类logits输出#[b, 512] => [b, 100]logits = fc_net(out)#标签做one_hot encodingy_onehot = tf.one_hot(y, depth=100)#计算损失loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)loss = tf.reduce_mean(loss)#计算梯度grads = tape.gradient(loss, variables)#更新参数optimizer.apply_gradients(zip(grads, variables))if (step % 100 == 0):print("Epoch[", epoch + 1, "/", num_epoches, "]: step-", step, " loss:", float(loss))#进行验证total_samples = 0total_correct = 0for x,y in test_db:out = cnn_net(x)out = tf.reshape(out, [-1, 512])logits = fc_net(out)prob = tf.nn.softmax(logits, axis=1)pred = tf.argmax(prob, axis=1)pred = tf.cast(pred, dtype=tf.int32)correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)correct = tf.reduce_sum(correct)total_samples += x.shape[0]total_correct += int(correct)#统计准确率acc = total_correct / total_samplesprint("Epoch[", epoch + 1, "/", num_epoches, "]: accuracy:", acc)
if __name__ == '__main__':main()

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/842.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

centos搭建yum源

目录 1.createrepo简介 2.repo搭建思路 3.安装 4.使用 1.createrepo简介 createrepo 是一个用于创建 RPM 包的工具,它可以帮助你创建一个本地的 YUM 仓库。createrepo 并不是用于运行 YUM 仓库服务的软件,而是用来生成仓库的元数据,使得…

区分软件成熟度模型集成的五个等级

概念讲解 软件成熟度模型集成(CMMI,Capability Maturity Model Integration)是一种评估和改进软件开发过程的模型。CMMI的五个成熟度等级分别是: 初始级(Level 1 - Initial):在这个等级&#x…

【Taro3踩坑日记】找不到sass的类型定义文件

问题截图如下:找不到sass的类型定义文件 解决办法: 1、npm i types/sass1.43.1 2、然后配置 TypeScript 编译选项:确保 TypeScript 编译器能够识别 Sass 文件,并正确处理它们。

在一个态势感知复杂网络系统中,存在着态、势、感、知四种损失函数和梯度变化...

反向传播是一种用于训练神经网络的常用技术,它通过计算损失函数对网络参数的梯度,然后利用梯度下降等优化算法来更新参数,从而使网络逐步优化以减少损失函数的值。 在反向传播中,损失函数的选择非常重要,通常采用的损失…

PyTorch的核心概念

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「stormsha的主页」…

AWS账号注册以及Claude 3 模型使用教程!

哈喽哈喽大家好呀,伙伴们!你听说了吗?最近AWS托管了大热模型:Claude 3 Opus!想要一探究竟吗?那就赶紧来注册AWS账号吧!别担心,现在注册还免费呢!而且在AWS上还有更多的大…

【Linux】认识文件(一):文件标识符

【Linux】认识文件(一):文件标识符 一.什么是文件?1.文件的本质2.文件的分类 二.访问文件操作1.C语言中的访问文件接口i.fopenii.fcloseiii.fwrite 2.系统访问文件接口i.openii.closeiii.write 三.文件管理1.对所有打开文件的管理…

FlinkCDC基础篇章2-数据源 SqlServerCDC写入到ES中

接着 上期FlinkCDC基础篇章1-安装使用 下载 Flink 和所需要的依赖包 # 下载 Flink 1.17.0 并将其解压至目录 flink-1.17.0 下载下面列出的依赖包,并将它们放到目录 flink-1.17.0/lib/ 下: 下载链接只对已发布的版本有效, SNAPSHOT 版本需要本地编译 …

大华相机C#学习之IStream类

获取方式 IDevice.StreamGrabber 设备实例对象的StreamGrabber属性可以返回IStream对象。 常用属性 IsStart 判断是否开启捕获码流&#xff0c;是&#xff0c;返回true&#xff1b;否&#xff0c;返回false。 private void test_Click(object sender, EventArgs e) {List<…

【分治】Leetcode 数组中的第K个最大元素

题目讲解 数组中的第K个最大元素 算法讲解 堆排序&#xff1a;1. 寻找最后一个节点的父亲&#xff0c;依次向上遍历&#xff0c;完成小堆的建立&#xff1b;2. 从最后一个元素开始&#xff0c;和堆顶的数据做交换&#xff0c;此时最小的数据在对后面&#xff0c;然后对剩下的…

sql知识总结二(接上)

2.updatexml报错注入 &#xff08;1&#xff09;判断字符型/数字型 如果是字符型再判断闭合方式&#xff08;备注&#xff1a;结尾加--看是否闭合&#xff0c;若页面正常执行则闭合完成&#xff09; ?id1") and 1updatexml(1,concat(0x7e,(select group_concat(table_…

Go语言中栈和堆对数据密集型应用程序性能的影响

在 Go 中,变量可以被分配在栈上或堆上。这两种类型的内存在根本上是不相同的,它们可以显著影响数据密集型应用程序的性能。 1. 栈 vs 堆 首先,让我们讨论一下栈和堆的区别。栈是默认内存;它是一种后进先出(LIFO)的数据结构,用于存储特定 goroutine 的所有局部变量。当一…

部署轻量级Gitea替代GitLab进行版本控制(一)

Gitea 是一款使用 Golang 编写的可自运营的代码管理工具。 Gitea Official Website gitea: Gitea的首要目标是创建一个极易安装&#xff0c;运行非常快速&#xff0c;安装和使用体验良好的自建 Git 服务。我们采用Go作为后端语言&#xff0c;这使我们只要生成一个可执行程序即…

【React】Sigma.js框架网络图-入门篇

一、介绍 Sigma.js是一个专门用于图形绘制的JavaScript库。 它使在Web页面上发布网络变得容易&#xff0c;并允许开发人员将网络探索集成到丰富的Web应用程序中。 Sigma.js提供了许多内置功能&#xff0c;例如Canvas和WebGL渲染器或鼠标和触摸支持&#xff0c;以使用户在网页上…

Echarts-丝带图

Echarts-丝带图 demo地址 打开CodePen 什么是丝带图&#xff1f; 丝带图是Power BI中独有额可视化视觉对象&#xff0c;它的工具提示能展示指标当期与下期的数据以及排名。需求&#xff1a;使用丝带图展示"2022年点播订单表"不同月份不同点播套餐对应订单数据。 …

vim之一键替换

Vim的substitute命令是一个非常强大的文本替换工具&#xff0c;它允许用户在整个文件或指定范围内执行文本替换操作。 命令格式 substitute命令的基本格式如下&#xff1a; :[range]s[ubstitute]/{pattern}/{string}/[flags] 其中&#xff1a; [range] 指定替换操作的范围…

搭建HBase2.x完全分布式集群(CentOS 9 + Hadoop3.x)

Apache HBase™是一个分布式、可扩展、大数据存储的Hadoop数据库。 当我们需要对大数据进行随机、实时的读/写访问时&#xff0c;可以使用HBase。这个项目的目标是在通用硬件集群上托管非常大的表——数十亿行X数百万列。Apache HBase是一个开源、分布式、版本化的非关系数据库…

宝塔手动安装grafana

1.下载 # 进入目标目录 cd /data/prometheus/ # 下载 wget https://dl.grafana.com/oss/release/grafana-8.0.4-1.x86_64.rpm # 安装 sudo yum install grafana-8.0.4-1.x86_64.rpm 2.运行项目 # 启动 /etc/init.d/grafana-server start 3.修改配置文件全局搜索 defaults.i…

【AIGC调研系列】llama 3与GPT4相比的优劣点

Llama 3与GPT-4相比&#xff0c;各有其优劣点。以下是基于我搜索到的资料的详细分析&#xff1a; Llama 3的优点&#xff1a; 更大的数据集和参数规模&#xff1a;Llama 3基于超过15T token的训练&#xff0c;这相当于Llama 2数据集的7倍还多[1][3]。此外&#xff0c;它拥有4…

Ceph学习 -11.块存储RBD接口

文章目录 RBD接口1.基础知识1.1 基础知识1.2 简单实践1.3 小结 2.镜像管理2.1 基础知识2.2 简单实践2.3 小结 3.镜像实践3.1 基础知识3.2 简单实践3.3 小结 4.容量管理4.1 基础知识4.2 简单实践4.3 小结 5.快照管理5.1 基础知识5.2 简单实践5.3 小结 6.快照分层6.1 基础知识6.2…