人工智能——深度学习

4. 深度学习

4.1. 概念

深度学习是一种机器学习的分支,旨在通过构建和训练多层神经网络模型来实现数据的高级特征表达和复杂模式识别。与传统机器学习算法相比,深度学习具有以下特点:

  • 多层表示学习:深度学习使用深层神经网络,允许多个层次的特征表达和抽象,从而能够自动发现和提取输入数据中的重要特征。
  • 端到端学习:通过将输入直接映射到输出,深度学习可以实现端到端的学习,无需手工设计特征提取器或预处理步骤。
  • 大规模并行计算:深度学习模型通常需要进行大量的矩阵运算,在现代硬件(如GPU)上可以进行高效的并行计算,加快了训练和推断的速度。
  • 梯度下降优化:深度学习模型通常使用梯度下降等优化算法来最小化损失函数,并通过反向传播算法有效地更新网络参数。
  • 泛化能力强:深度学习模型具有很强的泛化能力,能够在未见过的数据上进行准确的预测和分类。

深度学习在各个领域都取得了重大突破,包括计算机视觉、自然语言处理、语音识别等。它已经应用于图像分类、目标检测、机器翻译、智能助手等众多任务中,并在许多比赛和实际应用中取得优秀的结果。

神经网络功能强大,并且深度学习则是优化了数据分析,建模过程,因此基于神经网络的深度学习可以统一原来的传统机器学习。AlphaGo是深度学习战胜了世界围棋第一人李世石。2016年Google翻译基于深度学习更新,翻译能力得到大幅提升。最新的Google翻译是基于大语言模型。

4.2. 神经网络

4.2.1. 定义

4.2.1.1. 概念

深度学习是一种基于人工神经网络的机器学习方法,其核心思想是通过多层次的神经网络来模拟人脑的神经元之间的连接。深度学习的特点是可以通过大规模的数据来训练模型,并且可以自动学习到数据的特征表示。

上图就是一个神经网络的基本结构图,X1到Xn是输入,O1到Oj是输出,圆圈是神经元(也称感知机),连线带权重参与计算生成下一个神经元。隐层在实际的神经网络中可能会多层,并且都是全连接,所以计算量巨大,所以需要AI CPU、AI GPT等。

4.2.1.2. 感知机

如下图是一个神经元,其有3个输入数据x附加不同的权重w,另外有一个偏置(可以理解为线性函数中的截距)。

h = f(u + b)

我们先假设没有激活函数,来看下神经元的效果。

  • 如下图,单层单个元神经元可以用来作为分类。

  • 如下图,单层多个神经元可以完全更精细的线性分类分类

通过上面的示例可以看出,在没有激活函数的情况下,无论在多少个神经元作用下,其都是使用累加计算的,总是一阶的,总是线性的。线性函数只能处理一些简单的场景,复杂场景多是需要用曲线或曲面来区分的。如下图,用线段无法区分大小写字母。

我们就需要在神经元加一个函数来加强其能力,这就是激活函数的作用,它让神经元具备非线性表达能力。

我们让激活函数为Sigmoid函数,那么直线与Sigmoid函数相乘,就变成了曲线。

那么在3个神经元在3个激活函数的作用下,就可以形成3条曲线。

3条曲线在不同的权重作用下,可以拟合为一条新的形状,可以达到区分大小写字母的能力。

理论上来说,只要神经元足够多,无论多么复杂的分类都可以实现。

激活函数

激活函数可以选择不同的函数,Sigmoid是以前比较受欢迎的激活函数,但是其存在一些问题。当权重很小时,Sigmoid函数的作用也很小,容易导致梯度消失(简单讲是指区别度不大,导致学习的效率不佳)。ReLU系列的激活函数包括ReLU、Leaky ReLU、PReLU、ELU。

不同的激活函数有不同的应用场景,不同的计算量,需要根据经验进行选择调整。

softmax回归

为了结果更清晰,好对比,我们可能需要将结果进行归一化处理(归一化也被称为单位化,即所有结果之和为1)。

经过softmax回归计算之后,输出的结果可能是这样的:

4.2.2. 损失函数

如何评估神经网络的效果,我们就需要用到损失函数。损失函数(Loss Function)用来估量模型的预测值 f(x) 与真实值 y 的偏离程度。因为误差有正有负,所以可以采用平均绝对误差,均方误差(平均平方误差, Mean Squared Error ,MSE),这些多用于回归问题。用于二分类问题(是/否,对与错),多用交叉熵损失函数(CrossEntropy Loss)。多分类问题,可以用softmax函数,如上图的动物分类。

交叉熵损失函数(CrossEntropy Loss)

熵是用来描述物体混乱程度的概念,越混乱熵越大,也可以理解为数据越随机熵就越大。信息熵越大,事物越具不确定性,事物越复杂。

信息熵公式:

交叉熵主要用于度量两个概率分布之间的差异性。交叉熵越小,表示模型输出分布越接近真实值分布。

在机器学习框架中,交叉熵都有直接提供接口,我们只需要知道交叉熵的概念及其应用场景,知道使用即可。

4.2.3. 计算

4.2.3.1. 前向传播

有如下一个神经网络,3个输入,2个输出,单层神经网络有4个神经元。

转换为数学形式:

一步步从前往后进行计算,这就是前向传播计算。x1、x2、x3总是一起参与计算,其总的输出可以用一个矩阵[x1, x2, x3]表示,所以在神经网络的计算是,需要大量的矩阵计算。所以现在有很专用用于神经网络计算的神经网络处理器(Neural network Processing Unit, NPU)。

前向传播主要用于预测结果。

4.2.3.2. 反向传播

在神经网络学习的过程中,我们通过误差函数来求一个最小误差时的权重和截距(神经网络中叫偏置)。我们可以使用最小二乘法,也可以使用梯度下降法。使用最小二乘法效果好,但是计算量非常大,尤其是在大型神经网络中,如果使用最小二乘法计算量巨大,所以一般使用梯度下降法,梯度下降使用学习率(权重的步进值)这个超参数来控制下降的速度,来提升计算速度。

梯度下降是通过误差函数反向往前推的,所以也被称为反向传播。反向传播主要用于学习(训练)。

4.2.2. 分类

深度学习的主要方法包括卷积神经网络(CNN)、循环神经网络(RNN)和生成对抗网络(GAN)等。

1. 卷积神经网络:

卷积神经网络是一种专门用于处理具有网格结构的数据(如图像和语音)的神经网络。卷积神经网络通过卷积层、池化层和全连接层等组件来提取图像的特征表示,从而实现图像分类、目标检测和图像生成等任务。

2. 循环神经网络:

循环神经网络是一种可以处理序列数据(如语言和时间序列)的神经网络。循环神经网络通过循环连接来处理序列数据的时序信息,并且可以自动学习到序列数据的上下文信息。循环神经网络在自然语言处理、语音识别和机器翻译等领域有广泛应用。

传统的循环神经网络是全连接的,并不关注数据的前后顺序(如语言的前后顺序或时间序列等)。RNN中每个神经元的输出,不仅仅有上一层神经元的输出,还可能把数据序列前处理神经元的输出作为输入。

因为CNN增加了输入,计算量增加了。为了优化RNN,引入了LSTM(长短期记忆网络),减少计算量,并优化了前后依赖关系。

3. 生成对抗网络:

生成对抗网络是一种由生成器和判别器组成的对抗性模型。生成器通过学习训练数据的分布,生成与训练数据相似的新样本;判别器则通过学习区分真实样本和生成样本。生成对抗网络在图像生成、图像修复和文本生成等任务中取得了重要的突破。

4.3. 学习过程

4.3.1. 步骤

  1. 数据准备:收集和预处理数据,使其适合神经网络训练。这可能包括清理数据、删除异常值和对数据进行编码。
  2. 网络架构:设计神经网络的架构,包括层数、神经元数和连接方式。
  3. 初始化权重和偏差:为网络中的权重和偏差分配初始值。
  4. 前向传播:将输入数据通过网络,计算每个神经元的输出。
  5. 计算损失:将网络输出与预期输出进行比较,计算损失函数的值。
  6. 反向传播:使用链式法则计算损失函数相对于网络权重和偏差的梯度。
  7. 权重更新:使用梯度下降或其他优化算法更新网络权重和偏差,以减少损失。
  8. 重复步骤 4-7:重复前向传播、计算损失和反向传播的步骤,直到损失函数达到最小值或达到预定义的训练迭代次数。

4.3.2. 超参数

超参数(Hyperparameter)是机器学习模型中需要人为设定的参数,它们不是通过训练数据自动学习得到的,而是需要人工指定的参数。训练深度神经网络涉及调整以下超参数:

  • 学习率:控制权重更新的步长。
  • 批大小:每次前向和反向传播处理的数据样本数。
  • 正则化:防止过拟合的技术,例如权重衰减和 dropout。
  • 激活函数:神经元输出的非线性函数。
  • 优化器:用于更新权重的算法,例如梯度下降和 Adam。每次训练完成需要更新权重参数,直到损失函数达到要求,退出训练。

4.3.3. 挑战

训练深度神经网络的挑战

  • 过拟合:网络在训练数据上表现良好,但在新数据上表现不佳。
  • 欠拟合:网络无法从训练数据中学到足够的模式。
  • 梯度消失和爆炸:在反向传播过程中,梯度可能变得非常小或非常大,这会阻碍训练。
  • 局部最小值:优化算法可能收敛到局部最小值,而不是全局最小值。

4.3.4. 最佳实践

训练深度神经网络的最佳实践

  • 使用交叉验证来防止过拟合。
  • 使用正则化技术来减少过拟合。
  • 仔细调整超参数以获得最佳性能。
  • 使用早期停止来防止过拟合。
  • 使用权重初始化技术来防止梯度消失和爆炸。

4.4. 应用

深度学习理论上可以完全替代传统的机器学习算法,只要神经元足够,训练数据足够。传统机器学习能够达到的效果,深度学习都可以达到,并且可以拟合得更好。大力出奇迹在深度学习中完美体现。

4.4.1. 多元线性回归

如下13个自变量(输入),一个因变量(输出),因为是线性回归,只用一个神经元,且不需要激活函数。训练完成生成模型之后,可以保存模型,下次就直接使用模型来进行预测了。

神经网络识别手写数字

这段代码是使用 PyTorch 实现一个简单的全连接神经网络,用于在 MNIST 数据集上进行手写数字识别。

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# Hyper-parameters 
input_size = 784
hidden_size = 500
num_classes = 10
num_epochs = 5
batch_size = 100
learning_rate = 0.001# MNIST dataset download
train_dataset = torchvision.datasets.MNIST(root='../../data', train=True, transform=transforms.ToTensor(),  download=True)test_dataset = torchvision.datasets.MNIST(root='../../data', train=False, transform=transforms.ToTensor())# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)# Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):def __init__(self, input_size, hidden_size, num_classes):super(NeuralNet, self).__init__()self.fc1 = nn.Linear(input_size, hidden_size) self.relu = nn.ReLU()self.fc2 = nn.Linear(hidden_size, num_classes)  def forward(self, x):out = self.fc1(x)out = self.relu(out)out = self.fc2(out)return outmodel = NeuralNet(input_size, hidden_size, num_classes).to(device)# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):  # Move tensors to the configured deviceimages = images.reshape(-1, 28*28).to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 100 == 0:print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, total_step, loss.item()))# Test the model
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, 28*28).to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))# Save the model checkpoint(模型检测点,也即模型参数,供后续直接加载使用)
torch.save(model.state_dict(), 'model.ckpt')

执行结果:

4.4.3. 其他

深度神经网络推荐使用Pytorch。

Github上的代码示例:GitHub - yunjey/pytorch-tutorial: PyTorch Tutorial for Deep Learning Researchers

GitHub - zergtant/pytorch-handbook: pytorch handbook是一本开源的书籍,目标是帮助那些希望和使用PyTorch进行深度学习开发和研究的朋友快速入门,其中包含的Pytorch教程全部通过测试保证可以成功运行

GitHub - chenyuntc/pytorch-book: PyTorch tutorials and fun projects including neural talk, neural style, poem writing, anime generation (《深度学习框架PyTorch:入门与实战》)

GPU测试平台,可以利用Google的免费在线虚拟机器:https://colab.research.google.com/

或阿里云魔搭社区虚拟机,GPT免费36小时:魔搭社区

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/804244.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

麒麟KOS删除鼠标右键新建菜单里不需要的选项

原文链接:麒麟KOS删除鼠标右键新建菜单里不需要的选项 Hello,大家好啊!在日常使用麒麟KOS操作系统时,我们可能会发现鼠标右键新建菜单里包含了一些不常用或者不需要的选项。这不仅影响我们的使用效率,也让菜单显得杂乱…

新能源电力行业设备点巡检系统的应用

新能源电力行业正日益成为全球能源结构的重要支柱,其设备点巡检系统作为确保电力设施安全、高效运行的关键环节,正受到业界的广泛关注和应用。 设备点巡检系统是一种集数据采集、实时监控、智能分析于一体的现代化管理工具。在新能源电力行业中&#xff…

Java常见算法_常见的查找算法和排序算法——简介及代码演示

在本文中我将介绍Java中的常见算法,查找算法包括基本查找、二分查找、插值查找和分块查找。排序算法包括冒泡排序、选择排序、插入排序和快速排序 查找算法: 1.基本查找: 代码: public class BasicSearchDemo {public static …

SpringMVC:搭建第一个web项目并配置视图解析器

👉需求:用spring mvc框架搭建web项目,通过配置视图解析器达到jsp页面不得直接访问,实现基本的输出“hello world”功能。👩‍💻👩‍💻👩‍💻 1 创建web项目 1…

如何解决Python包管理问题:ERROR: Could not find a version that satisfies the requirement

如何解决Python包管理问题:“ERROR: Could not find a version that satisfies the requirement” 文章目录 如何解决Python包管理问题:“ERROR: Could not find a version that satisfies the requirement”错误描述问题分析解决方案检查包名确保网络连…

【JVM】面试题汇总

JVM1. 什么是JVM?2. 了解过字节码文件的组成吗?3. 什么是运行时数据区4. 哪些区域会出现内存溢出5. JVM在JDK6-8之间在内存区域上有什么不同 6. 类的生命周期 7. 什么是类加载器?类加载器有哪几种 8. 什么是双亲委派机制?有什么好…

“国字号”荣誉、全国试点,侨乡群众身边的“放心”公证处

日前,我市五邑公证处获评“全国公共法律服务工作先进集体”称号。 走进公证处,首先映入眼帘的是一间宽敞明亮的大厅,办证点内还设置多个独立办证室,工作人员热情地为前来办理业务的市民提供专业、人性化的公证服务。江门市五邑公证…

Windows上面搭建Flutter Android运行环境

Flutter Android环境搭建 电脑上面安装配置JDK电脑上下载安装Android Studio电脑上面下载配置Flutter Sdk (避坑点一)下载SDK配置对应的环境变量 到path 电脑上配置Flutter国内镜像运行 flutter doctor命令检测环境是否配置成功创建运行Flutter项目&…

ARM单片机的GPIO口在控制不同LED、按键时的设置

个人备忘,不喜勿喷。 GPIO口在驱动共阴极、共阳极LED灯时需要不同的初始化设置 对于这一类的led灯: 最好选择推挽、上拉、高速输出,同时IO口初始化时需要拉高。 上面这种需要下拉输入; 上图这种需要上拉输入,这样才…

vue点击上传图片并实现图片预览功能,并实现多张图片放到一个数组中进行后端请求(使用原生input)

一、将 File 对象转成 BASE64 字符串 &#xff08;FileReader&#xff09; <template><div><!-- 用来显示封面的图片 --><!-- <img src"/assets/images/cover.jpg" alt"" class"cover-img" ref"imgRef" />…

html基础(2)(链接、图像、表格、列表、id、块)

1、链接 <a href"https://www.example.com" target"_blank" title"Example Link">Click here</a> 在上示例中&#xff0c;定义了一个链接&#xff0c;在网页中显示为Click here&#xff0c;鼠标悬停指示为Example Link&#xff0c…

Java(多线程)

一、基本概念 进程&#xff1a;一个具有一定独立功能的程序关于某个数据集合的一次运行活动。它是操作系统动态执行的基本单元&#xff0c;在传统的操作系统中&#xff0c;进程既是基本的分配单元&#xff0c;也是基本的执行单元。线程&#xff1a;操作系统中能够进行运算的最…

java Web课程管理系统用eclipse定制开发mysql数据库BS模式java编程jdbc

一、源码特点 JSP 课程管理系统是一套完善的web设计系统&#xff0c;对理解JSP java 编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCAT7.0,eclipse开发&#xff0c;数据库为Mysql5.0&#xff0c;使用ja…

贪心算法|406.根据身高重建队列

力扣题目链接 class Solution { public:static bool cmp(const vector<int>& a, const vector<int>& b) {if (a[0] b[0]) return a[1] < b[1];return a[0] > b[0];}vector<vector<int>> reconstructQueue(vector<vector<int>…

骑砍2霸主MOD开发(2)-基础开发环境搭建

一.骑砍2霸主程序架构 二.骑砍2霸主C#接口层代码查看 1.C#反编译工具dnspy下载: 2.骑砍2霸主游戏引擎接口查看: 例如IMBAgent interface接口: #调用TaleWorlds.Native.dll中的函数 [EngineMethod("get_movement_flags", false)] uint GetMovementFlags(UIntPtr agen…

Visual Studio Code SSH 连接远程服务器

Visual Studio Code通过 SSH 连接远程服务器并实现免密登录&#xff0c;你可以按照以下步骤进行操作&#xff1a; 1. **安装插件**&#xff1a;首先&#xff0c;在 VS Code 中安装 "Remote - SSH" 插件。打开 VS Code&#xff0c;点击左侧的扩展图标&#xff0c;搜索…

微服务学习3

目录 1.微服务保护 1.1.服务保护方案 1.1.1.请求限流 1.1.2.线程隔离 1.1.3.服务熔断 1.2.Sentinel 1.2.1.微服务整合 1.2.2.请求限流 1.3.线程隔离 1.3.1.OpenFeign整合Sentinel 1.3.2.配置线程隔离 1.4.服务熔断 1.4.1.编写降级逻辑 1.4.2服务熔断 2.分布式事…

mp4转flv怎么转?电脑怎么把视频转成flv?

MP4&#xff08;MPEG-4 Part 14&#xff09;是一种多媒体容器格式&#xff0c;广泛用于包含视频、音频、字幕等多种数据流。MP4因其高度灵活性、压缩效率和兼容性成为视频领域的主流格式&#xff0c;支持范围涵盖从在线视频到移动设备的各类应用场景。 FLV文件格式的多个优点 …

scFed:联邦学习用于scRNA-seq分类

scRNA-seq的出现彻底改变了我们对生物组织中细胞异质性和复杂性的理解。然而&#xff0c;大型&#xff0c;稀疏的scRNA-seq数据集的隐私法规对细胞分类提出了挑战。联邦学习提供了一种解决方案&#xff0c;允许高效和私有的数据使用。scFed是一个统一的联邦学习框架&#xff0c…

Spring Validation解决后端表单校验

NotNull&#xff1a;从前台传递过来的参数不能为null,如果为空&#xff0c;会在控制台日志中把message打印出来 Range&#xff1a;范围&#xff0c;最大多少&#xff0c;最小多少 Patten&#xff0c;标注的字段值必须符合定义的正则表达式&#xff08;按照业务规则&#xff0…