PyTorch使用教程(11)-cuda的使用方法

1. 基本概念

CUDA(Compute Unified Device Architecture)是NVIDIA开发的一种并行计算平台和编程模型,专为图形处理器(GPU)设计,旨在加速科学计算、工程计算和机器学习等领域的高性能计算任务。CUDA允许开发人员使用GPU进行通用计算(也称为GPGPU,General-Purpose computing on Graphics Processing Units)。
在这里插入图片描述

2.Torch与CUDA

Torch是一个流行的深度学习库,由PyTorch开发团队创建,主要用于Python编程环境。当Torch结合CUDA时,它可以显著提升训练深度神经网络的速度。通过将数据和计算转移到GPU上,利用GPU的大量并行核心处理大量矩阵运算,实现对大规模数据集的高效处理。

3. 核心功能

(1)、torch.cuda.device
torch.cuda.device是一个上下文管理器,用于更改所选设备。它允许你在代码块内指定张量或模型应在哪个GPU上创建或执行。

(2)、 torch.cuda.is_available
torch.cuda.is_available()函数用于检查CUDA是否可用。如果系统中安装了NVIDIA的显卡驱动和CUDA工具包,并且PyTorch版本支持CUDA,那么该函数将返回True。

(3)、torch.device
torch.device是一个对象,表示张量可以存放的设备。它可以是CPU或某个GPU。通过指定torch.device(“cuda”),你告诉PyTorch你希望在一个支持CUDA的NVIDIA GPU上执行张量运算。如果有多个GPU,可以通过指定GPU的索引来选择其中一个,例如torch.device(“cuda:0”)表示第一个GPU,torch.device(“cuda:1”)表示第二个GPU,依此类推。

(4)、张量移动
在PyTorch中,你可以使用.to(‘cuda’)或.cuda()函数将张量(Tensor)从CPU移动到GPU。同样,你也可以使用这些方法将模型参数和优化器移动到GPU上。

4.功能示例

(1)、检查CUDA是否可用

import torchif torch.cuda.is_available():print("CUDA is available. Number of GPUs:", torch.cuda.device_count())
else:print("CUDA is not available.")

(2)、创建张量并移动到GPU

import torch# 在CPU上创建一个张量
x = torch.randn(3, 3)# 检查CUDA是否可用
if torch.cuda.is_available():# 将张量移动到GPUdevice = torch.device("cuda")x_gpu = x.to(device)print(x_gpu)  # 这将显示张量的设备为 "cuda:0"# 直接在GPU上创建另一个张量y = torch.randn(3, 3, device=device)z = x_gpu + y  # 这个加法操作在GPU上执行print(z)

(3)、在不同GPU上创建和操作张量

import torch# 在默认GPU上创建一个张量
x = torch.cuda.FloatTensor(1)
print("x.get_device() ==", x.get_device())  # 输出 0# 在GPU 1上创建一个张量
with torch.cuda.device(1):a = torch.cuda.FloatTensor(1)print("a.get_device() ==", a.get_device())  # 输出 1# 将CPU张量转移到GPU 1b = torch.FloatTensor(1).cuda()print("b.get_device() ==", b.get_device())  # 输出 1c = a + bprint("c.get_device() ==", c.get_device())  # 输出 1# 在GPU 0上的张量操作
z = x + x  # 仍然在GPU 0上
print("z.get_device() ==", z.get_device())  # 输出 0# 在特定GPU上创建张量
d = torch.randn(2).cuda(2)
print("d.get_device() ==", d.get_device())  # 输出 2

(4)、将模型和优化器移动到GPU

import torch
import torch.nn as nn
import torch.optim as optim# 创建一个简单的神经网络模型
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.fc1 = nn.Linear(3, 2)self.fc2 = nn.Linear(2, 1)def forward(self, x):x = self.fc1(x)x = self.fc2(x)return xnet = Net()# 检查CUDA是否可用
if torch.cuda.is_available():# 将模型参数和优化器移动到GPUdevice = torch.device("cuda")net = net.to(device)print(net)optimizer = optim.SGD(net.parameters(), lr=0.01)optimizer = optimizer.to(device)  # 注意:优化器通常不需要显式移动到GPU# 创建一些假数据并移动到GPU
inputs = torch.randn(20, 3).to(device)
targets = torch.randint(0, 2, (20,)).to(device)# 定义损失函数
criterion = nn.CrossEntropyLoss()# 训练模型
net.train()
for epoch in range(5):optimizer.zero_grad()outputs = net(inputs)loss = criterion(outputs, targets)loss.backward()optimizer.step()print(f'Epoch {epoch+1}, Loss: {loss.item()}')

5. 使用注意事项

(1)、GPU内存限制
显卡的内存是有限的,如果模型或数据过大,可能会导致内存不足的问题。可以通过减小批量大小、使用更小的模型或者使用分布式训练等方式来解决。

(2)、数据类型匹配
在使用CUDA加速时,需要确保模型和数据的数据类型匹配。通常情况下,模型和数据都应该使用torch.cuda.FloatTensor类型。

(3)、CUDA版本和驱动兼容性
确保安装了适用于CUDA的PyTorch版本以及相应版本的NVIDIA显卡驱动。不同版本的CUDA和PyTorch之间可能存在兼容性问题。

(4)、避免跨GPU操作
默认情况下,PyTorch不支持跨GPU操作。如果需要对分布在不同设备上的张量进行操作,需要显式地进行数据传输,这可能会引入额外的开销。

(5)、异步数据传输
为了将数据传输与计算重叠,可以使用异步的GPU副本。只需在调用cuda()时传递一个额外的async=True参数。此外,通过将pin_memory=True传递给DataLoader的构造函数,可以使DataLoader将batch返回到固定内存中,从而加快主机到GPU的复制速度。

(6)、多GPU训练
对于多GPU训练,PyTorch提供了nn.DataParallel等工具和函数来简化这一过程。然而,在使用多进程进行CUDA模型训练时需要注意线程安全和资源竞争等问题。

6、小结

torch.cuda是PyTorch中用于在NVIDIA GPU上进行加速计算的重要模块。通过合理利用CUDA的并行计算能力,可以显著提升深度学习模型的训练和推理速度。然而,在使用CUDA时也需要注意一些细节和限制,以确保程序的正确性和性能。通过本文的介绍和示例代码,希望读者能够更好地理解和使用torch.cuda进行深度学习开发。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/bicheng/68617.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

金融项目实战 07|Python实现接口自动化——连接数据库和数据清洗、测试报告、持续集成

目录 一、投资模块(投资接口投资业务) 二、连接数据库封装 和 清洗数据 1、连接数据库 2、数据清洗 4、调用 三、批量执行测试用例 并 生成测试报告 四、持续集成 1、代码上传gitee 2、Jenkin持续集成 一、投资模块(投资接口投资业务…

Ubuntu22.04安装paddle GPU版本

文章目录 确立版本安装CUDA与CUDNN安装paddle 确立版本 查看官网信息,确立服务版本:https://www.paddlepaddle.org.cn/documentation/docs/zh/2.6/install/pip/linux-pip.html 安装CUDA与CUDNN 通过nvidia-smi查看当前显卡驱动版本: 通过…

网络编程-UDP套接字

文章目录 UDP/TCP协议简介两种协议的联系与区别Socket是什么 UDP的SocketAPIDatagramSocketDatagramPacket 使用UDP模拟通信服务器端客户端测试 完整测试代码 UDP/TCP协议简介 两种协议的联系与区别 TCP和UDP其实是传输层的两个协议的内容, 差别非常大, 对于我们的Java来说, …

Unity补充 -- 协程相关

1.协程。 协程并不是线程。线程是主线程之外的另一条 代码按照逻辑执行通道。协程则是在代码在按照逻辑执行的同时,是否需要执行额外的语句块。 2.协程的作用。 在update执行的时候,是按照帧来进行刷新的,也是按照帧执行代码的。但是又不想…

IoTDB 常见问题 QA 第四期

关于 IoTDB 的 Q & A IoTDB Q&A 第四期来啦!我们将定期汇总我们将定期汇总社区讨论频繁的问题,并展开进行详细回答,通过积累常见问题“小百科”,方便大家使用 IoTDB。 Q1:Java 中如何使用 SSL 连接 IoTDB 问题…

Json转换类型报错问题:java.lang.Integer cannot be cast to java.math.BigDecimal

Json转换类型报错问题:java.lang.Integer cannot be cast to java.math.BigDecimal 小坑规避指南 小坑规避指南 项目中遇到json格式转换成Map,已经定义了Map的key和value的类型,但是在遍历Map取值的时候出现了类型转换的报错问题&#xff08…

数据结构——队列和栈(介绍、类型、Java手搓实现循环队列)

我是一个计算机专业研0的学生卡蒙Camel🐫🐫🐫(刚保研) 记录每天学习过程(主要学习Java、python、人工智能),总结知识点(内容来自:自我总结网上借鉴&#xff0…

python http server运行Angular 单页面路由时重定向,解决404问题

问题 当Angular在本地ng server运行时候,可以顺利访问各级路由。 但是运行ng build后,在dist 路径下的打包好的额index.html 必须要在服务器下运行才能加载。 在服务器下我们第一次访问路由页面时是没有问题的,但是尝试刷新页面或手动输入路…

SQL表间关联查询详解

简介 本文主要讲解SQL语句中常用的表间关联查询方式,包括:左连接(left join)、右连接(right join)、全连接(full join)、内连接(inner join)、交叉连接&…

Android Jni(一) 快速使用

文章目录 Android Jni(一) 快速使用1、 环境配置下载 NDK2、右键 add c to module3、创建一个 native 方法,并更具提示,自动创建对应的 JNI 实现4、实现对应 Jni 方法5、static loadLibrary6、调用执行 遇到的问题1、[CXX1300] CM…

【HarmonyOS之旅】基于ArkTS开发(二) -> UI开发之常见布局

目录 1 -> 自适应布局 1.1 -> 线性布局 1.1.1 -> 线性布局的排列 1.1.2 -> 自适应拉伸 1.1.3 -> 自适应缩放 1.1.4 -> 定位能力 1.1.5 -> 自适应延伸 1.2 -> 层叠布局 1.2.1 -> 对齐方式 1.2.2 -> Z序控制 1.3 -> 弹性布局 1.3.1…

物联网网关Web服务器--Boa服务器移植与测试

1、Boa服务器介绍 BOA 服务器是一个小巧高效的web服务器,是一个运行于unix或linux下的,支持CGI的、适合于嵌入式系统的单任务的http服务器,源代码开放、性能高。 Boa 嵌入式 web 服务器的官方网站是http://www.boa.org/。 特点 轻量级&#x…

tomcat状态一直是Exited (1)

docker run -di -p 80:8080 --nametomcat001 你的仓库地址/tomcat:9执行此命令后tomcat一直是Exited(1)状态 解决办法: 用以下命令创建运行 docker run -it --name tomcat001 -p 80:8080 -d 你的仓库地址/tomcat:9 /bin/bash最终结果 tomcat成功启动

三天急速通关Java基础知识:Day1 基本语法

三天急速通关JAVA基础知识:Day1 基本语法 0 文章说明1 关键字 Keywords2 注释 Comments2.1 单行注释2.2 多行注释2.3 文档注释 3 数据类型 Data Types3.1 基本数据类型3.2 引用数据类型 4 变量与常量 Variables and Constant5 运算符 Operators6 字符串 String7 输入…

表单中在不设置required的情况下在label前加*必填标识

参考:https://blog.csdn.net/qq_55798464/article/details/136233718 需求:在发票类型前面添加*必填标识 我最开始直接给发票类型这个表单类型添加了验证规则required:true,问题来了,这个发票类型它是有默认值的,所以我点击保存…

2025寒假备战蓝桥杯01---朴素二分查找的学习

文章目录 1.暴力方法的引入2.暴力解法的思考 与改进3.朴素二分查找的引入4.朴素二分查找的流程5.朴素二分查找的细节6.朴素二分查找的题目 1.暴力方法的引入 对于下面的这个有序的数据元素的组合,我们的暴力解法就是挨个进行遍历操作,一直找到和我们的这…

ROS机器人学习和研究的势-道-术-转型和变革的长期主义习惯

知易行难。说说容易做到难。 例如,不受成败评价影响,坚持做一件事情10年以上,专注事情本身。 机器人专业不合格且失败讲师如何让内心保持充盈的正能量(节选)-CSDN博客 时间积累 注册20年。 创作历程10年。 创作10年…

渗透测试之XEE[外部实体注入]漏洞 原理 攻击手法 xml语言结构 防御手法

目录 原理 XML语言解释 什么是xml语言: 以PHP举例xml外部实体注入 XML语言结构 面试题目 如何寻找xxe漏洞 XEE漏洞修复域防御 提高版本 代码修复 php java python 手动黑名单过滤(不推荐) 一篇文章带你深入理解漏洞之 XXE 漏洞 - 先知社区 原理 XXE&…

BUUCTF_Web([GYCTF2020]Ezsqli)

1.输入1 ,正常回显。 2.输入1 ,报错false,为字符型注入,单引号闭合。 原因: https://mp.csdn.net/mp_blog/creation/editor/145170456 3.尝试查询字段,回显位置,数据库,都是这个。…

react使用react-redux状态管理

1、安装 npm install react-redux2、创建store.js import { createStore } from redux;// 定义初始状态 const initialState {counter: 888 };// 定义 reducer 函数,根据 action 类型更新状态 function reducer(state initialState, action) {switch (action.ty…