训练 CNN 对 CIFAR-10 数据中的图像进行分类-keras实现

1. 加载 CIFAR-10 数据库

import keras
from keras.datasets import cifar10# 加载预先处理的训练数据和测试数据
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

2. 可视化前 24 个训练图像

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inlinefig = plt.figure(figsize=(20,5))
for i in range(36):ax = fig.add_subplot(3, 12, i + 1, xticks=[], yticks=[])ax.imshow(np.squeeze(x_train[i]))

3. 通过将每幅图像中的每个像素除以 255 来调整图像比例

事实上,代价函数的形状是一个碗,但如果特征的比例非常不同,它也可能是一个拉长的碗。下图显示了梯度下降法在特征 1 和特征 2 比例相同的训练集上的应用(左图),以及在特征 1 的值远小于特征 2 的训练集上的应用(右图)。

Tips : 使用梯度下降法时,应确保所有特征的比例相似,以加快训练速度,否则收敛时间会更长。

# rescale [0,255] --> [0,1]
x_train = x_train.astype('float32')/255
x_test = x_test.astype('float32')/255

 4. 将数据集分为训练集、测试集和验证集

from keras.utils import to_categorical# 对标签进行一次热编码
num_classes = len(np.unique(y_train))
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)# 将训练集分为训练集和验证集
(x_train, x_valid) = x_train[5000:], x_train[:5000]
(y_train, y_valid) = y_train[5000:], y_train[:5000]# 打印训练集的形状
print('x_train shape:', x_train.shape)# 打印训练、验证和测试图像的数量
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(x_valid.shape[0], 'validation samples')

5. 定义模型架构 

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropoutmodel = Sequential()
model.add(Conv2D(filters=16, kernel_size=2, padding='same', activation='relu', input_shape=(32, 32, 3)))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=64, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Dropout(0.3))
model.add(Flatten())
model.add(Dense(500, activation='relu'))
model.add(Dropout(0.4))
model.add(Dense(10, activation='softmax'))model.summary()

6. 编译模型 

# compile the model
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

7. 训练模型

from keras.callbacks import ModelCheckpoint   # 训练模型
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, save_best_only=True)hist = model.fit(x_train, y_train, batch_size=32, epochs=100,validation_data=(x_valid, y_valid), callbacks=[checkpointer], verbose=2, shuffle=True)

8. 加载验证精度最高的模型

# 加载验证精度最高的权重
model.load_weights('model.weights.best.hdf5')

 9. 计算测试集的分类精度

# 评估和打印测试精度
score = model.evaluate(x_test, y_test, verbose=0)
print('\n', 'Test accuracy:', score[1])

10. 可视化一些预测

这可能会让你对网络错误分类某些对象的原因有所了解。

# 在测试集上得到预测
y_hat = model.predict(x_test)# 定义文本标签 (source: https://www.cs.toronto.edu/~kriz/cifar.html)
cifar10_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 绘制测试图像的随机样本、预测标签和基本真实图像
fig = plt.figure(figsize=(20, 8))
for i, idx in enumerate(np.random.choice(x_test.shape[0], size=32, replace=False)):ax = fig.add_subplot(4, 8, i + 1, xticks=[], yticks=[])ax.imshow(np.squeeze(x_test[idx]))pred_idx = np.argmax(y_hat[idx])true_idx = np.argmax(y_test[idx])ax.set_title("{} ({})".format(cifar10_labels[pred_idx], cifar10_labels[true_idx]),color=("green" if pred_idx == true_idx else "red"))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/187916.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

csapp-linklab之第4阶段“输出学号”实验报告(switch跳转表)

实验内容 修改phase4.o相应节中的内容,使其与main.o链接后运行能够输出自己的学号: $ gcc -o linkbomb main.o phase4.o $ ./linkbomb $学号 实验提示 掌握switch语句的机器语言表示及其跳转表的实现。 找出跳转表 反汇编phase4.o,看看里…

el-table实现动态表头

1.1el-table渲染 <el-tableref"refreshTable":data"tableData"highlight-current-row><el-table-columnfixedwidth"170px"label"测点"align"center"prop"测站名称"/><el-table-column label"…

浅谈安科瑞可编程电测仪表在老挝某项目的应用

摘要&#xff1a;本文介绍了安科瑞多功能电能表在老挝某项目的应用。AMC系列交流多功能仪表是一款专门为电力系统、工矿企业、公用事业和智能建筑用于电力监控而设计的智能电表。 Abstract&#xff1a;This article introduces the application of the multi-function energy …

深度学习今年来经典模型优缺点总结,包括卷积、循环卷积、Transformer、LSTM、GANs等

文章目录 1、卷积神经网络&#xff08;Convolutional Neural Networks&#xff0c;CNN&#xff09;1.1 优点1.2 缺点1.3 应用场景1.4 网络图 2、循环神经网络&#xff08;Recurrent Neural Networks&#xff0c;RNNs&#xff09;2.1 优点2.2 缺点2.3 应用场景2.4 网络图 3、长短…

L1-010:比较大小

题目描述 本题要求将输入的任意3个整数从小到大输出。 输入格式: 输入在一行中给出3个整数&#xff0c;其间以空格分隔。 输出格式: 在一行中将3个整数从小到大输出&#xff0c;其间以“->”相连。 输入样例: 4 2 8输出样例: 2->4->8 程序代码 #include<stdio.h&…

基于YOLOv8深度学习的安全帽目标检测系统【python源码+Pyqt5界面+数据集+训练代码】目标检测、深度学习实战

《博主简介》 小伙伴们好&#xff0c;我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源&#xff0c;可关注公-仲-hao:【阿旭算法与机器学习】&#xff0c;共同学习交流~ &#x1f44d;感谢小伙伴们点赞、关注&#xff01; 《------往期经典推…

Git——使用Git进行程序开发

主要介绍个人开发提交记录的主要流程&#xff0c;包括以下内容&#xff1a; 索引- 提交的暂存区。查看工作的状态和内部变更。如何读取用于描述变更的已扩展统一diff格式。支持查询和交互的提交&#xff0c;修改提交。创建、显示和选择&#xff08;切换&#xff09;分支。切换…

婴儿专用洗衣机有必要买吗?宝宝洗衣机洗衣服

我们都知道刚出生的宝宝抵抗力较弱&#xff0c;很容易因为细菌感染然后生病&#xff0c;宝宝接触最多的就是衣服&#xff0c;我们在手洗的过程很难把衣服上的细菌清洗掉&#xff0c;而使用我们传统的洗衣机很容易造成细菌的第二次感染&#xff0c;很容易将宝宝的抵抗力弄得越来…

如何通过linux调用企业微信发送告警消息

一、前期准备 1、企业微信具备管理企业权限。 2、服务器有公网IP或者可以将本机端口通过net映射到公网。 二、通过脚本向企业微信发送消息 1、创建sh脚本用来发送消息。 vim 2.sh 注意&#xff1a;脚本中xxxx信息需要在企业微信管理后台获取。 #!/bin/bash # 设置企业…

2023年计网408

第33题 33.在下图所示的分组交换网络中&#xff0c;主机H1和H2通过路由器互连&#xff0c;2段链路的带宽均为100Mbps、 时延带宽积(即单向传播时延带宽)均为1000bits。若 H1向 H2发送1个大小为 1MB的文件&#xff0c;分组长度为1000B&#xff0c;则从H1开始发送时刻起到H2收到…

代码随想录刷题题Day2

刷题的第二天&#xff0c;希望自己能够不断坚持下去&#xff0c;迎来蜕变。&#x1f600;&#x1f600;&#x1f600; 刷题语言&#xff1a;C / Python Day2 任务 977.有序数组的平方 209.长度最小的子数组 59.螺旋矩阵 II 1 有序数组的平方&#xff08;重点&#xff1a;双指针…

将项目放到gitee上

参考 将IDEA中的项目上传到Gitee仓库中_哔哩哔哩_bilibili 如果cmd运行ssh不行的话&#xff0c;要换成git bash 如果初始化后的命令用不了&#xff0c;直接用idea项放右键&#xff0c;用git工具操作

XXL-Job详解(二):安装部署

目录 前言环境下载项目调度中心部署执行器部署 前言 看该文章之前&#xff0c;最好看一下之前的文章&#xff0c;比较方便我们理解 XXL-Job详解&#xff08;一&#xff09;&#xff1a;组件架构 环境 Maven3 Jdk1.8 Mysql5.7 下载项目 源码仓库地址链接: https://github.…

前端对浏览器的理解

浏览器的主要构成 用户界面 &#xff0d; 包括地址栏、后退/前进按钮、书签目录等&#xff0c;也就是你所看到的除了用来显示你所请求页面的主窗口之外的其他部分。 浏览器引擎 &#xff0d; 用来查询及操作渲染引擎的接口。 渲染引擎 &#xff0d; 用来显示请求的内容&#…

某60区块链安全之薅羊毛攻击实战一学习记录

区块链安全 文章目录 区块链安全薅羊毛攻击实战一实验目的实验环境实验工具实验原理实验内容薅羊毛攻击实战一 实验步骤EXP利用 薅羊毛攻击实战一 实验目的 学会使用python3的web3模块 学会分析以太坊智能合约薅羊毛攻击漏洞 找到合约漏洞进行分析并形成利用 实验环境 Ubun…

JVM类加载与运行时数据区

目录 一、类加载器 jvm类的加载过程 第一阶段&#xff1a;加载 第二阶段&#xff1a;链接阶段 第三阶段&#xff1a;初始化阶段&#xff1a; 双亲委派机制 沙箱安全机制 运行时数据区 栈-Xss1m 堆 TLAB 逃逸分析 方法区 常量池中有什么 StringTable为什么要调整位…

VS Code C++可视化调试配置Natvis,查看Qt、STL变量内容

VS Code C可视化调试配置Natvis 使用GlobalVisualizersDirectory Windows下 C:\Users\YourName\.vscode\extensions\ms-vscode.cpptools-1.18.5-win32-x64\debugAdapters\vsdbg\bin\Visualizers\Linux下 ~\.vscode\extensions\ms-vscode.cpptools-1.18.5-win32-x64\debugAd…

Spring Cloud 原理(第一节)

一、百度百科 Spring Cloud是一系列框架的有序集合。它利用Spring Boot的开发便利性巧妙地简化了分布式系统基础设施的开发&#xff0c;如服务发现注册、配置中心、消息总线、负载均衡、断路器、数据监控等&#xff0c;都可以用Spring Boot的开发风格做到一键启动和部署。Spri…

实验五 C语言函数程序设计习题 (使用函数计算两点间的距离,请编写函数fun,使用函数输出字符矩阵,使用函数求最大公约数和最小公倍数)

1. 使用函数计算两点间的距离&#xff1a;给定平面任意两点坐标(x1,y1)和(x2,y2)&#xff0c;求这两点之间的距离(保留2位)小数。要求定义和调用dist(x1,y1,x2,y2)计算两点间的距离。坐标中两点坐标之间的距离公式如下&#xff1a; #include <stdio.h> #include <math…

1.ORB-SLAM3中如何保存多地图、关键帧、地图点到二进制文件中

1 保存多地图 1.1 为什么保存(视觉)地图 因为我们要去做导航&#xff0c;导航需要先验地图。因此需要保存地图供导航使用&#xff0c;下面来为大家讲解如何保存多地图。 1.2 保存多地图的主函数SaveAtlas 2051 mStrSaveAtlasToFile是配置文件中传递的参数&#xff1a; 这里我们…