【Python深度学习(第二版)(3)】初识神经网络之深度学习hello world

文章目录

  • 一. 训练Keras中的MNIST数据集
  • 二. 工作流程
    • 1. 构建神经网络
    • 2. 准备图像数据
    • 3. 训练模型
    • 4. 利用模型进行预测
    • 5. (新数据上)评估模型精度

本节将首先给出一个神经网络示例,引出如下概念。了解完本节后,可以对神经网络在代码上的实现有一个整体的了解。

本节相关概念:

  • 样本
  • 标签
  • 层(layer)
  • 数据蒸馏
  • 密集连接
  • 10路softmax分类层
  • 编译(compilation)步骤的3个参数
  • 损失值、精度
  • 过拟合

我们来看一个神经网络的具体实例:使用Python的Keras库来学习手写数字分类。

在这个例子中,我们要解决的问题是,将手写数字的灰度图像(28像素×28像素)划分到10个类别中(从0到9)。我们将使用MNIST数据集。你可以将“解决”MNIST问题看作深度学习的“Hello World”,用来验证你的算法正在按预期运行。下图给出了MNIST数据集的一些样本。

在这里插入图片描述

说明

在机器学习中,分类问题中的某个类别叫作类(class),数据点叫作样本(sample)与某个样本对应的类叫作标签(label)(即描述:样本属于哪个类别)。

 

你不需要现在就尝试在计算机上运行这个例子。之后的文章会具体分析。

 

一. 训练Keras中的MNIST数据集

from tensorflow.keras.datasets import mnist (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

train_images和train_labels组成了训练集,模型将从这些数据中进行学习。然后,我们在测试集(包括test_images和test_labels)上对模型进行测试。

图像被编码为NumPy数组,而标签是一个数字数组,取值范围是0~9。图像和标签一一对应。

 
看一下训练数据:

>>> train_images.shape (60000, 28, 28) 
>>> len(train_labels) 
60000 
>>> train_labels 
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)

 

再来看一下测试数据:

>>> test_images.shape
(10000, 28, 28) 
>>> len(test_labels) 
10000 
>>> test_labels 
array([7, 2, 1, ..., 4, 5, 6], dtype=uint8)

 

二. 工作流程

我们的工作流程如下:
首先,将训练数据(train_images和train_labels)输入神经网络;
然后,神经网络学习将图像和标签关联在一起;
最后,神经网络对test_images进行预测,我们来验证这些预测与test_labels中的标签是否匹配。
 


具体的代码我们可以在 deep-learning-with-python-notebooks 中直接运行。


1. 构建神经网络

下面我们来构建神经网络,如下:

from tensorflow import keras 
from tensorflow.keras import layers model = keras.Sequential([ layers.Dense(512, activation="relu"), layers.Dense(10, activation="softmax") ])

 

神经网络的核心组件是层(layer)

具体来说,层从输入数据中提取表示。大多数深度学习工作涉及将简单的层链接起来,从而实现渐进式的数据蒸馏(data distillation)。深度学习模型就像是处理数据的筛子,包含一系列越来越精细的数据过滤器(也就是层)。
 

本例中的模型包含2个Dense层,它们都是密集连接(也叫全连接)的神经层。

第2层是一个10路softmax分类层,它将返回一个由10个概率值(总和为1)组成的数组。每个概率值表示当前数字图像属于10个数字类别中某一个的概率。
 

在训练模型之前,我们还需要指定编译(compilation)步骤的3个参数

  • 优化器(optimizer):模型基于训练数据来自我更新的机制,其目的是提高模型性能。
  • 损失函数(loss function):模型如何衡量在训练数据上的性能,从而引导自己朝着正确的方向前进。
  • 在训练和测试过程中需要监控的指标(metric):本例只关心精度(accuracy),即正确分类的图像所占比例。后面两章会详细介绍损失函数和优化器的确切用途。

如下代码展示了编译步骤。


model.compile(optimizer="rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"])

 

2. 准备图像数据

在开始训练之前,我们先对数据进行预处理,将其变换为模型要求的形状,并缩放到所有值都在[0, 1]区间。前面提到过,训练图像保存在一个uint8类型的数组中,其形状为(60000, 28, 28),取值区间为[0,255]。我们将把它变换为一个float32数组,其形状为(60000, 28 *28),取值范围是[0, 1]。

下面准备图像数据,如代码所示。

train_images = train_images.reshape((60000, 28 * 28)) 
train_images = train_images.astype("float32") / 255 test_images = test_images.reshape((10000, 28 * 28)) 
test_images = test_images.astype("float32") / 255

 

3. 训练模型

在Keras中,通过调用模型的fit方法调用数据,训练模型。

>>> model.fit(train_images, train_labels, epochs=5, batch_size=128) 
Epoch 1/5 
60000/60000 [===========================] - 5s - loss: 0.2524 - acc: 0.9273 Epoch 2/5 
51328/60000 [=====================>.....] - ETA: 1s - loss: 0.1035 - acc: 0.9692

训练过程中显示了两个数字:一个是模型在训练数据上的损失值(loss),另一个是模型在训练数据上的精度(acc)。我们很快就在训练数据上达到了0.989(98.9%)的精度。

现在我们得到了一个训练好的模型,可以利用它来预测新数字图像的类别概率(如下代码)。这些新数字图像不属于训练数据,比如可以是测试集中的数据。

 

4. 利用模型进行预测

>>> test_digits = test_images[0:10] 
>>> predictions = model.predict(test_digits) 
>>> predictions[0] 
array([1.0726176e-10, 1.6918376e-10, 6.1314843e-08, 8.4106023e-06, 2.9967067e-11, 3.0331331e-09, 8.3651971e-14, 9.9999106e-01, 2.6657624e-08, 3.8127661e-07], dtype=float32)

如上代码我们对11个test_images图片进行预测,是什么数字,我们拿到第一个图片预测的概率数组,其中索引为7时,概率最大(0.99999106,几乎等于1),所以根据我们的模型,这个数字一定是7。

>>> predictions[0].argmax() 
7 
>>> predictions[0][7] 
0.99999106

这里我们检查测试标签是否与之一致:

>>> test_labels[0] 7

平均而言,我们的模型对这种前所未见的数字图像进行分类的效果如何?我们来计算在整个测试集上的平均精度,如下代码所示。

 

5. (新数据上)评估模型精度

>>> test_loss, test_acc = model.evaluate(test_images, test_labels) 
>>> print(f"test_acc: {test_acc}") 
test_acc: 0.9785

测试精度约为97.8%,比训练精度(98.9%)低不少。训练精度和测试精度之间的这种差距是过拟合(overfit)造成的。

过拟合是指机器学习模型在新数据上的性能往往比在训练数据上要差。

第一个例子到这里就结束了。你刚刚看到了如何用不到15行Python代码构建和训练一个神经网络,对手写数字进行分类。

 

之后的文章我们将详细描述每一个步骤的原理,并且将学到张量(输入模型的数据存储对象)、张量运算(层的组成要素)与梯度下降(可以让模型从训练示例中进行学习)。
 

参考:
《Python深度学习(第二版)》–弗朗索瓦·肖莱
https://www.redhat.com/zh/topics/digital-transformation/what-is-deep-learning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/834773.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python数据分析——pandas DataFrame基础知识2

参考资料:活用pandas库 1、分组方式 我们可以把分组计算看作“分割-应用-组合”(split-apply-combine)的过程。首先把数据分割成若干部分,然后把选择的函数(或计算)应用于各部分,最后把所有独立…

如何安全高效地进行分公司文件下发?

确保分公司文件下发过程中的保密性和安全性,是企业信息安全管理的重要组成部分。以下是一些关键步骤和最佳实践: 权限管理:确保只有授权的人员可以访问文件。使用权限管理系统来控制谁可以查看、编辑或下载文件。 加密传输:在文…

Linux|进程地址空间

Linux|内存地址空间 现象基本概念理解如何理解地址空间什么是划分区域&#xff1f;地址空间的理解为什么要有地址空间&#xff1f;如何进一步理解页表和写时拷贝如何理解虚拟地址 Linux真正的进程调度方案 现象 #include <stdio.h> #include <string.h> #include …

Go 使用 MongoDB

MongoDB 安装(Docker)安装 MongoDB Go 驱动使用 Go Driver 连接到 MongoDB在 Go 里面使用 BSON 对象CRUD 操作 插入文档更新文档查询文档删除文档 下一步 MongoDB 安装(Docker) 先装个 mongo&#xff0c;为了省事就用 docker 了。 docker 的 daemon.json 加一个国内的源地址…

Milvus基本概念及其应用场景

Milvus是一款云原生向量数据库&#xff0c;具备高可用、高性能、易拓展的特点&#xff0c;主要用于海量向量数据的实时召回。以下是关于Milvus的基本概念解释&#xff1a; 向量数据库&#xff1a;Milvus是一个向量数据库&#xff0c;用于存储、索引和管理通过深度神经网络和机…

用Java爬虫解决问题:探索网络数据的奥秘

网络爬虫是一种用于自动获取互联网信息的程序&#xff0c;常用于搜索引擎、数据挖掘等领域。本文将介绍如何使用Java编写网络爬虫来解决问题&#xff0c;并提供具体的代码实现及测试&#xff0c;帮助读者掌握爬虫技术并应用于实际项目中。 1. 爬虫原理 爬虫通过模拟人类浏览器…

C++容器——set

set容器 是一个关联容器&#xff0c;按一定的顺序存储一组唯一的元素。 set容器中的元素会根据元素的值自动进行排序&#xff0c;并且不允许包含重复的元素&#xff0c;基于二叉树实现的。 特点&#xff1a; 唯一性&#xff1a; set容器中的元素是唯一的&#xff0c;即容器中…

Java 区块链应用 | 割韭菜之假如K线涨跌可随意变动修改的实现

大家好&#xff0c;我是程序员大猩猩。 我一直在想&#xff0c;币圈这个行情时涨时跌&#xff0c;不断的割韭菜&#xff0c;不是由市场决定的&#xff01;而是由交易所直接输入一个数值后点击确定按钮而变化的&#xff0c;那么是不是很恐怖的行为。 为了验证这么一个想法&…

【代码随想录算法训练Day2】LeetCode 977.有序数组的平方、LeetCode 209.长度最小的子数组、LeetCode 59.螺旋矩阵II

Day2 数组、双指针 LeetCode 977.有序数组的平方【排序/双指针】 要将数组的每个元素平方后在按非递减的顺序&#xff0c;最简单的方法就是先将每个数平方&#xff0c;再将结果数组排序。 解法1&#xff1a;排序 class Solution { public:vector<int> sortedSquares(…

Java实现Excel导入和校验

文章目录 效果实现1,添加依赖2,实体类Member.javaMemberVO.java3,校验、监听器ValidationTool.javaExcelReadListener.java4,请求接口参考博文效果 输入:导入测试.xlsx postman调用实例: postman输出结果: 日志输出: 实现 1,添加依赖 easyexcel要去掉poi-ooxm…

maven打包SpringBoot项目报错Perhaps you are running on a JRE rather than a JDK?

maven打包SpringBoot项目报错信息如下 [ERROR] No compiler is provided in this environment. Perhaps you are running on a JRE rather than a JDK?提示很明显&#xff0c;他需要JDK&#xff0c;你只有JRE 解决方法 通过yum搜索jdk可以看到以下两个应用 $ yum search j…

LabVIEW电机测试系统

LabVIEW电机测试系统 开发了一款基于LabVIEW的有限转角力矩电机测试系统&#xff0c;该系统针对现代高性能驱动系统中的关键组成部分——有限转角力矩电机的测试需求&#xff0c;提供了一种全面的测试解决方案。通过集成化的LabVIEW程序开发和高速数据采集硬件的应用&#xff…

SpringBoot:事务和AOP

事务 一组操作的集合,不可分割的工作单位,会被一起提交或撤销 要么同时成功,要么同时失败 实物操作 begin/start transaction 开启事务 commit 提交事务 rollback 回滚事务 eg:当我们需要保证数据的一致性,例如在删除时,删除了部门,却没有删除部门的员工,就会出现数据的…

开源模型应用落地-模型记忆增强-概念篇(一)

一、前言 语言模型的记忆是基于其训练数据。具体而言,对于较长的文本,模型可能会遗忘较早的信息,因为它的记忆是有限的,并且更容易受到最近出现的内容的影响。模型无法跨越其固定的上下文窗口,而是根据当前上下文生成回应。 提升模型记忆能力有多种方法,比如改进模型的结…

Leetcode—295. 数据流的中位数【困难】

2024每日刷题&#xff08;132&#xff09; Leetcode—295. 数据流的中位数 实现代码 class MedianFinder { public:MedianFinder() {}void addNum(int num) {if(maxHeap.empty() || num < maxHeap.top()) {maxHeap.push(num);} else {minHeap.push(num);}if(maxHeap.size(…

未授权访问:Jenkins未授权访问漏洞

目录 1、漏洞原理 2、环境搭建 3、未授权访问 4、利用未授权访问写入webshell 防御手段 今天继续学习各种未授权访问的知识和相关的实操实验&#xff0c;一共有好多篇&#xff0c;内容主要是参考先知社区的一位大佬的关于未授权访问的好文章&#xff0c;还有其他大佬总结好…

基于JSP动漫论坛的设计与实现(二)

目录 3. 系统开发环境及技术介绍 3.1 开发环境 3.2 开发工具 3.2.1 MyEclipse8.5 3.2.2 MySql 3.3 相关技术介绍 3.3.1 JSP技术简介 3.3.2 JDBC技术技术简介 3.3.3 MVC模式与Struts框架技术 4. 总体设计 4.1 系统模块总体设计 4.1.1 普通用户模块设计 4…

常见的获取dom元素的方法

获取 DOM 元素是前端开发中非常常见的操作。以下是几种常用的方法来获取 DOM 元素&#xff0c;以及它们的适用场景和示例&#xff1a; 1. getElementById 用于获取具有指定 id 属性的元素。 示例 let element document.getElementById(myId); 2. getElementsByClassName …

element ui的无法关掉的提示弹框

使用element的$alert组件的属性把X去掉和确定按钮和取消按钮去掉&#xff1b; import { MessageBox } from element-ui; MessageBox.alert(AI功能已到期或暂未开启, 友情提示, {showClose: false,showCancelButton: false,showConfirmButton: false }); 如果在router的路由守…

【QT教程】QT6音视频处理权威指南 QT音视频

QT6音视频处理权威指南 使用AI技术辅助生成 QT界面美化视频课程 QT性能优化视频课程 QT原理与源码分析视频课程 QT QML C扩展开发视频课程 免费QT视频课程 您可以看免费1000个QT技术视频 免费QT视频课程 QT统计图和QT数据可视化视频免费看 免费QT视频课程 QT性能优化视频免费…