pytorch常用的模块函数汇总(1)

目录

torch:核心库,包含张量操作、数学函数等基本功能

torch.nn:神经网络模块,包括各种层、损失函数和优化器等

torch.optim:优化算法模块,提供了各种优化器,如随机梯度下降 (SGD)、Adam、RMSprop 等。

torch.autograd:自动求导模块,用于计算张量的梯度


torch:核心库,包含张量操作、数学函数等基本功能

  1. torch.tensor(): 创建张量
  2. torch.zeros()torch.ones(): 创建全零或全一张量
  3. torch.rand(): 创建随机张量
  4. torch.from_numpy(): 从 NumPy 数组创建张量
  5. torch.add()torch.sub()torch.mul()torch.div(): 加法、减法、乘法、除法

  6. torch.mm()torch.matmul(): 矩阵乘法

  7. torch.exp()torch.log()torch.sin()torch.cos(): 指数、对数、正弦、余弦等数学函数

  8. torch.index_select(): 按索引选取张量的子集

  9. torch.masked_select(): 根据掩码选取张量的子集

    切片操作:类似 Python 中的列表切片操作,如 tensor[2:5]
  10. torch.view()torch.reshape(): 改变张量的形状

  11. torch.squeeze()torch.unsqueeze(): 压缩或扩展张量的维度

  12. torch.mean()torch.sum()torch.max()torch.min(): 计算张量均值、和、最大值、最小值等

  13. torch.broadcast_tensors(): 对张量进行广播操作

  14. torch.cat(): 拼接张量

  15. torch.stack(): 堆叠张量

  16. torch.split(): 分割张量

torch.nn:神经网络模块,包括各种层、损失函数和优化器等

  • 神经网络层

    • torch.nn.Linear(in_features, out_features): 全连接层,进行线性变换。
    • torch.nn.Conv2d(in_channels, out_channels, kernel_size): 2D卷积层。
    • torch.nn.MaxPool2d(kernel_size): 2D 最大池化层。
    • torch.nn.ReLU(): ReLU 激活函数。
    • torch.nn.Sigmoid(): Sigmoid 激活函数。
    • torch.nn.Dropout(p): Dropout 层,用于防止过拟合。

备注:Sigmoid 激活函数是一种常用的非线性激活函数,其作用可以总结如下:

将输入映射到 (0, 1) 范围内:输出范围在 0 到 1 之间,可以将任意实数输入映射到 0 到 1 之间。这种特性在某些情况下很有用,比如对于二分类任务,Sigmoid 函数的输出可以被解释为样本属于正类的概率。

引入非线性变换: Sigmoid 函数是一种非线性函数,可以引入神经网络的非线性变换能力,使得神经网络可以学习更加复杂的模式和关系。在深度神经网络中,非线性激活函数的使用可以帮助神经网络学习非线性模式,提高网络的表达能力。

输出平滑且连续: Sigmoid 函数具有平滑的 S 形曲线,在定义域内都是可导的,这使得在反向传播算法中计算梯度变得相对容易。这一点对于神经网络的训练至关重要。

  • 损失函数

torch.nn.CrossEntropyLoss(): 交叉熵损失函数,常用于多分类问题。

交叉熵损失函数用于衡量两个概率分布之间的差异,通常用于多分类任务中。在神经网络的多分类任务中,输入模型的输出是一个概率分布,表示每个类别的预测概率,而交叉熵损失函数则用于比较这个预测概率分布与实际标签的分布之间的差异。

torch.nn.CrossEntropyLoss() 来计算交叉熵损失函数,它会自动将模型的输出通过 Softmax 函数转换为概率分布,并计算交叉熵损失。

torch.nn.MSELoss(): 均方误差损失函数,常用于回归问题。

均方误差损失函数用于衡量模型输出与实际目标之间的差异,通常在回归任务中使用。该损失函数计算预测值与真实值之间的平方差,并将所有样本的平方差求平均得到最终的损失值。

  • 优化器

    • torch.optim.SGD(model.parameters(), lr=learning_rate): 随机梯度下降优化器。
    • torch.optim.Adam(model.parameters(), lr=learning_rate): Adam 优化器。
  • 模型定义相关

    • torch.nn.Module: 所有神经网络模型的基类,需要继承这个类。
    • model.forward(input_tensor): 定义前向传播。
  • 数据处理相关

    • torch.utils.data.Dataset: PyTorch 数据集的基类,需要自定义数据集时使用。
    • torch.utils.data.DataLoader(dataset, batch_size, shuffle): 数据加载器,用于批量加载数据。
  • torch.optim:优化算法模块,提供了各种优化器,如随机梯度下降 (SGD)、Adam、RMSprop 等。

  • 优化器(Optimizer)类

    • torch.optim.SGD(params, lr=0.01, momentum=0, weight_decay=0):随机梯度下降优化器,实现了带动量的随机梯度下降
    • torch.optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):Adam 优化器,结合了动量方法和 RMSProp 方法,通常在深度学习中表现良好。
    • torch.optim.Adagrad(params, lr=0.01, lr_decay=0, weight_decay=0):Adagrad 优化器,自适应地为参数分配学习率。它根据参数的历史梯度信息对每个参数的学习率进行调整。这意味着对于不同的参数,Adagrad可以为其分配不同的学习率,从而更好地适应参数的更新需求
    • torch.optim.RMSprop(params, lr=0.01, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):RMSprop 优化器,有效地解决了 Adagrad 学习率下降较快的问题(RMSprop对梯度平方项进行指数加权平均)。

备注:

1.在优化算法中,momentum(动量)是一种用于加速模型训练的技巧。动量项的引入旨在解决标准随机梯度下降在训练过程中可能遇到的震荡和收敛速度慢的问题。

动量项的引入可以帮助优化算法在参数更新时更好地利用之前的更新方向,从而在一定程度上减少参数更新的波动,加快收敛速度,并有助于跳出局部极小值。具体来说,动量项在参数更新时会考虑之前的更新方向,并对当前的更新方向进行一定程度的调整。

在 PyTorch 的 torch.optim.SGD 中,动量可以通过设置 momentum 参数来控制。通常情况下,动量的取值范围在 0 到 1 之间,常见的默认取值为 0.9。当动量设为0.9时,每次迭代,都会保留上一次速度的 90%,并使用当前梯度微调最终的更新方向。

总结来说,动量项的引入可以提高随机梯度下降的稳定性和收敛速度,有助于在训练神经网络时更快地找到较优的解。

2. 

在优化算法中,weight decay(权重衰减)是一种用于控制模型参数更新的正则化技术。权重衰减通过在优化过程中对参数进行惩罚,防止其取值过大,从而有助于降低过拟合的风险。

具体来说,在SGD中的weight_decay参数是对模型的权重进行L2正则化,即在计算梯度时额外增加一个关于参数的惩罚项。这个惩罚项会使得优化算法更倾向于选择较小的权重值,从而降低模型的复杂度,减少过拟合的风险。

在PyTorch中,torch.optim.SGD中的weight_decay参数用于控制权重衰减的程度。通常情况下,weight_decay的取值为一个小的正数,比如 0.001 或 0.0001。设置了weight_decay之后,在计算梯度时会额外考虑到权重的惩罚项,从而影响参数的更新方式。

总结来说,权重衰减是一种正则化技术,通过对模型参数的惩罚来控制模型的复杂度,减少过拟合的风险,提高模型的泛化能力。

3.

  • L1正则化:L1正则化会给模型的损失函数添加一个关于权重绝对值的惩罚项,即L1范数(权重的绝对值之和)。在梯度下降过程中,L1正则化会导致部分权重直接变为0,因此可以实现稀疏性,有特征选择的效果。L1正则化倾向于产生稀疏的权重矩阵,可以用于特征选择和降维。

L1正则化的惩罚项是模型权重的L1范数,即权重的绝对值之和。在优化过程中,为了最小化损失函数并减少正则化项的影响,优化算法会尝试将权重调整到较小的值。由于L1正则化的几何形状在坐标轴上拐角处就会与坐标轴相交,这就导致了在坐标轴上许多点都是对称的,因此在这些点上的梯度不唯一。这意味着在这些对称点上,优化算法更有可能将权重调整为0,从而导致稀疏性。

  • L2正则化:L2正则化会给模型的损失函数添加一个关于权重平方的惩罚项,即L2范数的平方(权重的平方和)。在梯度下降过程中,L2正则化会使得权重都变得比较小,但不会直接导致稀疏性。L2正则化对异常值比较敏感,因为它会平方每个权重,使得异常值对损失函数的影响更大。

  • 总的来说,L1正则化和L2正则化都是常用的正则化技术,它们在模型训练过程中都有助于控制模型的复杂度,减少过拟合的风险。选择使用哪种正则化方法通常取决于具体的问题和数据特点,以及对模型稀疏性的需求。在实际应用中,有时也会将L1和L2正则化结合起来,形成弹性网络正则化(Elastic Net regularization),以兼顾两种正则化方法的优势。

4. 

betas 是 Adam 算法中的两个超参数之一,它控制了梯度的一阶矩估计和二阶矩估计的指数衰减率。betas 是一个长度为2的元组,通常形式为 (beta1, beta2)。在 Adam 算法中,beta1 控制了一阶矩估计(梯度的均值)的衰减率,beta2 控制了二阶矩估计(梯度的平方的均值)的衰减率。

通常情况下,beta1 的默认值为 0.9,beta2 的默认值为 0.999。这意味着在每次迭代中,一阶矩估计将保留当前梯度的 90%,而二阶矩估计将保留当前梯度的平方的 99.9%。这些衰减率的选择使得 Adam 算法能够在训练过程中自适应地调整学习率,并对梯度的变化做出快速或缓慢的响应,从而更有效地更新模型参数。

总之,betas 参数在 Adam 算法中起着调节梯度一阶和二阶矩估计衰减率的作用,通过合理设置 betas 可以影响算法的收敛性和稳定性。

  • 调整学习率的函数

    • torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)根据给定的函数 lr_lambda 调整学习率。
    • torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1):每个 step_size 个 epoch 将学习率降低为原来的 gamma 倍。
    • torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1):在指定的里程碑上将学习率降低为原来的 gamma 倍。
  • 其他常用函数

    • zero_grad():用于将模型参数的梯度清零,通常在每个 batch 后调用。
    • step(closure):用于执行单步优化器的更新,需要传入一个闭包函数 closure
    • state_dict() 和 load_state_dict():用于保存和加载优化器的状态字典,方便恢复训练。
  • torch.autograd:自动求导模块,用于计算张量的梯度

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/782550.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

手机投屏到windows11电脑

1 安装无线投影组件 2 电脑端打开允许其他设备投影的开关 3 手机找到投屏选项 4 手机搜索可用设备连接即可 这里的官方文档给的不太好,给了一些让人眼花撩乱的信息,以下是经过整合的有效信息

每日一题 --- 四数之和[力扣][Go]

四数之和 题目:18. 四数之和 给你一个由 n 个整数组成的数组 nums ,和一个目标值 target 。请你找出并返回满足下述全部条件且不重复的四元组 [nums[a], nums[b], nums[c], nums[d]] (若两个四元组元素一一对应,则认为两个四元组…

FL Studio21.2.3中文版软件新功能介绍及下载安装步骤教程

FL Studio21.2中文版的适用人群非常广泛,主要包括以下几类: FL Studio 21 Win-安装包下载如下: https://wm.makeding.com/iclk/?zoneid55981 FL Studio 21 Mac-安装包下载如下: https://wm.makeding.com/iclk/?zoneid55982 音乐制作人&#xff1a…

开发指南020-banner

<dependency><groupId>org.qlm</groupId><artifactId>qlm-common</artifactId><version>1.0-SNAPSHOT</version> </dependency> 以上组件封装了平台的banner&#xff0c;不做任何配置的话&#xff0c;将输出平台的banner 想修…

二维码门楼牌管理应用平台建设:三维白模数据建设的意义

文章目录 前言一、三维白模数据建设的意义二、二维码门楼牌管理系统的构建三、二维码门楼牌管理系统的优势四、面临的挑战与未来展望 前言 随着城市管理的精细化和智能化需求日益增强&#xff0c;二维码门楼牌管理应用平台的建设成为推动城市管理现代化的重要手段。本文将探讨…

预处理、编译、汇编、链接过程

预处理、编译、汇编、链接过程 预处理 引入头文件 #include 展开宏定义 #define 处理条件编译指令 #ifdef 删除注释 添加行号 在Linux下可以使用gcc -E命令把hello.c文件预处理成hello.i文件。windows这些操作都集成在编译器visual studio这些里面了。 编译 进行语法分…

第几个幸运数字(蓝桥杯)

文章目录 第几个幸运数字题目描述答案&#xff1a;1905生成法C代码代码详细注释代码思路解释 第几个幸运数字 题目描述 本题为填空题&#xff0c;只需要算出结果后&#xff0c;在代码中使用输出语句将所填结果输出即可。 到x星球旅行的游客都被发给一个整数&#xff0c;作为…

Opencv C++和Python教程

1、何为Opencv? OpenCV是一个开源的计算机视觉和机器学习库,它提供了丰富的图像处理和计算机视觉算法,如图像处理、目标检测、人脸识别、物体跟踪等。OpenCV最初由英特尔公司发起,现在是由社区维护和开发。OpenCV支持多种编程语言,如C++、Python、Java等,可以在不同的操…

软考高级架构师:安全模型概念和例题

作者&#xff1a;明明如月学长&#xff0c; CSDN 博客专家&#xff0c;大厂高级 Java 工程师&#xff0c;《性能优化方法论》作者、《解锁大厂思维&#xff1a;剖析《阿里巴巴Java开发手册》》、《再学经典&#xff1a;《Effective Java》独家解析》专栏作者。 热门文章推荐&am…

软考高级架构师:信息安全保护等级

作者&#xff1a;明明如月学长&#xff0c; CSDN 博客专家&#xff0c;大厂高级 Java 工程师&#xff0c;《性能优化方法论》作者、《解锁大厂思维&#xff1a;剖析《阿里巴巴Java开发手册》》、《再学经典&#xff1a;《Effective Java》独家解析》专栏作者。 热门文章推荐&am…

Java接口实战:模拟咖啡制作、订购与消费完整流程(day14)

定义接口&#xff1a; // 咖啡制作接口 interface CoffeeMaker { Coffee makeCoffee(String type); } // 咖啡店接口 interface CoffeeShop { void orderCoffee(String type, CoffeeConsumer consumer); } // 咖啡消费者接口 interface CoffeeConsumer { void …

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《考虑新能源发电商租赁共享储能的电力市场博弈分析》

本专栏栏目提供文章与程序复现思路&#xff0c;具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

解决前后端通信跨域问题

因为浏览器具有同源策略的效应。 同源策略是一个重要的网络安全机制&#xff0c;用于Web浏览器中&#xff0c;以防止一个网页文档或脚本来自一个源&#xff08;域、协议和端口&#xff09;&#xff0c;获取另一个源的数据。同源策略的目的是保护用户的隐私和安全&#xff0c;防…

maven pom relativePath属性的作用

maven pom relativePath属性的作用 文章目录 maven pom relativePath属性的作用一、relativePath出现的地方二、relativePath默认值三、四、<relativePath>一个pom路径 一、relativePath出现的地方 搭建maven项目&#xff0c;子模块指定父模块试&#xff0c;经常会在par…

专升本-信息技术介绍

信息技术是什么&#xff1f; 用于管理和处理信息所采用的各种技术的总称 以电子计算机和现代通信为主要手段 位于信息科学体系的技术应用层次 新一代信息技术有哪些&#xff1a; 代表性&#xff1a;人工智能&#xff0c;量子信息&#xff0c;移动通信&#xff0c;物联网&a…

Ubuntu中文输入法设置指南:轻松上手,畅享输入体验

在Linux的世界里,Ubuntu以其强大的功能和优美的界面设计赢得了众多用户的喜爱。然而,对于许多中文用户来说,如何在Ubuntu上设置中文输入法却是一个不小的挑战。今天,就让我来为大家详细介绍一下如何在Ubuntu上轻松设置中文输入法,让您的输入体验更加流畅自如。 首先,我们…

【使用python读取多类型文件夹中的文档内容】

突发奇想&#xff0c;想使用python读取多类型文件夹中的文档内容&#xff0c;在Python中&#xff0c;读取多类型文件夹中的文档内容通常涉及几个步骤&#xff1a; 遍历文件夹以获取文件列表。根据文件扩展名判断文件类型。使用适当的库或方法来读取每种文件类型的内容。 以下…

java数组与集合框架(三)--Map,Hashtable,HashMap,LinkedHashMap,TreeMap

Map集合&#xff1a; Map接口: 基于 键&#xff08;key&#xff09;/值&#xff08;value&#xff09;映射 Map接口概述 Map与Collection并列存在。用于保存具有映射关系的数据:key-value Map 中的key 和value 都可以是任何引用类型的数据Map 中的key 用Set来存放&#xff0…

stitcher类实现多图自动拼接

效果展示 第一组&#xff1a; 第二组&#xff1a; 第三组&#xff1a; 第四组&#xff1a; 运行代码 import os import sys import cv2 import numpy as npdef Stitch(imgs,savePath): stitcher cv2.Stitcher.create(cv2.Stitcher_PANORAMA)(result, pano) stitcher.st…

【每日跟读】常用英语500句(400~500)

【每日跟读】常用英语500句 Where can I buy a ticket? 在哪里能买到票&#xff1f; When is the next train? 下趟火车什么时候到&#xff1f; Thank you so much for helping me move yesterday. 非常感谢你昨天帮我搬家 I’m feeling a little under the weather toda…