【深度学习_TensorFlow】自定义层实现cifar10

写在前面

尽管 tf.keras 提供了很多的常用网络层类,但深度学习可以使用的网络层远远不止这些。科研工作者一般是自行实现了较为新颖的网络层,经过大量实验验证有效后,深度学习框架才会跟进,内置对这些网络层的支持。因此掌握自定义网络层、网络的实现非常重要。


写在中间

1. 初始化方法

对于自定义的网络层,我们至少需要实现初始化__init__方法和前向传播逻辑call方法。我们以具体的自定义网络层为例,假设需要一个没有偏置向量的全连接层,即 bias 为0,同时固定激活函数为 ReLU 函数。我们通过实现这个“特别的”网络层类来阐述如何实现自定义网络层。

首先创建类,并继承自 Layer 基类。

  • 创建初始化方法,并调用母类的初始化函数,由于是全连接层,因此需要设置两个参数:输入特征的长度 inp_dim 和输出特征的长度outp_dim,并通过 self.add_variable(name, shape)创建 shape 大小,名字为 name 的张量𝑾,并设置为需要优化。
class MyDense(layers.Layer):def __init__(self, inp_dim, outp_dim):super(MyDense, self).__init__()# 添加权重,用于线性变换self.kernel = self.add_weight('w', [inp_dim, outp_dim])# 添加偏置项self.bias = self.add_weight('b', [outp_dim])
  • 完成自定义类的初始化工作后,我们来设计自定义类的前向运算逻辑,对于这个例子,只需要完成𝑶 = 𝑿@𝑾矩阵运算,并通过固定的 ReLU 激活函数即可;

  • 自定义类的前向运算逻辑实现在 call(inputs, training=None)函数中,其中 inputs代表输入,由用户在调用时传入;training 参数用于指定模型的状态:training 为 True 时执行训练模式,training 为 False 时执行测试模式,默认参数为 None,即测试模式。

 def call(self, inputs, training=None): # 实现自定义类的前向计算逻辑 # X@W out = inputs @ self.kernel # 执行激活函数运算 out = tf.nn.relu(out) return out  

2. 自定义网络

Sequential 容器适合于数据按序从第一层传播到第二层,再从第二层传播到第三层,以此规律传播的网络模型。对于复杂的网络结构,例如第三层的输入不仅是第二层的输出,还有第一层的输出,此时使用自定义网络更加灵活。下面我们来创建自定义网络类

  • 首先创建类,并继承自 Model 基类:

  • 接着在类里创建对应的网络层对象。

  • 然后实现自定义网络的前向运算逻辑。

class MyModel(tf.keras.Model):def __init__(self):super(MyModel, self).__init__()# 添加自定义的全连接层作为网络的组成部分self.fc1 = MyDense(28 * 28, 256)  # 输入维度为28*28,输出维度为256self.fc2 = MyDense(256, 128)  # 输入维度为256,输出维度为128self.fc3 = MyDense(128, 64)  # 输入维度为128,输出维度为64self.fc4 = MyDense(64, 32)  # 输入维度为64,输出维度为32self.fc5 = MyDense(32, 10)  # 输入维度为32,输出维度为10def call(self, inputs, training=None):"""前向传播函数Args:inputs (tf.Tensor): 输入张量training (bool, 可选): 是否处于训练模式Returns:tf.Tensor: 输出张量"""x = self.fc1(inputs)  # 经过第一层全连接层x = tf.nn.relu(x)  # ReLU激活函数处理x = self.fc2(x)  # 经过第二层全连接层x = tf.nn.relu(x)  x = self.fc3(x)  # 经过第三层全连接层x = tf.nn.relu(x)x = self.fc4(x)  # 经过第四层全连接层x = tf.nn.relu(x)x = self.fc5(x)  # 经过第五层全连接层,无激活函数处理,直接输出结果return x

只学理论可不行,我们还要学会实际应用,接下来我们就使用cifar10数据集来实战

3. cifar10实战

CIFAR-10是一个包含60000张32x32 RGB彩色图片的数据集,被用于物体识别。数据集包含10个类别的图片,每个类别有6000张。这些类别包括飞机、汽车、鸟类、猫、鹿、狗、蛙类、马、船和卡车。其中,50000张图片用于训练集,10000张图片用于测试集。与MNIST数据集相比,CIFAR-10数据集的图像具有更高的色彩分辨率和更复杂的物体形状,因此对物体识别的挑战更大。

注意:由于我们的图像的尺寸很小,图案本身就不清楚,加之网络简单,所以训练出的模型准确率不高是正常现象。

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics# 图像预处理函数
def preprocess(x, y):# 一张一张的传入图像数据# 将图片数值范围调整到[-1, 1]x = 2 * tf.cast(x, dtype=tf.float32) / 255. - 1.# 标签数据y = tf.squeeze(y)   # 删除大小为 1 的维度y = tf.cast(y, dtype=tf.int32)    # 将标签转换为int32类型y = tf.one_hot(y, depth=10)    # 将标签转为one-hot编码return x, y# 加载CIFAR10数据集,分为训练数据和测试数据
(x, y), (x_test, y_test) = datasets.cifar10.load_data()# 打印数据集信息
print('datasets:', x.shape, y.shape, x_test.shape, y_test.shape)
print('图片像素范围:', x.min(), x.max())# 处理训练集
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = train_db.map(preprocess).shuffle(10000).batch(128)# 处理测试集
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).batch(128)# 取一个batch并打印shape
sample = next(iter(train_db))
print('batch:', sample[0].shape, sample[1].shape)# 自定义全连接层
class MyDense(layers.Layer):# 初始化函数def __init__(self, inp_dim, outp_dim):super(MyDense, self).__init__()# 添加一个kernel变量,作为自定义层的权重矩阵# inp_dim为输入特征维度,outp_dim为输出特征维度self.kernel = self.add_weight('w', [inp_dim, outp_dim])# 前向计算函数def call(self, inputs, training=None):# inputs为输入张量# self.kernel为自定义层的权重矩阵# 使用矩阵乘法计算线性变换x = inputs @ self.kernel# 返回计算结果return x# 自定义模型类,继承自tf.keras.Model
class MyNetwork(tf.keras.Model):# 初始化函数def __init__(self):super(MyNetwork, self).__init__()# 定义全连接层,层间节点数逐步减小self.fc1 = MyDense(32 * 32 * 3, 256)self.fc2 = MyDense(256, 128)self.fc3 = MyDense(128, 64)self.fc4 = MyDense(64, 32)self.fc5 = MyDense(32, 10)# 前向计算函数def call(self, inputs, training=None):# 将输入reshape为一维向量x = tf.reshape(inputs, [-1, 32 * 32 * 3])# 通过自定义的全连接层计算x = self.fc1(x)x = tf.nn.relu(x)x = self.fc2(x)x = tf.nn.relu(x)x = self.fc3(x)x = tf.nn.relu(x)x = self.fc4(x)x = tf.nn.relu(x)x = self.fc5(x)# 返回最终结果return x# 构建网络
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(learning_rate=1e-3),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])# 训练
network.fit(train_db, epochs=15, validation_data=test_db, validation_freq=1)# 评估
print('测试集评估...')
network.evaluate(test_db)
# 保存权重
network.save_weights('ckpt/weights.ckpt')
print('保存权重...')# 恢复权重
network = MyNetwork()
network.compile(optimizer=optimizers.Adam(lr=1e-3),loss=tf.losses.CategoricalCrossentropy(from_logits=True),metrics=['accuracy'])network.load_weights('ckpt/weights.ckpt')
print('读取权重,重新评估...')
network.evaluate(test_db)

写在最后

👍🏻点赞,你的认可是我创作的动力!
⭐收藏,你的青睐是我努力的方向!
✏️评论,你的意见是我进步的财富!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/36998.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

学习笔记整理-面向对象-01-认识对象

一、认识对象 1. 对象 对象(object)是键值对的集合,表示属性和值的映射关系。 对象的语法 k和v之间用冒号分割,每组k:v之间用逗号分割,最后一个k:v对后可以不书写逗号。 属性是否加引号 如果对象的属性键名不符合命名规范,则这…

数组slice、splice字符串substr、split

一、定义 这篇文章主要对数组操作的两种方法进行介绍和使用,包括:slice、splice。对字符串操作的两种方法进行介绍和使用,包括:substr、split (一)、数组 slice:可以操作的数据类型有:数组字符串 splice:数组 操作数组…

一个基础但全面的Vue的表单范例,很基础,但是很容易,也很全面。

下面这个案例,路人朋友们可以直接粘贴到html文件类型中运行,注意引入Vuejs的路径即可,不会改的可以参考我第一篇Vue入门,同时建议同志们手打,真的前端都不能熟能生巧,既不要编程了, 可以详细看注…

计算机网络-物理层(一)物理层的概念与传输媒体

计算机网络-物理层(一)物理层的概念与传输媒体 物理层相关概念 物理层的作用用来解决在各种传输媒体上传输比特0和1的问题,进而为数据链路层提供透明(看不见)传输比特流的服务物理层为数据链路层屏蔽了各种传输媒体的差异,使数据…

最新Kali Linux安装教程:从零开始打造网络安全之旅

Kali Linux,全称为Kali Linux Distribution,是一个操作系统(2013-03-13诞生),是一款基于Debian的Linux发行版,基于包含了约600个安全工具,省去了繁琐的安装、编译、配置、更新步骤,为所有工具运行提供了一个…

[低端局][cx32L003] 移植U8G2

文章目录 一、简介(1)U8g2(2)U8x8 二、配置要求三、移植步骤(1)文件准备和添加(2)实现回调接口(I2C的读写函数)①软件I2C②硬件I2C (3)功能裁剪① u8g2_d_set…

Gof23设计模式之模板方法模式

1.定义 定义一个操作中的算法骨架,而将算法的一些步骤延迟到子类中,使得子类可以不改变该算法结构的情况下重定义该算法的某些特定步骤。 2.结构 模板方法(Template Method)模式包含以下主要角色: 抽象类&#xff0…

Kerberos 重新认识 From Oracle安全

参考 https://docs.oracle.com/cd/E24847_01/html/819-7061/seamtm-1.html#scrolltoc Kerberos服务 Kerberos服务是一种网络身份认证协议,由麻省理工学院(MIT)开发。它提供了强大的身份验证功能,用于在计算机网络中验证用户和服务…

买爱心气球(nim博弈)

链接:登录—专业IT笔试面试备考平台_牛客网 来源:牛客网 Alice 和 Bob 是一对竞技编程选手,他们路过了一家气球店,发现有 m 个大爱心气球和 n个小爱心气球。他们决定玩一个游戏,游戏规则如下: Alice先手拿…

Python Selenium 设置带账号密码的socks5代理,启动浏览器

selenium添加带有账密的socks5代理 我们都知道在使用selenium开发爬虫的时候不可避免的会使用socks5高匿名代理。一般情况下我们使用方法如下(开发语言为python): from selenium import webdriver chrome_options webdriver.ChromeOptions() chrome_options.add_…

Java并发之ReentrantLock

AQS AQS(AbstractQueuedSynchronizer):抽象队列同步器,是一种用来构建锁和同步器的框架。在是JUC下一个重要的并发类,例如:ReentrantLock、Semaphore、CountDownLatch、LimitLatch等并发都是由AQS衍生出来…

React Native Expo项目,复制文本到剪切板

装包: npx expo install expo-clipboard import * as Clipboard from expo-clipboardconst handleCopy async (text) > {await Clipboard.setStringAsync(text)Toast.show(复制成功, {duration: 3000,position: Toast.positions.CENTER,})} 参考链接&#xff1a…

3.文件目录

第四章 文件管理 3.文件目录 ​   对于D盘这个根目录来说它对应的目录文件就是图中的样子,其实就是用一个所谓的目录表来表示这个目录下面存放了哪些东西。在D盘中的每一个文件,每一个文件夹都会对应这个目录表中的一个表项,所以其实这些一…

如何写简历?

写程序员简历时,可以从以下几个方面入手: 1. 个人信息:在简历的开头,包含个人基本信息如姓名、联系方式、地址等。 2. 求职目标/职业目标:明确自己希望得到的职位或行业,并简要描述为什么适合该职位。 3…

Autoware感知02—欧氏聚类(lidar_euclidean_cluster_detect)源码解析

文章目录 引言一、点云回调函数:二、预处理(1)裁剪距离雷达过于近的点云,消除车身的影响(2)点云降采样(体素滤波,默认也是不需要的)(3)裁剪雷达高…

【概念篇】文件概述

✅作者简介:大家好,我是小杨 📃个人主页:「小杨」的csdn博客 🐳希望大家多多支持🥰一起进步呀! 文件概述 1,文件的概念 狭义上的文件是计算机系统中用于存储和组织数据的一种数据存…

React源码解析18(5)------ 实现函数组件【修改beginWork和completeWork】

摘要 经过之前的几篇文章,我们实现了基本的jsx,在页面渲染的过程。但是如果是通过函数组件写出来的组件,还是不能渲染到页面上的。 所以这一篇,主要是对之前写得方法进行修改,从而能够显示函数组件,所以现…

【深度学习】NLP中的对抗训练

在NLP中,对抗训练往往都是针对嵌入层(包括词嵌入,位置嵌入,segment嵌入等等)开展的,思想很简单,即针对嵌入层添加干扰,从而提高模型的鲁棒性和泛化能力,下面结合具体代码…

Spark 学习记录

基础 SparkContext是什么?有什么作用? https://blog.csdn.net/Shockang/article/details/118344357 SparkContext 是什么? SparkContext 是通往 Spark 集群的唯一入口,可以用来在 Spark 集群中创建 RDDs 、累加和广播变量( Br…

【数据库基础】Mysql下载安装及配置

下载 下载地址:https://downloads.mysql.com/archives/community/ 当前最新版本为 8.0版本,可以在Product Version中选择指定版本,在Operating System中选择安装平台,如下 安装 MySQL安装文件分两种 .msi和.zip [外链图片转存失…