Pytorch学习笔记——在GPU上进行训练

文章目录

      • 1. 环境准备
      • 2. 导入必要库
      • 3. 加载数据集
      • 4. 定义简单的神经网络模型
      • 5. 检查和设置GPU设备
      • 6. 定义损失函数和优化器
      • 7. 训练模型
      • 8. 全部代码展示及运行结果

1. 环境准备

首先,确保PyTorch已经安装,且CUDA(NVIDIA的并行计算平台和编程模型)支持已经正确配置。可以通过以下代码检查CUDA是否可用:

print(torch.cuda.is_available())  # 如果返回True,则CUDA可用

配置PyTorch环境和CUDA支持,可以参考我写的这篇博客
Pytorch学习笔记——环境配置安装

2. 导入必要库

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear, Flatten, Sequential
from torch.utils.data import DataLoader
  • torch 是PyTorch的核心库。
  • torchvision 提供了用于计算机视觉任务的工具,包括数据集和变换。
  • nn 包含了构建神经网络所需的各种模块。
  • DataLoader 用于加载数据集并进行批处理。

3. 加载数据集

使用CIFAR-10数据集进行训练,它是一个常用的小型图像数据集。加载数据集并创建数据加载器。

# 加载数据集
dataset = torchvision.datasets.CIFAR10(root="data1", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=64)
  • torchvision.datasets.CIFAR10:下载并加载CIFAR-10数据集。
  • DataLoader:将数据集划分为小批次,并进行数据加载。

4. 定义简单的神经网络模型

定义一个简单的卷积神经网络(CNN)模型:

self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),  # 第一次卷积MaxPool2d(2),  # 第一次最大池化Conv2d(32, 32, 5, padding=2),  # 第二次卷积MaxPool2d(2),  # 第二次最大池化Conv2d(32, 64, 5, padding=2),  # 第三次卷积MaxPool2d(2),  # 第三次最大池化Flatten(),    # 展平层Linear(1024, 64),  # 第一个全连接层Linear(64, 10),  # 第二个全连接层)
def forward(self, x):x = self.model1(x)return x
  • Conv2d:二维卷积层,用于提取图像特征。
  • MaxPool2d:最大池化层,用于下采样。
  • Flatten:将多维输入展平为一维,用于全连接层的输入。
  • Linear:全连接层,用于分类任务。

5. 检查和设置GPU设备

需要检查是否有可用的GPU,并将模型和数据移动到GPU上。

# 检查是否有可用的GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 将模型和数据转移到GPU上
mynn = NN().to(device)
print(mynn)
  • torch.device("cuda"):如果CUDA可用,则使用GPU;否则使用CPU。
  • to(device):将模型转移到指定的设备(CPU或GPU)。

6. 定义损失函数和优化器

定义损失函数和优化器来训练模型:

# 定义损失函数
loss = nn.CrossEntropyLoss().to(device)# 定义优化器
optim = torch.optim.SGD(mynn.parameters(), lr=0.01)
  • nn.CrossEntropyLoss:适用于分类问题的损失函数。
  • torch.optim.SGD:随机梯度下降优化器。

7. 训练模型

通过多个epoch对模型进行训练。在每个epoch中,进行前向传播、计算损失、反向传播和参数更新:

# 多轮学习 0 - 20 20轮
for epoch in range(20):running_loss = 0.0for data in dataloader:# 确保数据也转移到GPU上imgs, targets = data[0].to(device), data[1].to(device)optim.zero_grad()  # 清零梯度缓存outputs = mynn(imgs)  # 前向传播loss_value = loss(outputs, targets)  # 计算损失loss_value.backward()  # 反向传播,计算梯度optim.step()  # 根据梯度更新权重running_loss += loss_value.item()  # 累加损失值print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}")print("------------------------------")
  • optim.zero_grad():清零之前计算的梯度。
  • outputs = mynn(imgs):进行前向传播。
  • loss_value.backward():进行反向传播,计算梯度。
  • optim.step():根据计算的梯度更新权重。
  • running_loss:累加损失值以计算平均损失。

8. 全部代码展示及运行结果

# -*- coding: utf-8 -*-
# @Author: kk
import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear, Flatten, Sequential
from torch.utils.data import DataLoader# 加载数据集
dataset = torchvision.datasets.CIFAR10(root="data1", train=False, transform=torchvision.transforms.ToTensor(), download=True)
# loader加载
dataloader = DataLoader(dataset, batch_size=64)# 网络
class NN(nn.Module):def __init__(self):super(NN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2),  # 第一次卷积MaxPool2d(2),  # 第一次最大池化Conv2d(32, 32, 5, padding=2),  # 第二次卷积MaxPool2d(2),  # 第二次最大池化Conv2d(32, 64, 5, padding=2),  # 第三次卷积MaxPool2d(2),  # 第三次最大池化Flatten(),    # 展平层Linear(1024, 64),  # 第一个全连接层Linear(64, 10),  # 第二个全连接层)def forward(self, x):x = self.model1(x)return x# 检查是否有可用的GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 将模型和数据转移到GPU上
mynn = NN().to(device)
print(mynn)
loss = nn.CrossEntropyLoss().to(device)# 优化器
optim = torch.optim.SGD(mynn.parameters(), lr=0.01)# 多轮学习  0 - 20  20轮
for epoch in range(20):running_loss = 0.0for data in dataloader:# 确保数据也转移到GPU上imgs, targets = data[0].to(device), data[1].to(device)optim.zero_grad()  # 清零梯度缓存outputs = mynn(imgs)  # 前向传播loss_value = loss(outputs, targets)  # 计算损失loss_value.backward()  # 反向传播,计算梯度optim.step()  # 根据梯度更新权重running_loss += loss_value.item()  # 累加损失值print(f"Epoch {epoch + 1}, Loss: {running_loss / len(dataloader)}")print("------------------------------")

运行结果如下,发现在每一轮过后,Loss在逐渐减小:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/48412.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

org.springframework.context.annotation.ImportSelector的作用是什么?

org.springframework.context.annotation.ImportSelector 是 Spring 框架中一个非常有用的接口,它允许你根据条件动态地向 Spring 应用上下文中导入配置类。这对于开发模块化、可扩展且可配置的 Spring 应用来说非常有用。 如何使用 ImportSelector 创建 ImportSel…

Leetcode热题100 Day2

六、三数之和 这一题最关键的想法是把第二层嵌套和第三层嵌套合并为同一层嵌套,合并后即可使用两指针法。但是即使这样我在写的时候还是花了很多时间,一个是边界条件的处理(尤其是连续有相同值的处理)以及我发现了leetcode的编译…

用PyTorch从零开始编写DeepSeek-V2

DeepSeek-V2是一个强大的开源混合专家(MoE)语言模型,通过创新的Transformer架构实现了经济高效的训练和推理。该模型总共拥有2360亿参数,其中每个令牌激活21亿参数,支持最大128K令牌的上下文长度。 在开源模型中&…

java-双亲委派机制

Java虚拟机(JVM)中的类加载器(Class Loader)负责将类(.class文件)加载到JVM中,以便Java程序能够使用这些类。在JVM中,类加载器被组织成一种层次结构关系,这种层次结构关系…

vue3前端开发-小兔鲜项目-一些额外提醒的内容

vue3前端开发-小兔鲜项目-一些额外提醒的内容!今天这一篇文章,是提醒大家,如果你正在学习小兔鲜这个前端项目,有些地方需要提醒大家,额外注意的地方。 第一个:就是大家在进入二级页面后,有一个分…

深度学习-7-使用DCGAN生成动漫头像(实战)

参考什么是GAN生成对抗网络,使用DCGAN生成动漫头像 1 什么是生成对抗网络 生成对抗网络,英文是Generative Adversarial Networks,简称GAN。 GAN是一种无监督的深度学习模型,于2014年首次被提出。该算法通过竞争学习的方式生成新的、且与原始数据集相似的数据。 这些生成…

公司培训总结:拒绝倦怠,探索工作中的自驱力

拒绝倦怠,探索工作中的自驱力 在快节奏、高压力的现代职场环境中,感到迷茫和缺乏动力是许多人的共同体验。工作中的倦怠感可能会导致生产力下降、职业发展受阻,甚至影响到个人的心理健康。然而,挖掘并激发我们的内在驱动力&#…

昇思25天学习打卡营第19天|生成式-DCGAN生成漫画头像

打卡 目录 打卡 GAN基础原理 DCGAN原理 案例说明 数据集操作 数据准备 数据处理和增强 部分训练数据的展示 构造网络 生成器 生成器代码 ​编辑 判别器 判别器代码 模型训练 训练代码 结果展示(3 epoch) 模型推理 GAN基础原理 原理介…

C#实战 | 天行健、上下而求索

本文介绍C#开发入门案例。 01、项目一:创建控制台应用“天行健,君子以自强不息” 项目说明: 奋斗是中华民族的底色,见山开山,遇水架桥,正是因为自强不息的奋斗,才有了辉煌灿烂的中华民族。今…

xmind--如何快速将Excel表中多列数据,复制到XMind分成多级主题

每次要将表格中的数据分成多级时,只能复制粘贴吗 快来试试这个简易的方法吧 这个是原始的表格,分成了4级 步骤: 1、我们可以先按照这个层级设置下空列(后买你会用到这个空列) 二级不用加、三级前面加一列、四级前面加…

#和private有什么区别?

先上代码: class Person {#salary: numberconstructor(salary: number, private name: string) {this.#salary salaryconsole.log(this.#salary) // 可以访问私有属性console.log(this.name) // 可以访问公共属性} }const person new Person(1000, 张三); // co…

MAT使用

概念 Shallow heap & Retained Heap Shallow Heap就是对象本身占用内存的大小。 Retained Heap就是当前对象被GC后,从Heap上总共能释放掉的内存(表示如果一个对象被释放掉,那会因为该对象的释放而减少引用进而被释放的所有的对象(包括…

leetcode位运算(1684. 统计一致字符串的数目)

前言 经过前期的基础训练以及部分实战练习,粗略掌握了各种题型的解题思路。后续开始专项练习。 描述 给你一个由不同字符组成的字符串 allowed 和一个字符串数组 words 。如果一个字符串的每一个字符都在 allowed 中,就称这个字符串是 一致字符串 。 请…

Python-for-Android:把你的Python应用打包为APK文件

Python-for-Android(简称p4a)是一个开发工具,它可以将Python应用打包成可以在Android设备上运行的二进制文件。它是基于开源框架Kivy开发的,旨在为开发者提供在移动设备上轻松运行Python应用的解决方案。 什么是Python-for-Androi…

MyBatis Plus 实现中文排序的两种有效策略

前言 在MyBatis Plus项目开发中,针对中文数据的排序需求是一个常见的挑战,尤其是在需要按照拼音或特定语言逻辑排序时。本文整合了两种有效的方法,旨在帮助开发者克服MyBatis Plus在处理中文排序时遇到的障碍,确保数据能够按照预…

【React】JSX 实现列表渲染

文章目录 一、基础语法1. 使用 map() 方法2. key 属性的使用 二、常见错误和注意事项1. 忘记使用 key 属性2. key 属性的选择 三、列表渲染的高级用法1. 渲染嵌套列表2. 条件渲染列表项3. 动态生成组件 四、最佳实践 在 React 开发中,列表渲染是一个非常常见的需求。…

【多模态】CLIP-KD: An Empirical Study of CLIP Model Distillation

论文:CLIP-KD: An Empirical Study of CLIP Model Distillation 链接:https://arxiv.org/pdf/2307.12732 CVPR 2024 Introduction Motivation:使用大的Teacher CLIP模型有监督蒸馏小CLIP模型,出发点基于在资源受限的应用中&…

【WPF开发】控件介绍-button(按钮)

基本介绍 按钮(button)控件的作用 按钮控件(Button)是用户界面(UI)设计中最基本的元素之一,其主要作用包括: 触发操作:用户通过点击按钮来执行一个命令或触发一个事件&…

【网络】tcp_socket

tcp_socket 一、tcp_server与udp_server一样的部分二、listen接口(监听)三、accept接收套接字1、为什么还要多一个套接字(明明已经有了个socket套接字文件了,为什么要多一个accept套接字文件?)2、底层拿到新…

从R-CNN到Faster-R-CNN的简单介绍

1、R-CNN RCNN算法4个步骤 1、一张图像生成1K~2K个候选区域(使用Selective Search方法) 2、对每个候选区域,使用深度网络提取特征 3、特征送入每一类的SVM分类器,判别是否属于该类 4、使用回归器精细修正候选框位置 R-CNN 缺陷 : 1.训练…