PyTorch学习笔记(十六)——利用GPU训练

 一、方式一

网络模型、损失函数、数据(包括输入、标注)

找到以上三种变量,调用它们的.cuda(),再返回即可

if torch.cuda.is_available():mynn = mynn.cuda()
if torch.cuda.is_available():loss_function = loss_function.cuda()
for data in train_dataloader:imgs,targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()
for data in test_dataloader:imgs,targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()

完整代码:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
# from model import *# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="../datasets",train=True,transform=torchvision.transforms.ToTensor(),download=False)
test_data = torchvision.datasets.CIFAR10(root="../datasets",train=False,transform=torchvision.transforms.ToTensor(),download=False)train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用dataloader来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 创建网络模型
class MyNN(nn.Module):def __init__(self):super(MyNN, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x
mynn = MyNN()
if torch.cuda.is_available():mynn = mynn.cuda()# 损失函数
loss_function = nn.CrossEntropyLoss()
if torch.cuda.is_available():loss_function = loss_function.cuda()
# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(mynn.parameters(), lr=learning_rate)# 设置训练网络的一些参数
total_train_step = 0 # 记录训练次数
total_test_step = 0 # 记录测试次数
epoch = 10 # 训练的轮数# 添加tensorboard
writer = SummaryWriter("../logs_train")start_time = time.time()
for i in range(epoch):print("----------第{}轮训练开始----------".format(i+1))# 训练步骤开始mynn.train()for data in train_dataloader:imgs,targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()outputs = mynn(imgs)loss = loss_function(outputs, targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print("所用时间:{}".format(end_time - start_time))print("训练次数:{},loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_step)# 测试步骤开始mynn.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs,targets = dataif torch.cuda.is_available():imgs = imgs.cuda()targets = targets.cuda()outputs = mynn(imgs)loss = loss_function(outputs, targets)total_test_loss += lossaccuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint("整体测试集上的loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss",total_test_loss,total_test_step)writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)total_test_step += 1torch.save(mynn,"mynn_{}.pth".format(i))# torch.save(mynn.state_dict(),"mynn_{}.pth".format(i))print("模型已保存")writer.close()

 比较CPU和GPU的训练时间:

 查看GPU信息:

在 终端里输入nvidia-smi

 使用Google Colab:Google 为我们提供了一个免费的GPU

修改 ——> 笔记本设置 ——> 硬件加速器选择GPU(每周免费使用30h)

 

 

 二、方式二(更常用)

定义训练设备

device = torch.device("cpu")
# 对于单显卡来说,以下两种方式没有区别
device = torch.device("cuda")
device = torch.device("cuda:0")
# 一种语法的简写,程序在 CPU 或 GPU/cuda 环境下都能运行
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

网络模型、损失函数、数据(包括输入、标注)

找到以上三种变量,.to(device),再返回即可

mynn = MyNN()
mynn = mynn.to(device)
# 这里可以不用再赋值给mynn,直接mynn.to(device) 也可以
loss_function = nn.CrossEntropyLoss()
loss_function = loss_function.to(device)
# 这里可以不用再赋值给loss_function ,直接loss_function .to(device) 也可以
for data in train_dataloader:imgs,targets = dataimgs = imgs.to(device)targets = targets.to(device)# 这里必须赋值
for data in test_dataloader:imgs,targets = dataimgs = imgs.to(device)targets = imgs.to(device)# 这里必须赋值

完整代码:

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import time
# from model import *# 定义训练的设备
# device = torch.device("cpu")
# device = torch.device("cuda")
# device = torch.device("cuda:0")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 准备数据集
train_data = torchvision.datasets.CIFAR10(root="../datasets",train=True,transform=torchvision.transforms.ToTensor(),download=False)
test_data = torchvision.datasets.CIFAR10(root="../datasets",train=False,transform=torchvision.transforms.ToTensor(),download=False)train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用dataloader来加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 创建网络模型
class MyNN(nn.Module):def __init__(self):super(MyNN, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x
mynn = MyNN()
mynn.to(device)# 损失函数
loss_function = nn.CrossEntropyLoss()
loss_function.to(device)
# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(mynn.parameters(), lr=learning_rate)# 设置训练网络的一些参数
total_train_step = 0 # 记录训练次数
total_test_step = 0 # 记录测试次数
epoch = 10 # 训练的轮数# 添加tensorboard
writer = SummaryWriter("../logs_train")start_time = time.time()
for i in range(epoch):print("----------第{}轮训练开始----------".format(i+1))# 训练步骤开始mynn.train()for data in train_dataloader:imgs,targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = mynn(imgs)loss = loss_function(outputs, targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:end_time = time.time()print("所用时间:{}".format(end_time - start_time))print("训练次数:{},loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss",loss.item(),total_train_step)# 测试步骤开始mynn.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs,targets = dataimgs = imgs.to(device)targets = targets.to(device)outputs = mynn(imgs)loss = loss_function(outputs, targets)total_test_loss += lossaccuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint("整体测试集上的loss:{}".format(total_test_loss))print("整体测试集上的正确率:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss",total_test_loss,total_test_step)writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)total_test_step += 1torch.save(mynn,"mynn_{}.pth".format(i))# torch.save(mynn.state_dict(),"mynn_{}.pth".format(i))print("模型已保存")writer.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/45619.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringMVC之@RequestMapping注解

文章目录 前言一、RequestMapping介绍二、详解(末尾附源码,自行测试)1.RequestMapping注解的位置2.RequestMapping注解的value属性3.RequestMapping注解的method属性4.RequestMapping注解的params属性(了解)5.RequestM…

华为ENSP网络设备配置实战6(简单的链路聚合)

题目要求 1、创建聚合组,添加端口成员 2、PC1网段为vlan10,PC2网段为vlan20 3、LSW1为核心网关设备,正确配置PC网关 4、PC1与PC2互通 解题过程 1.1、 按照拓扑图,各个设备起名 sys (进入系统视图) sy…

TCP协议报文结构

TCP是什么 TCP(传输控制协议)是一种面向连接的、可靠的、全双工的传输协议。它使用头部(Header)和数据(Data)来组织数据包,确保数据的可靠传输和按序传递。 TCP协议报文结构 下面详细阐述TCP…

FRP内网穿透,配置本地电脑作为服务器

FRP内网穿透,配置本地电脑作为服务器 下载FRP服务端客户端 参考链接: https://www.it235.com/实用工具/内网穿透/pierce.html https://www.cnblogs.com/007sx/p/17469301.html 由于没有公网ip,所以尝试内网穿透将本地电脑作为服务器&#xff…

【Mariadb高可用MHA】

目录 一、概述 1.概念 2.组成 3.特点 4.工作原理 二、案例介绍 1.192.168.42.3 2.192.168.42.4 3.192.168.42.5 4.192.168.42.6 三、实际构建MHA 1.ssh免密登录 1.1 所有节点配置hosts 1.2 192.168.42.3 1.3 192.168.42.4 1.4 192.168.42.5 1.5 192.168.42.6 …

(二)结构型模式:7、享元模式(Flyweight Pattern)(C++实例)

目录 1、享元模式(Flyweight Pattern)含义 2、享元模式的UML图学习 3、享元模式的应用场景 4、享元模式的优缺点 5、C实现享元模式的简单实例 1、享元模式(Flyweight Pattern)含义 享元模式(Flyweight&#xff09…

OpenCV笔记之solvePnP函数和calibrateCamera函数对比

OpenCV笔记之solvePnP函数和calibrateCamera函数对比 文章目录 OpenCV笔记之solvePnP函数和calibrateCamera函数对比1.cv::solvePnP2.cv::solvePnP函数的用途和工作原理3.cv::solvePnP背后的数学方程式4.cv::SOLVEPNP_ITERATIVE、cv::SOLVEPNP_EPNP、cv::SOLVEPNP_P3P5.一个固定…

为什么需要单元测试?

为什么需要单元测试? 从产品角度而言,常规的功能测试、系统测试都是站在产品局部或全局功能进行测试,能够很好地与用户的需要相结合,但是缺乏了对产品研发细节(特别是代码细节的理解)。 从测试人员角度而言…

Qt应用开发(基础篇)——纯文本编辑窗口 QPlainTextEdit

一、前言 QPlainTextEdit类继承于QAbstractScrollArea,QAbstractScrollArea继承于QFrame,是Qt用来显示和编辑纯文本的窗口。 滚屏区域基类https://blog.csdn.net/u014491932/article/details/132245486?spm1001.2014.3001.5501框架类QFramehttps://blo…

Elasticsearch复合查询之Boosting Query

前言 ES 里面有 5 种复合查询,分别是: Boolean QueryBoosting QueryConstant Score QueryDisjunction Max QueryFunction Score Query Boolean Query在之前已经介绍过了,今天来看一下 Boosting Query 用法,其实也非常简单&…

Chapter 15: Object-Oriented Programming | Python for Everybody 讲义笔记_En

文章目录 Python for Everybody课程简介Object-oriented programmingManaging larger programsGetting startedUsing objectsStarting with programsSubdividing a problemOur first Python objectClasses as typesObject lifecycleMultiple instancesInheritanceSummaryGlossa…

python基础5——正则、数据库操作

文章目录 一、数据库编程1.1 connect()函数1.2 命令参数1.3 常用语句 二、正则表达式2.1 匹配方式2.2 字符匹配2.3 数量匹配2.4 边界匹配2.5 分组匹配2.6 贪婪模式&非贪婪模式2.7 标志位 一、数据库编程 可以使用python脚本对数据库进行操作,比如获取数据库数据…

Docker 搭建 LNMP + Wordpress(详细步骤)

目录 一、项目模拟 1. 项目环境 2. 服务器环境 3.任务需求 二、Linux 系统基础镜像 三、Nginx 1. 建立工作目录 2. 编写 Dockerfile 脚本 3. 准备 nginx.conf 配置文件 4. 生成镜像 5. 创建自定义网络 6. 启动镜像容器 7. 验证 nginx 四、Mysql 1.…

申请部署阿里云SSL免费证书

使用宝塔自动创建的证书有时候会报NET::ERR_CERT_COMMON_NAME_INVALID,并且每次只能三个月,需要点击续期非常麻烦,容易遗忘。 阿里云免费SSL证书 前往阿里云管理控制台【数字证书管理服务】【SSL证书】,每年20个额度,一…

springBoot 配置文件 flyway 插件相关参数说明

在Spring Boot应用中使用Flyway插件进行数据库迁移时,可以在应用的配置文件中配置相关参数。下面是常用的Flyway配置参数及其说明: flyway.enabled: 是否启用Flyway插件,默认为true,表示启用Flyway插件进行数据库迁移。flyway.ur…

基于Pytorch构建DenseNet网络对cifar-10进行分类

DenseNet是指Densely connected convolutional networks(密集卷积网络)。它的优点主要包括有效缓解梯度消失、特征传递更加有效、计算量更小、参数量更小、性能比ResNet更好。它的缺点主要是较大的内存占用。 DenseNet网络与Resnet、GoogleNet类似&#…

QChart:数据可视化(用图像形式显示数据内容)

1、数据可视化的图形有:柱状/线状/条形/面积/饼/点图、仪表盘、走势图,弦图、金字塔、预测曲线图、关系图、数学公式图、行政地图、GIS地图等。 2、在QT Creator的主页面,点击 欢迎》示例》右侧输入框 输入Chart,即可查看到QChar…

go es实例

go es实例 1、下载第三方库 go get github.com/olivere/elastic下载过程中出现如下报错: 解决方案: 2、示例 import package mainimport ("context""encoding/json""fmt""reflect""time""…

LabVIEW模拟化学反应器的工作

LabVIEW模拟化学反应器的工作 近年来,化学反应器在化学和工业过程领域有许多应用。高价值产品是通过混合产品,化学反应,蒸馏和结晶等多种工业过程转换原材料制成的。化学反应器通常用于大型加工行业,例如酿酒厂公司饮料产品的发酵…

提示词4大经典框架;将AI融入动画工作流的案例和实践经验;构建基于LLM的系统和产品的模式;提示工程的艺术 | ShowMeAI日报

👀日报&周刊合集 | 🎡生产力工具与行业应用大全 | 🧡 点赞关注评论拜托啦! 🤖 高效提示词的4大经典框架:ICIO、CRISPE、BROKE、RASCEF ICIO 框架 Intruction (任务) :你希望AI去做的任务&am…