卷积神经网络(CNN)多种图片分类的实现

文章目录

  • 前期工作
    • 1. 设置GPU(如果使用的是CPU可以忽略这步)
      • 我的环境:
    • 2. 导入数据
    • 3.归一化
    • 4.可视化
  • 二、构建CNN网络模型
  • 三、编译模型
  • 四、训练模型
  • 五、预测
  • 六、模型评估

前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")if gpus:gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPUtf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用tf.config.set_visible_devices([gpu0],"GPU")

2. 导入数据

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

3.归一化

# 将像素的值标准化至0到1的区间内。
train_images, test_images = train_images / 255.0, test_images / 255.0train_images.shape,test_images.shape,train_labels.shape,test_labels.shape

4.可视化

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']plt.figure(figsize=(20,10))
for i in range(20):plt.subplot(5,10,i+1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i][0]])
plt.show()

在这里插入图片描述

二、构建CNN网络模型

model = models.Sequential([layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)), #卷积层1,卷积核3*3layers.MaxPooling2D((2, 2)),                   #池化层1,2*2采样layers.Conv2D(64, (3, 3), activation='relu'),  #卷积层2,卷积核3*3layers.MaxPooling2D((2, 2)),                   #池化层2,2*2采样layers.Conv2D(64, (3, 3), activation='relu'),  #卷积层3,卷积核3*3layers.Flatten(),                      #Flatten层,连接卷积层与全连接层layers.Dense(64, activation='relu'),   #全连接层,特征进一步提取layers.Dense(10)                       #输出层,输出预期结果
])model.summary()  # 打印网络结构
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (None, 30, 30, 32)        896       
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 15, 15, 32)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 13, 13, 64)        18496     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 6, 6, 64)          0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 4, 4, 64)          36928     
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 64)                65600     
_________________________________________________________________
dense_1 (Dense)              (None, 10)                650       
=================================================================
Total params: 122,570
Trainable params: 122,570
Non-trainable params: 0
_________________________________________________________________

三、编译模型

model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])

四、训练模型

history = model.fit(train_images, train_labels, epochs=10, validation_data=(test_images, test_labels))
Epoch 1/10
1563/1563 [==============================] - 9s 4ms/step - loss: 1.7862 - accuracy: 0.3390 - val_loss: 1.2697 - val_accuracy: 0.5406
Epoch 2/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.2270 - accuracy: 0.5595 - val_loss: 1.0731 - val_accuracy: 0.6167
Epoch 3/10
1563/1563 [==============================] - 5s 3ms/step - loss: 1.0355 - accuracy: 0.6337 - val_loss: 0.9678 - val_accuracy: 0.6610
Epoch 4/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.9221 - accuracy: 0.6727 - val_loss: 0.9589 - val_accuracy: 0.6648
Epoch 5/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.8474 - accuracy: 0.7022 - val_loss: 0.8962 - val_accuracy: 0.6853
Epoch 6/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.7814 - accuracy: 0.7292 - val_loss: 0.9124 - val_accuracy: 0.6873
Epoch 7/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.7398 - accuracy: 0.7398 - val_loss: 0.8924 - val_accuracy: 0.6929
Epoch 8/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.7008 - accuracy: 0.7542 - val_loss: 0.9809 - val_accuracy: 0.6854
Epoch 9/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.6474 - accuracy: 0.7732 - val_loss: 0.8549 - val_accuracy: 0.7137
Epoch 10/10
1563/1563 [==============================] - 5s 3ms/step - loss: 0.6041 - accuracy: 0.7889 - val_loss: 0.8909 - val_accuracy: 0.7046

五、预测

通过模型进行预测得到的是每一个类别的概率,数字越大该图片为该类别的可能性越大

plt.imshow(test_images[10])

在这里插入图片描述

输出测试集中第一张图片的预测结果

import numpy as nppre = model.predict(test_images)
print(class_names[np.argmax(pre[10])])
313/313 [==============================] - 1s 3ms/step
airplane

六、模型评估

import matplotlib.pyplot as pltplt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show()test_loss, test_acc = model.evaluate(test_images,  test_labels, verbose=2)

在这里插入图片描述

print(test_acc)
0.7166000008583069

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/152327.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从多表连接视图对比人大金仓和Oracle

KING BASE 信息时代,数据是驱动业务决策和创新的核心资源。然而,随着数据量的不断增加,有效地处理和整合数据的过程变得愈发复杂。这时,多表连接视图悄然走进数据库世界,不仅能够将多个表中的数据整合在一起&#xff0…

管家婆订货易在线商城任意文件上传漏洞复现

0x01 产品简介 管家婆订货易,帮助传统企业构建专属的订货平台,PC微信APP小程序h5商城5网合一,无缝对接线下的管家婆ERP系统,让用户订货更高效。支持业务员代客下单,支持多级推客分销,以客带客,拓…

【C/C++】递归算法

信封 某人写了n封信和n个信封&#xff0c;如果所有的信都装错了信封。求所有的信都装错信封共有多少种不同情况 #include <iostream> using namespace std; const int N 30; int n; long f[N];int main() {scanf("%d", &n);f[1] 0, f[2] 1;for (int …

「Verilog学习笔记」ROM的简单实现

专栏前言 本专栏的内容主要是记录本人学习Verilog过程中的一些知识点&#xff0c;刷题网站用的是牛客网 分析 要实现ROM&#xff0c;首先要声明数据的存储空间&#xff0c;例如&#xff1a;[3:0] rom [7:0]&#xff1b;变量名称rom之前的[3:0]表示每个数据具有多少位&#xff0…

HLS基础issue

hls 是一个用C/c 来开发PL &#xff0c;产生rtl的工具 hls是按照rtl code来运行的 &#xff0c; 但是rtl会在不同器件调用不同的源语&#xff1b; 可能产生的ip使用在vivado另外一个器件的话 会存在问题&#xff1b; Hls &#xff1a; vivado ip &#xff0c; vitis kernel 是…

认识前端包常用包管理工具(npm、cnpm、pnpm、nvm、yarn)

随着前端的快速发展,前端的框架越来越趋向于工程化,所以对于包的使用也越来越多,为了优化性能和后期的维护更新,对于前端包的管理也尤为重要,本文主要阐述对node中包管理工具的理解和简单的使用方法。也欢迎各位大佬和同行们多多指教。😁😁😁 👉1. npm 安装npm 通…

自定义类使用ArrayList中的remove

Java中ArrayList对基础类型和字符串类型的删除操作&#xff0c;直接用remove方法即可。但是对于自定义的类来说&#xff0c;用remove方法删除不了&#xff0c;因为没有办法确定是否是要删除的对象。 ArrayList中remove源码是&#xff1a; public boolean remove(Object o) {if…

【linux】补充:高效处理文本的命令学习(tr、uniq、sort、cut)

目录 一、tr——转换、压缩、删除 1、tr -s “分隔符” &#xff08;指定压缩连续的内容&#xff09; 2、tr -d 想要删除的东西 ​编辑 3、tr -t 内容1 内容2 将内容1全部转换为内容2&#xff08;字符数需要一一对应&#xff09; 二、cut——快速剪裁命令 三、uniq——去…

git撤销未git commit的文件

目录 一、问题描述 二、方式1&#xff1a;git命令撤销&#xff08;更专业&#xff09; 1、文件已git add&#xff0c;未git commit 2、本地修改&#xff0c;未git add &#xff08;1&#xff09;撤销处于unstage的文件&#xff0c;即删除已有变动 &#xff08;2&#xff…

Vue3 动态设置 ref

介绍 在一些场景&#xff0c;ref设置是未知的需要根据动态数据来决定&#xff0c;如表格中的input框需要我们主动聚焦&#xff0c;就需要给每一个input设置一个ref&#xff0c;进而进行聚焦操作。 Demo 点击下面截图中的编辑按钮&#xff0c;自动聚焦到相应的输入框中。 &…

【算法基础】动态规划

背包问题 01背包 每个物品只能放一次 2. 01背包问题 - AcWing题库 二维dp #include<bits/stdc.h> const int N1010; int f[N][N]; int v[N],w[N]; signed main() {int n,m;std::cin>>n>>m; for(int i1;i<n;i) std::cin>>v[i]>>w[i];for…

Nginx的核心配置文件

Nginx的核心配置文件 学习Nginx首先需要对它的核心配置文件有一定的认识&#xff0c;这个文件位于Nginx的安装目录/usr/local/nginx/conf目录下&#xff0c;名字为nginx.conf 详细配置&#xff0c;可以参考resources目录下的<<nginx配置中文详解.conf>> Nginx的核…

warning C4251

c - Warning C4251 when building a DLL that exports a class containing an ATL::CString member - Stack Overflow

python中sklearn库在数据预处理中的详细用法,及5个常用的Scikit-learn(通常简称为 sklearn)程序代码示例

文章目录 前言1. 数据清洗&#xff1a;使用 sklearn.preprocessing 中的 StandardScaler 和 MinMaxScaler 进行数据规范化。2. 缺失值处理&#xff1a;使用 sklearn.impute 中的 SimpleImputer 来填充缺失值。3. 数据编码&#xff1a;使用 sklearn.preprocessing 中的 OneHotEn…

Nosql之redis概述及基本操作

关系数据库与非关系型数据库概述 关系型数据库 关系型数据库是一个结构化的数据库&#xff0c;创建在关系模型&#xff08;二维表格模型&#xff09;基础上&#xff0c;一般面向于记录。SQL语句(标准数据查询语言)就是一种基于关系型数据库的语言&#xff0c;用于执行对关系型…

常见负载均衡算法/策略(概念)

目录 1.1. 轮循均衡&#xff08;Round Robin&#xff09; 1.2. 权重轮循均衡&#xff08;Weighted Round Robin&#xff09; 1.3. 随机均衡&#xff08;Random&#xff09; 1.4. 权重随机均衡&#xff08;Weighted Random&#xff09; 1.5. 响应速度均衡&#xff08;R…

聊一聊go的单元测试(goconvey、gomonkey、gomock)

文章目录 概要一、测试框架1.1、testing1.2、stretchr/testify1.3、smartystreets/goconvey1.4、cweill/gotests 二、打桩和mock2.1、打桩2.2、mock2.2.1、mockgen2.2.1、示例 三、基准测试和模糊测试3.1、基准测试3.2、模糊测试 四、总结4.1、小结4.2、其他4.3、参考资料 概要…

六.Linux远程登录

1.说明&#xff1a;公司开发的时候&#xff0c;具体的应用场景是这样的 1.linux服务器是开发小组共享 2.正式上线的项目是运行在公网 3.因此程序员需要远程登录到Linux进行项目管理或者开发 4.画出简单的网络拓扑示意图(帮助理解) 5.远程登录客户端有Xshell6、Xftp6&#xff0…

7年经验之谈 —— 如何高效的开展app的性能测试?

APP性能测试是什么 从网上查了一下&#xff0c;貌似也没什么特别的定义&#xff0c;我这边根据自己的经验给出一个自己的定义&#xff0c;如有巧合纯属雷同。 客户端性能测试就是&#xff0c;从业务和用户的角度出发&#xff0c;设计合理且有效的性能测试场景&#xff0c;制定…

3D建模基础教程:石墨工具介绍

3DMAX的石墨&#xff08;Graphite&#xff09;工具是一个强大的建模工具&#xff0c;可以用来创建和编辑复杂的3D模型。下面是对石墨工具的详细介绍&#xff1a; 石墨工具的界面布局&#xff1a; 石墨工具的界面与3DMAX的主界面相同&#xff0c;包括菜单栏、工具栏、视图区、…