pytorch将数据与模型都放到GPU上训练

默认是CPU,如果想要用GPU需要:

  1. 安装配置cuda,然后更新/下载支持gpu版本的pytorch,可以参考:https://blog.csdn.net/weixin_35757704/article/details/124315569
  2. 设置device:
    device = torch.device('cuda' if torch.cuda.is_available else 'cpu')
    
    然后将数据与模型后面都额外加上.to(device)即可

示例程序

import torch
import torch.nn as nn# 一个简单的模型
class LinearRegressionModel(nn.Module):def __init__(self, input_shape, output_shape):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_shape, output_shape)def forward(self, x):out = self.linear(x)return outdef main():x_train = torch.randn(100, 4)  # 生成训练特征y_train = torch.randn(100, 1)  # 生成labelmodel = LinearRegressionModel(x_train.shape[1], 1)optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 优化函数criterion = nn.MSELoss()  # 损失函数for epoch in range(100):optimizer.zero_grad()outputs = model(x_train)loss = criterion(outputs, y_train)loss.backward()optimizer.step()if __name__ == '__main__':main()

修改为GPU版本:

import torch
import torch.nn as nn# 一个简单的模型
class LinearRegressionModel(nn.Module):def __init__(self, input_shape, output_shape):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_shape, output_shape)def forward(self, x):out = self.linear(x)return outdef main():# 1. 设置devicedevice = torch.device('cuda' if torch.cuda.is_available else 'cpu')# 2. 数据与模型后都加 .to(device) 即可x_train = torch.randn(100, 4).to(device)  # 生成训练特征y_train = torch.randn(100, 1).to(device)  # 生成labelmodel = LinearRegressionModel(x_train.shape[1], 1).to(device)  # next(transformer.parameters()).deviceoptimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 优化函数criterion = nn.MSELoss()  # 损失函数for epoch in range(100):optimizer.zero_grad()outputs = model(x_train)loss = criterion(outputs, y_train)loss.backward()optimizer.step()if __name__ == '__main__':main()

修改后:

  1. 查看变量的位置:可以使用x_train.device查看tensor变量的位置
  2. 查看模型的位置:可以使用next(model.parameters()).device查看模型的位置

注意:不在同一个位置上的变量之间无法计算,模型无法使用不在同一个位置的数据

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/65358.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CPT203 Software Engineering 软件工程 Pt.1 概论和软件过程(中英双语)

文章目录 1.Introduction1.1 What software engineering is and why it is important(什么是软件工程,为什么它很重要)1.1 We can’t run the modern world without software(我们的世界离不开软件)1.1.1 What is Soft…

从 Coding (Jenkinsfile) 到 Docker:全流程自动化部署 Spring Boot 实战指南(简化篇)

前言 本文记录使用 Coding (以 Jenkinsfile 为核心) 和 Docker 部署 Springboot 项目的过程,分享设置细节和一些注意问题。 1. 配置服务器环境 在实施此过程前,确保服务器已配置好 Docker、MySQL 和 Redis,可参考下列链接进行操作&#xff1…

[WASAPI]音频API:从Qt MultipleMedia走到WASAPI,相似与不同

[WASAPI] 从Qt MultipleMedia 来看WASAPI 最近在学习有关Windows上的音频驱动相关的知识,在正式开始说WASAPI之前,我想先说一说Qt的Multiple Media,为什么呢?因为Qt的MultipleMedia实际上是WASAPI的一层封装,它在是线…

绝美的数据处理图-三坐标轴-散点图-堆叠图-数据可视化图

clc clear close all %% 读取数据 load(MyColor.mat) %读取颜色包for iloop 1:25 %提取工作表数据data0(iloop) {readtable(data.xlsx,sheet,iloop)}; end%% 解析数据 countzeros(23,14); for iloop 1:25index(iloop) { cell2mat(table2array(data0{1,iloop}(1,1)))};data(i…

第三百四十六节 JavaFX教程 - JavaFX绑定

JavaFX教程 - JavaFX绑定 JavaFX绑定同步两个值:当因变量更改时,其他变量更改。 要将属性绑定到另一个属性,请调用bind()方法,该方法在一个方向绑定值。例如,当属性A绑定到属性B时,属性B的更改将更新属性A…

详解VHDL如何编写Testbench

1.概述 仿真测试平台文件(Testbench)是可以用来验证所设计的硬件模型正确性的 VHDL模型,它为所测试的元件提供了激励信号,可以以波形的方式显示仿真结果或把测试结果存储到文件中。这里所说的激励信号可以直接集成在测试平台文件中,也可以从…

RNA-Seq 数据集、比对和标准化

RNA-Seq 数据集、比对和标准化|玉米中的元基因调控网络突出了功能上相关的调控相互作用。 RNA-Seq 表达分析代码和数据 该仓库是一个公开可用 RNA-Seq 数据集的集合(主要是玉米数据),提供了系统分析这些数据的代码/流程,以及质量…

学技术学英文:Spring AOP和 AspectJ 的关系

AspectJ是AOP领域的江湖一哥, Spring AOP 只是一个小弟 Spring AOP is implemented in pure Java. There is no need for a special compilation process. Spring AOP does not need to control the class loader hierarchy and is thus suitable for use in a ser…

JVM学习-内存结构(二)

一、堆 1.定义 2.堆内存溢出问题 1.演示 -Xmx设置堆大小 3.堆内存的诊断 3.1介绍 1,2都是命令行工具(可直接在ideal运行时,在底下打开终端,输入命令) 1可以拿到Java进程的进程ID,2 jmap只能查询某一个时…

Browser Use:AI智能体自动化操作浏览器的开源工具

Browser Use:AI智能体自动化操作浏览器的开源工具 Browser Use 简介1. 安装所需依赖2. 生成openai密钥3. 编写代码4. 运行代码5. 部署与优化5.1 部署AI代理5.2 优化与扩展总结Browser Use 简介 browser-use是一个Python库,它能够帮助我们将AI代理与浏览器自动化操作结合起来;…

Spring Cloud——注册中心

介绍 什么是注册中心? 主要负责服务的注册与发现,确保服务之间的通信顺畅,具体来说,注册中心有以下主要功能:‌服务注册、服务发现、服务健康检查。 服务注册: 服务提供者在启动时会向注册中心注册自身服务…

CSS基础入门【2】

目录 一、知识复习 二、权重问题深入 2.1 同一个标签,携带了多个类名,有冲突: 2.2 !important标记 2.3 权重计算的总结 三、盒模型 3.1 盒子中的区域 3.2 认识width、height 3.3 认识padding 3.4 border 作业: 一、知识…

捋一捋相关性运算,以及DTD和NLP中的应用

捋一捋相关性运算,以及DTD和NLP中的应用 相关性和相干性,有木有傻傻分不清相关性数字信号的相关运算同维度信号的相关理解 相关--互相关--相干 回声消除过程如何套用这些知识相关性/相干性检测在DT中的应用时域的标量与向量结合的互相关方法适合block处理的频域相干…

Elasticsearch:normalizer

一、概述 ‌Elastic normalizer‌是Elasticsearch中用于处理keyword类型字段的一种工具,主要用于对字段进行规范化处理,确保在索引和查询时保持一致性。 Normalizer与analyzer类似,都是对字段进行处理,但normalizer不会对字段进…

go语言的成神之路-筑基篇-对文件的操作

目录 一、对文件的读写 Reader?接口 ?Writer接口 copy接口 bufio的使用 ioutil库? 二、cat命令 三、包 1. 包的声明 2. 导入包 3. 包的可见性 4. 包的初始化 5. 标准库包 6. 第三方包 ?7. 包的组织 8. 包的别名 9. 包的路径 10. 包的版本管理 四、go mo…

【入门】拐角III

描述 输入整数N&#xff0c;输出相应方阵。 输入描述 一个整数N。&#xff08; 0 < n < 10 ) 输出描述 一个方阵&#xff0c;每个数字的场宽为3。 用例输入 1 5 用例输出 1 5 5 5 5 55 4 4 4 45 4 3 3 35 4 3 2 25 4 3 2 1 来源 二维数组…

攻防世界 ics-06

开启场景 可以交互的按钮不是很多&#xff0c;没有什么有用信息&#xff0c;查看页面源代码找到了index.php &#xff0c;后面跟着“报表中心” 传参访问 /index.php 看到了参数 id1&#xff0c;用 burp 抓包爆破&#xff08;这里应该不是 sql 注入&#xff09; 2333 的长度与众…

VMware虚拟机安装银河麒麟操作系统KylinOS教程(超详细)

目录 引言1. 下载2. 安装 VMware2. 安装银河麒麟操作系统2.1 新建虚拟机2.2 安装操作系统2.3 网络配置 3. 安装VMTools 创作不易&#xff0c;禁止转载抄袭&#xff01;&#xff01;&#xff01;违者必究&#xff01;&#xff01;&#xff01; 创作不易&#xff0c;禁止转载抄袭…

ByConity BSP 解锁数据仓库新未来

文章目录 前言BSP 模式简介基于 TPC-DS 的 ELT 活动测试环境登录 ECS数据查询配置 执行 02.sqlsql解释&#xff1a;1. 第一步&#xff1a;创建 wscs 临时表2. 第二步&#xff1a;创建 wswscs 临时表3. 第三步&#xff1a;对比 2001 年和 2002 年的数据子查询 1&#xff1a;提取…

论文解读 | EMNLP2024 一种用于大语言模型版本更新的学习率路径切换训练范式

点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入&#xff01; 点击 阅读原文 观看作者讲解回放&#xff01; 作者简介 王志豪&#xff0c;厦门大学博士生 刘诗雨&#xff0c;厦门大学硕士生 内容简介 新数据的不断涌现使版本更新成为大型语言模型&#xff08;LLMs&#xff…