第T2周:彩色图片分类

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

👉 要求:

  • 学习如何编写一个完整的深度学习程序
  • 了解分类彩色图片会灰度图片有什么区别
  • 测试集accuracy到达72%

🦾我的环境:

  • 语言环境:Python3.8
  • 编译器:Jupyter Lab
  • 深度学习环境:
    • TensorFlow2

一、 前期准备

1.1. 设置GPU

  • 如果设备上支持GPU就使用GPU,否则使用CPU
  • Mac上的GPU使用mps
import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")gpu0
PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')

1.2. 导入数据

使用dataset下载MNIST数据集,并划分好训练集与测试集

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
A local file was found, but it seems to be incomplete or outdated because the auto file hash does not match the original value of 6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce so we will re-download the data.
Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
170498071/170498071 [==============================] - 8500s 50us/step

1.3. 归一化

数据归一化作用

● 使不同量纲的特征处于同一数值量级,减少方差大的特征的影响,使模型更准确。
● 加快学习算法的收敛速度。

更详解的介绍请参考文章:🔗归一化与标准化

# 将像素的值标准化至0到1的区间内。(对于灰度图片来说,每个像素最大值是255,每个像素最小值是0,也就是直接除以255就可以完成归一化。)
train_images, test_images = train_images / 255.0, test_images / 255.0# 查看数据维数信息
train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
((50000, 32, 32, 3), (10000, 32, 32, 3), (50000, 1), (10000, 1))

1.4. 可视化图片

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']plt.figure(figsize=(20,10))
for i in range(20):plt.subplot(5,10,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i][0]])
plt.show()

在这里插入图片描述

二、构建简单的CNN网络

⭐池化层

池化层对提取到的特征信息进行降维,一方面使特征图变小,简化网络计算复杂度;另一方面进行特征压缩,提取主要特征,增加平移不变性,减少过拟合风险。但其实池化更多程度上是一种计算性能的一个妥协,强硬地压缩特征的同时也损失了一部分信息,所以现在的网络比较少用池化层或者使用优化后的如SoftPool。

池化层包括最大池化层(MaxPooling)和平均池化层(AveragePooling),均值池化对背景保留更好,最大池化对纹理提取更好)。同卷积计算,池化层计算窗口内的平均值或者最大值。例如通过一个 2*2 的最大池化层,其计算方式如下:
在这里插入图片描述

我们即将构建模型的结构图,我以分别二维和三维的形式展示出来方便大家理解。

  • 平面结构图
    在这里插入图片描述

  • 立体结构图
    在这里插入图片描述

model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)), #卷积层1,卷积核3*3layers.MaxPooling2D((2, 2)),                   #池化层1,2*2采样layers.Conv2D(64, (3, 3), activation='relu'),  #卷积层2,卷积核3*3layers.MaxPooling2D((2, 2)),                   #池化层2,2*2采样layers.Conv2D(64, (3, 3), activation='relu'),  #卷积层3,卷积核3*3layers.Flatten(),                      #Flatten层,连接卷积层与全连接层layers.Dense(64, activation='relu'),   #全连接层,特征进一步提取layers.Dense(10)                       #输出层,输出预期结果
])model.summary()  # 打印网络结构
Model: "sequential"
_________________________________________________________________Layer (type)                Output Shape              Param #   
=================================================================conv2d (Conv2D)             (None, 30, 30, 32)        896       max_pooling2d (MaxPooling2  (None, 15, 15, 32)        0         D)                                                              conv2d_1 (Conv2D)           (None, 13, 13, 64)        18496     max_pooling2d_1 (MaxPoolin  (None, 6, 6, 64)          0         g2D)                                                            conv2d_2 (Conv2D)           (None, 4, 4, 64)          36928     flatten (Flatten)           (None, 1024)              0         dense (Dense)               (None, 64)                65600     dense_1 (Dense)             (None, 10)                650       =================================================================
Total params: 122570 (478.79 KB)
Trainable params: 122570 (478.79 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________2024-06-23 22:16:01.054779: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M2
2024-06-23 22:16:01.054802: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2024-06-23 22:16:01.054811: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2024-06-23 22:16:01.054984: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-06-23 22:16:01.055316: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)

三、编译模型

model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

四、训练模型

history = model.fit(train_images, train_labels, epochs=10, validation_data=(test_images, test_labels))
Epoch 1/102024-06-23 22:16:41.825293: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.1563/1563 [==============================] - ETA: 0s - loss: 1.5781 - accuracy: 0.42422024-06-23 22:16:54.304550: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.1563/1563 [==============================] - 13s 8ms/step - loss: 1.5781 - accuracy: 0.4242 - val_loss: 1.3528 - val_accuracy: 0.5133
Epoch 2/10
1563/1563 [==============================] - 12s 8ms/step - loss: 1.2892 - accuracy: 0.5464 - val_loss: 1.2880 - val_accuracy: 0.5617
Epoch 3/10
1563/1563 [==============================] - 12s 8ms/step - loss: 1.3585 - accuracy: 0.5521 - val_loss: 1.6484 - val_accuracy: 0.5155
Epoch 4/10
1563/1563 [==============================] - 12s 8ms/step - loss: 2.0448 - accuracy: 0.5044 - val_loss: 3.0545 - val_accuracy: 0.4380
Epoch 5/10
1563/1563 [==============================] - 12s 8ms/step - loss: 5.7139 - accuracy: 0.4563 - val_loss: 20.7035 - val_accuracy: 0.2908
Epoch 6/10
1563/1563 [==============================] - 12s 8ms/step - loss: 45.9029 - accuracy: 0.3672 - val_loss: 109.2576 - val_accuracy: 0.3624
Epoch 7/10
1563/1563 [==============================] - 12s 8ms/step - loss: 504.0281 - accuracy: 0.2838 - val_loss: 1375.9681 - val_accuracy: 0.2399
Epoch 8/10
1563/1563 [==============================] - 12s 8ms/step - loss: 3719.2263 - accuracy: 0.2359 - val_loss: 6212.4688 - val_accuracy: 0.2268
Epoch 9/10
1563/1563 [==============================] - 12s 8ms/step - loss: 11472.0957 - accuracy: 0.2238 - val_loss: 20005.8828 - val_accuracy: 0.1773
Epoch 10/10
1563/1563 [==============================] - 12s 8ms/step - loss: 25618.4004 - accuracy: 0.2182 - val_loss: 31095.4336 - val_accuracy: 0.2160

五、预测

通过模型进行预测得到的是每一个类别的概率,数字越大该图片为该类别的可能性越大

plt.imshow(test_images[1])

在这里插入图片描述

输出测试集中第一张图片的预测结果

import numpy as nppre = model.predict(test_images)
print(class_names[np.argmax(pre[1])])
 75/313 [======>.......................] - ETA: 0s2024-06-23 22:20:12.257425: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.313/313 [==============================] - 1s 3ms/step
ship

六、模型评估

import matplotlib.pyplot as pltplt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show()test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)

在这里插入图片描述

print(test_acc)
0.6845156432345124

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/32897.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

测试测量-DMM直流精度

测试测量-DMM直流精度 最近去面试&#xff0c;发现了自己许多不足&#xff0c;比如我从未考虑过万用表准或者不准&#xff0c;或者万用表有多准&#xff1f; 在过去的实验室中&#xff0c;常用的DMM有KEYSIGHT 34401A以及 KEITHLEY THD2015&#xff0c;就以这两台为例&#x…

Leetcode 3195. Find the Minimum Area to Cover All Ones I

Leetcode 3195. Find the Minimum Area to Cover All Ones I 1. 解题思路2. 代码实现 题目链接&#xff1a;3195. Find the Minimum Area to Cover All Ones I 1. 解题思路 这一题还是挺简单的&#xff0c;只要找到所有1所在的元素的上下左右4个边界&#xff0c;作为目标矩形…

线程C++

#include <thread> #include <chrono> #include <cmath> #include <mutex> #include <iostream> using namespace std;mutex mtx; void threadCommunicat() {int ans 0;while (ans<3){mtx.lock();//上锁cout << "ans" <…

预制直埋聚氨酯保温管

&#x1f600;宝子们&#xff0c;今天来给大家介绍一下预制聚氨酯直埋保温管&#x1f389;&#xff01; 它可是个好东西&#x1f44d;&#xff0c;具有超强的保温性能&#x1f9ca;&#xff0c;能够有效保持管道内的温度&#x1f60e;。 而且它还很耐用&#x1f4aa;&#xff0…

解析Java中1000个常用类:AbstractSet类,你学会了吗?

推荐一个我自己写的小报童专栏导航网站: http://xbt100.top 收录了生财有术项目精选、AI海外赚钱、纯银的产品分析等专栏,陆续会收录更多的专栏,欢迎体验~复制URL可直达。 以下是正文。 在 Java 集合框架中,AbstractSet 是一个重要的抽象类,为实现自定义的集合(Set)提…

【Python】处理 scikit-learn 中的 SettingWithCopyWarning

那年夏天我和你躲在 这一大片宁静的海 直到后来我们都还在 对这个世界充满期待 今年冬天你已经不在 我的心空出了一块 很高兴遇见你 让我终究明白 回忆比真实精彩 &#x1f3b5; 王心凌《那年夏天宁静的海》 这不是一个错误&#xff0c;而是一个 SettingW…

用 idea 启动多个实例

在学习负载均衡的时候&#xff0c;要模拟多个实例均提供一个服务&#xff0c;我们要如何用 idea 启动多个实例呢&#xff1f; 如下图&#xff0c;我们已经启动了一个 ProductService 服务&#xff0c;现在想再启动两个相同的服务 1. 选中要启动的服务,右键选择 Copy Configura…

用Java获取键盘输入数的个十百位数

这段Java代码是一个简单的程序&#xff0c;用于接收用户输入的一个三位数&#xff0c;并将其分解为个位、十位和百位数字&#xff0c;然后分别打印出来。下面是代码的详细解释&#xff1a; 导入所需类库: import java.util.Scanner;&#xff1a;导入Scanner类&#xff0c;用于从…

opencv学习笔记(3)

绘制直线 line(img, 开始点&#xff0c;结束点&#xff0c;颜色&#xff0c;线宽&#xff0c;线型(默认为8)) import cv2 import numpy as npimg np.zeros((640, 480, 3), np.uint8)#画线&#xff0c;坐标点为(x, y) cv2.line(img, (10, 20), (10, 220), (0, 0, 255), 5, 4)…

【经典算法】LeetCode 22括号生成(Java/C/Python3/Go实现含注释说明,中等)

作者主页&#xff1a; &#x1f517;进朱者赤的博客 精选专栏&#xff1a;&#x1f517;经典算法 作者简介&#xff1a;阿里非典型程序员一枚 &#xff0c;记录在大厂的打怪升级之路。 一起学习Java、大数据、数据结构算法&#xff08;公众号同名&#xff09; ❤️觉得文章还…

urllib3版本与系统openssl版本不兼容

urllib3版本与系统openssl版本不兼容 报错信息 ImportError: urllib3 v2.0 only supports OpenSSL 1.1.1, currently the ssl 解决办法 安装urllib3的1.x.xx版本&#xff0c;如&#xff1a; pip install urllib31.25.11

UsersGUI.java用户界面

完成效果图&#xff1a; 点击阅读按钮&#xff1a; 点击删除按钮&#xff1a; 点击新建按钮&#xff1a; Code /* This GUI application allows users to manage their diaries: ​ Read: Users can read existing diaries. Create: Users can create new diaries. Delete: Us…

ARC学习(3)基本编程模型认识(三)

笔者来介绍arc的编程模型的中断流程和异常流程 1、中断介绍 主要介绍一下中断进入的流程&#xff0c;包括需要配置的寄存器等信息。 中断号&#xff1a;16-255&#xff0c;总共240个中断。触发类型&#xff1a;脉冲或者电平触发中断优先级&#xff1a;16个&#xff0c;0最大&…

【git1】指令,commit,免密

文章目录 1.常用指令&#xff1a;git branch查看本地分支&#xff0c; -r查看远程分支&#xff0c; -a查看本地和远程&#xff0c;-v查看各分支最后一次提交, -D删除分支2.commit规范&#xff1a;git commit进入vi界面&#xff08;进入前要git config core.editor vim设一下vi模…

DVWA-XSS(Stored)-httponly分析

拿DVWA的XSS为例子 httponly是微软对cookie做的扩展。这个主要是解决用户的cookie可能被盗用的问题。 接DVWA的分析&#xff0c;发现其实Impossible的cookie都是设置的httponly1&#xff0c;samesite1. 这两个参数的意思参考Set-Cookie HttpOnly:阻止 JavaScript 通过 Documen…

Java项目:基于SSM框架实现的精品酒销售管理系统分前后台【ssm+B/S架构+源码+数据库+毕业论文】

一、项目简介 本项目是一套基于SSM框架实现的精品酒销售管理系统 包含&#xff1a;项目源码、数据库脚本等&#xff0c;该项目附带全部源码可作为毕设使用。 项目都经过严格调试&#xff0c;eclipse或者idea 确保可以运行&#xff01; 该系统功能完善、界面美观、操作简单、功…

文本三剑客—sed命令

sed命令 一、概念 sed是一种流编辑器&#xff0c;一次处理一行内容。 处理方式&#xff1a;一行一行处理&#xff0c;处理完当前行&#xff0c;才会处理下一行&#xff0c;直到文件末尾。 如果只是展示&#xff0c;会放在缓冲区&#xff08;模式空间&#xff09;&#xff0…

微信公众号 H5授权登录实现(最详细)

一、微信公众号 &#xff08;一&#xff09;基础信息 微信授权类型 自己的网站、APP等第三方&#xff0c;要实现接入微信授权登录&#xff0c;有多种方式&#xff1a;微信公众号&#xff08;网页&#xff09;、微信小程序、微信开放平台&#xff08;APP&#xff09;等等。 【…

面试:关于word2vec的相关知识点Hierarchical Softmax和NegativeSampling

1、为什么需要Hierarchical Softmax和Negative Sampling 从输入层到隐含层需要一个维度为NK的权重矩阵&#xff0c;从隐含层到输出层又需要一个维度为KN的权重矩阵&#xff0c;学习权重可以用反向传播算法实现&#xff0c;每次迭代时将权重沿梯度更优的方向进行一小步更新。但…

100337. 最大化子数组的总成本

Powered by:NEFU AB-IN Link 文章目录 100337. 最大化子数组的总成本题意思路代码 100337. 最大化子数组的总成本 题意 给你一个长度为 n 的整数数组 nums。 子数组 nums[l…r]&#xff08;其中 0 < l < r < n&#xff09;的 成本 定义为&#xff1a; cost(l, r)…