深度学习笔记(七)——基于Iris/MNIST数据集构建基础的分类网络算法实战

文中程序以Tensorflow-2.6.0为例
部分概念包含笔者个人理解,如有遗漏或错误,欢迎评论或私信指正。
截图和程序部分引用自北京大学机器学习公开课

认识网络的构建结构

在神经网络的构建过程中,都避不开以下几个步骤:

  1. 导入网络和依赖模块
  2. 原始数据处理和清洗
  3. 加载训练和测试数据
  4. 构建网络结构,确定网络优化方法
  5. 将数据送入网络进行训练,同时判断预测效果
  6. 保存模型
  7. 部署算法,使用新的数据进行预测推理

使用Keras快速构建网络的必要API

在tensorflow2版本中将很多基础函数进行了二次封装,进一步急速了算法初期的构建实现。通过keras提供的很多高级API可以在较短的代码体量上实现网络功能。同时通过搭配tf中的基础功能函数可以实现各种不同类型的卷积和组合操作。正是这中高级API和底层元素及的操作大幅度的提升了tensorflow的自由程度和易用性。

常用网络

全连接层
tf.keras.layers.Dense(units=3, activation=tf.keras.activations.softmax, kernel_regularizer=tf.keras.regularizers.L2())

units:维数(神经元个数)
activation:激活函数,可选:relu softmax sigmoid tanh,这里记不住的话可以用tf.keras.activations.逐个查看
kernel_regularizer:正则化函数,同样的可以使用tf.keras.regularizers.逐个查看
全连接层是标准的神经元组成,更多被用在网络的后端或解码端(Decoder)用来输出预测数据。

拉伸层(维度展平)
tf.keras.layers.Flatten()

这个函数默认不需要输入参数,直接使用,它会将多维的数据按照每一行依次排开首尾连接变成一个一维的张量。通常在数据输入到全连接层之前使用。

卷积层
tf.keras.layers.Conv2D(filters=3, kernel_size=3, strides=1, padding='valid')

filters:卷积核个数
kernel_size:卷积核尺寸
strides:卷积核步长,卷积核是在原始数据上滑动遍历完成数据计算。
padding:可填 ‘valid’ ‘same’,是否使用全零填充,影响最后卷积结果的大小。
卷积一般被用来提取数据的数据特征。卷积最关键的就是卷积核个数和卷积核尺寸。假设输入一个1nn大小的张量,经过x个卷积核+步长为2+尺寸可以整除n的卷积层之后会输出一个x*(n/2)*(n/2)大小的张量。可以理解为卷积步长和卷积核大小影响输出张量的长宽,卷积核的大小影响输出张量的深度。

构建网络

使用Sequential构建简单网络,或者构建网络模块。列表中顺序包含网络的各个层。

tf.keras.models.Sequential([ ])

使用独立的class构建,这里定义一个类继承自 tensorflow.keras.Model 后面基本是标准结构>初始化相关参数>定义网络层>重写call函数定义前向传播层的连接顺序。后续随着使用的深入可以进一步的添加更多函数来实现不同类型的网络。

class mynnModel(Model):    # 继承from tensorflow.keras import Model 作为父类def __init__(self):super(IrisModel, self).__init__()   # 初始化父类的参数self.d1 = layers.Dense(units=3, activation=tf.keras.activations.softmax, kernel_regularizer=tf.keras.regularizers.L2())def call(self, input):  # 重写前向传播函数y = self.d1(input)return ymodel = IrisModel()

训练及其参数设置

设置训练参数
tensorflow.keras.Model.compile(optimizer=参数更新优化器,loss=损失函数metrics=准确率计算方式,即输出数据类型和标签数据类型如何对应)

具体参数可以看下面的内容:

optimizer:参数优化器 SGD:        tf.keras.optimizers.SGD(learning_rate=0.1,momentum=动量参数) learning_rate学习率,momentum动量参数AdaGrad:    tf.keras.optimizers.Adagrad(learning_rate=学习率)Adam:       tf.keras.optimizers.Adam(learning_rate=学习率 , beta_1=0.9, beta_2=0.999)
loss:损失函数MSE:        tf.keras.losses.MeanSquaredError()交叉熵损失: tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False) from_logits=true时输出值经过一次softmax概率归一化
metrics:准确率计算方式,就是输出数据类型和标签数据类型如何对应数值型(两个都是序列值):    'accuracy'都是独热码:    'categorical_accuracy'标签是数值,输出是独热码: 'sparse_categorical_accuracy'
训练
tensorflow.keras.Model.model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)

网络传入参数含义如下:

输入的数据依次为:输入训练特征数据,标签数据,单次输入数据量,迭代次数
validation_split=从训练集划分多少比例数据用来测试 /  validation_data=(测试特征数据,测试标签数据) 这两个参数智能二选一
validation_freq=多少次epoch测试一次
输出网络信息
tensorflow.keras.Model.model.summary()

上面这个函数可以在训练结束或者训练开始之前输出一次网络的结构信息用于确认。

实际应用展示

环境

软件环境的配置可以查看环境配置流程说明

cuda = 11.8	# CUDA也可以使用11.2版本
python=3.7
numpy==1.19.5
matplotlib== 3.5.3
notebook==6.4.12
scikit-learn==1.2.0
tensorflow==2.6.0
keras==2.6.0
使用iris数据集构建基础的分类网络
import tensorflow as tf
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)model = tf.keras.models.Sequential([ tf.keras.layers.Dense(3, activation='softmax',kernel_regularizer=tf.keras.regularizers.l2())])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary( )

通过上面这样几行简单的代码,我们实现了对iris数据的分类训练。在上面的代码中使用了Sequential函数来构建网络。

使用MNIST数据集设计分类网络

在开始下面的代码之前,要先下载对应的数据 https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 复制这段网址在浏览器打开会直接下载数据,然后将下载好的mnist.npz复制到一个新的路径下,然后在tf.keras.datasets.mnist.load_data(path=‘you file path ’)代码中的这行里修改为你的路径,注意要使用绝对路径

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras import layers
from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(path='E:\Tensorflow\data\mnist.npz') # 注意替换自己的使用绝对路径
x_train, x_test = x_train/255.0, x_test/255.0	# 图像数据归一化
print('训练集样本的大小:', x_train.shape)
print('训练集标签的大小:', y_train.shape)
print('测试集样本的大小:', x_test.shape)
print('测试集标签的大小:', y_test.shape)
#可视化样本,下面是输出了训练集中前20个样本
fig, ax = plt.subplots(nrows=4,ncols=5,sharex='all',sharey='all')
ax = ax.flatten()
for i in range(20):img = x_train[i].reshape(28, 28)ax[i].imshow(img,cmap='Greys')
ax[0].set_xticks([])
ax[0].set_yticks([])
plt.tight_layout()
plt.show()
# 定义网络结构
class mnisModel(Model):def __init__(self, *args, **kwargs):super(mnisModel, self).__init__(*args, **kwargs)self.flatten1=layers.Flatten()self.d1=layers.Dense(128, activation=tf.keras.activations.relu)self.d2=layers.Dense(10, activation=tf.keras.activations.softmax)def call(self, input):x = self.flatten1(input)x = self.d1(x)x = self.d2(x)return(x)
model = mnisModel()
#设置训练参数
model.compile(optimizer='adam',     # 'adam'  tf.keras.optimizers.Adam(learning_rate=0.4 , beta_1=0.9, beta_2=0.999)loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])
# 训练
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data = (x_test, y_test), validation_freq=1)
model.summary()

运行后会先显示数据集中的前二十个数字
在这里插入图片描述
关闭数字展示窗口后开始训练,并看到训练的过程
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/625890.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Android+物联网】Android封装MQTT连接阿里云物联网平台

前言: 亲测可行,本文实现Android封装MQTT连接阿里云物联网平台。将MQTT协议和连接阿里云平台的操作通过Android studio写入APP中,并简单设计UI。实现手机APP远程控制单片机LED灯亮灭的功能。 关于《Android软件开发》,见如下专栏…

手拉手Vue3生命周期实战应用

每个 Vue 组件实例在创建时都需要经历一系列的初始化步骤,比如设置好数据侦听,编译模板,挂载实例到 DOM,以及在数据改变时更新 DOM。在此过程中,它也会运行被称为生命周期钩子的函数,让开发者有机会在特定阶…

SpringAMQP的使用

1. 简介: SpringAMQP是基于RabbitMQ封装的一套模板,并且还利用SpringBoot对其实现了自动装配,使用起来非常方便。 SpringAmqp的官方地址:https://spring.io/projects/spring-amqp SpringAMQP提供了三个功能: 自动声…

【linux】查看Debian应用程序图标对应的可执行命令

在Debian系统中,应用程序图标通常与.desktop文件关联。您可以通过查看.desktop文件来找到对应的可执行命令。这些文件通常位于/usr/share/applications/或~/.local/share/applications/目录下。这里是如何查找的步骤: 1. 打开文件管理器或终端。 2. 导…

20240115如何在线识别俄语字幕?

20240115如何在线识别俄语字幕? 2024/1/15 21:25 百度搜索:俄罗斯语 音频 在线识别 字幕 Bilibili:俄语AI字幕识别 音视频转文字 字幕小工具V1.2 BING:音视频转文字 字幕小工具V1.2 https://www.bilibili.com/video/BV1d34y1F7…

<Linux> 进程间通信

目录 前言: 一、进程间通信 (一)进程间通信目的 (二)进程通信的要求 (三)进程间通信分类 二、管道 (一)什么是管道 (二)基本原理 &#…

会声会影2024什么时间发布呢?会声会影2024会有那些新功能

近年来,随着科技的不断进步,各种软件的功能越来越强大,其中最为常用的莫过于视频编辑软件。而会声会影作为一款颇受欢迎的视频编辑软件,备受用户关注。那么,会声会影2024什么时间发布呢? 首先,我…

Java 使用 EasyExcel 爬取数据

一、爬取数据的基本思路 分析要爬取数据的来源 1. 查找数据来源:浏览器按 F12 或右键单击“检查”打开开发者工具查看数据获取时的请求地址 2. 查看接口信息:复制请求地址直接到浏览器地址栏输入看能不能取到数据 3. 推荐安装插件:FeHelper&a…

搭建知识付费小程序平台:如何避免被坑,选择最佳方案?

随着知识经济的兴起,知识付费已经成为一种趋势。越来越多的人开始将自己的知识和技能进行变现,而知识付费小程序平台则成为了一个重要的渠道。然而,市面上的知识付费小程序平台琳琅满目,其中不乏一些不良平台,让老实人…

高可用架构去中心化重要?

1 背景 在互联网高可用架构设计中,应该避免将所有的控制权都集中到一个中心服务,即便这个中心服务是多副本模式。 对某个中心服务(组件)的过渡强依赖,那等同于把命脉掌握在依赖方手里,依赖方的任何问题都可…

个性化定制的知识付费小程序,为用户提供个性化的知识服务

明理信息科技知识付费saas租户平台 随着知识经济的兴起,越来越多的人开始重视知识付费,并希望通过打造自己的知识付费平台来实现自己的知识变现。本文将介绍如何打造自己的知识付费平台,并从定位、内容制作、渠道推广、运营维护四个方面进行…

如何保证Kafka不丢失消息

丢失消息有 3 种不同的情况,针对每一种情况有不同的解决方案。 生产者丢失消息的情况消费者丢失消息的情况Kafka 弄丢了消息 生产者丢失消息的情况 生产者(Producer) 调用send方法发送消息之后,消息可能因为网络问题并没有发送过去。所以,我们…

@Controller层自定义注解拦截request请求校验

一、背景 笔者工作中遇到一个需求,需要开发一个注解,放在controller层的类或者方法上,用以校验请求参数中(不管是url还是body体内,都要检查,有token参数,且符合校验规则就放行)是否传了一个token的参数&am…

从车联网到智慧城市:智慧交通的革新之路

一、引言 1、智慧城市的概念和发展背景 智慧城市(Smart City)是指以信息技术为基础,运用信息与通信等手段,对城市各个核心系统各项关键数据进行感测、分析、整合和利用,实现对城市生活环境的感知、资源的调控&#x…

Linux下的HTTPS配置:从证书到安全连接

在当今的互联网环境中,数据传输的安全性越来越受到重视。HTTPS,作为HTTP的安全版本,通过使用SSL/TLS协议来加密数据传输,确保了数据在传输过程中的安全。在Linux环境下,配置HTTPS需要从证书的生成到服务器的配置进行一…

用 YAML 文件配置 CI/CD 管道

MSBuild 参数: 在使用 MSBuild 命令行生成打包项目(就像在 Visual Studio 中使用向导生成项目一样)之前,生成过程可以通过编辑 Package.appxmanifest 文件中 Package 元素的 Version 属性,来对生成的 MSIX 包进行版本控制。 在 Azure Pipelines 中,可以使用某个表达式来…

内网yum仓库 ftp;http方式

ftp方式 服务端 客户端 vim /etc/yum.repos.d/ftp.repo http方式 服务端 yum install httpd -y systemctl start httpd cd /var/www/html/ mkdir centos7 mount /dev/sr0 /var/www/html/centos7 客户端

pip与pip3的区别

pip 和 pip3 都是 Python 的包管理工具,用于安装第三方库。它们的区别在于: pip 是 Python 2 和 Python 3 通用的包管理工具,它可以安装适用于 Python 2 和 Python 3 的库。pip3 是专门用于 Python 3 的包管理工具,它只能安装适用…

CAN-位填充

位填充定义(Bit Stuffing) 当CAN节点发送 逻辑电平(显性dominant或隐性recessive)为持续相同的5位时,它必须添加一位反向电平。 CAN接收 节点会自动删除这个新增的额外电平位。 位填充作用 1---位填充是为了防止突发…

【2024-01-15】某安居客验证码分析-滑块验证码

声明:该专栏涉及的所有案例均为学习使用,严禁用于商业用途和非法用途,否则由此产生的一切后果均与作者无关!如有侵权,请私信联系本人删帖! 文章目录 一、抓包分析二、参数分析1.请求getInfoTp2.校验checkInfoTp一、抓包分析 网址: aHR0cHM6Ly9hcGkuYW5qdWtlLmNvbS93ZWI…