torch中张量与数据类型的介绍

PyTorch张量的定义介绍

PyTorch最基本的操作对象是张量,它表示一个多维数组,类似NumPy的数组,但是前者可以在GPU上加速计算

初始化张量
t=torch.tensor([1,2])  # 创建一个张量
print(t)       
t.dtype       #打印t的数据类型为torch.int64

如果直接从Python数据创建张量,无须指定类型,PyTorch会自动推荐其类型,可通过张量的dtype属性查看其数据的类型,与numpy中的语法很相似

#创建float类型
t=torch.FloatTensor([1,2])
print(t)
print(t.dtype)

# 创建int类型的数据
t=torch.LongTensor([1,2])
print(t)
print(t.dtype)

也可以使用torch.from_numpy()方法从NumPy数组ndarray创建张量。

np_array=np.array([[1,2],[3,4]])
t_np=torch.from_numpy(np_array.reshape(1,4)).type(torch.FloatTensor)
print(t_np)
print(t_np.dtype)

首先我们能够看得到np_array依然是numpy的对象数据,依旧可以使用reshape等相关的方法
我们可以使用torch.from_numpy()方法在numpy数组上创建张量

t=torch.tensor([1,2],dtype=torch.int64)
print(t)
print(t.dtype)

t=torch.tensor([1,2],dtype=torch.float32)
print(t)
print(t.dtype)

这里用的最多的两种类型为int64,float32,这两种类型也尝尝被表示为torch.long和torch.float,其实这两种了类型对应了两种创建张量的方法,分别为
torch.FloatTensor()和torch.LongTensor()

#三种方法创建int64类型的数据
#1.
t1=torch.LongTensor([1,2])
print(t1)
print(t1.dtype)#2.
t2=torch.tensor([1,2],dtype=torch.long)
print(t2)
print(t2.dtype)#3.
t3=torch.tensor([1,2],dtype=torch.int64)
print(t3)
print(t3.dtype)

#三种方法创建float32类型的数据
#1.
t1=torch.FloatTensor([1,2])
print(t1)
print(t1.dtype)#2.
t2=torch.tensor([1,2],dtype=torch.float)
print(t2)
print(t2.dtype)#3.
t3=torch.tensor([1,2],dtype=torch.float32)
print(t3)
print(t3.dtype)

那他们确实是相等的吗,我们用代码来实现一下

print(torch.float==torch.float32)#判断是否相等,结果放回为true
print(torch.long==torch.int64)

PyTorch中张量的类型可以使用type()方法进行转换

t=torch.tensor([1,2],dtype=torch.float)
print(t.dtype)
#使用type()方法进行数据类型转化
t=t.type(torch.int64)
print(t.dtype)

torch框架中提供了两个快捷的实力转换方法

t=torch.tensor([1,2],dtype=torch.float)
print(t.dtype)
t=t.long()#使用long()将float32数据转换类int64的数据类型
print(t.dtype)
t=t.float()
print(t.dtype)

创建随机值张量

t=torch.rand(2,3)#创建一个2列3行的0-1均匀分布的随机数
print(t)t=torch.randn(2,3)#创建一个2列3行的标准正态分布随机数
print(t)t=torch.zeros(3,3)#创建一个3行3列的全是0的张量
print(t)t=torch.ones(3,2)#创建一个3行2列的全是1的张量
print(t)

t=torch.ones(3,3,3)#创建三维的数组,就只需要传入3个数据进入就行
print(t)

x=torch.zeros_like(t)  #类似的方法还有torch.ones_like()
print(x)x=torch.rand_like(t)
print(x)

张量的属性

t=torch.ones(2,3,dtype=torch.float32)
print(t.shape)
print(t.size())
print(t.size(0))
print(t.dtype)
print(t.device)

将张量移动到显存

张量可以在cpu上运行,也可以在GPU上运行,在GPU上的运算速度通常高于CPU,默认是在CPU上创建张量,如下可以使用tensor.to()方法将张量移动到GPU上

# 如果GPU可用,将张量移动到显存
if torch.cuda.is_available():t=t.to('cuda')print('true')
print(t.device)

#一般使用如下代码获取当前可用设备
device="cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))
t=t.to(device)
print(t.device)

张量的运算

t1=torch.randn(2,3)
t2=torch.ones(2,3)print(t1)
print(t2)
print(t1+3)   # t1中每一个元素都加3
print(t1+t2)  #t1与t2中每一个相同位置的元素相加 #不能将不同维数的张量相加,张量当中也有一个广播机制存在

张量加法运算
t3=torch.ones(1,3)
print(t3)
print(t1+t3)

t4=torch.randn(2,3)
t5=torch.ones(2,3)
print(t4)
print(t5)
t4.add_(t5)
print(t4)

如果一个运算方法后面加上下划线,代表就地改变原值,即上方中的t1.add_(t2)会直接将运算结果保存为t1,这样做可以节省内存,但是缺点就是会直接改变t1原值,在使用此方法的时候一定要谨慎使用。

矩阵乘法
print(t5.T)
print(t4.matmul(t5.T)) # matmul中没有下划线的方法
print(t4 @ (t5.T)) #都是对t1和t2的转置进行矩阵乘法

tensor.item
#输出一个python浮点数
print(t3)
print(t3.item())
#首先将张量中的所有的元素求和,得到只有一个元素的张量,然后使用tensor.item()方法将其转换为标量
#这种转换在我们希望打印模型正确率和损失值的时候很常见

与numpy数据类型的转换

a=np.random.randn(2,3)  #创建一个形状为(2,3)的ndarray对象
print("ndarray为:\n",a)                
t=torch.from_numpy(a)   #使用torch.from_numpy创建一个ndarray创建一个张量
print("tensor为:\n",t)       
print("numpy所对应的ndarray为:\n",t.numpy())         #使用torch.numpy()方法获得张量对应的ndarray

张量的变形

tensor.size()和tensor.shape属性可以返回张量的形状,方需要改变张量的形状时,可以通过tensor.view()方法,这个方法相当于NumPy中的reshape方法,用于改变张量的形状,但是在转换的过程中,一定要确保元素数量一致。

t=torch.randn(4,6,dtype=torch.float32)
print(t)
print(t.shape)t1=t.view(3,8)
print(t1)
print(t1.shape)

t1=t.view(-1,1)
print(t1)

如果想展成横着的一维,就默认让view(1,-1)中设置1

t1=t.view(1,-1)
print(t1)

# 也可以使用view增加维度,当然元素个数是不变的
t1=t.view(1,4,6)
print(t1)
print(t1.shape)

对于维度长度为1的张量,可以使用torch.squeeze()方法去掉长度为1的维度,相应的也有一个增加长度为1维度的方法,即torch.unsqueeze()

print(t1.shape)
print(t1)
t2=torch.squeeze(t1)
print(t2.shape)
print(t2)
t3=torch.unsqueeze(t2,0)
print(t3.shape)
print(t3)

张量的自动微分

在PyTorch中,张量有一个requires_grad属性可以在创建张量时指定此属性为True,如果requires_grad属性被设置为True,PyTorch将开始跟踪对此张量的所有计算,完成计算,可以对计算记过调用backward()方法,PyTorch将自动计算所有梯度。该张量的梯度将累加到张量的grad属性中,张量的grad_fn属性则指向运算生成此张量的方法。
张量的requires_grad属性用来明确是否跟踪张量的梯度,grad属性表示计算得到的梯度,grad_fn属性表示运算得到生成此张量的方法。

t=torch.ones(2,2,requires_grad=True)  # 这里将requires_grad设为True
print(t)
print(t.requires_grad)  # 输出是否跟踪计算张量梯度,输出True
print(t.grad)  #输出tensor.grad输出张量的梯度,输出为None,表示目前t没有梯度

接下来进行张量运算,得到y

y=t+5
print(y)
print(y.grad_fn)
print(y.grad)

# 进行其他运算
z=y*2
out=z.mean()
print(out)

上面的代码中,首先创建了张量,并指定requires_grad属性为True,目前其grad和grad_fn属性均为空。然后经过加法、乘法和取均运算,我们得到了out这个最终结果,注意,现在out只有单个元素,它是一个标量。下面在out上执行自动微分运算,并且输出t的梯度(d(out)/d(x))

out.backward()
print(t.grad)

print(t.requires_grad)
print((t+2).requires_grad)

# 将代码块包装在with torch.no_grad():上下文中
with torch.no_grad():print((t+2).requires_grad) # PyTorch没有继续跟踪此张量的运算

也可以使用tensor.detach()方法获得具有相同内容但是不需要跟踪运算的新张量,可以认为是获得张量的值

print(out.requires_grad)
s=out.detach()  # 获得out的值,也可以使用out.data()方法
print(s.requires_grad)

可以使用requires.grad_()方法就地改变张量的这个属性,当我们希望模型的参数不在随着训练变化时,我们可以使用此方法

print(t.requires_grad)
t.requires_grad_(False)
print(t.requires_grad)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/228403.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

尺度函数与小波函数

尺度函数与小波函数 尺度函数 设存在函数 φ j , k ( x ) 2 j / 2 φ ( 2 j x − k ) \varphi_{j,k}(x)2^{j/2}\varphi(2^{j}x-k) φj,k​(x)2j/2φ(2jx−k) 对所有的 j j j, k ∈ Z k{\in}\mathbb{Z} k∈Z 和 φ ( x ) ∈ L 2 ( R ) \varphi(x){\in}L^2(R) φ(x)∈L2(R)…

为什么Apache Doris适合做大数据的复杂计算,MySQL不适合?

为什么Apache Doris适合做大数据的复杂计算,MySQL不适合? 一、背景说明二、DB架构差异三、数据结构差异四、存储结构差异五、总结 一、背景说明 经常有小伙伴发出这类直击灵魂的疑问: Q:“为什么Apache Doris适合做大数据的复杂计…

大数据与深度挖掘:如何在数字营销中与研究互动

数字营销最吸引人的部分之一是对数据的内在关注。 如果一种策略往往有积极的数据,那么它就更容易采用。同样,如果一种策略尚未得到证实,则很难获得支持进行测试。 数字营销人员建立数据信心的主要方式是通过研究。这些研究通常分为两类&…

【教3妹学编程-算法题】找出峰值

3妹:2哥2哥,你有没有看到新闻:北京地铁事故中102人骨折! 2哥 : 看到了,没想到坐个地铁还出事故了。 3妹:事故原因为雪天轨滑导致前车信号降级,紧急制动停车,后车因所在区段位于下坡地…

【️Java是值传递还是引用传递?】

✅Java是值传递还是引用传递? ✅Java是值传递还是引用传递?✅典型理解 ✅增加知识仓✅Java的求值策略✅Java中的对象传递✅值传递和共享对象传递的现象冲突吗? ✅总结 ✅Java是值传递还是引用传递? ✅典型理解 编程语言中需要进行方法间的…

kafka学习笔记--Kafka副本

本文内容来自尚硅谷B站公开教学视频,仅做个人总结、学习、复习使用,任何对此文章的引用,应当说明源出处为尚硅谷,不得用于商业用途。 如有侵权、联系速删 视频教程链接:【尚硅谷】Kafka3.x教程(从入门到调优…

比特币即自由

号外:教链内参12.15《疯狂的铭文》 文 | Ross Ulbricht. 原文标题:Bitcoin Equals Freedom. 2019.9.25 在中本聪发明比特币后的头一年左右,发生了一些特别的事情,不仅没有人预料到,甚至很多人认为不可能。试着想象一下…

昇腾Profiling性能分析工具使用问题案例

昇腾Profiling性能分析工具用于采集和分析运行在昇腾硬件上的AI任务各个运行阶段的关键性能指标, 用户可根据输出的性能数据,快速定位软、硬件性能瓶颈,提升AI任务性能分析的效率。具体使用方法请参考: 本期分享几个关于Profiling性能分析工具…

【CMU 15-445】Lecture 11: Joins Algorithms 学习笔记

Joins Algorithms Nested Loop JoinNaive Nested Loop JoinBLock Nested Loop JoinIndex Nested Loop Join Sort-Merge JoinHash JoinBasic Hash JoinPartitioned Hash Join Conclusion 本节课主要介绍的是数据库系统中的一些Join算法 Nested Loop Join Naive Nested Loop Joi…

高压脉冲发生器的各种电路图

高压脉冲发生器电路图一: 高压脉冲发生器的主放电回路的等效电路。其中,S是可控开关,C1是电容器组电容,R1是高压变压器输入端的损耗电阻,L1,L2分别是高压变压器初次级电感,K为耦合系数&#xff…

架构设计系列之基础设施能力建设

周末聊两句: 今天将的基础设施能力建设部分,一般的架构书籍中都不存在的部分,这是我在实践过程中的经验和能力总结部分,希望和大家有一个很好的交流自从在 WeChat 中开了订阅号的两周半的时间,非常感谢大家的支持&…

K - 近邻算法

1、算法介绍 KNN(K Near Neighbor):k个最近的邻居,即每个样本都可以用它最接近的k个邻居来代表。KNN算法属于监督学习方式的分类算法,我的理解就是计算某给点到每个点的距离作为相似度的反馈。 简单来讲,KN…

代码随想录算法训练营第十八天 | 前中后序构造二叉树

目录 力扣题目 力扣题目记录 513.找树左下角的值 递归 迭代法 总结 112. 路径总和 106.从中序与后序遍历序列构造二叉树 总结 力扣题目 用时:2h 1、513.找树左下角的值 2、112. 路径总和 3、106.从中序与后序遍历序列构造二叉树 力扣题目记录 513.找树…

持续集成交付CICD:基于 GitLabCI 与 JenkinsCD 实现后端项目发布

目录 一、实验 1. GitLabCI环境设置 2.优化GitLabCI共享库代码 3.JenkinsCD 发布后端项目 4.再次优化GitLabCI共享库代码 5.JenkinsCD 再次发布后端项目 一、实验 1. GitLabCI环境设置 (1)GitLab给后端项目添加CI配置路径 (2&#xf…

算法通关村第十二关—字符串冲刺题(黄金)

字符串冲刺题 一、最长公共前缀 LeetCode14 编写一个函数来查找字符串数组中的最长公共前缀。如果不存在公共前缀,返回空字符串"" 示例1: 输入:strs["flower","fLow","flight"] 输出:&…

机器学习算法---时间序列

类别内容导航机器学习机器学习算法应用场景与评价指标机器学习算法—分类机器学习算法—回归机器学习算法—聚类机器学习算法—异常检测机器学习算法—时间序列数据可视化数据可视化—折线图数据可视化—箱线图数据可视化—柱状图数据可视化—饼图、环形图、雷达图统计学检验箱…

SVPWM马鞍波形仿真(python)

SVPWM波的原理不再过多介绍。 最近在学习SVPWM,仿真了一下马鞍波。 python源码贡献出来。 import numpy as np import matplotlib.pyplot as plt import matplotlib.animation as anim############################################# # 我们的目的是根据机械角度&…

12.16_黑马数据结构与算法笔记Java

目录 167 B树 remove 168 B树 remove 搭架子 169 B树 remove case1-4 170 B树 remove case5-6分析 171 B树 remove case5 旋转 172 B树 remove case5 合并 173 B树 remove case6 174 B树 remove 演示1 175 B树 remove 演示2 176 哈希表 概述 177 哈希表 hash码映射索…

XXE漏洞 [NCTF2019]Fake XML cookbook1

打开题目 查看源代码 发现我们post传入的数据都被放到了doLogin.php下面 访问一下看看 提示加载外部xml实体 bp抓包一下看看 得到flag 或者这样 但是很明显这样是不行的,因为资源是在admin上,也就是用户名那里 PHP引用外部实体,常见的利用…

(2)Linux 操作系统||基本创建与操作

本章将浅谈一下 "操作系统是什么" 的问题,随后通过讲解一些 Linux 下的基本指令,显示目录内容、跳转操作和文件的创建与删除。在讲解的同时我会穿插一些知识点,比如 Linux 隐藏文件、路径等基础知识。 了解操作系统 什么是操作系统…