TensorFlow学习之:深度学习基础

神经网络基础

神经网络是深度学习的核心,它们受人脑的结构和功能启发,能够通过学习大量数据来识别模式和解决复杂问题。神经网络的基本工作原理包括前向传播和反向传播两个阶段。

前向传播(Forward Propagation)

前向传播是神经网络进行预测的过程。数据从输入层开始,经过隐藏层,最后到达输出层。在每一层,数据都会通过节点(神经元)进行处理,并通过激活函数进行非线性转换。

  1. 输入层:接收输入数据。例如,在图像识别任务中,输入层的神经元数量通常等于图像的像素数。
  2. 隐藏层:一个或多个隐藏层对数据进行进一步处理。每个神经元在这些层中都会接收前一层所有神经元的输出作为输入。
  3. 输出层:生成最终的预测结果。输出层的神经元数量通常取决于任务的类型(例如,在分类任务中等于类别数)。

每个神经元的输出计算公式为:

反向传播(Backpropagation)

反向传播是神经网络学习和更新参数的过程。它通过计算损失函数(如均方误差或交叉熵损失)相对于网络参数的梯度,然后使用这些梯度来更新网络的权重和偏置,从而减小预测错误。

  1. 计算损失:首先,计算网络输出和实际值之间的误差(损失)。
  2. 梯度计算:然后,利用链式法则自输出层向输入层逐层计算每个参数(权重和偏置)对损失的影响(梯度)。
  3. 更新参数:最后,使用梯度下降或其他优化算法根据计算出的梯度更新网络的参数。

这个过程会在训练数据上重复多次(训练迭代),每次迭代都会使神经网络的预测更加准确。

激活函数

激活函数在神经网络中非常重要,它引入非线性因素,使得神经网络能够学习和模拟复杂的函数。没有激活函数,即使网络有多个隐藏层,它也只能表示线性关系。

常用的激活函数包括:

  • ReLU(Rectified Linear Unit):最常用的激活函数,对于正输入保持不变,对于负输入则输出0。
  • Sigmoid:将输入映射到0和1之间,常用于二分类任务的输出层。
  • Tanh(Hyperbolic Tangent):将输入映射到-1和1之间,形状和Sigmoid类似但范围更广。
  • Softmax:将输入映射为概率分布,常用于多分类任务的输出层。

神经网络通过前向传播进行预测,通过反向传播进行学习,利用激活函数引入非线性,这些都是实现其强大功能的关键因素。理解这些基本概念是深入学习深度学习的基础。

构建简单的神经网络

使用TensorFlow构建和训练一个简单的全连接神经网络是学习深度学习的基础。全连接网络,也称为密集网络,是最简单的神经网络结构,其中网络中的每个神经元都与前一层的所有神经元相连。这里,我们将通过构建一个用于手写数字识别(MNIST数据集)的全连接网络来演示这个过程。

步骤 1: 导入必要的库

首先,我们需要导入TensorFlow和其他可能需要的库。如果你还没有安装TensorFlow,请先按照官方指南进行安装。

import tensorflow as tf
from tensorflow.keras import layers, models

步骤 2: 加载数据集

MNIST数据集包含了60000个训练样本和10000个测试样本,每个样本是一个28x28的灰度手写数字图像。

mnist = tf.keras.datasets.mnist(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0  # 归一化

步骤 3: 构建模型

接下来,我们使用tf.keras来构建一个简单的全连接网络。这个网络将包含一个输入层,几个隐藏层,和一个输出层。

model = models.Sequential([layers.Flatten(input_shape=(28, 28)),  # 将28x28的图像展平为一个784维的向量layers.Dense(128, activation='relu'),  # 第一个隐藏层,128个节点layers.Dropout(0.2),  # 防止过拟合layers.Dense(10, activation='softmax')  # 输出层,10个节点对应10个类别
])

在这里,Dense是全连接层。第一个Dense层有128个神经元并使用ReLU激活函数。Dropout层随机地将输入单元的一部分设置为0,有助于防止过拟合。最后一个Dense层是输出层,使用Softmax激活函数输出预测的概率分布。

步骤 4: 编译模型

在训练模型之前,我们需要编译它,设置优化器、损失函数和评估指标。

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

这里我们使用adam优化器和sparse_categorical_crossentropy损失函数,这是处理多分类问题常用的设置。评估指标使用准确率。

步骤 5: 训练模型

现在,我们可以训练模型了。

model.fit(x_train, y_train, epochs=5)

这里,epochs=5表示我们将整个数据集迭代5次。

步骤 6: 评估模型

最后,我们评估模型在测试集上的性能。

 
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print('\nTest accuracy:', test_acc)

这个简单的全连接网络在MNIST数据集上就能达到相当不错的准确率。

通过以上步骤,你可以看到使用TensorFlow构建和训练全连接网络是相对直接的。TensorFlow提供了丰富的API和工具,使得从构建到训练再到评估模型的整个过程都变得简单高效。随着深入学习,你将能够构建更复杂的模型来解决更多样化的问题。

损失函数和优化器

在深度学习中,损失函数(Loss Function)和优化器(Optimizer)是训练神经网络的核心组件。损失函数衡量模型预测与实际值之间的差距,优化器则用于根据损失函数的梯度更新模型的权重,以最小化损失。

损失函数

损失函数是模型性能的量化度量,它计算了模型预测值与实际标签值之间的差异。根据任务的不同,选择合适的损失函数是至关重要的。

  • 均方误差(Mean Squared Error, MSE):用于回归问题,计算预测值和真实值之间差值的平方和的均值。公式为

  • 交叉熵损失(Cross-Entropy Loss)

  • Hinge损失:常用于支持向量机(SVM)中,但也可以用于训练深度学习模型,特别是在一些二分类问题上。

优化器

优化器负责模型训练过程中的权重更新。选择合适的优化器可以加速模型的训练并提高模型的性能。

  • 随机梯度下降(Stochastic Gradient Descent, SGD):最基本的优化器,按照梯度下降的方向更新权重,可选地加入动量(Momentum)来加速训练。

  • Adam:结合了AdaGrad和RMSProp优化器的优点,通过计算梯度的一阶矩估计和二阶矩估计来调整学习率,通常在实践中表现良好。

  • RMSprop:适用于非平稳目标的优化器,通过除以一个衰减的平均值来调整学习率。

  • AdaGrad:通过累积过去梯度的平方来调整每个参数的学习率,适合处理稀疏数据。

选择损失函数和优化器

选择哪个损失函数和优化器取决于具体的应用场景和问题类型。例如,对于多类分类问题,常用的损失函数是分类交叉熵损失,而优化器则可以选择Adam,因为它在许多情况下都能提供良好的性能和快速收敛。

总之,损失函数定义了模型的优化目标,而优化器定义了如何达到这个目标。在深度学习中合理选择并调整它们对于训练有效的模型至关重要。

过拟合和正则化

在深度学习中,过拟合是一个常见问题,它发生在模型对训练数据学得太好,以至于损失了对新数据的泛化能力。过拟合的模型在训练集上表现出色,但在验证集或测试集上表现较差。为了缓解这个问题,正则化技术被广泛应用。

过拟合的原因

  • 数据集太小:模型没有足够的数据来学习泛化的特征,而是记住了训练数据的特点。
  • 模型太复杂:模型的参数过多,学习能力太强,导致模型学到了训练数据的噪声。

解决过拟合的方法

数据增强

增加数据的多样性可以有效减少过拟合。对于图像数据,可以通过旋转、缩放、裁剪等方法来增加数据集的大小和多样性。

降低模型复杂度

减少模型的大小(例如,减少网络层或每层的神经元数量)可以减轻过拟合,但这可能会影响模型的学习能力。

正则化技术

正则化是一种限制模型复杂度的方法,以提高模型的泛化能力。常见的正则化技术包括:

  • L1和L2正则化:在损失函数中添加一个正则项,L1正则化倾向于生成稀疏权重矩阵,L2正则化可以防止权重变得太大。

  • Dropout:在训练过程中随机“丢弃”一部分神经元(即将它们的输出设置为0),这可以被看作是从原始网络中采样出大量不同的子网络来共同决定最终结果,从而增加网络的泛化能力。

  • 早停(Early Stopping):在训练过程中,如果验证集的性能在连续多个epoch后不再提高,则提前终止训练。这可以防止模型在训练集上过度拟合。

批归一化(Batch Normalization)

虽然批归一化主要用于加速深度网络的训练,但它也有轻微的正则化效果,因为它为每层的输入引入了噪声。

代码示例:使用Dropout正则化

在TensorFlow中,可以通过添加Dropout层来实现Dropout正则化。

model = models.Sequential([layers.Flatten(input_shape=(28, 28)),layers.Dense(128, activation='relu'),layers.Dropout(0.2),  # Dropout层layers.Dense(10, activation='softmax')
])

小结

处理过拟合问题是构建高性能深度学习模型的关键步骤。通过合理选择数据处理方法、调整模型结构、应用正则化技术等手段,可以有效提高模型在未见数据上的泛化能力。在实践中,通常需要尝试不同的策略组合,以找到最适合特定问题和数据集的解决方案。

项目:手写数字识别(MNIST数据集)

手写数字识别是机器学习中的一个经典问题,通常使用MNIST数据集来解决。MNIST数据集包含了70,000张大小为28x28像素的手写数字图像,分为60,000张训练图像和10,000张测试图像。每张图像都标有相应的数字(0到9)。解决这个问题的一个标准方法是使用深度学习模型,特别是卷积神经网络(CNN),因为CNN非常适合处理图像数据。

以下是使用TensorFlow构建和训练一个简单CNN模型进行手写数字识别的步骤:

步骤 1: 导入必要的库

import tensorflow as tf
from tensorflow.keras import layers, models

步骤 2: 加载和准备MNIST数据集

TensorFlow提供了加载MNIST数据集的函数,这些函数会自动下载数据。

mnist = tf.keras.datasets.mnist(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 归一化图像到0-1范围
train_images, test_images = train_images / 255.0, test_images / 255.0

步骤 3: 构建模型

我们构建一个简单的CNN模型,该模型包括两个卷积层,紧跟着两个池化层,然后是两个全连接层。

model = models.Sequential([layers.Reshape(target_shape=(28, 28, 1), input_shape=(28, 28)),  # 将图像格式从(28, 28)转换为(28, 28, 1)layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),layers.MaxPooling2D((2, 2)),layers.Conv2D(64, (3, 3), activation='relu'),layers.MaxPooling2D((2, 2)),layers.Flatten(),layers.Dense(64, activation='relu'),layers.Dense(10, activation='softmax')
])

步骤 4: 编译模型

在训练模型之前,需要指定损失函数、优化器和评估指标。

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

步骤 5: 训练模型

使用训练数据对模型进行训练。

model.fit(train_images, train_labels, epochs=5)

步骤 6: 评估模型

最后,评估模型在测试集上的表现。

test_loss, test_acc = model.evaluate(test_images, test_labels)print('Test accuracy:', test_acc)

总结

通过这个项目,你可以了解到使用卷积神经网络进行图像分类的基本流程。MNIST手写数字识别是深度学习入门的一个非常好的项目,它不仅可以帮助你熟悉TensorFlow的基本操作,还能让你对深度学习模型的构建和训练有一个直观的理解。随着经验的积累,你可以尝试使用更复杂的模型或技术来进一步提高识别准确率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/805069.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

全国水科技大会 免费征集《水环境治理减污降碳协同增效示范案例》

申报时间截止到2024年4月15日,请各单位抓紧申报,申报条件及申报表请联系:13718793867 围绕水环境治理减污降碳协同增效领域,以资源化、生态化和可持续化为导向,面向生态、流城、城市、农村、工业园区、电力、石化、钢…

在VsCode中写vue的css,代码提示一直不出现或提示错误

在我们vue项目正常写css样式,便会出现一下提示,如: 但有时无提示,那么这种情况有以下几种解决方案 观察Vscode插件是否正常 Vetur和Vue - Official是否安装(Vue - Official的前身就是Volar) 安装了检查是否最新版本 确保你的s…

el-table 设置固定列导致行错位的解决方案

element 官方给出的解决办法是使用doLayout,使用doLayout重新加载一下table就好了。 updated() {// tableRef是表格的ref属性值if (this.$refs.tableRef&& this.$refs.tableRef.doLayout) {this.$refs.tableRef.doLayout();}},调整前后效果对比&#xff1a…

lua学习笔记14(协程的学习)

print("*****************************协程的学习*******************************") --创建1 coroutine.create(function()) 使用1 coroutine.resume(co) -- 创建2 co2coroutine.wrap(fun) 使用2 co2() --协程的挂起函数 coroutine.yield() --协程的状态 --c…

跨学科高手揭秘:仿真技术如何改变工程世界

编辑 / 木子 审核 / 朝阳 “在高铁上睡觉,最大的噪音不是来自车轮与铁轨的摩擦声,也不是汽笛的轰鸣,而是巨大的‘嘶嘶’声——那是我大学时期做实验发出的声音。”12月9日,中国科学技术大学2024届毕业生郭骞在“伟骅科技”公众号…

4.19号驱动

1. ARM裸机开发和Linux系统开发的异同 相同点:都是对硬件进行操作 不同点: 有无操作系统 是否具备多进程多线程开发 是否可以调用库函数 操作地址是否相同,arm操作物理地址,驱动操作虚拟地址 2. Linux操作系统的层次 应用层…

(2022级)成都工业学院数据库原理及应用实验二:CASE工具关系模型建模

写在前面 1、基于2022级软件工程/计算机科学与技术实验指导书 2、代码仅提供参考 3、如果代码不满足你的要求,请寻求其他的途径 运行环境 window11家庭版 PowerDesigner 16.1 实验要求 某医院一个门诊部排班管理子系统涉及如下信息: 若干科室&a…

成都百洲文化传媒有限公司靠谱吗?怎么样?

随着互联网的迅猛发展,电子商务行业迎来了前所未有的发展机遇。在这个变革的浪潮中,成都百洲文化传媒有限公司凭借其深厚的行业经验和创新的服务模式,正逐渐成为电商服务领域的新领军者。 一、创新引领,塑造电商服务新标准 成都百…

Windows下docker-compose部署DolphinScheduler

参照:快速上手 - Docker部署(Docker) - 《Apache DolphinScheduler v3.1.0 使用手册》 - 书栈网 BookStack 下载源文件 地址:https://dolphinscheduler.apache.org/zh-cn/download/3.2.1 解压到指定目录,进入apache-dolphinscheduler-xxx-…

vscode开发小程序项目并在微信开发者工具运行

需求:vscode开发uniapp之后在微信开发者工具运行,更改的时候微信开发者也同步更改 创建微信小程序所需插件,在vscode的插件管理里面安装就可以了 1.微信小程序开发工具 2.vscode weapp api 3.vscode wxml 4.vscode wechat 1.创建小程序命…

2024年武汉中级工程师评审学历、论文、业绩有什么要求?

2024年大部分地区职称申报已经开始,今年因为政策变动,基本上需要全员参加水平能力测试,水测通过之后安排评审,那么对于中级职称评审有什么要求呢?我们一起跟甘建二看看。 一、2024年武汉中级工程师职称评审学历要求&am…

Web前端—属性描述符

属性描述符 假设有一个对象obj var obj {a:1 }观察这个对象,我们如何来描述属性a: 值为1可以重写可以遍历 我们可以通过Object.getOwnPropertyDescriptor得到它的属性描述符 var desc Object.getOwnPropertyDescriptor(obj, a); console.log(desc);我…

安卓逆向 | 某X游戏垂类Web nonce

*本案例仅做分析参考,如有侵权请联系删除 1.逻辑分析 通过XHR断点,然后逐步往上调发现nonce生出处。 在console执行下函数 其中 i,是当前日期和时间的秒级时间戳,并将其向下取整到最接近的整数。 i = ~~(+_.w() / 1e3)w</

设计模式之迭代器模式(上)

迭代器模式 1&#xff09;概述 1.概念 存储多个成员对象&#xff08;元素&#xff09;的类叫聚合类(Aggregate Classes)&#xff0c;对应的对象称为聚合对象。 聚合对象有两个职责&#xff0c;一是存储数据&#xff0c;二是遍历数据。 2.概述 迭代器模式(Iterator Patter…

Go语言不能常量取址!?

题如下图 在软件开发中&#xff0c;常量是一种重要的编程元素&#xff0c;它们在程序中起到固定值的作用被大量使用 Go语言中的常量取址 在 Go 语言中&#xff0c;常量是无法被取址的。这意味着我们不能使用取址操作符 & 来获取常量的地址。例如&#xff1a; const a …

【Java EE】关于Spring MVC 响应

文章目录 &#x1f38d;返回静态页面&#x1f332;RestController 与 Controller 的关联和区别&#x1f334;返回数据 ResponseBody&#x1f38b;返回HTML代码片段&#x1f343;返回JSON&#x1f340;设置状态码&#x1f384;设置Header&#x1f338;设置Content-Type&#x1f…

MySQL高级(索引分类-聚集索引-二级索引)

目录 1、主键索引、唯一索引、常规索引、全文索引 2、 聚集索引、二级索引 3、回表查询 4、通过id查询和通过name查询那个执行效率高&#xff1f; 5、 InnoDB主键索引的 B tree 高度为多高呢&#xff1f; 1、主键索引、唯一索引、常规索引、全文索引 在MySQL数据库&#xff0c…

[【JSON2WEB】 13 基于REST2SQL 和 Amis 的 SQL 查询分析器

【JSON2WEB】01 WEB管理信息系统架构设计 【JSON2WEB】02 JSON2WEB初步UI设计 【JSON2WEB】03 go的模板包html/template的使用 【JSON2WEB】04 amis低代码前端框架介绍 【JSON2WEB】05 前端开发三件套 HTML CSS JavaScript 速成 【JSON2WEB】06 JSON2WEB前端框架搭建 【J…

微信小程序picker设置了系统年度,打开选择年份从1年开始显示

背景&#xff1a;开发微信小程序时&#xff0c;使用了picker组件&#xff0c;设置值为当前系统时间年份&#xff0c;可以正常回显年份。但是打开面板选择年份的时候&#xff0c;默认从一年开始显示的。如下图所示。 原因&#xff1a;因为绑定的年份字段为Number类型。 解决方案…