第T2周:彩色图片分类

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

👉 要求:

  • 学习如何编写一个完整的深度学习程序
  • 了解分类彩色图片会灰度图片有什么区别
  • 测试集accuracy到达72%

🦾我的环境:

  • 语言环境:Python3.8
  • 编译器:Jupyter Lab
  • 深度学习环境:
    • TensorFlow2

一、 前期准备

1.1. 设置GPU

  • 如果设备上支持GPU就使用GPU,否则使用CPU
  • Mac上的GPU使用mps
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")gpu0
PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')

1.2. 导入数据

使用dataset下载MNIST数据集,并划分好训练集与测试集

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
A local file was found, but it seems to be incomplete or outdated because the auto file hash does not match the original value of 6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce so we will re-download the data.
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 8500s 50us/step

1.3. 归一化

数据归一化作用

● 使不同量纲的特征处于同一数值量级,减少方差大的特征的影响,使模型更准确。
● 加快学习算法的收敛速度。

更详解的介绍请参考文章:🔗归一化与标准化

# 将像素的值标准化至0到1的区间内。(对于灰度图片来说,每个像素最大值是255,每个像素最小值是0,也就是直接除以255就可以完成归一化。)
train_images, test_images = train_images / 255.0, test_images / 255.0# 查看数据维数信息
train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
((50000, 32, 32, 3), (10000, 32, 32, 3), (50000, 1), (10000, 1))

1.4. 可视化图片

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']plt.figure(figsize=(20,10))
for i in range(20):plt.subplot(5,10,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i][0]])
plt.show()

在这里插入图片描述

二、构建简单的CNN网络

⭐池化层

池化层对提取到的特征信息进行降维,一方面使特征图变小,简化网络计算复杂度;另一方面进行特征压缩,提取主要特征,增加平移不变性,减少过拟合风险。但其实池化更多程度上是一种计算性能的一个妥协,强硬地压缩特征的同时也损失了一部分信息,所以现在的网络比较少用池化层或者使用优化后的如SoftPool。

池化层包括最大池化层(MaxPooling)和平均池化层(AveragePooling),均值池化对背景保留更好,最大池化对纹理提取更好)。同卷积计算,池化层计算窗口内的平均值或者最大值。例如通过一个 2*2 的最大池化层,其计算方式如下:
在这里插入图片描述

我们即将构建模型的结构图,我以分别二维和三维的形式展示出来方便大家理解。

  • 平面结构图
    在这里插入图片描述

  • 立体结构图
    在这里插入图片描述

model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)), #卷积层1,卷积核3*3layers.MaxPooling2D((2, 2)),                   #池化层1,2*2采样layers.Conv2D(64, (3, 3), activation='relu'),  #卷积层2,卷积核3*3layers.MaxPooling2D((2, 2)),                   #池化层2,2*2采样layers.Conv2D(64, (3, 3), activation='relu'),  #卷积层3,卷积核3*3layers.Flatten(),                      #Flatten层,连接卷积层与全连接层layers.Dense(64, activation='relu'),   #全连接层,特征进一步提取layers.Dense(10)                       #输出层,输出预期结果
])model.summary()  # 打印网络结构
Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================conv2d (Conv2D)             (None, 30, 30, 32)        896       max_pooling2d (MaxPooling2  (None, 15, 15, 32)        0         D)                                                              conv2d_1 (Conv2D)           (None, 13, 13, 64)        18496     max_pooling2d_1 (MaxPoolin  (None, 6, 6, 64)          0         g2D)                                                            conv2d_2 (Conv2D)           (None, 4, 4, 64)          36928     flatten (Flatten)           (None, 1024)              0         dense (Dense)               (None, 64)                65600     dense_1 (Dense)             (None, 10)                650       =================================================================
Total params: 122570 (478.79 KB)
Trainable params: 122570 (478.79 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________2024-06-23 22:16:01.054779: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2
2024-06-23 22:16:01.054802: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2024-06-23 22:16:01.054811: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2024-06-23 22:16:01.054984: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-06-23 22:16:01.055316: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)

三、编译模型

model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

四、训练模型

history = model.fit(train_images, train_labels, epochs=10, validation_data=(test_images, test_labels))
Epoch 1/102024-06-23 22:16:41.825293: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.1563/1563 [==============================] - ETA: 0s - loss: 1.5781 - accuracy: 0.42422024-06-23 22:16:54.304550: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.1563/1563 [==============================] - 13s 8ms/step - loss: 1.5781 - accuracy: 0.4242 - val_loss: 1.3528 - val_accuracy: 0.5133
Epoch 2/10
1563/1563 [==============================] - 12s 8ms/step - loss: 1.2892 - accuracy: 0.5464 - val_loss: 1.2880 - val_accuracy: 0.5617
Epoch 3/10
1563/1563 [==============================] - 12s 8ms/step - loss: 1.3585 - accuracy: 0.5521 - val_loss: 1.6484 - val_accuracy: 0.5155
Epoch 4/10
1563/1563 [==============================] - 12s 8ms/step - loss: 2.0448 - accuracy: 0.5044 - val_loss: 3.0545 - val_accuracy: 0.4380
Epoch 5/10
1563/1563 [==============================] - 12s 8ms/step - loss: 5.7139 - accuracy: 0.4563 - val_loss: 20.7035 - val_accuracy: 0.2908
Epoch 6/10
1563/1563 [==============================] - 12s 8ms/step - loss: 45.9029 - accuracy: 0.3672 - val_loss: 109.2576 - val_accuracy: 0.3624
Epoch 7/10
1563/1563 [==============================] - 12s 8ms/step - loss: 504.0281 - accuracy: 0.2838 - val_loss: 1375.9681 - val_accuracy: 0.2399
Epoch 8/10
1563/1563 [==============================] - 12s 8ms/step - loss: 3719.2263 - accuracy: 0.2359 - val_loss: 6212.4688 - val_accuracy: 0.2268
Epoch 9/10
1563/1563 [==============================] - 12s 8ms/step - loss: 11472.0957 - accuracy: 0.2238 - val_loss: 20005.8828 - val_accuracy: 0.1773
Epoch 10/10
1563/1563 [==============================] - 12s 8ms/step - loss: 25618.4004 - accuracy: 0.2182 - val_loss: 31095.4336 - val_accuracy: 0.2160

五、预测

通过模型进行预测得到的是每一个类别的概率,数字越大该图片为该类别的可能性越大

plt.imshow(test_images[1])

在这里插入图片描述

输出测试集中第一张图片的预测结果

import numpy as nppre = model.predict(test_images)
print(class_names[np.argmax(pre[1])])
 75/313 [======>.......................] - ETA: 0s2024-06-23 22:20:12.257425: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.313/313 [==============================] - 1s 3ms/step
ship

六、模型评估

import matplotlib.pyplot as pltplt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show()test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)

在这里插入图片描述

print(test_acc)
0.6845156432345124

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/32897.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

测试测量-DMM直流精度

测试测量-DMM直流精度 最近去面试&#xff0c;发现了自己许多不足&#xff0c;比如我从未考虑过万用表准或者不准&#xff0c;或者万用表有多准&#xff1f; 在过去的实验室中&#xff0c;常用的DMM有KEYSIGHT 34401A以及 KEITHLEY THD2015&#xff0c;就以这两台为例&#x…

线程C++

#include <thread> #include <chrono> #include <cmath> #include <mutex> #include <iostream> using namespace std;mutex mtx; void threadCommunicat() {int ans 0;while (ans<3){mtx.lock();//上锁cout << "ans" <…

预制直埋聚氨酯保温管

&#x1f600;宝子们&#xff0c;今天来给大家介绍一下预制聚氨酯直埋保温管&#x1f389;&#xff01; 它可是个好东西&#x1f44d;&#xff0c;具有超强的保温性能&#x1f9ca;&#xff0c;能够有效保持管道内的温度&#x1f60e;。 而且它还很耐用&#x1f4aa;&#xff0…

用 idea 启动多个实例

在学习负载均衡的时候&#xff0c;要模拟多个实例均提供一个服务&#xff0c;我们要如何用 idea 启动多个实例呢&#xff1f; 如下图&#xff0c;我们已经启动了一个 ProductService 服务&#xff0c;现在想再启动两个相同的服务 1. 选中要启动的服务,右键选择 Copy Configura…

用Java获取键盘输入数的个十百位数

这段Java代码是一个简单的程序&#xff0c;用于接收用户输入的一个三位数&#xff0c;并将其分解为个位、十位和百位数字&#xff0c;然后分别打印出来。下面是代码的详细解释&#xff1a; 导入所需类库: import java.util.Scanner;&#xff1a;导入Scanner类&#xff0c;用于从…

opencv学习笔记(3)

绘制直线 line(img, 开始点&#xff0c;结束点&#xff0c;颜色&#xff0c;线宽&#xff0c;线型(默认为8)) import cv2 import numpy as npimg np.zeros((640, 480, 3), np.uint8)#画线&#xff0c;坐标点为(x, y) cv2.line(img, (10, 20), (10, 220), (0, 0, 255), 5, 4)…

UsersGUI.java用户界面

完成效果图&#xff1a; 点击阅读按钮&#xff1a; 点击删除按钮&#xff1a; 点击新建按钮&#xff1a; Code /* This GUI application allows users to manage their diaries: ​ Read: Users can read existing diaries. Create: Users can create new diaries. Delete: Us…

ARC学习(3)基本编程模型认识(三)

笔者来介绍arc的编程模型的中断流程和异常流程 1、中断介绍 主要介绍一下中断进入的流程&#xff0c;包括需要配置的寄存器等信息。 中断号&#xff1a;16-255&#xff0c;总共240个中断。触发类型&#xff1a;脉冲或者电平触发中断优先级&#xff1a;16个&#xff0c;0最大&…

【git1】指令,commit,免密

文章目录 1.常用指令&#xff1a;git branch查看本地分支&#xff0c; -r查看远程分支&#xff0c; -a查看本地和远程&#xff0c;-v查看各分支最后一次提交, -D删除分支2.commit规范&#xff1a;git commit进入vi界面&#xff08;进入前要git config core.editor vim设一下vi模…

DVWA-XSS(Stored)-httponly分析

拿DVWA的XSS为例子 httponly是微软对cookie做的扩展。这个主要是解决用户的cookie可能被盗用的问题。 接DVWA的分析&#xff0c;发现其实Impossible的cookie都是设置的httponly1&#xff0c;samesite1. 这两个参数的意思参考Set-Cookie HttpOnly:阻止 JavaScript 通过 Documen…

Java项目:基于SSM框架实现的精品酒销售管理系统分前后台【ssm+B/S架构+源码+数据库+毕业论文】

一、项目简介 本项目是一套基于SSM框架实现的精品酒销售管理系统 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目附带全部源码可作为毕设使用。 项目都经过严格调试&#xff0c;eclipse或者idea 确保可以运行&#xff01; 该系统功能完善、界面美观、操作简单、功…

微信公众号 H5授权登录实现(最详细)

一、微信公众号 &#xff08;一&#xff09;基础信息 微信授权类型 自己的网站、APP等第三方&#xff0c;要实现接入微信授权登录&#xff0c;有多种方式&#xff1a;微信公众号&#xff08;网页&#xff09;、微信小程序、微信开放平台&#xff08;APP&#xff09;等等。 【…

面试:关于word2vec的相关知识点Hierarchical Softmax和NegativeSampling

1、为什么需要Hierarchical Softmax和Negative Sampling 从输入层到隐含层需要一个维度为NK的权重矩阵&#xff0c;从隐含层到输出层又需要一个维度为KN的权重矩阵&#xff0c;学习权重可以用反向传播算法实现&#xff0c;每次迭代时将权重沿梯度更优的方向进行一小步更新。但…

详细解析MATLAB和Simulink中的文件格式:mat, mdl, mexw32, 和 m 文件

matlab 探索MATLAB和Simulink中的文件格式&#xff1a;MAT, MDL, MEXW32, 和 M 文件**MAT 文件 (.mat)****MDL 文件 (.mdl)****MEX 文件 (.mexw32/.mexw64)****M 文件 (.m)****总结** 探索MATLAB和Simulink中的文件格式&#xff1a;MAT, MDL, MEXW32, 和 M 文件 当你开始使用M…

Python 虚拟环境 requirements.txt 文件生成 ;pipenv导出pip安装文件

搜索关键词: Python 虚拟环境Pipenv requirements.txt 文件生成;Pipenv 导出 pip requirements.txt安装文件 本文基于python版本 >3.9 文章内容有效日期2023年01月开始(因为此方法从这个时间开始是完全ok的) 上述为pipenv的演示版本 使用以下命令可精准生成requirement…

Java8 --- Gradle7.4整合IDEA

目录 一、Gradle整合IDEA 1.1、Groovy安装 1.1.1、配置环境变量 ​编辑 1.2、创建项目 ​编辑 1.3、Groovy基本语法 1.3.1、基本语法 1.3.2、引号 1.3.3、语句结构 1.3.4、数据类型 1.3.5、集合操作 1.4、使用Gradle创建普通Java工程 1.5、使用Gradle创建Java ss…

使用 axios 进行 HTTP 请求

使用 axios 进行 HTTP 请求 文章目录 使用 axios 进行 HTTP 请求1、介绍2、安装和引入3、axios 基本使用4、axios 发送 GET 请求5、axios 发送 POST 请求6、高级使用7、总结 1、介绍 什么是 axios axios 是一个基于 promise 的 HTTP 库&#xff0c;可以用于浏览器和 Node.js 中…

计算机组成入门知识

前言&#x1f440;~ 数据库的知识点先暂且分享到这&#xff0c;接下来开始接触计算机组成以及计算机网络相关的知识点&#xff0c;这一章先介绍一些基础的计算机组成知识 一台计算机如何组成的&#xff1f; 存储器 CPU cpu的工作流程 主频 如何衡量CPU好坏呢&#xff1f…

我的常见问题记录

1,maven在idea工具可以正常使用,在命令窗口执行出现问题 代码: E:\test-hello\simple-test>mvn clean compile [INFO] Scanning for projects... [WARNING] [WARNING] Some problems were encountered while building the effective model for org.consola:simple-test:jar…