神经网络初始化实例化的维度与调用输入数据的维度

神经网络初始化实例化的维度与调用输入数据的维度

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

#from agents.helpers import SinusoidalPosEmb
class SinusoidalPosEmb(nn.Module):
def init(self, dim=16): #dim为初始化需要设置的参数 比如默认为16 计算后会升维
super().init()
self.dim = dim

def forward(self, x):device = x.devicehalf_dim = self.dim // 2emb = math.log(10000) / (half_dim - 1)emb = torch.exp(torch.arange(half_dim, device=device) * -emb)emb = x[:, None] * emb[None, :]emb = torch.cat((emb.sin(), emb.cos()), dim=-1)return emb

class MLP(nn.Module):
“”"
MLP Model
“”"
def init(self, ##初始化以及参数
state_dim=2, #####初始化实例化定义的维度这个对看懂代码很关键! 与后面调用函数的输入数据的维度一般需要一致
action_dim=2,
device=“cpu”,
t_dim=16):

    super(MLP, self).__init__()self.device = "cpu"self.time_mlp = nn.Sequential(SinusoidalPosEmb(t_dim),  #  这个是初始化为16维度nn.Linear(t_dim, t_dim * 2),nn.Mish(),nn.Linear(t_dim * 2, t_dim),)input_dim = state_dim + action_dim + t_dim ## 2+2+16=20  维度数量self.mid_layer = nn.Sequential(nn.Linear(input_dim, 256),nn.Mish(),nn.Linear(256, 256),nn.Mish(),nn.Linear(256, 256),nn.Mish())self.final_layer = nn.Linear(256, action_dim)  ##输出2维度def forward(self, x, time, state):  ##定义个方法t = self.time_mlp(time)x = torch.cat([x, t, state], dim=1)  ###第二个维度以后要一致x = self.mid_layer(x)return self.final_layer(x)

MLPinstance=MLP()#初始化一个实例
MLPinstance

######################
x = torch.rand(5, 1, 2) # [10, 1, 8] #bath_size,1,2维度
time=5 #标量不对
state=torch.rand(5, 1, 2)
x.shape
torch.tensor(range(5)).unsqueeze(1).unsqueeze(2).shape,torch.rand(5, 1, 2).shape #有什么区别?

这两个 torch.tensor 的操作创建了不同形状和内容的张量。

1. torch.tensor(range(5)).unsqueeze(1).unsqueeze(2)

- 首先,torch.tensor(range(5)) 创建了一个一维张量,内容为 [0, 1, 2, 3, 4]

- 然后,.unsqueeze(1) 在第1个维度(现在是一维张量的唯一维度,等同于插入一个新的列维度)添加一个维度,使得形状变为 (5, 1)

- 接着,.unsqueeze(2) 再次在新维度的后面添加一个维度,最终形状变为 (5, 1, 1)。因此,这个操作的结果是一个形状为 (5, 1, 1) 的张量,每个元素都是从0到4的数字,每一行重复同一个数字,并且在最后两个维度上只有一个单位。

2. torch.rand(5, 1, 2)

- 这个操作直接创建了一个形状为 (5, 1, 2) 的三维张量,其中的所有元素都是从0到1之间的随机数(均匀分布)。这意味着你得到的是一个有5行,每行包含一个大小为1的子列表,每个子列表内有2个随机数的张量。

总结:

- 形状不同:前者形状为 (5, 1, 1),主要由连续的整数构成;后者形状为 (5, 1, 2),由随机浮点数构成。

- 内容不同:前者的内容是确定的,是0到4的整数,每个数字沿最后一个维度重复;后者的内容是随机的,范围在0到1之间。

- 数据类型不同:默认情况下,前者(基于 range)会是整数类型(除非显式转换),而后者明确是浮点数类型,因为使用了 torch.rand

torch.tensor(range(5)).unsqueeze(1).unsqueeze(1).shape在那个维度后面升维度,torch.tensor(range(5)).unsqueeze(1).unsqueeze(2).shape

######################################

x = torch.rand(2, 1) # [10, 1, 8]
time=torch.tensor([5]) ##一维度的可以
state=torch.rand(2, 1)
x.shape

实例化SinusoidalPosEmb类

pos_emb = SinusoidalPosEmb(dim=16)

创建一个示例输入张量x

假设我们有一个序列长度为5,维度为16的输入

#x = torch.tensor(range(5)).unsqueeze(1).unsqueeze(2) # 形状为 [5, 1] bath size,以及对应的timestep
time = torch.tensor(time)
print(time.shape)

调用forward方法

positional_embedding = pos_emb(time)

##########################################################################

x = torch.rand(5, 1) # [10, 1, 8]
time= torch.tensor(range(5)).unsqueeze(1)#没必要这样 经过embedding后会增加一个维度
state=torch.rand(5, 1)
x.shape

#################################################################
x = torch.rand(5, 1)
time= torch.tensor([5])
state=torch.rand(5, 1)
x.shape,time.shape
#############################################成功调用的####################
x = torch.rand(5, 1) ####一般定义了2维度 所以一般输入的数据就是2个维度的多个元素 当然少于2维的运算后要能够升维度 或者运算后能够降低倒定义初始化中需要的的维度!!!!!!!!
time= torch.tensor(range(5))
state=torch.rand(5, 1)
x.shape,time.shape

torch.tensor([5]) 和torch.tensor(range(5))的维度有什么区别

torch.tensor([5]) 创建的是一个形状为 torch.Size([1]) 的张量,表示它是一个包含单个元素的一维张量。

torch.tensor(range(5)) 创建的是一个形状为 torch.Size([5]) 的张量,表示它是一个包含5个元素的一维张量。

MLPinstance(x,time,state) #成功调用

将位置索引 x 转换为形状 (5, 1),频率向量转换为形状 (1, 8) 是怎么理解请举例
当我们谈论将位置索引 x 转换为形状 (5, 1) 和频率向量转换为形状 (1, 8),我们实际上是在讨论在进行矩阵运算之前对张量(在PyTorch中,张量是多维数组)的形状调整,以便它们能够进行有效的点乘操作。这个过程通常称为“广播”(broadcasting),它允许不同形状的张量进行数学运算,只要它们在没有明确指定的维度上大小为1或者完全匹配。

例子说明:

位置索引 x

原始的位置索引 x 是一个一维张量,表示5个不同的位置:

x = torch.tensor([0, 1, 2, 3, 4])

形状是 (5,),表示有5个元素。

为了使其能与频率向量正确点乘,我们需要将其扩展成一个二维张量,形状变为 (5, 1)。这意味着每个位置现在被视为一个单独的行,每一行只有一个元素:

x_expanded = x.unsqueeze(1)
# x_expanded 的形状现在是 (5, 1),内容为:
# tensor([[0],
#         [1],
#         [2],
#         [3],
#         [4]])

通过 unsqueeze(1) 操作,我们在索引1的位置增加了一个新的维度,使得每个位置值变成一个单独的列向量。

频率向量

频率向量是基于一半的维度数(如果 dim=16,则 half_dim=8)创建的,假定我们已经计算了对应的频率值并存储在一个一维张量中:

freqs = torch.tensor([...])  # 假设这里填入了计算得到的频率值序列

其原始形状可能是 (8,),代表了8个不同频率的值。

为了点乘,我们需要将其调整为 (1, 8) 形状,意味着它变成了一个只有一行的矩阵:

freqs_expanded = freqs.unsqueeze(0)
# freqs_expanded 的形状现在是 (1, 8),内容保持原频率值不变,但结构上作为单行多列

通过 unsqueeze(0),我们在索引0的位置增加了一个新的维度,使频率值构成了一行向量。

结果

完成上述操作后,x_expandedfreqs_expanded 可以直接相乘,因为它们在第二个维度(列)上是匹配的(一个是1,另一个是8,但因为是点乘,一维会自动复制以匹配另一方的维度)。这样,每个位置的索引值会与所有的频率值进行一一对应乘法运算,从而生成最终的编码向量的一部分。

a为(5,3,1)的张量
b为(5,16)的张量 能否把b变为(5,3,16)的维度?
要将形状为 (5, 16) 的张量 b 转换成形状为 (5, 3, 16),以便它可以与形状为 (5, 3, 1) 的张量 a 在某些维度上进行运算,我们需要沿着一个新维度重复 b 中的每一行数据三次(因为 a 的第二个维度是 3)。这可以通过使用 torch.repeat_interleave 函数来完成,具体是沿着第1维度(在PyTorch中,维度索引从0开始计数,所以这是第二个维度)进行重复。

下面是操作的示例代码:

import torch# 假设 b 是形状为 (5, 16) 的张量
b = torch.randn(5, 16)# 将 b 调整为形状为 (5, 3, 16)
b_expanded = b.unsqueeze(1).repeat(1, 3, 1)print("Original shape of b:", b.shape)
print("Expanded shape of b:", b_expanded.shape)

这段代码首先使用 unsqueeze(1)b 添加一个新的维度(变成 (5, 1, 16)),然后使用 repeat(1, 3, 1) 沿着新添加的维度复制数据三次,最终得到形状为 (5, 3, 16) 的张量。

这样调整后,b_expanded 就可以和形状为 (5, 3, 1) 的张量 a 在相应的维度上进行操作了,比如逐元素相乘等运算。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/25906.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

人脸识别和 ArcFace:用于深度人脸识别的附加角边际损失

在本文中,您将发现一种 ArcFace 方法,该方法可获得用于人脸识别的高分辨特征。阅读本文后,你将了解: 人脸识别任务如何工作。如何计算人脸匹配。SoftMax 和 ArcFace 的直观区别。ArcFace 的几何解释。ArcFace 背后的数学原理本文假定您已经熟悉用于多类分类、检测和 SoftMax…

电子设计类论文降低重复率的方法

1、知网学位论文检测为整篇上传,格式对检测结果可能会造成影响,需要将最终交稿格式提交检测,将影响降到最小,此影响为几十字的小段可能检测不出。对于3万字符以上文字较多的论文是可以忽略的。 2、上传论文后,系统会自…

2024.6.11总结

今天做个面试小结吧。面试了十几家android开发岗,在技术面都挂了。一家软件测试岗,二面挂了。 可见,一般面试有三轮,分别为hr面,专业面(也称技术面)和主管面。 hr面:之前广州电信的hr面挂了&a…

Vuepress 2从0-1保姆级进阶教程——标准化流程(Tailwindcss+autoprefixer+commitizen)

Vuepress 2 专栏目录【已完结】 1. 入门阶段 Vuepress 2从0-1保姆级入门教程——环境配置篇Vuepress 2从0-1保姆级入门教程——安装流程篇Vuepress 2从0-1保姆级入门教程——文档配置篇Vuepress 2从0-1保姆级入门教程——主题与部署 2.进阶阶段 Vuepress 2从0-1保姆级进阶教程—…

电影推荐系统的设计

管理员账户功能包括:系统首页,个人中心,管理员管理,用户管理,免费电影管理,付费电影管理,电影论坛管理 前台账户功能包括:系统首页,个人中心,付费电影&#x…

element-plus的el-space标签的使用

el-space标签可以很方便的设置标签间距和分隔符&#xff0c;对齐方式&#xff0c;是否拆行等属性。 <script setup lang"ts"> import { onMounted, ref } from vue;const sizeref(30)</script><template><el-space wrap :size"size"…

【最新鸿蒙应用开发】——类Web开发范式2——前端语法

兼容JS的类Web开发范式 JS FA应用的JS模块(entry/src/main/js/module)的典型开发目录结构如下&#xff1a; 1. 项目基本结构 1.1. 目录结构 1.2. 项目文件分类如下&#xff1a; .hml结尾的HML模板文件&#xff0c;这个文件用来描述当前页面的文件布局结构。 .css结尾的CSS样…

MIPI A-PHY协议学习

一、说明 A-PHY是一种高带宽串行传输技术,主要为了减少传输线并实现长距离传输的目的,比较适用于汽车。同时,A-PHY兼容摄像头的CSI协议和显示的DSI协议。其主要特征: 长距离传输,高达15m和4个线内连接器; 高速率,支持2Gbps~16Gbps; 支持多种车载线缆(同轴线、屏蔽差分…

Spark RDD算子

Spark RDD算子 转换算子&#xff08;Transformation Operators&#xff09; 类别算子名称简要介绍映射类算子map对RDD中的每个元素进行操作&#xff0c;返回一个新的RDDflatMap类似于map&#xff0c;但每个输入元素可映射到0或多个输出元素mapPartitions对RDD的每个分区中的元…

在VMware虚拟机上安装win10 跳过 通过microsoft登录

在VMware虚拟机上安装win10 跳过 “通过microsoft登录” 配置虚拟机&#xff0c;将网卡断开&#xff0c; 具体操作&#xff1a; 虚拟机/设置/硬件/网络适配器/设备状态&#xff0c;取消已连接和启动时连接的两个对号&#xff0c; 再把虚拟机重启&#xff0c;然后就可以跳过这个…

通过技术优化财务规划报告,重塑企业体验

财务报告使企业的管理层能够及时、准确、清晰且一致地了解整个企业的财务业绩和风险机遇。它促进了企业内部利益相关者之间的沟通&#xff0c;从而支持基于数据驱动的洞察力提升和战略决策。但财务报告往往需要占用大量的时间来运行和准备&#xff0c;且可能使最终结论偏离核心…

什么是PV操作

PV操作是一种在操作系统中用于同步和互斥的机制,它基于信号量(Semaphore)的概念。在并发编程中,多个进程或线程可能会同时访问共享资源,PV操作可以用来确保这些访问是同步的,以防止竞态条件和数据不一致的问题。 PV操作包括两个原子操作: P操作(Proberen,测试):这…

使用 C# 学习面向对象编程:第 4 部分

C# 构造函数 第 1 部分仅介绍了类构造函数的基础知识。 在本课中&#xff0c;我们将详细讨论各种类型的构造函数。 属性类型 默认构造函数构造函数重载私有构造函数构造函数链静态构造函数析构函数 请注意构造函数的一些基本概念&#xff0c;并确保你的理解非常清楚&#x…

从入门到精通:进程间通信

引言 在现代操作系统中&#xff0c;进程是程序运行的基本单位。为了实现复杂的功能&#xff0c;多个进程常常需要进行通信。进程间通信&#xff08;Inter-Process Communication, IPC&#xff09;是指多个进程之间进行数据交换的一种机制。IPC的主要目的包括数据传输、资源共享…

WDF驱动开发-电源策略(三)

多组件设备的 KMDF 驱动程序只能将请求发送到处于活动状态的组件。 通常&#xff0c;驱动程序将 I/O 队列分配给组件或组件集。 首先考虑分配给单个组件的队列。 驱动程序在组件变为活动状态时启动队列&#xff0c;并在组件空闲时停止队列。 因此&#xff0c;当 KMDF 调用队列…

生成式人工智能重置:从初期热潮到战略扩展

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

PyTorch学习8:多分类问题

文章目录 前言一、说明二、示例1.步骤2.示例代码 总结 前言 介绍如何利用PyTorch中Softmax 分类器实现多分类问题。 一、说明 1.多分类问题的输出是一个分布&#xff0c;满足和为1. 2.Softmax 分类器 3.损失函数&#xff1a;交叉熵损失 torch.nn.CrossEntropyLoss() 二、…

运维开发详解:DevOps 理念下的高效运维实践

目录 前言 1、 运维开发的核心概念 2、 运维开发的技术栈 3、运维开发的实践案例 4、 运维开发的挑战与机遇 5、 运维开发的未来发展趋势 6、运维开发概念 7、运维开发的角色 8、成为一名优秀的运维开发工程师 9、总结 前言 随着互联网业务的快速发展&#xff0c;传…

虚拟化 之一 详解 jailhouse 架构及原理、软硬件要求、源码文件、基本组件

Jailhouse 是一个基于 Linux 实现的针对创建工业级应用程序的小型 Hypervisor&#xff0c;是由西门子公司的 Jan Kiszka 于 2013 年开发的&#xff0c;并得到了官方 Linux 内核的支持&#xff0c;在开源社区中获得了知名度和吸引力。 Jailhouse Jailhouse 是一种轻量级的虚拟化…

微软如何打造数字零售力航母系列科普13 - Prime Focus Technologies在NAB 2024上推出CLEAR®对话人工智能联合试点

Prime Focus Technologies在NAB 2024上推出CLEAR对话人工智能联合试点 彻底改变您与内容的互动方式&#xff0c;从内容的创建到分发 洛杉矶&#xff0c;2024年4月9日/PRNewswire/-媒体和娱乐&#xff08;M&E&#xff09;行业人工智能技术解决方案的先驱Prime Focus Techn…