pytorch:to()、device()、cuda()将Tensor或模型移动到指定的设备上

将Tensor或模型移动到指定的设备上:tensor.to(‘cuda:0’)

  • 最开始读取数据时的tensor变量copy一份到device所指定的GPU上去,之后的运算都在GPU上进行
  • 在做高维特征运算的时候,采用GPU无疑是比用CPU效率更高,如果两个数据中一个加了.cuda()或者.to(device),而另外一个没有加,就会造成类型不匹配而报错。

1. Tensor.to(device)

功能:将Tensor移动到指定的设备上。

以下代码将Tensor移动到GPU上:
device = torch.device(“cuda:0” if torch.cuda.is_available() else “cpu”)

1.1 修改dtype

a = tensor.to(torch.float64).

  • tensor.dtype : torch.float32
  • a.dtype : torch.float64

1.2 改变device:用字符串形式给出

a = tensor.to('cuda:0').

  • tensor.device : device(type=‘cpu’)
  • a.device : device(type=‘cuda’, index=0)

1.3 改变device:用torch.device给出

cuda0 = torch.device('cuda:0') .
b = tensor.to(cuda0) .

  • tensor.device : device(type=‘cpu’)
  • b.device : device(type=‘cuda’, index=0)

1.4 同时改变device和dtype

c = tensor.to('cuda:0',torch.float64) .
other = torch.randn((), dtype=torch.float64, device=cuda0) .
d = tensor.to(other, non_blocking=True) .

  • tensor.device:device(type=‘cpu’)
  • d :tensor([], device=‘cuda:0’, dtype=torch.float64))

2. model.to(device)

功能:将模型移动到指定的设备上。

使用以下代码将模型移动到GPU上:

import torch
import torch.nn as nnclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(10, 5)self.fc2 = nn.Linear(5, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return xmodel = Net()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

如果有多个GPU,使用以下方法:

if torch.cuda.device_count() > 1:model = nn.DataParallel(model,device_ids=[0,1,2])model.to(device)

将由GPU保存的模型加载到GPU上。

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)

将由GPU保存的模型加载到CPU上。
torch.load()函数中的map_location参数设置为torch.device('cpu')

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
model.to(device)

将由CPU保存的模型加载到GPU上。
torch.load()函数中的map_location参数设置为torch.device('cuda')

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)

参考:PyTorch之Tensor.to(device)和model.to(device)

3. .to(device) 和.cuda()的区别

  • .to(device) 可以指定CPU 或者GPU
  • .cuda() 只能指定GPU

图参考:pytorch中.to(device) 和.cuda()的区别
在这里插入图片描述

官方文档:CUDA SEMANTICS

with torch.cuda.device(1):# allocates a tensor on GPU 1a = torch.tensor([1., 2.], device=cuda)# transfers a tensor from CPU to GPU 1b = torch.tensor([1., 2.]).cuda()# a.device and b.device are device(type='cuda', index=1)# You can also use ``Tensor.to`` to transfer a tensor:b2 = torch.tensor([1., 2.]).to(device=cuda)# b.device and b2.device are device(type='cuda', index=1)
  • 两个方法都可以达到同样的效果,在pytorch中,即使是有GPU的机器,它也不会自动使用GPU,而是需要在程序中显示指定。
  • 调用model.cuda(),可以将模型加载到GPU上去。这种方法不被提倡,而建议使用model.to(device)的方式,这样可以显示指定需要使用的计算资源,特别是有多个GPU的情况下。

4. CUDA相关信息查询

import torch
print('CUDA版本:',torch.version.cuda)
print('Pytorch版本:',torch.__version__)
print('显卡是否可用:','可用' if(torch.cuda.is_available()) else '不可用')
print('显卡数量:',torch.cuda.device_count())
print('当前显卡的CUDA算力:',torch.cuda.get_device_capability(0))
print('当前显卡型号:',torch.cuda.get_device_name(0))
>>>
CUDA版本: 11.7
Pytorch版本: 1.13.1
显卡是否可用: 可用
显卡数量: 1
当前显卡的CUDA算力: (8, 6)
当前显卡型号: NVIDIA GeForce RTX 3060 Laptop GPU

参考:https://blog.csdn.net/weixin_43845386/article/details/131723010

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/222610.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

word四级目录序号不随上级目录序号变化问题解决方法

一、word中的几个元素简介 1、word中的列表 如下图所示,代表word的列表: 2、word中的标题 如下图所示,代表word的标题: 3、word中的编号/序号 如下图所示,代表word的编号/序号: 4、word中的目录 如下图…

Stable diffusion 简介

Stable diffusion 是 CompVis、Stability AI、LAION、Runway 等公司研发的一个文生图模型,将 AI 图像生成提高到了全新高度,其效果和影响不亚于 Open AI 发布 ChatGPT。Stable diffusion 没有单独发布论文,而是基于 CVPR 2022 Oral —— 潜扩…

在接口实现类中,加不加@Override的区别

最近的软件构造实验经常需要设计接口,我们知道Override注解是告诉编译器,下面的方法是重写父类的方法,那么单纯实现接口的方法需不需要加Override呢? 定义一个类实现接口,使用idea时,声明implements之后会…

cfa一级考生复习经验分享系列(三)

从总成绩可以看出,位于90%水平之上,且置信区间全体均高于90%线。 从各科目成绩可以看出,所有科目均位于90%线上或高于90%线,其中,另类与衍生、公司金额、经济学、权益投资、固定收益、财报分析表现较好,目测…

QEMU源码全解析 —— virtio(1)

接前一篇文章: 本文内容参考: 《趣谈Linux操作系统》 —— 刘超,极客时间 《QEMU/KVM》源码解析与应用 —— 李强,机械工业出版社 特此致谢! virtio简介 对于一台虚拟机而言,除了要虚拟化CPU和内存&…

一站式查询热门小程序排名,助力小程序运营决策

如今小程序数量激增,竞争日益激烈,如何能在众多同类小程序中脱颖而出,提高曝光度与下载量,是每一个小程序运营者都极为关心的问题。对此,及时准确地查询自己小程序的热门排名,分析强劲对手,找出自己的短板,都是提高小程序竞争力的重要一环。那我们该如何方便快捷地查询到这些关…

DNS:从域名解析到网络连接

目录 解密 DNS:从域名解析到网络连接的不可或缺 1. DNS的基本工作原理 1.1 本地解析器查询 1.2 递归查询 1.3 迭代查询 1.4 TLD 查询 1.5 权威 DNS 查询 2. DNS的重要性与作用 2.1 地址解析与负载均衡 2.2 网络故障处理与容错 2.3 安全性与防护 3. DNS的…

Flink 流处理流程 API详解

流处理API的衍变 Storm:TopologyBuilder构建图的工具,然后往图中添加节点,指定节点与节点之间的有向边是什么。构建完成后就可以将这个图提交到远程的集群或者本地的集群运行。 Flink:不同之处是面向数据本身的,会把D…

PyTorch 的 10 条内部用法

欢迎阅读这份有关 PyTorch 原理的简明指南[1]。无论您是初学者还是有一定经验,了解这些原则都可以让您的旅程更加顺利。让我们开始吧! 1. 张量:构建模块 PyTorch 中的张量是多维数组。它们与 NumPy 的 ndarray 类似,但可以在 GPU …

基于 Webpack5 Module Federation 的业务解耦实践

前言 本文中会提到很多目前数栈中使用的特定名词,统一做下解释描述 dt-common:每个子产品都会引入的公共包(类似 NPM 包) AppMenus:在子产品中快速进入到其他子产品的导航栏,统一维护在 dt-common 中,子产品从 dt-com…

c/c++ | 少量内存溢出不会影响程序的正常运行

常见 通过指针 获得 字符串“aaaaaaaaaaaaaaaaaaaaaaaaaaa” 数据,然后通过strcpy 拷贝到对象 char str1[5] 中,事实是造成了内存溢出,但是 ,最后 str1 打印输出的结果 是上面的 “aaaaaaaaaaaaaaaaaaaaaaaaaaa” 这是不正常的&am…

Draw.io or diagrams.net 使用方法

0 Preface/Foreword 在工作中,经常需要用到框图,流程图,时序图,等等,draw.io可以完成以上工作。 official website:draw.io 1 Usage 1.1 VS code插件 draw.io可以扩展到VS code工具中。

百度搜索展现服务重构:进步与优化

作者 | 瞭东 导读 本文将简单介绍搜索展现服务发展过程,以及当前其面临的三大挑战:研发难度高、架构能力欠缺、可复用性低,最后提出核心解决思路和具体落地方案,期望大家能有所收货和借鉴。 全文4736字,预计阅读时间12…

产品入门第四讲:Axure动态面板

📚📚 🏅我是默,一个在CSDN分享笔记的博主。📚📚 ​​​​​ 🌟在这里,我要推荐给大家我的专栏《Axure》。🎯🎯 🚀无论你是编程小白,还…

17.分割有效信息【2023.12.9】

1.问题描述 有时候我们需要截取字符串以获取有用的信息,比如对于字符串 “日期:2010-10-29”,我们需要截取后面的 10 个字符来获取日期,以便进行进一步分析。编写一个程序,输入一个字符串,然后输出截取后的…

【Spark精讲】Spark Shuffle详解

目录 Shuffle概述 Shuffle执行流程 总体流程 中间文件 ShuffledRDD生成 Stage划分 Task划分 Map端写入(Shuffle Write) Reduce端读取(Shuffle Read) Spark Shuffle演变 SortShuffleManager运行机制 普通运行机制 bypass 运行机制 Tungsten Sort Shuffle 运行机制…

Spark环境搭建和使用方法

目录 一、安装Spark (一)基础环境 (二)安装Python3版本 (三)下载安装Spark (四)配置相关文件 二、在pyspark中运行代码 (一)pyspark命令 &#xff08…

AR眼镜光学方案_AR眼镜整机硬件定制

增强现实(Augmented Reality,AR)技术通过将计算机生成的虚拟物体或其他信息叠加到真实世界中,实现对现实的增强。AR眼镜作为实现AR技术的重要设备,具备虚实结合、实时交互的特点。为了实现透视效果,AR眼镜需要同时显示真实的外部世…

基于vue实现的疫情数据可视化分析及预测系统-计算机毕业设计推荐django

目 录 摘 要 I ABSTRACT II 目 录 II 第1章 绪论 1 1.1背景及意义 1 1.2 国内外研究概况 1 1.3 研究的内容 1 第2章 相关技术 3 2.1 nodejs简介 4 2.2 express框架介绍 6 2.4 MySQL数据库 4 第3章 系统分析 5 3.1 需求分析 5 3.2 系统可行性分析 5 3.2.1技术可行性:…

mac python安装grpcio以及xcode升级权限问题记录

问题1: ERROR: Could not build wheels fol grpcio, which is required to install pyproject.toml-based projects pip3 install --no-cache-dir --force-reinstall -Iv grpcio1.41.0 # (我这里是降级安装的) 问题2: fatal error: ‘stdio.h’ file not found 25 | #include …