基于Keras的手写数字识别(附源码)

目录

引言

为什么要创建虚拟环境,好处在哪里?

源码 

我修改的部分

调用本地数据

修改第二层卷积层


引言

本文是博主为了记录一个好的开源代码而写,下面是代码出处!强烈建议收藏!【深度学习实战—1】:基于Keras的手写数字识别(非常详细、代码开源)

写的非常好,但是复现这篇博客却让我吃了很多苦头, 大家要先下载Anaconda3然后创建一个虚拟环境,在虚拟环境里面主要下载以下三个东西版本号只要对应好,肯定能运行,其他的库少什么安装什么!如果用显卡跑模型,原博客有提及配置!

版本号
Python版本3.7.3
Keras版本2.4.3
tensorflow版本2.4.0

为什么要创建虚拟环境,好处在哪里?

在进行机器学习项目时,我们经常会遇到需要为不同的模型安装不同版本的Python或相关库的情况。这是因为每个模型可能依赖于特定版本的库,这些版本之间可能存在兼容性差异。如果不使用虚拟环境,而是在主环境中直接安装这些库,可能会遇到以下问题:

首先,当你为新的模型安装特定版本的库时,可能会覆盖掉主环境中已经存在的其他模型所需的库版本,导致之前的模型无法正常运行。

其次,不同的Python版本之间也可能存在兼容性问题。如果你直接在主环境中升级或降级Python版本,可能会影响到依赖于特定Python版本的其他项目。

为了避免这些问题,使用虚拟环境变得尤为重要。虚拟环境是一个隔离的Python环境,其中可以安装特定版本的Python和库,而不会影响到主环境或其他虚拟环境。这样,你可以为每个机器学习模型创建一个独立的虚拟环境,并在其中安装所需的Python版本和库版本,从而确保每个模型都能在其特定的环境中稳定运行。

通过这种方法,你可以轻松地管理多个项目,而无需担心库版本冲突或Python版本不兼容的问题。希望这样的解释能帮助大家更好地理解虚拟环境在机器学习项目中的重要性。

源码 

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from keras.datasets import mnist
from sklearn.metrics import confusion_matrix
import seaborn as sns
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from keras.utils import np_utils
import tensorflow as tfconfig = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)# 设定随机数种子,使得每个网络层的权重初始化一致
# np.random.seed(10)# x_train_original和y_train_original代表训练集的图像与标签, x_test_original与y_test_original代表测试集的图像与标签
(x_train_original, y_train_original), (x_test_original, y_test_original) = mnist.load_data()
# 假设你已经知道mnist.npz文件的路径
# file_path = 'mnist.npz'  # 替换为你的mnist.npz文件的实际路径
#
# # 加载npz文件
# with np.load(file_path, allow_pickle=True) as f:
#     x_train_original = f['x_train']
#     y_train_original = f['y_train']
#     x_test_original = f['x_test']
#     y_test_original = f['y_test']"""
数据可视化
"""# 单张图像可视化
def mnist_visualize_single(mode, idx):if mode == 0:plt.imshow(x_train_original[idx], cmap=plt.get_cmap('gray'))title = 'label=' + str(y_train_original[idx])plt.title(title)plt.xticks([])  # 不显示x轴plt.yticks([])  # 不显示y轴plt.show()else:plt.imshow(x_test_original[idx], cmap=plt.get_cmap('gray'))title = 'label=' + str(y_test_original[idx])plt.title(title)plt.xticks([])  # 不显示x轴plt.yticks([])  # 不显示y轴plt.show()# 多张图像可视化
def mnist_visualize_multiple(mode, start, end, length, width):if mode == 0:for i in range(start, end):plt.subplot(length, width, 1 + i)plt.imshow(x_train_original[i], cmap=plt.get_cmap('gray'))title = 'label=' + str(y_train_original[i])plt.title(title)plt.xticks([])plt.yticks([])plt.show()else:for i in range(start, end):plt.subplot(length, width, 1 + i)plt.imshow(x_test_original[i], cmap=plt.get_cmap('gray'))title = 'label=' + str(y_test_original[i])plt.title(title)plt.xticks([])plt.yticks([])plt.show()mnist_visualize_multiple(mode=0, start=0, end=4, length=2, width=2)
# 原始数据量可视化
print('训练集图像的尺寸:', x_train_original.shape)
print('训练集标签的尺寸:', y_train_original.shape)
print('测试集图像的尺寸:', x_test_original.shape)
print('测试集标签的尺寸:', y_test_original.shape)"""
数据预处理
"""
#
# 从训练集中分配验证集
x_val = x_train_original[50000:]
y_val = y_train_original[50000:]
x_train = x_train_original[:50000]
y_train = y_train_original[:50000]
print('======================')
# 打印验证集数据量
print('验证集图像的尺寸:', x_val.shape)
print('验证集标签的尺寸:', y_val.shape)
print('======================')
# 将图像转换为四维矩阵(nums,rows,cols,channels), 这里把数据从unint类型转化为float32类型, 提高训练精度。
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1).astype('float32')
x_val = x_val.reshape(x_val.shape[0], 28, 28, 1).astype('float32')
x_test = x_test_original.reshape(x_test_original.shape[0], 28, 28, 1).astype('float32')
#
# 原始图像的像素灰度值为0-255,为了提高模型的训练精度,通常将数值归一化映射到0-1。
x_train = x_train / 255
x_val = x_val / 255
x_test = x_test / 255
#
print('训练集传入网络的图像尺寸:', x_train.shape)
print('验证集传入网络的图像尺寸:', x_val.shape)
print('测试集传入网络的图像尺寸:', x_test.shape)
# #
# 图像标签一共有10个类别即0-9,这里将其转化为独热编码(One-hot)向量
y_train = np_utils.to_categorical(y_train)
print(y_train[0])y_val = np_utils.to_categorical(y_val)
y_test = np_utils.to_categorical(y_test_original)#
# """
# 定义网络模型
# """
#
#
def CNN_model():model = Sequential()model.add(Conv2D(filters=16, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))  # 卷积层model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))  # 池化层# model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))  # 卷积层model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu'))  # 卷积层model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))  # 池化层model.add(Flatten())  # 平铺层model.add(Dense(100, activation='relu'))  # 全连接层model.add(Dense(10, activation='softmax'))  # 全连接层print(model.summary())return model#
#
# """
# 训练网络
# """
#
model = CNN_model()
# #
# 编译网络(定义损失函数、优化器、评估指标)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])# 开始网络训练(定义训练数据与验证数据、定义训练代数,定义训练批大小) 原来20
train_history = model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=20, batch_size=32, verbose=2)# 模型保存
model.save('handwritten_numeral_recognition.h5')#
#
# #
# #
# 定义训练过程可视化函数(训练集损失、验证集损失、训练集精度、验证集精度)
def show_train_history(train_history, train, validation):plt.plot(train_history.history[train])plt.plot(train_history.history[validation])plt.title('Train History')plt.ylabel(train)plt.xlabel('Epoch')plt.legend(['train', 'validation'], loc='best')plt.show()show_train_history(train_history, 'accuracy', 'val_accuracy')
show_train_history(train_history, 'loss', 'val_loss')# 输出网络在测试集上的损失与精度
score = model.evaluate(x_test, y_test)
print('Test loss:', score[0])
print('Test accuracy:', score[1])# 测试集结果预测
predictions = model.predict(x_test)
predictions = np.argmax(predictions, axis=1)
print('前9张图片预测结果:', predictions[:9])# 预测结果图像可视化
def mnist_visualize_multiple_predict(start, end, length, width):for i in range(start, end):plt.subplot(length, width, 1 + i)plt.imshow(x_test_original[i], cmap=plt.get_cmap('gray'))title_true = 'true=' + str(y_test_original[i])title_prediction = ',' + 'prediction' + str(model.predict_classes(np.expand_dims(x_test[i], axis=0)))title = title_true + title_predictionplt.title(title)plt.xticks([])plt.yticks([])plt.show()mnist_visualize_multiple_predict(start=0, end=9, length=3, width=3)# 混淆矩阵
cm = confusion_matrix(y_test_original, predictions)
cm = pd.DataFrame(cm)
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']def plot_confusion_matrix(cm):plt.figure(figsize=(10, 10))sns.heatmap(cm, cmap='Oranges', linecolor='black', linewidth=1, annot=True, fmt='', xticklabels=class_names,yticklabels=class_names)plt.xlabel("Predicted")plt.ylabel("Actual")plt.title("Confusion Matrix")plt.show()plot_confusion_matrix(cm)

我修改的部分

调用本地数据

# x_train_original和y_train_original代表训练集的图像与标签, x_test_original与y_test_original代表测试集的图像与标签
# (x_train_original, y_train_original), (x_test_original, y_test_original) = mnist.load_data()
# 假设你已经知道mnist.npz文件的路径
file_path = 'mnist.npz'  # 替换为你的mnist.npz文件的实际路径# 加载npz文件
with np.load(file_path, allow_pickle=True) as f:x_train_original = f['x_train']y_train_original = f['y_train']x_test_original = f['x_test']y_test_original = f['y_test']

因为原来的代码是每次运行都请求下载网上的在线数据,这是没必要的,当你运行了一次,可以把数据存在本地,然后以后本地调用

修改第二层卷积层

def CNN_model():model = Sequential()model.add(Conv2D(filters=16, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))  # 卷积层model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))  # 池化层# model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu', input_shape=(28, 28, 1)))  # 卷积层model.add(Conv2D(filters=32, kernel_size=(5, 5), activation='relu'))  # 卷积层model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))  # 池化层model.add(Flatten())  # 平铺层model.add(Dense(100, activation='relu'))  # 全连接层model.add(Dense(10, activation='softmax'))  # 全连接层print(model.summary())return model

 原文中的第二层卷积层的输入是规定为(28,28,1),但是这是有问题的,应该是不设置参数,这样子的话,会自动将第一个池化层的输出当作输入

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/840893.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【spring】@ControllerAdvice注解学习

ControllerAdvice介绍 ControllerAdvice 是 Spring 框架提供的一个注解,用于定义一个全局的异常处理类或者说是控制器增强类(controller advice class)。这个特性特别适用于那些你想应用于整个应用程序中多个控制器的共有行为,比…

ctfhub中的SSRF的相关例题(下)

目录 URL Bypass 知识点 相关例题 数字IP Bypass 相关例题 方法一:使用数字IP 方法二:转16进制 方法三:用localhost代替 方法四:特殊地址 302跳转 Bypass ​编辑 关于localhost原理: DNS重绑定 Bypass 知识点&…

ant design pro 6.0搭建教程

一、搭建 环境: Node.js 18.16.1 ant design pro 6.0 注意:选择umi3时,使用node.js 18版本的会报错,可以实践一下,这里就不再进行实践了。 umi3需要版本是低于node.js 18的 node下载地址: https://nodejs.…

可重构柔性装配产线,为智能制造领域带来了新的革命性变革

随着科技的飞速发展,个性化需求逐渐成为市场的主导。在这个充满变革的时代,制造业正面临着前所未有的挑战和机遇。如何快速响应市场需求、提高生产效率、保证产品质量,成为每一家制造企业必须思考的问题。 在这样的背景下,富唯智…

免费插件集-illustrator插件-Ai插件-文本对象和文本段落互转

文章目录 1.介绍2.安装3.通过窗口>扩展>知了插件4.功能解释5.总结 1.介绍 本文介绍一款免费插件,加强illustrator使用人员工作效率,进行文本对象和文本段落互转。首先从下载网址下载这款插件 https://download.csdn.net/download/m0_67316550/878…

00.OpenLayers快速开始

00OpenLayers快速开始 官方文档: 快速开始:https://openlayers.org/doc/quickstart.html 需要node环境 一、设置新项目 npm create ol-app my-app cd my-app npm start第一个命令将创建一个名为 my-app​ 的目录(如果您愿意,…

Java——简易图书管理系统

本文使用 Java 实现一个简易图书管理系统 一、思路 简易图书管理系统说白了其实就是 用户 与 图书 这两个对象之间的交互 书的属性有 书名 作者 类型 价格 借阅状态 而用户可以分为 普通用户 管理员 使用数组将书统一管理起来 用户对这个数组进行操作 普通用户可以进…

有趣的css - 圆形背景动效多选框

大家好,我是 Just,这里是「设计师工作日常」,今天分享的是用 css 实现一个圆形背景动效多选框,适用提醒用户勾选场景,突出多选框选项,可以有效增加用户识别度。 最新文章通过公众号「设计师工作日常」发布…

VBA批量合并带有图片、表格与文本框的Word

本文介绍基于VBA语言,对大量含有图片、文本框与表格的Word文档加以批量自动合并,并在每一次合并时添加分页符的方法。 在我们之前的文章基于Python中docx与docxcompose批量合并多个Word文档文件并逐一添加分页符(https://blog.csdn.net/zhebu…

helloworld 可执行程序得到的过程

// -E 预处理 开发过程中可以确定某个宏 // -c 把预处理 编译 汇编 都做了,但是不链接 // -o 指定输出文件 // -I 指定头文件目录 // -L 指定链接库文件目录 // -l 指定链接哪一个库文件 #include <stdio.h> #include <stdlib.h> #include <string.h>int mai…

【微积分】CH16 integrals and vector fields听课笔记

【托马斯微积分学习日记】13.1-线积分_哔哩哔哩_bilibili 概述 16.1line integrals of scalar functions [中英双语]可视化多元微积分 - 线积分介绍_哔哩哔哩_bilibili 16.2vector fields and line integrals&#xff1a; work circulation and flux 向量场差不多也是描述某种…

gpt-4o继续迭代考场安排程序 一键生成考场清单

接上两篇gpt-4o考场安排-CSDN博客&#xff0c;考场分层次安排&#xff0c;最终exe版-CSDN博客 当然你也可以只看这一篇。 今天又添加了以下功能&#xff0c;程序见后。 1、自动分页&#xff0c;每个考场打印一页 2、添加了打印试场单页眉 3、添加了页脚 第X页&#xff0c;…

Leetcode刷题笔记1:数组基础1

导语 leetcode刷题笔记记录&#xff0c;本篇博客记录数组基础1部分的题目&#xff0c;主要题目包括&#xff1a; Leetcode 704 二分查找Leetcode 27 移除元素 知识点 二分查找 原理 二分查找的适用对象为有序数组且数组中无重复元素&#xff0c;其主要原理是每次都从有序…

AI视频教程下载:全面掌握ChatGPT和LangChain开发AI应用(附源代码)

这是一门深入的课程&#xff0c;涉及ChatGPT、LangChain和Python。打造专注于现实世界AI集成的AI应用&#xff0c;课件附有每一节涉及到的源代码。 **你将学到什么&#xff1a;** - 将ChatGPT集成到LangChain的生产风格应用中 - 使用LangChain组件构建复杂的文本生成管道 - …

推荐五个线上兼职,在家也能轻松日入百元,适合上班族和全职宝妈

在这个瞬息万变的时代&#xff0c;你是否也曾考虑过在繁忙的工作之外&#xff0c;寻找一份兼职副业来补贴家用&#xff0c;同时保持生活的多样性&#xff1f;别急&#xff0c;现在就让我为你揭秘五个可靠的日结线上兼职岗位&#xff0c;助你轻松迈向财务自由之路&#xff01; 一…

云WAF与传统WAF:网络安全的双重防线

在网络安全领域&#xff0c;Web应用防火墙&#xff08;WAF&#xff09;是守护企业网络安全的重要盾牌。随着云计算技术的迅猛发展&#xff0c;云WAF作为一种新型的安全服务模式&#xff0c;正逐渐成为企业网络安全防护的新宠。本文将深入探讨云WAF与传统WAF的区别&#xff0c;分…

使用 Flask 和 Celery 构建异步任务处理应用

文章目录 什么是 Flask&#xff1f;什么是 Celery&#xff1f;如何在 Flask 中使用 Celery&#xff1f;步骤 1&#xff1a;安装 Flask 和 Celery步骤 2&#xff1a;创建 Flask 应用程序步骤 3&#xff1a;运行 Celery Worker步骤 4&#xff1a;启动 Flask 应用程序 结论 在构建…

高校网站群及融媒体中心建设方案

一、项目背景 随着信息技术的飞速发展&#xff0c;互联网已成为高校展示形象、传播信息、服务师生、沟通社会的重要渠道。然而&#xff0c;目前许多高校在网站建设和媒体传播方面存在以下问题&#xff1a; 网站分散、缺乏统一规划&#xff1a;各高校内部往往存在多个部门或学院…

零拷贝(Zero-Copy)

1.背景 现在有这样一个场景&#xff0c;我们需要在本地选择一个文件后&#xff0c;然后上传到网络上。 我们再看看文件的内容数据的具体搬运过程&#xff1a; 你会发现&#xff0c;在整个文件搬运的过程中&#xff0c;发生了多次的数据拷贝和上下文转换。 4次数据拷贝&#…