PyTorch框架学习三——张量操作

PyTorch框架学习三——张量操作

  • 一、拼接
    • 1.torch.cat()
    • 2.torch.stack()
  • 二、切分
    • 1.torch.chunk()
    • 2.torch.split()
  • 三、索引
    • 1.torch.index_select()
    • 2.torch.masked_select()
  • 四、变换
    • 1.torch.reshape()
    • 2.torch.transpace()
    • 3.torch.t()
    • 4.torch.squeeze()
    • 5.torch.unsqueeze()

一、拼接

1.torch.cat()

功能:将tensor按照维度dim进行拼接,除了需要拼接的维度外,其余维度尺寸得是相同的。

torch.cat(tensors, dim=0, out=None)

看一下所有的参数:
在这里插入图片描述

  1. tensors:需要被拼接的张量序列。
  2. dim:(int,可选)被拼接的维度,默认为0。
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790,  0.1497],[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790,  0.1497],[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,-1.0969, -0.4614],[-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,-0.5790,  0.1497]])

2.torch.stack()

功能:在新创建的维度dim上进行拼接,所有的张量必须是相同的维度。

torch.stack(tensors, dim=0, out=None)

在这里插入图片描述
注意:stack()会创建一个新的维度。

t = torch.ones((2, 3))
t_stack = torch.stack([t, t, t], dim=2)
print("\nt_stack:{} shape:{}".format(t_stack, t_stack.shape))

在这里插入图片描述
原来t的维度是(2, 3),本来是没有第三维的,但是stack()会构建新的dim=2,就是先构建第三维dim=2,然后在该维度上进行拼接。

二、切分

1.torch.chunk()

功能:将tensor按维度dim进行平均切分。如果不能整除,最后一份tensor在该维度上的长度小于其他tensor。

torch.chunk(input, chunks, dim=0)

在这里插入图片描述

  1. input:要切分的张量。
  2. chunks:要切分的份数。
  3. dim:要切分的维度,默认为0。
a = torch.ones((2, 7))  # 7
list_of_tensors = torch.chunk(a, dim=1, chunks=3)   # 3for idx, t in enumerate(list_of_tensors):print("第{}个张量:{}, shape is {}".format(idx+1, t, t.shape))

在这里插入图片描述

2.torch.split()

功能:将tensor按dim进行切分。

torch.split(tensor, split_size_or_sections, dim=0)

在这里插入图片描述

  1. tensor:要切分的张量。
  2. split_size_or_sections:(int或list(int))为int时,表示每一份的长度,如果不能整除,最后一份的长度要小于其他的张量,为list时,按list元素来切分。
  3. dim:同上。
>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],[2, 3],[4, 5],[6, 7],[8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],[2, 3]]),tensor([[4, 5],[6, 7]]),tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),tensor([[2, 3],[4, 5],[6, 7],[8, 9]]))

三、索引

1.torch.index_select()

功能:在dim上,按照index索引数据,返回一个依据index索引数据拼接的张量。

torch.index_select(input, dim, index, out=None)

在这里插入图片描述

  1. input:要索引的张量。
  2. dim:被索引的维度。
  3. index:一维张量,包括了要索引的数据序号。(long,不能是float)
  4. out:输出张量(可选)。
>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],[-0.4664,  0.2647, -0.1228, -1.1068],[-1.1734, -0.6571,  0.7230, -0.6004]])
>>> indices = torch.tensor([0, 2])
>>> torch.index_select(x, 0, indices)
tensor([[ 0.1427,  0.0231, -0.5414, -1.0009],[-1.1734, -0.6571,  0.7230, -0.6004]])
>>> torch.index_select(x, 1, indices)
tensor([[ 0.1427, -0.5414],[-0.4664, -0.1228],[-1.1734,  0.7230]])

2.torch.masked_select()

功能:按照mask中的True进行索引,返回一个一维张量。

torch.masked_select(input, mask, out=None)

在这里插入图片描述

>>> x = torch.randn(3, 4)
>>> x
tensor([[ 0.3552, -2.3825, -0.8297,  0.3477],[-1.2035,  1.2252,  0.5002,  0.6248],[ 0.1307, -2.0608,  0.1244,  2.0139]])
>>> mask = x.ge(0.5)
>>> mask
tensor([[False, False, False, False],[False, True, True, True],[False, False, False, True]])
>>> torch.masked_select(x, mask)
tensor([ 1.2252,  0.5002,  0.6248,  2.0139])

四、变换

1.torch.reshape()

功能:变换张量的形状。

torch.reshape(input, shape)

在这里插入图片描述

  1. input:输入张量。
  2. shape:新张量的形状。当某个维度为-1时,表示该维度不用关心,可以从别的维度计算得到。
>>> a = torch.arange(4.)
>>> torch.reshape(a, (2, 2))
tensor([[ 0.,  1.],[ 2.,  3.]])
>>> b = torch.tensor([[0, 1], [2, 3]])
>>> torch.reshape(b, (-1,))
tensor([ 0,  1,  2,  3])

2.torch.transpace()

功能:交换tensor的两个维度。

torch.transpose(input, dim0, dim1)

在这里插入图片描述

  1. input:输入张量。
  2. dim0和dim1:要交换的两个维度。
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 1.0028, -0.9893,  0.5809],[-0.1669,  0.7299,  0.4942]])
>>> torch.transpose(x, 0, 1)
tensor([[ 1.0028, -0.1669],[-0.9893,  0.7299],[ 0.5809,  0.4942]])

3.torch.t()

功能:2维tensor转置,对矩阵而言。等价于torch.transpose(input, 0, 1)。

torch.t(input)
>>> x = torch.randn(())
>>> x
tensor(0.1995)
>>> torch.t(x)
tensor(0.1995)
>>> x = torch.randn(3)
>>> x
tensor([ 2.4320, -0.4608,  0.7702])
>>> torch.t(x)
tensor([ 2.4320, -0.4608,  0.7702])
>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.4875,  0.9158, -0.5872],[ 0.3938, -0.6929,  0.6932]])
>>> torch.t(x)
tensor([[ 0.4875,  0.3938],[ 0.9158, -0.6929],[-0.5872,  0.6932]])

注意:只对矩阵会转置,对标量和向量都不会。

4.torch.squeeze()

功能:压缩长度为1的维度(轴)。

torch.squeeze(input, dim=None, out=None)

在这里插入图片描述

  1. input:输入张量。
  2. dim:(可选)若为None,移除所有长度为1的轴,若指定轴,当且仅当该轴长度为1时移除。
  3. out:输出张量。
>>> x = torch.zeros(2, 1, 2, 1, 2)
>>> x.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x)
>>> y.size()
torch.Size([2, 2, 2])
>>> y = torch.squeeze(x, 0)
>>> y.size()
torch.Size([2, 1, 2, 1, 2])
>>> y = torch.squeeze(x, 1)
>>> y.size()
torch.Size([2, 2, 1, 2])

5.torch.unsqueeze()

功能:返回一个新的张量,对输入的指定位置插入维度 1。

torch.unsqueeze(input, dim)
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0)
tensor([[ 1,  2,  3,  4]])
>>> torch.unsqueeze(x, 1)
tensor([[ 1],[ 2],[ 3],[ 4]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/491718.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

'chcp' 不是内部或外部命令,也不是可运行的程序

在cmd窗口中输入activate tensorflow时报错chcp 不是内部或外部命令,也不是可运行的程序 添加两个环境变量即可解决: 将Anaconda的安装地址添加到环境变量“PATH”,如果没有可以新建一个,我的安装地址是“D:\Anaconda”&#xf…

2019年全球企业人工智能发展现状分析报告

来源:199IT互联网数据中心《悬而未决的AI竞赛——全球企业人工智能发展现状》由德勤洞察发布,德勤中国科技、传媒和电信行业编译。为了解全球范围内的企业在应用人工智能技术方面的情况以及所取得的成效,德勤于2018年第三季度针对早期人工智能…

qt调动DLL

void func(void); // dll库中的函数 typedef void (*PFUNC)(void); 方法一&#xff1a; HMODULE g_hAPIDLL NULL; wchar_t tcDLLPath[100] L"D:\\name.dll"; g_hAPIDLL ::LoadLibrary(tcDLLPath); if (NULL g_hAPIDLL) { qDebug() << "load library f…

PyTorch框架学习四——计算图与动态图机制

PyTorch框架学习四——计算图与动态图机制一、计算图二、动态图与静态图三、torch.autograd1.torch.autograd.backward()2.torch.autograd.grad()3.autograd小贴士4.代码演示理解&#xff08;1&#xff09;构建计算图并反向求导&#xff1a;&#xff08;2&#xff09;grad_tens…

ipynb文件转为python(.py)文件

在Anaconda中的jupyter打开该ipynb文件&#xff0c;然后依次点击File—>Download as—>python(.py)

美国准备跳过5G直接到6G 用上万颗卫星包裹全球,靠谱吗?

来源&#xff1a;瞭望智库这项2015年提出的计划&#xff0c;规模极其巨大&#xff0c;总计要在2025年前发射近12000颗卫星。有自媒体认为&#xff0c;该计划表示美国将在太空中建立下一代宽带网络&#xff0c;绕过5G&#xff0c;直接升级到6G&#xff0c;并据此认为“6G并不遥远…

8月读书分享-《执行力是训练出来的》

写在最开头的是&#xff0c;没有拿到这本书之前其实我是很期待的&#xff0c;因为我觉得执行力是我所很需要的东西。但是拿到书之后就有一些失望了&#xff0c;因为我发现他的章节实在是太多了&#xff0c;我总觉得如果章节太多会不会其实是作者的归纳整理能力不太好呢&#xf…

PyTorch框架学习五——图像预处理transforms(一)

PyTorch框架学习五——图像预处理transforms&#xff08;一&#xff09;一、transforms运行机制二、transforms的具体方法1.裁剪&#xff08;1&#xff09;随机裁剪&#xff1a;transforms.RandomCrop()&#xff08;2&#xff09;中心裁剪&#xff1a;transforms.CenterCrop()&…

机器之心 GitHub 项目地址:

机器之心 GitHub 项目地址&#xff1a;https://github.com/jiqizhixin/ML-Tutorial-Experiment

IBM Watson大裁70% 员工,撕掉了国内大批伪AI企业最后一块遮羞布!

来源:新医路Watson 是IBM 的重量级AI 系统&#xff1b;近年IBM 大力发展AI 医疗&#xff0c;在2015 年成立独立的 Watson Health 部门&#xff0c;并收购多家医疗数据公司&#xff0c;前景看好。然而短短三年&#xff0c;这个明星部门就要裁员50% 到70% 的员工&#xff0c;代表…

PyTorch框架学习六——图像预处理transforms(二)

PyTorch框架学习六——图像预处理transforms&#xff08;二&#xff09;&#xff08;续&#xff09;二、transforms的具体方法4.图像变换&#xff08;1&#xff09;尺寸变换&#xff1a;transforms.Resize()&#xff08;2&#xff09;标准化&#xff1a;transforms.Normalize()…

新浪微博学习的知识点

instancetype 默认会识别当前是哪个类或者对象调用,就会转换成对应的类的对象 模型设计思想:Item:就是苹果的模型命名规范 tabBarItem: 决定着 tabBar 上按钮的内容 NSMutableDictionary *att [[NSMutableDictionary alloc] init]; att[NSForegroundColorAttributeName] […

numpy方法读取加载mnist数据集

方法来自机器之心公众号 首先下载mnist数据集&#xff0c;并将里面四个文件夹解压出来&#xff0c;下载方法见前面的博客 import tensorflow as tf import numpy as np import osdataset_path rD:\PycharmProjects\tensorflow\MNIST_data # 这是我存放mnist数据集的位置 is_…

纳米线传感器来了,传感芯片还会远吗

来源&#xff1a;科学网“无旁路电路”纳米线桥接生长方案 黄辉供图微型气体检测仪 黄辉供图人工智能、可穿戴装备、物联网等信息技术迅猛发展&#xff0c;需要海量的传感器提供支持&#xff0c;大数据和云计算等业务也需要各种传感器实时采集数据来支撑。但目前的传感器存在国…

PyTorch框架学习七——自定义transforms方法

PyTorch框架学习七——自定义transforms方法一、自定义transforms注意要素二、自定义transforms步骤三、自定义transforms实例&#xff1a;椒盐噪声虽然前面的笔记介绍了很多PyTorch给出的transforms方法&#xff0c;也非常有用&#xff0c;但是也有可能在具体的问题中需要开发…

Fashion-MNIST下载地址

训练集的图像&#xff1a;60000&#xff0c;http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz 训练集的类别标签&#xff1a;60000&#xff0c;http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.…

dfs算法

一般bfs算法都是使用递归 //下面简单的代码 visited[Max]; dfs(_graph g,int vo){ print(vo); visited[vo]1 for(int i0;i<Max;i){ if(visited[i]0){ dfs(g,i); } } }转载于:https://www.cnblogs.com/dick159/p/4900935.html

美国芯片简史:军方大力扶持下的产物 但一度被日 韩超越

来源&#xff1a;知乎专栏腾讯科技近日发起系列策划&#xff0c;聚焦各个芯片大国的发展历程。第四期&#xff1a;《美国芯片简史》。集成电路是电子信息产业的的基石&#xff0c;电子信息产业对国民经济与社会发展具有重大推动作用。从全球集成电路产业发展历程来看&#xff0…

PyTorch框架学习八——PyTorch数据读取机制(简述)

PyTorch框架学习八——PyTorch数据读取机制&#xff08;简述&#xff09;一、数据二、DataLoader与Dataset1.torch.utils.data.DataLoader2.torch.utils.data.Dataset三、数据读取整体流程琢磨了一段时间&#xff0c;终于对PyTorch的数据读取机制有了一点理解&#xff0c;并自己…

使用feed_dict不一定要用占位符

使用feed_dict一般会伴有占位符&#xff0c;如 x tf.placeholder(tf.float32) 但是没有tf.placeholder也可以使用feed_dict方法&#xff0c;如下面这个例子&#xff1a; import tensorflow as tfinput1 tf.constant([2], dtypetf.float32) input2 tf.constant([3], dtype…