常用torch.nn

目录

  • 一、torch.nn和torch.nn.functional
  • 二、nn.Linear
  • 三、nn.Embedding
  • 四、nn.Identity
  • 五、Pytorch非线性激活函数
  • 六、nn.Conv2d
  • 七、nn.Sequential
  • 八、nn.ModuleList
  • 九、torch.outer torch.cat

一、torch.nn和torch.nn.functional

Pytorch中torch.nn和torch.nn.functional的区别及实例详解
Pytorch中,nn与nn.functional有哪些区别?
相同之处:

两者都继承于nn.Module
nn.x与nn.functional.x的实际功能相同,比如nn.Conv3d和nn.functional.conv3d都是进行3d卷积
运行效率几乎相同
不同之处:
nn.x是nn.functional.x的类封装,nn.functional.x是具体的函数接口
nn.x除了具有nn.functional.x功能之外,还具有nn.Module相关的属性和方法,比如:train(),eval()等
nn.functional.x直接传入参数调用,nn.x需要先实例化再传参调用
nn.x能很好的与nn.Sequential结合使用,而nn.functional.x无法与nn.Sequential结合使用
nn.x不需要自定义和管理weight,而nn.functional.x需自定义weight,作为传入的参数

二、nn.Linear

torch.nn.functional 中的 Linear 函数
torch.nn.functional.linear 是 PyTorch 框架中的一个功能模块,主要用于实现线性变换。这个函数对于构建神经网络中的全连接层(或称为线性层)至关重要。它能够将输入数据通过一个线性公式(y = xA^T + b)转换为输出数据,其中 A 是权重,b 是偏置项。

用途
神经网络构建:在构建神经网络时,linear 函数用于添加线性层。
特征变换:在数据预处理和特征工程中,使用它进行线性特征变换。

基本用法如下:

output = torch.nn.functional.linear(input, weight, bias=None)
input:输入数据
weight:权重
bias:偏置(可选)
class torch.nn.Linear(in_features, out_features, bias=True)
参数:
in_features - 每个输入样本的大小
out_features - 每个输出样本的大小
bias - 若设置为False,这层不会学习偏置。默认值:True

对输入数据做线性变换:y=Ax+b
形状:
输入: (N,in_features)
输出: (N,out_features)
变量:
weight -形状为(out_features x in_features)的模块中可学习的权值
bias -形状为(out_features)的模块中可学习的偏置
例子:

>>> m = nn.Linear(20, 30)
>>> input = autograd.Variable(torch.randn(128, 20))
>>> output = m(input)
>>> print(output.size())

三、nn.Embedding

pytorch nn.Embedding详解
nn.Embedding作用
nn.Embedding是PyTorch中的一个常用模块,其主要作用是将输入的整数序列转换为密集向量表示。在自然语言处理(NLP)任务中,可以将每个单词表示成一个向量,从而方便进行下一步的计算和处理。
nn.Embedding词向量转化
nn.Embedding是将输入向量化,定义如下:

torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)
参数说明:
num_embeddings :字典中词的个数
embedding_dim:embedding的维度
padding_idx(索引指定填充):如果给定,则遇到padding_idx中的索引,则将其位置填00是默认值,事实上随便填充什么值都可以)。

在PyTorch中,nn.Embedding用来实现词与词向量的映射。nn.Embedding具有一个权重(.weight),形状是(num_words, embedding_dim)。例如一共有100个词,每个词用16维向量表征,对应的权重就是一个100×16的矩阵。
Embedding的输入形状N×W,N是batch size,W是序列的长度,输出的形状是N×W×embedding_dim。
Embedding输入必须是LongTensor,FloatTensor需通过tensor.long()方法转成LongTensor。
Embedding的权重是可以训练的,既可以采用随机初始化,也可以采用预训练好的词向量初始化。

四、nn.Identity

nn.Identity() 是 PyTorch 中的一个层(layer)。它实际上是一个恒等映射,不对输入进行任何变换或操作,只是简单地将输入返回作为输出。

五、Pytorch非线性激活函数

torch.nn.functional非线性激活函数

六、nn.Conv2d

torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1)

对几个输入特征图组成的输入信号应用2D卷积。
参数:
input – 输入张量 (minibatch x in_channels x iH x iW)
weight – 过滤器张量 (out_channels, in_channels/groups, kH, kW)
bias – 可选偏置张量 (out_channels) -
stride – 卷积核的步长,可以是单个数字或一个元组 (sh x sw)。默认为1
padding – 输入上隐含零填充。可以是单个数字或元组。 默认值:0
groups – 将输入分成组,in_channels应该被组数除尽

例子:

>>> # With square kernels and equal stride
>>> filters = autograd.Variable(torch.randn(8,4,3,3))
>>> inputs = autograd.Variable(torch.randn(1,4,5,5))
>>> F.conv2d(inputs, filters, padding=1)

七、nn.Sequential

一个序列容器,用于搭建神经网络的模块被按照被传入构造器的顺序添加到nn.Sequential()容器中。除此之外,一个包含神经网络模块的OrderedDict也可以被传入nn.Sequential()容器中。利用nn.Sequential()搭建好模型架构,模型前向传播时调用forward()方法,模型接收的输入首先被传入nn.Sequential()包含的第一个网络模块中。然后,第一个网络模块的输出传入第二个网络模块作为输入,按照顺序依次计算并传播,直到nn.Sequential()里的最后一个模块输出结果。
nn.Sequential()和torch.nn.ModuleList的区别在于:torch.nn.ModuleList只是一个储存网络模块的list,其中的网络模块之间没有连接关系和顺序关系。而nn.Sequential()内的网络模块之间是按照添加的顺序级联的。

八、nn.ModuleList

nn.ModuleList() 是 PyTorch 中的一个类,用于管理神经网络模型中的子模块列表。它允许我们将多个子模块组织在一起,并将它们作为整个模型的一部分进行管理和操作。

在神经网络模型的开发过程中,通常需要定义和使用多个子模块,例如不同的层、块或者其他组件。nn.ModuleList() 提供了一种方便的方式来管理这些子模块,并确保它们被正确地注册为模型的一部分。

九、torch.outer torch.cat

torch.outer: 实现张量的点积,张量都需要是一维向量

import torch
v1 = torch.arange(1., 5.)
print(v1)
v2 = torch.arange(1., 4.)
print(v2)
print(torch.outer(v1, v2))tensor([1., 2., 3., 4.])
tensor([1., 2., 3.])
tensor([[ 1.,  2.,  3.],[ 2.,  4.,  6.],[ 3.,  6.,  9.],[ 4.,  8., 12.]])

torch.cat: 函数将两个张量(tensor)按指定维度拼接在一起,注意:除拼接维数dim数值可不同外其余维数数值需相同,方能对齐。torch.cat()函数不会新增维度,而torch.stack()函数会新增一个维度,相同的是两个都是对张量进行拼接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/19449.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue使用axios实现调用后端接口

准备后端接口 首先,我已经写好一个后端接口用来返回我的用户数据,并用Postman测试成功如下: 以我的接口为例,接口地址为:http://localhost:8080/user/selectAll 返回Json为: {"code": "2…

docker制作高版本jdk17镜像踩坑

1、创建目录并下载jdk上传到服务器中 从jdk官网下载jdk17镜像,提示:下载到本地用xftp上传到服务器(速度会快点) jdk官网:https://www.oracle.com/java/technologies/downloads/#graalvmjava21 创建目录,将…

Ubuntu系统编译内核——deb安装 / install安装

摘要 本文简要记录两种编译内核的方法: 打包成deb模块安装(推荐);直接make install安装; 更推荐使用——打包成deb模块安装,因为可以方便的拷贝下次其他机器使用。 1. 编译环境准备 系统:lin…

强化学习——学习笔记3

一、强化学习都有哪些分类? 1、基于模型与不基于模型 根据是否具有环境模型,强化学习算法分为两种:基于模型与不基于模型 基于模型的强化学习(Model-based RL):可以简单的使用动态规划求解,任务可定义为预测和控制&am…

cesium 实现自定义弹窗并跟随场景移动

cesium 添加点位自定义弹窗跟随场景移动 完整代码演示可直接copy使用 1 效果图&#xff1a; 2 深入理解 就是原始点位的数据 id>property 点位真实渲染到球体上的笛卡尔坐标系 id>_polyline 的路径下 可以通过 3 代码示例 <!DOCTYPE html> <html lang"…

【数据分享】2017-2023年全球范围10米精度土地覆盖数据

土地覆盖数据是我们在各项研究中都非常常用的数据&#xff0c;土地覆盖数据的来源也有很多。之前我们分享过欧空局发布的2020年和2021年的10米分辨率的土地覆盖数据,也分享过我国首套1米分辨率的土地覆盖数据&#xff08;均可查看之前的文章获悉详情&#xff09;&#xff01; …

dwc3 DR_MODE 处理初始化 OTG gadget

dwc3控制器是怎么处理otg-CSDN博客 dwc3_probe static int dwc3_probe(struct platform_device *pdev) {struct device *dev &pdev->dev;struct resource *res, dwc_res;struct dwc3 *dwc;int ret;void __iomem …

管道液位传感器可以检测哪些液体?

管道液位传感器是一种专门用于检测流动性比较好的液体的传感器装置。它采用光学感应原理&#xff0c;不涉及任何机械运动&#xff0c;具有长寿命、安装方便和微功耗的特点。相比传统机械式液位传感器&#xff0c;光电管道传感器有效解决了低精度和卡死失效等问题&#xff0c;同…

Django 解决 CSRF 问题

在 Django 出现 CSRF 问题 要解决这个问题&#xff0c;就得在 html 里这么修改 <!DOCTYPE html> <html><head></head><body><form action"/login/" method"post">{% csrf_token %}</form></body> </…

C++基础知识之类和对象

一、类 类是一种用户自定义的数据类型&#xff0c;用于封装数据和方法。它定义了一组属性&#xff08;数据成员&#xff09;和方法&#xff08;成员函数&#xff09;&#xff0c;并且可以被多个对象共享。在面向对象编程中&#xff0c;类是一种用于创建对象的蓝图或模板。它定义…

短视频脚本创作的五个方法 沈阳短视频剪辑培训

说起脚本&#xff0c;我们大概都听过影视剧脚本、剧本&#xff0c;偶尔可能在某些综艺节目里听过台本。其中剧本是影视剧拍摄的大纲&#xff0c;用来指导影视剧剧情的走向和发展&#xff0c;而台本则是综艺节目流程走向的指导大纲。 那么&#xff0c;短视频脚本是什么&#xf…

探析GPT-4o:技术之巅的跃进

如何评价GPT-4o? 简介&#xff1a;最近&#xff0c;GPT-4o横空出世。对GPT-4o这一人工智能技术进行评价&#xff0c;包括版本间的对比分析、GPT-4o的技术能力以及个人感受等。 随着人工智能领域的不断发展&#xff0c;GPT系列模型一直处于行业的前沿。最近&#xff0c;GPT-4…

前端实习记录——git篇(一些问题与相关命令)

1、版本控制 &#xff08;1&#xff09;版本回滚 git log // 查看版本git reset --mixed HEAD^ // 回滚到修改状态&#xff0c;文件内容没有变化git reset --soft HEAD^ // 回滚暂存区&#xff0c;^的个数代表几个版本git reset --hard HEAD^ // 回滚到修改状态&#xff…

生态农业:引领未来农业新篇章

生态农业&#xff0c;正以其独特的魅力和创新理念&#xff0c;引领着未来农业发展的新篇章。在这个充满变革的时代&#xff0c;我们需要更加关注农业的可持续发展&#xff0c;而生态农业正是实现这一目标的重要途径。 生态农业产业的王总说&#xff1a;生态农业强调生态平衡和可…

python基础-内置函数3-类与对象相关内置函数

文章目录 python基础-内置函数3类与对象getattr()hasattr()setattr()delattr()vars()dir()property()super()classmethod()staticmethod()isinstance()issubclass()callable()object()repr()ascii()id()hash()type() python基础-内置函数3 类与对象 getattr() getattr(objec…

以讲师能力提升,优路教育促学员拓宽职业原野

在建设教育强国的过程中&#xff0c;加强教师队伍建设被视为重要的基础工作。当前&#xff0c;我国正大力推进高素质“双师型”职业教育教师队伍建设&#xff0c;以培养更多既具备理论教学能力&#xff0c;又拥有实践教学经验的教师。在这一背景下&#xff0c;优路教育积极响应…

【文档+源码+调试讲解】古典舞在线交流平台的设计与实现

摘 要 随着互联网技术的发展&#xff0c;各类网站应运而生&#xff0c;网站具有新颖、展现全面的特点。因此&#xff0c;为了满足用户古典舞在线交流的需求&#xff0c;特开发了本古典舞在线交流平台。 本古典舞在线交流平台应用Java技术&#xff0c;MYSQL数据库存储数据&…

构建一个简单的情感分析器:使用Python和spaCy

构建一个简单的情感分析器&#xff1a;使用Python和spaCy 引言 情感分析是自然语言处理&#xff08;NLP&#xff09;中的一项重要技术&#xff0c;它可以帮助企业和研究人员理解公众对特定主题或产品的看法。 在本篇文章中&#xff0c;我们将使用Python编程语言和 spaCy 库来构…

FreeRTOS【7】队列使用

1.开发背景 操作系统提供了多线程并行的操作&#xff0c;为了方便代码的维护&#xff0c;各个线程都分配了专用的内存并处理对应的内容。但是线程间也是需要协助操作的&#xff0c;例如一个主线程接收信息&#xff0c;会把接收的信息并发到其他线程&#xff0c;即主线程不阻塞&…

[LitCTF 2023]yafu (中级) (素数分解)

题目&#xff1a; from Crypto.Util.number import * from secret import flagm bytes_to_long(flag) n 1 for i in range(15):n *getPrime(32) e 65537 c pow(m,e,n) print(fn {n}) print(fc {c})n 152412082177688498871800101395902107678314310182046454156816957…