pytorch:to()、device()、cuda()将Tensor或模型移动到指定的设备上

将Tensor或模型移动到指定的设备上:tensor.to(‘cuda:0’)

  • 最开始读取数据时的tensor变量copy一份到device所指定的GPU上去,之后的运算都在GPU上进行
  • 在做高维特征运算的时候,采用GPU无疑是比用CPU效率更高,如果两个数据中一个加了.cuda()或者.to(device),而另外一个没有加,就会造成类型不匹配而报错。

1. Tensor.to(device)

功能:将Tensor移动到指定的设备上。

以下代码将Tensor移动到GPU上:
device = torch.device(“cuda:0” if torch.cuda.is_available() else “cpu”)

1.1 修改dtype

a = tensor.to(torch.float64).

  • tensor.dtype : torch.float32
  • a.dtype : torch.float64

1.2 改变device:用字符串形式给出

a = tensor.to('cuda:0').

  • tensor.device : device(type=‘cpu’)
  • a.device : device(type=‘cuda’, index=0)

1.3 改变device:用torch.device给出

cuda0 = torch.device('cuda:0') .
b = tensor.to(cuda0) .

  • tensor.device : device(type=‘cpu’)
  • b.device : device(type=‘cuda’, index=0)

1.4 同时改变device和dtype

c = tensor.to('cuda:0',torch.float64) .
other = torch.randn((), dtype=torch.float64, device=cuda0) .
d = tensor.to(other, non_blocking=True) .

  • tensor.device:device(type=‘cpu’)
  • d :tensor([], device=‘cuda:0’, dtype=torch.float64))

2. model.to(device)

功能:将模型移动到指定的设备上。

使用以下代码将模型移动到GPU上:

import torch
import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return xmodel = Net()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

如果有多个GPU,使用以下方法:

if torch.cuda.device_count() > 1:model = nn.DataParallel(model,device_ids=[0,1,2])model.to(device)

将由GPU保存的模型加载到GPU上。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由GPU保存的模型加载到CPU上。
torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
model.to(device)

将由CPU保存的模型加载到GPU上。
torch.load()函数中的map_location参数设置为torch.device('cuda')

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

参考:PyTorch之Tensor.to(device)和model.to(device)

3. .to(device) 和.cuda()的区别

  • .to(device) 可以指定CPU 或者GPU
  • .cuda() 只能指定GPU

图参考:pytorch中.to(device) 和.cuda()的区别
在这里插入图片描述

官方文档:CUDA SEMANTICS

with torch.cuda.device(1):# allocates a tensor on GPU 1a = torch.tensor([1., 2.], device=cuda)# transfers a tensor from CPU to GPU 1b = torch.tensor([1., 2.]).cuda()# a.device and b.device are device(type='cuda', index=1)# You can also use ``Tensor.to`` to transfer a tensor:b2 = torch.tensor([1., 2.]).to(device=cuda)# b.device and b2.device are device(type='cuda', index=1)
  • 两个方法都可以达到同样的效果,在pytorch中,即使是有GPU的机器,它也不会自动使用GPU,而是需要在程序中显示指定。
  • 调用model.cuda(),可以将模型加载到GPU上去。这种方法不被提倡,而建议使用model.to(device)的方式,这样可以显示指定需要使用的计算资源,特别是有多个GPU的情况下。

4. CUDA相关信息查询

import torch
print('CUDA版本:',torch.version.cuda)
print('Pytorch版本:',torch.__version__)
print('显卡是否可用:','可用' if(torch.cuda.is_available()) else '不可用')
print('显卡数量:',torch.cuda.device_count())
print('当前显卡的CUDA算力:',torch.cuda.get_device_capability(0))
print('当前显卡型号:',torch.cuda.get_device_name(0))
>>>
CUDA版本: 11.7
Pytorch版本: 1.13.1
显卡是否可用: 可用
显卡数量: 1
当前显卡的CUDA算力: (8, 6)
当前显卡型号: NVIDIA GeForce RTX 3060 Laptop GPU

参考:https://blog.csdn.net/weixin_43845386/article/details/131723010

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/222610.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

word四级目录序号不随上级目录序号变化问题解决方法

一、word中的几个元素简介 1、word中的列表 如下图所示,代表word的列表: 2、word中的标题 如下图所示,代表word的标题: 3、word中的编号/序号 如下图所示,代表word的编号/序号: 4、word中的目录 如下图…

Stable diffusion 简介

Stable diffusion 是 CompVis、Stability AI、LAION、Runway 等公司研发的一个文生图模型,将 AI 图像生成提高到了全新高度,其效果和影响不亚于 Open AI 发布 ChatGPT。Stable diffusion 没有单独发布论文,而是基于 CVPR 2022 Oral —— 潜扩…

面试经典150题(20)

leetcode 150道题 计划花两个月时候刷完,今天(第八天)完成了1道(20)150: 这个题花了我快 2个小时。。。 20:(6. N 字形变换)题目描述: 将一个给定字符串 s 根据给定的行数 numRow…

Linux 之 性能优化

uptime $ uptime -p up 1 week, 1 day, 21 hours, 27 minutes$ uptime12:04:11 up 8 days, 21:27, 1 user, load average: 0.54, 0.32, 0.23“12:04:11” 表示当前时间“up 8 days, 21:27,” 表示运行了多长时间“load average: 0.54, 0.32, 0.23”“1 user” 表示 正在登录…

cfa一级考生复习经验分享系列(二)

本人学历背景:毕业于普通二本学校,本科专业为金融学,所以还算是有一定的基础,今年19年6月份参加了cfa一级的考试,最终通过了一级的考试,虽然成绩一般,但是已经知足了,因为自己复习的…

在接口实现类中,加不加@Override的区别

最近的软件构造实验经常需要设计接口,我们知道Override注解是告诉编译器,下面的方法是重写父类的方法,那么单纯实现接口的方法需不需要加Override呢? 定义一个类实现接口,使用idea时,声明implements之后会…

cfa一级考生复习经验分享系列(三)

从总成绩可以看出,位于90%水平之上,且置信区间全体均高于90%线。 从各科目成绩可以看出,所有科目均位于90%线上或高于90%线,其中,另类与衍生、公司金额、经济学、权益投资、固定收益、财报分析表现较好,目测…

芯知识 | 什么是可重复擦写(Flash型)语音芯片?

什么是可重复擦写(Flash型)语音芯片? 可重复擦写(Flash型)语音芯片是一种嵌入式语音存储解决方案,采用了Flash存储技术,使得语音内容能够被多次擦写、更新,为各种嵌入式系统提供了灵…

QEMU源码全解析 —— virtio(1)

接前一篇文章: 本文内容参考: 《趣谈Linux操作系统》 —— 刘超,极客时间 《QEMU/KVM》源码解析与应用 —— 李强,机械工业出版社 特此致谢! virtio简介 对于一台虚拟机而言,除了要虚拟化CPU和内存&…

.360勒索病毒解密方法|勒索病毒解决|勒索病毒恢复|数据库修复

导言: 在数字化时代,.360勒索病毒如影随形,威胁个人和组织的数据安全。本文将深入介绍.360病毒的特征、威胁,以及如何有效地恢复被加密的数据文件,同时提供预防措施,助您更好地保护数字资产。如不幸感染这…

原生小程序中对特定数据进行计算(wxml中wxs的使用)

背景&#xff1a;商品详情页对好评数进行统计&#xff0c;但是现在只有商品数据 使用wxs编写方法&#xff0c;module.exports导出&#xff0c;wxml中使用module名进行获取{{goodsRate.getRate(goodsInfoList)}} <wxs module"goodsRate">module.exports {get…

一站式查询热门小程序排名,助力小程序运营决策

如今小程序数量激增,竞争日益激烈,如何能在众多同类小程序中脱颖而出,提高曝光度与下载量,是每一个小程序运营者都极为关心的问题。对此,及时准确地查询自己小程序的热门排名,分析强劲对手,找出自己的短板,都是提高小程序竞争力的重要一环。那我们该如何方便快捷地查询到这些关…

Java通过documents4j和libreoffice把word转为pdf

文章目录 word转pdf的相关第三方jar说明Linux系统安装LibreOffice在线安装离线安装word转pdf验证 Java工具类代码 word转pdf的相关第三方jar说明 docx4j 免费开源、稍微复杂点的word&#xff0c;样式完全乱了&#xff0c;且xalan升级为2.7.3后会报错。poi 免费开源、官方文档少…

DNS:从域名解析到网络连接

目录 解密 DNS&#xff1a;从域名解析到网络连接的不可或缺 1. DNS的基本工作原理 1.1 本地解析器查询 1.2 递归查询 1.3 迭代查询 1.4 TLD 查询 1.5 权威 DNS 查询 2. DNS的重要性与作用 2.1 地址解析与负载均衡 2.2 网络故障处理与容错 2.3 安全性与防护 3. DNS的…

Flink 流处理流程 API详解

流处理API的衍变 Storm&#xff1a;TopologyBuilder构建图的工具&#xff0c;然后往图中添加节点&#xff0c;指定节点与节点之间的有向边是什么。构建完成后就可以将这个图提交到远程的集群或者本地的集群运行。 Flink&#xff1a;不同之处是面向数据本身的&#xff0c;会把D…

PyTorch 的 10 条内部用法

欢迎阅读这份有关 PyTorch 原理的简明指南[1]。无论您是初学者还是有一定经验&#xff0c;了解这些原则都可以让您的旅程更加顺利。让我们开始吧&#xff01; 1. 张量&#xff1a;构建模块 PyTorch 中的张量是多维数组。它们与 NumPy 的 ndarray 类似&#xff0c;但可以在 GPU …

【springboot】【easyexcel】excel文件读取

目录 pom.xmlExcelVo逐行读取并处理全部读取并处理向ExcelListener 传参 pom.xml <dependency><groupId>com.alibaba</groupId><artifactId>easyexcel</artifactId><version>3.1.1</version> </dependency>ExcelVo 字段映射…

基于 Webpack5 Module Federation 的业务解耦实践

前言 本文中会提到很多目前数栈中使用的特定名词&#xff0c;统一做下解释描述 dt-common&#xff1a;每个子产品都会引入的公共包(类似 NPM 包) AppMenus&#xff1a;在子产品中快速进入到其他子产品的导航栏&#xff0c;统一维护在 dt-common 中&#xff0c;子产品从 dt-com…

c/c++ | 少量内存溢出不会影响程序的正常运行

常见 通过指针 获得 字符串“aaaaaaaaaaaaaaaaaaaaaaaaaaa” 数据&#xff0c;然后通过strcpy 拷贝到对象 char str1[5] 中&#xff0c;事实是造成了内存溢出&#xff0c;但是 &#xff0c;最后 str1 打印输出的结果 是上面的 “aaaaaaaaaaaaaaaaaaaaaaaaaaa” 这是不正常的&am…

Draw.io or diagrams.net 使用方法

0 Preface/Foreword 在工作中&#xff0c;经常需要用到框图&#xff0c;流程图&#xff0c;时序图&#xff0c;等等&#xff0c;draw.io可以完成以上工作。 official website:draw.io 1 Usage 1.1 VS code插件 draw.io可以扩展到VS code工具中。