Pytorch(2)-tensor常用操作

tensor常用数学操作

  • 1. 随机数
    • 1.1 torch.rand() - 均匀分布数字
    • 1.2 torch.randn() - 正态分布数字
  • 2. 求和
    • 2.1 torch.sum(data, dim)
    • 2.2 numpy.sum(data, axis)
  • 3. 求积
    • 3.1 点乘--对应位置相乘
    • 3.2 矩阵乘法
  • 4. 均值、方差
    • 4.1 torch tensor.mean() .std()
    • 4.2 numpy array.mean() .std()
    • 4.3 numpy 与torch .std()计算公式差别
  • 5 求幂运算--torch.pow()
  • 6. tensor 取值
    • 6.1 scaler.item()
    • 6.2 tensor.tolist()
  • 7. 降升维
    • 7.1 torch.squeeze() 降维
    • 7.2 torch.unsqueeze() 升维
  • 8. 最大/最小/非零 值索引
    • 8.1 tensor.argmax()
    • 8.2 tensor.argmin()
    • 8.3 tensor.nonzero()
  • 9. 矩阵拼接
    • 9.1 torch.cat((a, b, c), dim )
    • 9.2 numpy.concatenate((a,b), axis)
  • 10. 矩阵拉伸
    • 10.1 torch.flatten()
    • 10.2 numpy.matrix.flatten

import torch

1. 随机数

1.1 torch.rand() - 均匀分布数字

产生大小指定的,[0,1)之间的均匀分布的样本.

> torch.rand(2,3)
>>tensor([[0.0270, 0.9856, 0.6599],[0.2237, 0.3888, 0.4566]])
> torch.rand(3,1)
>>tensor([[0.1268],[0.3370],[0.5097]])

1.2 torch.randn() - 正态分布数字

产生大小为指定的,正态分布的采样点,数据类型是tensor

> torch.randn(4)
>>tensor([-2.1436,  0.9966,  2.3426, -0.6366])>torch.randn(2, 3)
>>tensor([[ 1.5954,  2.8929, -1.0923],[ 1.1719, -0.4709, -0.1996]])

2. 求和

2.1 torch.sum(data, dim)

>>> a=torch.ones(2,3)
>>> a
tensor([[1., 1., 1.],[1., 1., 1.]])
# 按行求和
>>> b=torch.sum(a,1)		# 每列叠加,按行求和
>>> b
tensor([3., 3.])
>>> b.size()
torch.Size([2])
# 按列求和
>>> d=torch.sum(a,0)		# 每行叠加,按列求和
>>> d
tensor([2., 2., 2.])
>>> d.size()				
torch.Size([3])

2.2 numpy.sum(data, axis)

>>> import numpy as np
>>> np.sum([[0, 1], [0, 5]], axis=0)  #每行叠加,按列求和
array([0, 6])
>>> np.sum([[0, 1], [0, 5]], axis=1)  
array([1, 5])

3. 求积

3.1 点乘–对应位置相乘

数组和矩阵对应位置相乘,输出与相乘数组/矩阵的大小一致

np.multiply()
torch直接用* 就能实现

3.2 矩阵乘法

矩阵乘法:两个矩阵需要满足一定的行列关系

torch.matmul(tensor1, tensor2)
numpy.matmul(array1, array2)

4. 均值、方差

4.1 torch tensor.mean() .std()

>a.mean()  # a为Tensor型变量
>a.std()>torch.mean(a)  # a为Tensor型变量
>torch.std(a)>>> torch.Tensor([1,2,3,4,5])
tensor([1., 2., 3., 4., 5.])
>>> a=torch.Tensor([1,2,3,4,5])
>>> a.mean()
tensor(3.)
>>> torch.mean(a)
tensor(3.)
>>> a.std()
tensor(1.5811)
>>>> torch.std(a)
tensor(1.5811)       # 注意和numpy求解的区别

torch.mean(input) 输出input 各个元素的的均值,不指定任何参数就是所有元素的算术平均值,指定参数可以计算每一行或者 每一列的算术平均数

> a = torch.randn(4, 4)
>>tensor([[-0.3841,  0.6320,  0.4254, -0.7384],[-0.9644,  1.0131, -0.6549, -1.4279],[-0.2951, -1.3350, -0.7694,  0.5600],[ 1.0842, -0.9580,  0.3623,  0.2343]])
# 每一行的平均值
> torch.mean(a, 1, True)  #dim=true,计算每一行的平均数,输出与输入有相同的维度:两维度(4,1)
>>tensor([[-0.0163],[-0.5085],[-0.4599],[ 0.1807]])> torch.mean(a, 1)   # 不设置dim,默认计算每一行的平均数,内嵌了一个torch.squeeze(),将数值为1的维度压缩(4,)
>>tensor([-0.0163, -0.5085, -0.4599,  0.1807])

4.2 numpy array.mean() .std()

>a.mean()  # a为np array型变量
>a.std()>numpy.mean(a)   # a为np array型变量
>numpy.std(a)
>>> import numpy
>>> c=numpy.array([1,2,3,4,5])
>>> c.mean()
3.0
>>> numpy.mean(c)
3.0
>>> c.std()
1.4142135623730951
>>> numpy.std(c)
1.4142135623730951>>> d=numpy.array([1,1,1,1])
>>> d.mean()
1.0
>>> d.std()
0.0

4.3 numpy 与torch .std()计算公式差别

numpy:
std=1N∑i=1N(xi−x‾)2std=\sqrt{\frac{1}{N}\sum_{i=1}^N(x_i-\overline{x})^2}std=N1i=1N(xix)2

torch:
std=1N−1∑i=1N(xi−x‾)2std=\sqrt{\frac{1}{N-1}\sum_{i=1}^N(x_i-\overline{x})^2}std=N11i=1N(xix)2

5 求幂运算–torch.pow()

对输入的每分量求幂次运算

>>> a = torch.randn(4)
>>> a
tensor([ 0.4331,  1.2475,  0.6834, -0.2791])
>>> torch.pow(a, 2)
tensor([ 0.1875,  1.5561,  0.4670,  0.0779])>>> exp = torch.arange(1., 5.)
>>> a = torch.arange(1., 5.)
>>> a
tensor([ 1.,  2.,  3.,  4.])
>>> exp
tensor([ 1.,  2.,  3.,  4.])
>>> torch.pow(a, exp)
tensor([   1.,    4.,   27.,  256.])

和numpy中的numpy.power()作用类似。

6. tensor 取值

6.1 scaler.item()

一个Tensor调用.item()方法就可以返回这个Tensor 对应的标准python 类型的数据.注意事项,只针对一个元素的标量.若是向量,可以调用tolist()方法.
在这里插入图片描述在低版本的torch中tensor没有.item()属性,那么直接用[0]访问其中的数据.

>>> import torch
>>>c=torch.tensor([2.5555555])
>>> c2.5556
[torch.FloatTensor of size 1]
>>> c[0]
2.555555582046509
>>> round(c[0],3)
2.556

6.2 tensor.tolist()

(略)

7. 降升维

7.1 torch.squeeze() 降维

将输入所有为1的维度去除(2,1)-> (2,)(以行向量的形式存在)

torch.squeeze(input, dim=None, out=None)

> x = torch.zeros(2, 1, 2, 1, 2)
> x.size()
>>torch.Size([2, 1, 2, 1, 2])
>
> y = torch.squeeze(x)
> y.size()
>>torch.Size([2, 2, 2])

7.2 torch.unsqueeze() 升维

Returns a new tensor with a dimension of size one inserted at the specified position)
给数据增加一维度,常看到数据的维度为.size([])懵逼了,在后续计算的时候会造成问题。所以需要给数据升维度。

> x = torch.tensor([1, 2, 3, 4])
>torch.unsqueeze(x, 0)
>>tensor([[ 1,  2,  3,  4]])
>
>torch.unsqueeze(x, 1)
>>tensor([[ 1],[ 2],[ 3],[ 4]])

8. 最大/最小/非零 值索引

8.1 tensor.argmax()

tensor.argmax(dim)

返回tensor最大的值对应的index。dim 不设置-全部元素的最大值,dim = 0 每列的最大值,dim = 1每行的最大值。

>>> a = torch.randn(3,4)
>>> a
tensor([[ 1.1360, -0.5890,  1.8444,  0.6960],[ 0.3462, -1.1812, -1.5536,  0.4504],[-0.4464, -0.5600, -0.1655,  0.3914]])
>>> a.argmax()
tensor(2)               # 按行拉成一维向量,对应的小次奥
>>> a.argmax(dim = 0)
tensor([0, 2, 0, 0])    # 每一列最大值的索引
>>> a.argmax(dim = 1)
tensor([2, 3, 3])       # 每一行最大值索引

8.2 tensor.argmin()

与tensor.argmax() 用法相同,在多分类问题求准确率时会用到。

output_labels = outputs.argmax(dim = 1)
train_acc = (output_labels == labels).float().mean()

8.3 tensor.nonzero()

返回非零元素对应的下标

>>> a = torch.tensor([[1,0],[0,3]])
>>> a.nonzero()
tensor([[0, 0],[1, 1]])
>>> a= torch.tensor([1,2,3,0,4,5,0])
>>> a.nonzero()
tensor([[0],[1],[2],[4],[5]])

9. 矩阵拼接

9.1 torch.cat((a, b, c), dim )

>>> x = torch.randn(2, 3)
>>> x
tensor([[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 0)
tensor([[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790,  0.1497],[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790,  0.1497],[ 0.6580, -1.0969, -0.4614],[-0.1034, -0.5790,  0.1497]])
>>> torch.cat((x, x, x), 1)
tensor([[ 0.6580, -1.0969, -0.4614,  0.6580, -1.0969, -0.4614,  0.6580,-1.0969, -0.4614],[-0.1034, -0.5790,  0.1497, -0.1034, -0.5790,  0.1497, -0.1034,-0.5790,  0.1497]])

9.2 numpy.concatenate((a,b), axis)

>>> a = np.array([[1, 2], [3, 4]])
>>> b = np.array([[5, 6]])
>>> np.concatenate((a, b), axis=0)
array([[1, 2],[3, 4],[5, 6]])
>>> np.concatenate((a, b.T), axis=1)
array([[1, 2, 5],[3, 4, 6]])
>>> np.concatenate((a, b), axis=None)
array([1, 2, 3, 4, 5, 6])

10. 矩阵拉伸

10.1 torch.flatten()

矩阵按行展开:

>>> t = torch.tensor([[[1, 2],[3, 4]],[[5, 6],[7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],[5, 6, 7, 8]])

10.2 numpy.matrix.flatten

>>> m = np.matrix([[1,2], [3,4]])
>>> m.flatten()
matrix([[1, 2, 3, 4]])
>>> m.flatten('F')
matrix([[1, 3, 2, 4]])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/445302.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深度学习(07)-- 经典CNN网络结构(Inception (v1-v4))

文章目录目录1.Inception介绍1.1 Inception结构1.2 Inception V1(GoogleNet)1.3 Inception V2(Batch Norm)1.4 Inception V3(Factorization)1.5 Inception V4(ResNet)1.5 Inception v1~v4 总结1.6 Inception进阶2.Inception实现目…

Python(13)-函数,lambda语句

函数1 函数定义2 函数调用3 函数注释文档4 函数参数4.1 参数列表,默认参数,任意参数4.1.1 无缺省值参数4.1.2(部分)缺省值参数4.1.3 数量不定形参数4.2 可变对象和不可变对象4.3 作用域4.3.1 globals()函数4.3.2 global 声明变量为全局变量5 函数返回值5…

深度学习(08)-- Residual Network (ResNet)

文章目录目录1.残差网络基础1.1基本概念1.2VGG19、ResNet34结构图1.3 梯度弥散和网络退化1.4 残差块变体1.5 ResNet模型变体1.6 Residual Network补充1.7 1*1卷积核(补充)2.残差网络介绍(何凯明)3.ResNet-50(Ng)3.1 非常深的神经网…

redis——命令请求的执行过程

发送命令请求 当用户在客户端中键入一个命令请求时, 客户端会将这个命令请求转换成协议格式, 然后通过连接到服务器的套接字, 将协议格式的命令请求发送给服务器。 读取命令请求 当客户端与服务器之间的连接套接字因为客户端的写入而变得可…

深度学习(09)-- DenseNet

文章目录目录1.DenseNet网络结构2.稠密连接及其优点3.代码实现4.补充说明目录 1.DenseNet网络结构 2.稠密连接及其优点 每层以之前层的输出为输入,对于有L层的传统网络,一共有L个连接,对于DenseNet,则有L*(L1)/2。 这篇论文主要…

redis——缓存击穿/穿透/雪崩

缓存穿透 一般的缓存系统,都是按照key去缓存查询,如果不存在对应的value,就去后端系统查找(比如DB)。 一些恶意的请求会故意查询不存在的key,请求量很大,就会对后端系统造成很大的压力。这就叫做缓存穿透…

python(15)-window7配置iPython

前提:安装了Pythonanaconda anaconda安装参考:https://www.zhihu.com/question/58033789 在window系统下可以使用两种方法来实现类似与于Linux终端命令运行程序的方法(推荐方式2): 1.cmd:自己没有操作过,可以参考下面…

深度学习(10)-- Capsules Networks(CapsNet)

版权声明&#xff1a;本文为博主原创文章&#xff0c;未经博主允许不得转载。 https://blog.csdn.net/malele4th/article/details/79430464 </div><div id"content_views" class"markdown_views"><!-- flowchart 箭头图标 勿删 --&g…

手把手maven的功能/安装/使用/idea集成

看这篇文章不用着急安装&#xff0c;跟着步骤一定会成功&#xff0c;要理解maven是什么&#xff0c;如何使用。 介绍 maven官网 对于一个小白来说&#xff0c;官网有用的信息就是这些 不管如何介绍maven&#xff0c;作为使用者来说&#xff0c;主要感觉两个方面有帮助&#x…

python(16)-列表list,for循环

高级数据类型--列表1列表定义2列表中取值3列表的增&#xff0c;删&#xff0c;查&#xff0c;改3.1修改指定位置的数据3.2确定指定元素的索引3.3增加操作3.4删除操作3.5 元素是否存在与列表中 in3.6在指定索引位置插入元素4列表的数据统计5列表排序6列表的循环遍历-for7多维度l…

深度学习(11)-- GAN

TensorFlow &#xff08;GAN&#xff09; 目录 TensorFlow &#xff08;GAN&#xff09;目录1、GAN1.1 常见神经网络形式1.2 生成网络1.3 新手画家 & 新手鉴赏家1.4 GAN网络1.5 例子 1、GAN 今天我们会来说说现在最流行的一种生成网络, 叫做 GAN, 又称生成对抗网络, 也…

redis——数据结构和对象的使用介绍

redis官网 微软写的windows下的redis 我们下载第一个 额案后基本一路默认就行了 安装后&#xff0c;服务自动启动&#xff0c;以后也不用自动启动。 出现这个表示我们连接上了。 redis命令参考链接 String 字符串结构 struct sdshdr{//记录buf数组中已使用字节的数量int …

Python模块(1)-Argparse 简易使用教程

argparse 简易使用教程1.概况2. action3. argparse 使用demo3.1 argparse 实现加法器3.2 D-Model parser1.概况 argparse是Python中用于解析命令行参数的一个模块&#xff0c;可以自动生成help和usage信息&#xff1b;当从终端输入的参数无效时&#xff0c;模块会输出提示信息…

redis——NOSQL及redis概述

NoSql入门概述 单机Mysql的美好时代 瓶颈&#xff1a; 数据库总大小一台机器硬盘内存放不下数据的索引&#xff08;B tree&#xff09;一个机器的运行内存放不下访问量&#xff08;读写混合&#xff09;一个实例不能承受Memcached&#xff08;缓存&#xff09; MySql 垂直拆…

Python(17)-元组tuple

高级数据类型--元组1.元组的定义2.元组基本操作3.元组的循环遍历4.元组的应用场景5.元组与格式化字符串6.元组与列表之间的转换元组的最大特征就是可访问不可改&#xff0c;可作为字典的键值&#xff0c;因为键值必须是唯一的。字符串也是不可边类型&#xff0c;因此也适合做字…

深度学习(莫烦 神经网络 lecture 3) Keras

神经网络 & Keras 目录 神经网络 & Keras目录1、Keras简介1.1 科普: 人工神经网络 VS 生物神经网络1.2 什么是神经网络 (Neural Network)1.3 神经网络 梯度下降1.4 科普: 神经网络的黑盒不黑1.5 Why Keras?1.6 兼容 backend 2、如何搭建各种神经网络2.1 Regressor回归…

阿里Java编程规约(集合)

【强制】关于 hashCode 和 equals 的处理&#xff0c;遵循如下规则&#xff1a; 1&#xff09; 只要覆写 equals&#xff0c;就必须覆写 hashCode。 2&#xff09; 因为 Set 存储的是不重复的对象&#xff0c;依据 hashCode 和 equals 进行判断&#xff0c;所以 Set 存储的对…

Pytorch(3)-数据载入接口:Dataloader、datasets

pytorch数据载入1.数据载入概况Dataloader 是啥2.支持的三类数据集2.1 torchvision.datasets.xxx2.2 torchvision.datasets.ImageFolder2.3 写自己的数据类&#xff0c;读入定制化数据2.3.1 数据类的编写map-style范式iterable-style 范式2.3.2 DataLoader 导入数据类1.数据载入…

大数据学习(5)-- NoSQL数据库

文章目录目录1.NoSQL的介绍2.NoSQL产生的原因2.1 web2.02.2 NoSQL兴起原因3.NoSQL和关系数据库的区别4.NoSQL的四大类型4.1 键值数据库4.2 列族数据库4.3 文档数据库4.4 图形数据库4.5 不同类型的NoSQL数据库进行比较5.NoSQL的三大基石5.1 CAP5.2 base5.3 最终一致性6.从NoSQL到…

经典算法重点总结

文章目录排序算法冒泡排序直接插入排序希尔排序直接选择排序快速排序堆排序归并排序总结查找算法顺序查找二分查找插值查找斐波那契查找树表查找分块查找哈希查找总结排序算法 冒泡排序 void bubbleSort(int a[] , int n){for(int i n-1 ; i > 0 ; i--){for(int j 0 ; j …