04 卷积神经网络搭建

一、数据集

MNIST数据集是从NIST的两个手写数字数据集:Special Database 3 和Special Database 1中分别取出部分图像,并经过一些图像处理后得到的[参考]。

MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的灰度图像,每张图像包含一个手写数字。
 

 1.1 准备数据

将数据集分为训练集、验证集和测试集 

#训练集有60000张图片,前5000张图片作为验证集,后55000作为训练集

1. `x_train_all` 和 `y_train_all`:
   - `x_train_all` 包含了完整的训练数据集的图像数据,这些图像用于训练深度学习模型。
   - `y_train_all` 包含了完整的训练数据集的标签,即与 `x_train_all` 中的图像相对应的类别标签。

2. `x_test` 和 `y_test`:
   - `x_test` 包含了测试数据集的图像数据,这些图像用于评估深度学习模型的性能。
   - `y_test` 包含了测试数据集的标签,即与 `x_test` 中的图像相对应的类别标签。

from tensorflow import keras
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as pltfashion_mnist = keras.datasets.fashion_mnist
(x_train_all, y_train_all), (x_test, y_test) = fashion_mnist.load_data()
x_valid, x_train = x_train_all[:5000], x_train_all[5000:]
y_valid, y_train = y_train_all[:5000], y_train_all[5000:]print(x_valid.shape, y_valid.shape)
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)

1.2 标准化

# 标准化
from sklearn.preprocessing import StandardScalerscaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28, 1)
x_valid_scaled = scaler.transform(x_valid.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28, 1)x_test_scaled = scaler.transform(x_test.astype(np.float32).reshape(-1, 1)).reshape(-1, 28, 28, 1)

1.3 数据集

def make_dataset(data, target, epochs, batch_size, shuffle=True):dataset = tf.data.Dataset.from_tensor_slices((data, target))if shuffle:dataset = dataset.shuffle(10000)dataset = dataset.repeat(epochs).batch(batch_size).prefetch(50)return datasetbatch_size = 64
epochs = 20
train_dataset = make_dataset(x_train_scaled, y_train, epochs, batch_size)

1.4 搭建模型

model = keras.models.Sequential()
# 卷积
model.add(keras.layers.Conv2D(filters = 32, kernel_size = 3, padding = 'same',activation='relu',# batch_size, height, width, channelsinput_shape=(28, 28, 1))) # (28, 28, 32)model.add(keras.layers.Conv2D(filters = 32, kernel_size = 3, padding = 'same',activation='relu'))
# 池化
model.add(keras.layers.MaxPool2D()) # (14, 14, 32)model.add(keras.layers.Conv2D(filters = 64, kernel_size = 3, padding = 'same',activation='relu')) # (14, 14, 64)
model.add(keras.layers.Conv2D(filters = 64, kernel_size = 3, padding = 'same',activation='relu'))
# 池化
model.add(keras.layers.MaxPool2D()) # (7, 7, 64)
model.add(keras.layers.Conv2D(filters = 128, kernel_size = 3, padding = 'same',activation='relu')) # (7, 7, 128)
model.add(keras.layers.Conv2D(filters = 128, kernel_size = 3, padding = 'same',activation='relu')) # (7, 7, 128)
# 池化, 向下取整
model.add(keras.layers.MaxPooling2D()) # (3, 3, 128)model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(512, activation='relu'))
model.add(keras.layers.Dense(256, activation='relu'))
model.add(keras.layers.Dense(10, activation='softmax'))model.compile(loss='sparse_categorical_crossentropy',optimizer='adam',metrics=['accuracy'])print(model.summary())

keras.models.Sequential()`是使用 Keras 构建神经网络模型的开始。`keras.models.Sequential()` 创建了一个 Sequential 模型对象,这是 Keras 中的一种常见模型类型。Sequential 模型是一个线性的、层叠的神经网络模型,适用于顺序层的堆叠,其中每一层都是依次添加到模型中

一旦创建了 Sequential 模型,你可以使用 `.add()` 方法来逐层添加神经网络层,从输入层到输出层。每个层都可以通过实例化 Keras 中的层类来创建,例如

`keras.layers.Dense` 用于全连接层,

`keras.layers.Conv2D` 用于卷积层等。

以下是一个简单的例子,演示如何使用 Sequential 模型创建一个简单的前馈神经网络

from tensorflow import keras# 创建一个 Sequential 模型
model = keras.models.Sequential()# 添加输入层和第一个隐藏层
model.add(keras.layers.Input(shape=(input_shape,)))  # 输入层,input_shape 根据你的数据维度定义
model.add(keras.layers.Dense(units=128, activation='relu'))  # 隐藏层1# 添加第二个隐藏层
model.add(keras.layers.Dense(units=64, activation='relu'))  # 隐藏层2# 添加输出层
model.add(keras.layers.Dense(units=num_classes, activation='softmax'))  # 输出层,num_classes 是输出类别的数量# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

上面的代码创建了一个具有两个隐藏层和一个输出层的前馈神经网络模型,并使用了 ReLU 激活函数和 softmax 激活函数。这只是一个简单的示例,你可以根据你的任务和数据来构建更复杂的模型。一旦模型构建完成,你可以使用 `.fit()` 方法来训练模型。

model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

是 Keras 中用于编译深度学习模型的代码,它设置了模型的损失函数、优化器和评估指标。让我为你解释每个参数的含义:

- `loss='sparse_categorical_crossentropy'`: 这里设置了模型的损失函数。`sparse_categorical_crossentropy` 是一种用于多类别分类问题的损失函数。它适用于目标变量是整数形式(类别标签)的情况,而不需要将目标变量进行独热编码(one-hot encoding)。模型的目标是最小化这个损失函数,从而使预测结果尽可能接近真实标签。

- `optimizer='adam'`: 这里设置了优化器,用于模型的参数更新。`'adam'` 是一种常用的优化算法,它基于梯度下降的方法,具有自适应学习率和动量的特性,通常在深度学习中表现良好。优化器的作用是最小化损失函数,从而调整模型的权重和参数以使模型更好地拟合数据。

- `metrics=['accuracy']`: 这里设置了评估指标,用于在模型训练期间监测模型性能。

`['accuracy']` 表示模型在训练期间将计算并输出准确度(accuracy),即正确分类的样本数与总样本数的比率。准确度通常用于分类问题的性能评估。

一旦模型编译完成,你可以使用 `model.fit()` 方法来训练模型,该方法会使用上述设置的损失函数、优化器和评估指标来进行训练。例如:

```python
model.fit(x_train, y_train, epochs=10, validation_data=(x_valid, y_valid))
```

1.5 train

eval_dataset = make_dataset(x_valid_scaled, y_valid, epochs=1, batch_size=32, shuffle=False)history = model.fit(train_dataset, steps_per_epoch=x_train_scaled.shape[0] // batch_size,epochs=10,validation_data=eval_dataset)

 

1.6 模型评估

model.evaluate(eval_dataset)

 

1.7 可视化

def plot_learning_curves(history):pd.DataFrame(history.history).plot(figsize=(8, 5))plt.grid(True)plt.gca().set_ylim(0, 1)plt.show()plot_learning_curves(history)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/72094.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

deepstream6.2部署yolov5详细教程与代码解读

文章目录 引言一.环境安装1、yolov5环境安装2、deepstream环境安装 二、源码文件说明三.wts与cfg生成1、获得wts与cfg2、修改wts 四.libnvdsinfer_custom_impl_Yolo.so库生成五.修改配置文件六.运行demo 引言 DeepStream 是使用开源 GStreamer 框架构建的优化图形架构&#xf…

cesium创建基本的实体、点、线、多边形(vue)

1.通过viewer实例的entities对象实现 实现代码&#xff1a; <template><div id"container"></div> </template><script> import * as Cesium from cesium/Cesium import "cesium/Widgets/widgets.css" export default {mo…

LeetCode刷题笔记【25】:贪心算法专题-3(K次取反后最大化的数组和、加油站、分发糖果)

文章目录 前置知识1005.K次取反后最大化的数组和题目描述分情况讨论贪心算法 134. 加油站题目描述暴力解法贪心算法 135. 分发糖果题目描述暴力解法贪心算法 总结 前置知识 参考前文 参考文章&#xff1a; LeetCode刷题笔记【23】&#xff1a;贪心算法专题-1&#xff08;分发饼…

gRPC远程进程调用

gRPC远程进程调用 rpc简介golang实现rpc方法一net/rpc库golang实现rpc方法二jsonrpc库grpc和protobuf在一起第一个grpc应用grpc服务的定义和服务的种类grpc stream实例1-服务端单向流grpc stream实例2-客户端单向流grpc stream实例3-双向流grpc整合gin

【2023高教社杯】C题 蔬菜类商品的自动定价与补货决策 问题分析、数学模型及python代码实现

【2023高教社杯】C题 蔬菜类商品的自动定价与补货决策 1 题目 C题蔬菜类商品的自动定价与补货决策 在生鲜商超中&#xff0c;一般蔬菜类商品的保鲜期都比较短&#xff0c;且品相随销售时间的增加而变差&#xff0c; 大部分品种如当日未售出&#xff0c;隔日就无法再售。因此&…

已经2023年了,你还不会手撕轮播图?

目录 一、前言二、动画基础1. 定时器2. left与offsetLeft3. 封装函数3.1 物体3.2 目标点3.3 回调函数 4.封装 三、基础结构3.1 焦点图3.2 按钮3.3 小圆点3.4 总结 四、按钮显示五、圆点5.1 生成5.2 属性5.3 移动 六、按钮6.1 准备6.2 出错6.2.1 小圆点跟随6.2.2 图片返回 6.3 b…

BLE架构与开源协议栈

BLE架构&#xff1a; 简单来说&#xff0c;BLE协议栈可以分成三个部分&#xff0c;主机(host)程序&#xff0c;控制器(controller)程序&#xff0c;主机控制器接口(HCI)。如果再加上底层射频硬件和顶层用户程序&#xff0c;则构成了完整的BLE协议&#xff0c;如下图所示&#…

ModuleNotFoundError: No module named ‘lavis‘解决方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

c语言实训心得3篇集合

c语言实训心得体会一&#xff1a; 在这个星期里&#xff0c;我们专业的学生在专业老师的带领下进行了c语言程序实践学习。在这之前&#xff0c;我们已经对c语言这门课程学习了一个学期&#xff0c;对其有了一定的了解&#xff0c;但是也仅仅是停留在了解的范围&#xff0c;对里…

第十八课、Qt 下载、安装与配置

功能描述&#xff1a;介绍了 Qt 的下载、安装和配置的全部过程&#xff0c;并对关键页面选项进行了详细说明 一、Qt 的下载 Qt 官方下载地址&#xff1a;https://www.qt.io/zh-cn/downloadhttps://download.qt.io/https://download.qt.io/https://www.qt.io/zh-cn/download进入…

GptFuck—开源Gpt4分享

这个项目不错&#xff0c;分享给大家 项目地址传送门

深入探索KVM虚拟化技术:全面掌握虚拟机的创建与管理

文章目录 安装KVM开启cpu虚拟化安装KVM检查环境是否正常 KVM图形化创建虚拟机上传ISO创建虚拟机加载镜像配置内存添加磁盘能否手工指定存储路径呢&#xff1f;创建成功安装完成查看虚拟机 KVM命令行创建虚拟机创建磁盘通过命令行创建虚拟机手动安装虚拟机 KVM命令行创建虚拟机-…

数据集笔记:GeoLife GPS 数据 (user guide)

数据链接&#xff1a;https://www.microsoft.com/en-us/download/details.aspx?id52367 1 数据基本信息 1.1 数据介绍 182名用户在超过三年的时间内&#xff08;从2007年4月到2012年8月&#xff09;在&#xff08;微软亚洲研究院&#xff09;Geolife项目中收集的。该数据集…

使用SpringCloud Eureka 搭建EurekaServer 集群- 实现负载均衡故障容错【上】

&#x1f600;前言 本篇博文是关于使用SpringCloud Eureka 搭建EurekaServer 集群- 实现负载均衡&故障容错&#xff0c;希望你能够喜欢 &#x1f3e0;个人主页&#xff1a;晨犀主页 &#x1f9d1;个人简介&#xff1a;大家好&#xff0c;我是晨犀&#xff0c;希望我的文章可…

865. 具有所有最深节点的最小子树(javascript)865. Smallest Subtree with all the Deepest Nodes

给定一个根为 root 的二叉树&#xff0c;每个节点的深度是 该节点到根的最短距离 。 返回包含原始树中所有 最深节点 的 最小子树 。 如果一个节点在 整个树 的任意节点之间具有最大的深度&#xff0c;则该节点是 最深的 。 一个节点的 子树 是该节点加上它的所有后代的集合…

【计算机网络】 IP协议格式以及以太网帧结构

文章目录 IP协议格式以太网帧结构 IP协议格式 IP工作在网络层 IP头分为两部分&#xff0c;固定部分和可变部分&#xff0c;固定部分就是一定要带这些数据&#xff0c;正常存储应该是连续的&#xff0c;并不是像图中这样会换行&#xff0c;图中只是为了方便观察。 首先是一个版…

山洪、地质灾害监测利器-泥石流、山体滑坡AI视觉仪

1、设备介绍 AI视觉仪通过AI算法智能化摄像机&#xff0c;能够及时、全面的把握边坡潜在安全风险&#xff0c;有效防范自然灾害。支持全天候运行&#xff0c;在恶劣环境及气候条件下仍能正常进行监测数据采集。自动识别监控区域内是否有泥石流、山体滑坡等&#xff0c;一旦检测…

uni-app:重置表单数据

效果 代码 <template><form><input type"text" v-model"inputValue" placeholder"请输入信息"/><input type"text" v-model"inputValue1" placeholder"请输入信息"/><input type&quo…

排序之插入排序

文章目录 前言一、直接插入排序1、基本思想2、直接插入排序的代码实现3、直接插入排序总结 二、希尔排序1、希尔排序基本思想2、希尔排序的代码实现3、希尔排序时间复杂度 前言 排序&#xff1a;所谓排序&#xff0c;就是使一串记录&#xff0c;按照其中的某个或某些关键字的大…

IDEA找不到Maven窗口

有时候导入项目或者创建项目时候Maven窗口找不到了 然后指定项目的pom.xml文件