Tensorflow2.0笔记 - ResNet实践

        本笔记记录使用ResNet18网络结构,进行CIFAR100数据集的训练和验证。由于参数较多,训练时间会比较长,因此只跑了10个epoch,准确率还没有提升上去。

import os
import time
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, Inputos.environ['TF_CPP_MIN_LOG_LEVEL']='2'
#tf.random.set_seed(12345)
tf.__version__#关于ResNet的描述,可以参考如下链接:
#https://blog.csdn.net/qq_39770163/article/details/126169080
#代码基于ResNet18结构,有少许不一样
class BasicBlock(layers.Layer):def __init__(self, filter_num, strides = 1):super(BasicBlock, self).__init__()#卷积层1self.conv1 = layers.Conv2D(filter_num, (3,3), strides = strides, padding='same')#BN层self.bn1 = layers.BatchNormalization()#Relu层self.relu = layers.Activation('relu')#卷积层2,BN层2,self.conv2 = layers.Conv2D(filter_num, (3,3), strides = 1, padding='same')self.bn2 = layers.BatchNormalization()#Shortcutif strides != 1:#如果strides不为1,需要下采样self.downsample = Sequential()self.downsample.add(layers.Conv2D(filter_num, (1,1), strides=strides))else:#strides为1, 直接返回原始值即可self.downsample = lambda x:xdef call(self, inputs, training = None):#经过第一个卷积层,BN和Reluout = self.conv1(inputs)out = self.bn1(out)out = self.relu(out)#经过第二个卷积层out = self.conv2(out)out = self.bn2(out)#Shortt处理,out和输入相加identity = self.downsample(inputs)output = layers.add([out, identity])#再经过一个reluoutput = tf.nn.relu(output)return outputclass ResNet(keras.Model):#layer_dims表示对应位置的ResBlock包含了几个BasicBlock#比如[2,2,2,2] => 总共4个ResBlock,每个ResBlock包含两个BasicBlock#num_classes表示输出的类别的个数def __init__(self, layer_dims, num_classes=100):super(ResNet, self).__init__()#预处理单元self.stem = Sequential([layers.Conv2D(64, (3,3), strides=(1,1)),layers.BatchNormalization(),layers.Activation('relu'),layers.MaxPool2D(pool_size=(2,2), strides=(1,1), padding='same')])#创建中间ResBlock层self.layer1 = self.buildResBlock(64, layer_dims[0])self.layer2 = self.buildResBlock(128, layer_dims[1], strides=2)self.layer3 = self.buildResBlock(256, layer_dims[2], strides=2)self.layer4 = self.buildResBlock(512, layer_dims[3], strides=2)#自适应输出层self.avgpool = layers.GlobalAveragePooling2D()#全连接层self.fc = layers.Dense(num_classes)def call(self, inputs, training = None):x = self.stem(inputs)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)#经过avgpool => [b, 512]x = self.avgpool(x)#经过Dense => [b, 100]x = self.fc(x)return xdef buildResBlock(self, filter_num, blocks, strides = 1):resBlocks = Sequential()resBlocks.add(BasicBlock(filter_num, strides))#后续的resBlock的strides都设置为1for _ in range(1, blocks):resBlocks.add(BasicBlock(filter_num))return resBlocks;def ResNet18():return ResNet([2, 2, 2 ,2]);def ResNet34():return ResNet([3, 4, 6, 3])#加载CIFAR100数据集
#如果下载很慢,可以使用迅雷下载到本地,迅雷的链接也可以直接用官网URL:
#      https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
#下载好后,将cifar-100.python.tar.gz放到 .keras\datasets 目录下(我的环境是C:\Users\Administrator\.keras\datasets)
# 参考:https://blog.csdn.net/zy_like_study/article/details/104219259
(x_train,y_train), (x_test, y_test) = datasets.cifar100.load_data()
print("Train data shape:", x_train.shape)
print("Train label shape:", y_train.shape)
print("Test data shape:", x_test.shape)
print("Test label shape:", y_test.shape)def preprocess(x, y):x = tf.cast(x, dtype=tf.float32) / 255.y = tf.cast(y, dtype=tf.int32)return x,yy_train = tf.squeeze(y_train, axis=1)
y_test = tf.squeeze(y_test, axis=1)batch_size = 128
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.shuffle(1000).map(preprocess).batch(batch_size)test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(batch_size)sample = next(iter(train_db))
print("Train data sample:", sample[0].shape, sample[1].shape, tf.reduce_min(sample[0]), tf.reduce_max(sample[0]))def main():#创建ResNetresNet = ResNet18()resNet.build(input_shape=[None, 32, 32, 3])resNet.summary()#设置优化器optimizer = optimizers.Adam(learning_rate=1e-3)#进行训练num_epoches = 10for epoch in range(num_epoches):for step, (x,y) in enumerate(train_db):with tf.GradientTape() as tape:#[b, 32, 32, 3] => [b, 100]logits = resNet(x)#标签做one_hot encodingy_onehot = tf.one_hot(y, depth=100)#计算损失loss = tf.losses.categorical_crossentropy(y_onehot, logits, from_logits=True)loss = tf.reduce_mean(loss)#计算梯度grads = tape.gradient(loss, resNet.trainable_variables)#更新参数optimizer.apply_gradients(zip(grads, resNet.trainable_variables))if (step % 100 == 0):print("Epoch[", epoch + 1, "/", num_epoches, "]: step - ", step, " loss:", float(loss))#进行验证total_samples = 0total_correct = 0for x,y in test_db:logits = resNet(x)prob = tf.nn.softmax(logits, axis=1)pred = tf.argmax(prob, axis=1)pred = tf.cast(pred, dtype=tf.int32)correct = tf.cast(tf.equal(pred, y), dtype=tf.int32)correct = tf.reduce_sum(correct)total_samples += x.shape[0]total_correct += int(correct)#统计准确率acc = total_correct / total_samplesprint("Epoch[", epoch + 1, "/", num_epoches, "]: accuracy:", acc)if __name__ == '__main__':main()

运行结果:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/5792.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

附录6-5 黑马优购项目-我的与后端本地化

目录 1 我的 2 后端本地化 1 我的 tarbar我的 只有这两个页面 其中未登录页面中只有一键登录有用,其他都是写死的,一键登录的功能仅仅是切换到登录的页面 目前微信小程序和微信用户的信息是脱钩的(之前的wx.getUserProfile与wx.getUs…

企业气候风险披露、报表词频、文本分析数据集合(2007-2022年)

01、数据介绍 企业气候风险披露是指企业通过一定的方式,将气候变化对其影响、自身采取的应对措施等信息披露出来。这有助于投资者更准确地评估企业价值,发现投资机会,规避投资风险。解企业在气候风险方面的关注度和披露情况。 可以帮助利益…

Django后台项目开发实战七

为后台管理系统换风格 第七阶段 安装皮肤包 pip install django-grappelli 在 setting.py 注册 INSTALLED_APPS [grappelli,django.contrib.admin,django.contrib.auth,django.contrib.contenttypes,django.contrib.sessions,django.contrib.messages,django.contrib.stat…

【yolov8】yolov8剪枝训练流程

yolov8剪枝训练流程 流程: 约束剪枝微调 一、正常训练 yolo train model./weights/yolov8s.pt datayolo_bvn.yaml epochs100 ampFalse projectprun nametrain二、约束训练 2.1 修改YOLOv8代码: ultralytics/yolo/engine/trainer.py 添加内容&#…

R语言4版本安装mvstats(纯新手)

首先下载mvstats.R文件 下载mvstats.R文件点此链接:https://download.csdn.net/download/m0_62110645/89251535 第一种方法 找到mvstats.R的文件安装位置(R语言的工作路径) getwd() 将mvstats.R保存到工作路径 在R中输入命令 source(&qu…

ctf web-部分

** web基础知识 ** *一.反序列化 在PHP中,反序列化通常是指将序列化后的字节转换回原始的PHP对象或数据结构的过程。PHP中的序列化和反序列化通过serialize()和unserialize()函数实现。 1.序列化serialize() 序列化说通俗点就是把一个对象变成可以传输的字符串…

创新指南|如何通过用户研究打造更好的人工智能产品

每个人都对人工智能感到兴奋,但对错过机会 (FOMO) 的恐惧正在驱使公司将人工智能嵌入到每个产品功能中。这可能会导致以技术为中心的方法,从而掩盖产品开发的基本目标:创建真正解决用户问题并满足他们需求的解决方案。本文将介绍通过用户研究…

HawkEye—高效、细粒度的大页管理算法

文章目录 HawkEye—高效、细粒度的大页管理算法1.作者简介2.文章简介与摘要3.简介(1).当时的SOTA系统概述LinuxFreeBSDIngensHawkEye 4.动机(1).地址翻译开销与内存膨胀(2).缺页中断延迟与缺页中断次数(3).多处理器大页面分配(4).如何测算地址翻译开销? 5.设计与实现…

大长案例 - 通用的三方接口调用方案设计

文章目录 引言身份验证防止重复提交数据完整性和加密回调地址安全事件响应可用性 设计方案概述1. API密钥生成2. 接口鉴权3. 回调地址设置4. 接口API设计 权限划分权限划分概述1. 应用ID(AppID)2. 应用公钥(AppKey)【(…

安装VMware Tools报错处理(SP1)

一、添加共享文件 因为没有VMware Tools,所以补丁只能通过共享文件夹进行传输了。直接在虚拟机的浏览器下载的话,自带的IE浏览器太老了,网站打不开,共享文件夹会方便一点,大家也可以用自己的方法,能顺利上…

【Go语言快速上手(六)】管道, 网络编程,反射,用法讲解

💓博主CSDN主页:杭电码农-NEO💓   ⏩专栏分类:Go语言专栏⏪   🚚代码仓库:NEO的学习日记🚚   🌹关注我🫵带你学习更多Go语言知识   🔝🔝 GO快速上手 1. 前言2. 初识管道3. 管…

清新优雅、功能强大的后台管理模板 | 开源日报 No.238

soybeanjs/soybean-admin Stars: 7.0k License: MIT soybean-admin 是一个基于 Vue3、Vite5、TypeScript、Pinia、NaiveUI 和 UnoCSS 的清新优雅且功能强大的后台管理模板。 使用最新流行的技术栈,如 Vue3、Vite5 和 TypeScript。采用清晰的项目架构,易…

Mac M2 本地下载 Xinference

想要在Mac M2 上部署一个本地的模型。看到了Xinference 这个工具 一、Xorbits Inference 是什么 Xorbits Inference(Xinference)是一个性能强大且功能全面的分布式推理框架。可用于大语言模型(LLM),语音识别模型&…

Kubernetes 弃用Docker后 Kubelet切换到Containerd

containerd 是一个高级容器运行时,又名 容器管理器。简单来说,它是一个守护进程,在单个主机上管理完整的容器生命周期:创建、启动、停止容器、拉取和存储镜像、配置挂载、网络等。 containerd 旨在轻松嵌入到更大的系统中。Docke…

screen服务使用解析

一、为什么要使用screen服务 当我们在进行一些常见的远程操作时,通常首先会先进行远程ssh登录 或者telnet连接到远程服务器上,然后执行相关操作,或程序启动等。 1、程序所需的执行时间过长,可能需要挂载几天的那种,可…

Linux(ubuntu)—— 用户管理user 用户组group

一、用户 1.1、查看所有用户 cat /etc/passwd 1.2、新增用户 useradd 命令,我这里用的是2.4的命令。 然后,需要设置密码 passwd student 只有root用户才能用passwd命令设置其他用户的密码,普通用户只能够设置自己的密码 二、组 2.1查看…

基于ROS从零开始构建自主移动机器人:仿真和硬件

书籍:Build Autonomous Mobile Robot from Scratch using ROS:Simulation and Hardware 作者:Rajesh Subramanian 出版:Apress 书籍下载-《基于ROS从零开始构建自主移动机器人:仿真和硬件》您将开始理解自主机器人发…

aic8800 linux

编译方法参考 http://t.csdnimg.cn/epR89 aic8800 源码在 github 里。同样需要 cfg80211 和 mac80211 aic_load_fw/aic_load_fw.ko aic8800_fdrv/aic8800_fdrv.ko都放到放 .ko 的地方 src/USB/driver_fw/drivers/aic8800 就是源码,没有蓝牙的型号不需要aic_btusb …

ip地址与硬件地址的区别是什么

在数字世界的浩瀚海洋中,每一台联网的设备都需要一个独特的标识来确保信息的准确传输。这些标识,我们通常称之为IP地址和硬件地址。虽然它们都是用来识别网络设备的,但各自扮演的角色和所处的层次却大相径庭。虎观代理小二将带您深入了解IP地…

6.k8s中的secrets资源

一、Secret secrets资源,类似于configmap资源,只是secrets资源是用来传递重要的信息的; secret资源就是将value的值使用base64编译后传输,当pod引用secret后,k8s会自动将其base64的编码,反编译回正常的字符…