《深度学习》PyTorch框架 优化器、激活函数讲解

目录

一、深度学习核心框架的选择

        1、TensorFlow

                1)概念

                2)优缺点

        2、PyTorch

                1)概念

                2)优缺点

        3、Keras

                1)概念

                2)优缺点

        4、Caffe

                1)概念

                2)优缺点

二、pytorch安装

        1、安装

        2、pytorch分为CPU版本和GPU版本

                1)CPU版本

                2)GPU版本

        3、相关显卡参数

                1)显卡容量

                2)显存频率

                3)显存位宽

                4)如何查看电脑是CPU还是GPU

三、PyTorch框架认识

        1、利用MNIST数据集实现神经网络的图像识别

        2、大致流程

        3、模型的结构

        4、优化器

                1)BGD(Batch Gradient Descent):批量梯度下降法

                2)SGD(Stochastic Gradient Descent):随机梯度下降

                3)Adam(Adaptive Moment Estimation):自适应优化算法

                4)Adagrad(Adaptive Gradient Algorithm):自适应学习率优化算法

                5)RMSprop(Root Mean Square Propagation):自适应学习率优化算法。

                6)小批量梯度下降法(Mini-batch Gradient Descent)

                7)等等多种优化算法

四、激活函数

        1、常见激活函数

                1)Sigmoid

 

                2)ReLU

                3)anh

                4)LeakyReLU

                5)Softmax

        2、梯度消失

        3、梯度爆炸


一、深度学习核心框架的选择

        1、TensorFlow

                1)概念

                     由Google开发并维护的深度学习框架,具有广泛的生态系统和强大的功能。它支持多种硬件平台,包括CPU、GPU和TPU,并且提供易于使用的高级API(如Keras)和灵活的底层API。

                2)优缺点

                      优点:广泛的生态系统和强大的功能、支持跨平台使用......

                      缺点:代码比较冗余,上手有难度......

                     

        2、PyTorch

                1)概念

                        由Facebook开发的深度学习框架,被认为是TensorFlow的竞争者之一。它具有动态计算图的特性,使得模型的定义和训练更加灵活。PyTorch也具有广泛的生态系统,并且在学术界和研究领域非常受欢迎。

                2)优缺点

                      优点:上手极容易,直接套用模板、易于调试和可视化.......

                      缺点:相对较小的生态系统、相对较少的文档和教程资源

        3、Keras

                1)概念

                        一个高级的深度学习框架,在tensorflow基础上做了封装,可以在TensorFlow和Theano等后端上运行。Keras具有简洁的API,使得模型的定义和训练变得简单易用。它适合对深度学习有基本了解的初学者或者快速原型开发。

                2)优缺点

                      优点:简化代码难度、简洁易用的API、多后端支持.....

                      缺点:功能相对有限、性能较差

        4、Caffe

                1)概念

                        一个由贾扬清等开发的深度学习框架,主要面向卷积神经网络(CNN)的应用。Caffe具有高效的C++实现和易于使用的配置文件,是许多计算机视觉任务的首选框架。

                2)优缺点

                      优点:只需要配置文件即可搭建深度神经网络模型

                      缺点:安装麻烦,缺失很多新网络模型,近几年几乎不更新

二、pytorch安装

        1、安装

                安装教程见上一篇博客,连接如下:

https://blog.csdn.net/qq_64603703/article/details/142218264?fromshare=blogdetail&sharetype=blogdetail&sharerId=142218264&sharerefer=PC&sharesource=qq_64603703&sharefrom=from_linkicon-default.png?t=O83Ahttps://blog.csdn.net/qq_64603703/article/details/142218264?fromshare=blogdetail&sharetype=blogdetail&sharerId=142218264&sharerefer=PC&sharesource=qq_64603703&sharefrom=from_link

       

        2、pytorch分为CPU版本和GPU版本

                1)CPU版本

                        CPU又称中央处理器,作为计算机系统的运算控制核心,是信息处理、程序运行的最终执行单元。可以形象地理解为有25%的ALU(运算单元)、有25%的Control(控制单元)、50%的Cache(缓存)单元,如下图所示:

       

                2)GPU版本

                GPU又称图像处理器,是一种专门在个人电脑等一些移动设备上做图像和图形相关运算工作的微处理器。可以形象地理解为90%的ALU(运算单元),5%的Control(控制单元)、5%的Cache(缓存)。

                如上图所示可发现,GPU中的控制单元和缓存的位置在整个模块的左侧一点点,剩下的全部都是运算单元用来计算的,而CPU中控制单元和缓存的位置几乎占了整个模块的一半,大大减少了运算能力,所以pytorch可以安装cuda及相关驱动来调用GPU对模型进行计算,以的到加速运算的目的。

                例如有下列图片,需要对其进行训练,将其传入CPU,可见传入空间几乎占满,而传入GPU却绰绰有余。

        3、相关显卡参数

                1)显卡容量

                        决定着临时存储数据的能力,如 6GB、8GB、24GB、48GB等等

                2)显存频率

                        反应显存的速度,如 1600MHz、1800MHz、3800MHz、5000MHz等

                3)显存位宽

                        一个时钟周期内所能传送数据的位数,如 64、128、192、256、384、448、512。

                4)如何查看电脑是CPU还是GPU

                        右击状态栏打开任务管理器,在性能里即可查看

三、PyTorch框架认识

        1、利用MNIST数据集实现神经网络的图像识别

                代码流程见上节课所学内容

《深度学习》PyTorch 手写数字识别 案例解析及实现 <上>

《深度学习》PyTorch 手写数字识别 案例解析及实现 <下>

       

        2、大致流程

                有如下手写图片,现在想通过训练模型来判断这个手写数字所代表的是什么数字,此时首先使用命令datasets.MNIST下载训练数据集和测试数据集,这两份数据中包含大量的手写数字及其对应的真实数字类型,将这些图片以例如64张图片及其类别打包成一份,然后再在GPU中建立模型,将这些打包好的图片数据信息传入GPU对其进行计算和训练,训练好的模型可以导入上述打包好的测测试集数据进行测试并与真实值对比,然后计算得到准确率。

       

        3、模型的结构

                例如使用神经网络的多层感知器

        4、优化器

                1)BGD(Batch Gradient Descent):批量梯度下降法

                   使用全样本数据计算梯度,例如一个batch_size=64,计算出64个梯度值,好处是收敛次数少。坏处是每次迭代需要用到所有数据,占用内存大耗时大。

                2)SGD(Stochastic Gradient Descent):随机梯度下降

                   从64个样本中随机抽出一组,训练后按梯度更新一次

                   SGD的原理是在每次迭代中,从训练集中随机选择一个样本进行梯度计算,并根据学习率和动量等参数更新模型参数。

                3)Adam(Adaptive Moment Estimation):自适应优化算法

                   结合了动量和RMSprop的思想,Adam使用动量的概念来加速收敛,并根据每个参数的历史梯度自适应地调整学习率。它计算每个参数的自适应学习率,以及每个参数的梯度的指数移动平均方差。

                4)Adagrad(Adaptive Gradient Algorithm):自适应学习率优化算法

                   它为每个参数维护一个学习率,并根据参数的历史梯度调整学习率。Adagrad使用参数的梯度平方和的平方根来缩放学习率,从而对于稀疏参数更加适用。

                5)RMSprop(Root Mean Square Propagation):自适应学习率优化算法。

                   它类似于Adagrad,但引入了一个衰减系数来平衡历史梯度的重要性。RMSprop使用历史梯度的平均值的平方根来调整学习率。

                6)小批量梯度下降法(Mini-batch Gradient Descent)

                   将训练数据集分成小批量用于计算模型误差和更新模型参数。是批量梯度下降法和随机梯度下降法的结合。

                7)等等多种优化算法

四、激活函数

        1、常见激活函数

                1)Sigmoid

                      Sigmoid函数将输入映射到0到1之间的连续值,其将输入转换成概率值,常用于二分类问题。Sigmoid函数的缺点是在输入较大或较小的情况下,梯度接近于0,可能导致梯度消失问题。

       

                2)ReLU

                      ReLU是最常用的激活函数之一。它将输入小于0的值设为0,大于等于0的值保持不变。ReLU的原理是通过引入非线性,使得神经网络能够学习更复杂的函数。ReLU具有简单的计算和导数计算,且能够缓解梯度消失问题。

                3)anh

                      anh函数将输入映射到-1到1之间的连续值。它的原理与Sigmoid函数类似,但输出范围更大。Tanh函数也具有非线性性质,但仍存在梯度消失问题。

                4)LeakyReLU

                      LeakyReLU是ReLU的变体,它在输入小于0时引入小的斜率,使得负数部分也能有一定的激活。LeakyReLU的原理是通过避免ReLU中的“神经元死亡”问题,进一步缓解梯度消失。

                5)Softmax

                      Softmax函数将输入转换为概率分布,用于多分类问题。Softmax的原理是将输入的指数形式归一化,保证输出是一个概率分布,且每个类别的概率和为1。

        2、梯度消失

                指在神经网络的反向传播过程中,梯度逐渐变小并趋近于零的现象。当梯度接近于零时,权重更新的幅度变得非常小,导致网络参数更新缓慢甚至停止更新,从而影响网络的训练效果。

                通常发生在使用一些特定的激活函数和深层神经网络中。当深层网络的激活函数是Sigmoid或Tanh等饱和函数时,这些函数的导数在输入较大或较小的情况下接近于零,导致梯度逐渐缩小。随着反向传播的进行,梯度会传递到浅层网络,导致浅层网络的参数更新缓慢,最终影响整个网络的训练效果。

        

        3、梯度爆炸

                指在神经网络的训练过程中,梯度增长得非常快,导致梯度值变得非常大甚至无限大的现象。当梯度值变得非常大时,权重的更新幅度也会变得非常大,导致网络参数发生剧烈的变化,进而影响网络的稳定性和训练效果。

                梯度爆炸通常发生在使用一些特定的激活函数和深层神经网络中。当深层网络的激活函数是非线性函数时,特别是使用在深层堆叠的神经网络中时,梯度可能会无限制地增大。这是因为在反向传播过程中,梯度会在每个隐藏层传递并相乘,导致梯度指数级地增长。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/54440.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux操作系统:GCC(GNU Compiler Collection)编译器

在 Linux 系统中,gcc(GNU Compiler Collection)是一个非常强大的编译器,主要用于编译 C 语言程序。 除了基本的编译和链接命令外,gcc还提供了许多选项和功能。 以下是一些常用的 gcc命令及其功能: 1. 基本…

Python | Leetcode Python题解之第420题强密码检验器

题目: 题解: class Solution:def strongPasswordChecker(self, password: str) -> int:n len(password)has_lower has_upper has_digit Falsefor ch in password:if ch.islower():has_lower Trueelif ch.isupper():has_upper Trueelif ch.isdi…

基于SpringBoot+Vue的智慧物业管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 精品专栏:Java精选实战项目源码、Python精…

transformer模型进行英译汉,汉译英

上面是在测试集上的表现 下面是在训练集上的表现 上面是在训练集上的评估效果 这是在测试集上的评估效果,模型是transformer模型,模型应该没问题,以上的是一个源序列没加结束符和加了结束符的情况。 transformer源序列做遮挡填充的自注意力,这就让编码器的输出中每个token的语…

寄存器与内存

第三课:寄存器与内存、中央处理器(CPU)、指令和程序及高级 CPU 设计-CSDN博客 锁存器 引入 ABO0(开始状态)001(将A置1)110(将A置0)11 无论怎么做,都没法从1变…

大学生必看!60万人在用的GPT4o大学数学智能体有多牛

❤️作者主页:小虚竹 ❤️作者简介:大家好,我是小虚竹。2022年度博客之星🏆,Java领域优质创作者🏆,CSDN博客专家🏆,华为云享专家🏆,掘金年度人气作者&#x1…

Mamba所需的causal-conv1d 和mamba-ssm库在哪下载?

背景介绍 参照 Mamba [state-spaces/mamba: Mamba SSM architecture (github.com)] github中提到的环境安装[Installation 一栏] [Option] pip install causal-conv1d>1.4.0: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.…

Qt_窗口界面QMainWindow的介绍

目录 1、菜单栏QMenuBar 1.1 使用QMainWindow的准备工作 1.2 在ui文件中设计窗口 1.3 在代码中设计窗口 1.4 实现点击菜单项的反馈 1.5 菜单中设置快捷键 1.6 菜单中添加子菜单 1.7 菜单项中添加分割线和图标 1.8 关于菜单栏创建方式的讨论 2、工具栏QToolBar …

k8s Service 服务

文章目录 一、为什么需要 Service二、Kubernetes 中的服务发现与负载均衡 -- Service三、用例解读1、Service 语法2、创建和查看 Service 四、Headless Service五、集群内访问 Service六、向集群外暴露 Service七、操作示例1、获取集群状态信息2、创建 Service、Deployment3、创…

飞腾计算模块RapidIO性能测试

1、背景介绍 飞腾计算模块采用FT2000 64核处理器,搭配Tsi721 PCIE转RapidIO芯片,实现飞腾平台下的SRIO数据通信。操作系统采用麒麟信安,内核版本4.19.90. 2、驱动加载 驱动加载部分类似之前写过的X86平台下的RapidIO驱动加载,具…

Rsync未授权访问漏洞复现及彻底修复

一、什么是 Rsync? Rsync 是一种广泛使用的文件传输工具,它允许系统管理员和用户通过局域网(LAN)或广域网(WAN)在计算机之间同步文件和目录。Rsync 支持通过本地或远程 shell 访问,也可以作为守…

【Linux】常用指令详解一(ls,-a,-l,-d,cd,pwd,mkdir,touch,rm,clear)

1.前言 读了一些Linux常用指令的博文,很可惜没读到一点点手把手教怎么操作的博文,所以写一篇手把手教适合初学者的Linux常用指令博文 Linux的命令是树状结构 输入这一句命令:yum install -y tree 即可以查看Linux树状目录结构 查看示例&am…

STM32快速复习(十二)FLASH闪存的读写

文章目录 一、FLASH是什么?FLASH的结构?二、使用步骤1.标准库函数2.示例函数 总结 一、FLASH是什么?FLASH的结构? 1、FLASH简介 (1)STM32F1系列的FLASH包含程序存储器、系统存储器和选项字节三个部分&…

pytorch实现RNN网络

目录 1.导包 2. 加载本地文本数据 3.构建循环神经网络层 4.初始化隐藏状态state 5.创建随机的数据,检测一下代码是否能正常运行 6. 构建一个完整的循环神经网络 7.模型训练 8.个人知识点理解 1.导包 import torch from torch import nn from torch.nn imp…

Qt+FFmpeg开发视频播放器笔记(三):音视频流解析封装

音频解析 音频解码是指将压缩的音频数据转换为可以再生的PCM(脉冲编码调制)数据的过程。 FFmpeg音频解码的基本步骤如下: 初始化FFmpeg解码器(4.0版本后可省略): 调用av_register_all()初始化编解码器。 调用avcodec_register_all()注册所有编解码器。 打开输入的音频流:…

pthread_cond_signal 和pthread_cond_wait

0、pthread_join()函数作用: pthread_join() 函数会一直阻塞调用它的线程,直至目标线程执行结束(接收到目标线程的返回值),阻塞状态才会解除。如果 pthread_join() 函数成功等到了目标线程执行结束(成功获取…

运行 xxxxApplication 时出错。命令行过长。 通过 JAR 清单或通过类路径文件缩短命令行,然后重新运行。

一、问题描述 运行 xxxxApplication 时出错。命令行过长。 通过 JAR 清单或通过类路径文件缩短命令行,然后重新运行。 二、问题分析 在idea中,运行一个springboot项目,在使用大量的库和依赖的时候,会出现报错“命令行过长”&…

Java | Leetcode Java题解之第406题根据身高重建队列

题目&#xff1a; 题解&#xff1a; class Solution {public int[][] reconstructQueue(int[][] people) {Arrays.sort(people, new Comparator<int[]>() {public int compare(int[] person1, int[] person2) {if (person1[0] ! person2[0]) {return person2[0] - perso…

Java项目实战II基于Java+Spring Boot+MySQL的车辆管理系统(开发文档+源码+数据库)

目录 一、前言 二、技术介绍 三、系统实现 四、论文参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 "随着…

Arthas jvm(查看当前JVM的信息)

文章目录 二、命令列表2.1 jvm相关命令2.1.3 jvm&#xff08;查看当前JVM的信息&#xff09; 二、命令列表 2.1 jvm相关命令 2.1.3 jvm&#xff08;查看当前JVM的信息&#xff09; 基础语法&#xff1a; jvm [arthas18139]$ jvmRUNTIME …