ytorch深度学习完整GPU图像分类代码

1. CPU与GPU不同

1.输入数据
2.网络模型
3.损失函数
.cuda()

  1. 说明:下面代码中GPU版本中取消下划线的即为CPU版本

2.完成的分类代码(GPU)

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter# from model import *
# 准备数据集
from torch import nn
from torch.utils.data import DataLoader# 定义训练的设备
~~device = torch.device("cuda")~~ train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),download=True)# length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
# 如果train_data_size=10, 训练数据集的长度为:10
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用 DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 创建网络模型
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x
tudui = Tudui()
~~tudui = tudui.to(device)~~ # 损失函数
loss_fn = nn.CrossEntropyLoss()
~~loss_fn = loss_fn.to(device)~~ 
# 优化器
# learning_rate = 0.01
# 1e-2=1 x (10)^(-2) = 1 /100 = 0.01
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10# 添加tensorboard
writer = SummaryWriter("../logs_train")for i in range(epoch):print("-------第 {} 轮训练开始-------".format(i+1))# 训练步骤开始tudui.train()for data in train_dataloader:imgs, targets = data~~imgs = imgs.to(device)targets = targets.to(device)~~ outputs = tudui(imgs)loss = loss_fn(outputs, targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1if total_train_step % 100 == 0:print("训练次数:{}, Loss: {}".format(total_train_step, loss.item()))writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试步骤开始tudui.eval()total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, targets = data~~imgs = imgs.to(device)targets = targets.to(device)~~ outputs = tudui(imgs)loss = loss_fn(outputs, targets)total_test_loss = total_test_loss + loss.item()accuracy = (outputs.argmax(1) == targets).sum()total_accuracy = total_accuracy + accuracyprint("整体测试集上的Loss: {}".format(total_test_loss))print("整体测试集上的正确率: {}".format(total_accuracy/test_data_size))writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)total_test_step = total_test_step + 1torch.save(tudui, "tudui_{}.pth".format(i))print("模型已保存")writer.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/813462.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot3整合Mybatis plus

Java版本:17 Spring Boot版本:3.1.10 Mybatis plus版本:3.5.5 源码地址:Gitee仓库 01 创建我们的项目工程 首先,我们创建一个maven工程spring-boot3-demo,pom文件配置如下。 这里我们将spring-boot-start…

【python】基于librosa库提取音频特征

一、源码 import librosa audio_path ./audio.mp3 audio, sr librosa.load(audio_path) # 提取音频信号的时域特征 amplitude librosa.amplitude_to_db(librosa.stft(audio), refnp.max) # 提取音频信号的频域特征 mfccs librosa.feature.mfcc(audio, srsr) # 提取音频信号…

视频号小店好做吗?普通人没有货源,也可以做吗?

大家好,我是电商糖果 视频号小店作为2022年才出来的电商黑马项目,吸引了不少正在找创业项目的朋友。 这里面也有很多没有接触过电商,没有货源的普通人。 于是不少朋友就问糖果,如果普通人没有货源可以做吗?小店好做…

JCYZ H3CNE-RS+

JCYZ H3CNE-RS 20240413 20240413 https://www.h3c.com/cn/ 支持–软件下载–其他产品–模拟器官方下载 人才研学中心—技术认证—电子资料 按范围划分:局域网 城域网 广域网 按拓扑结构划分:总线型 环型 星型 树型 全网状 部分网状(优缺点&a…

简单好用的SaaS知识库工具都在这了,看完赶紧收藏!

在信息飞速发展的今天,企业如何有效地管理海量的信息和知识成为了提高工作效率的关键。SaaS知识库工具正成为企业寻求的解决方案,它们不仅能够帮助团队组织文档,而且优化知识分享流程。现在就让我们来看看市场上几款简单又好用的SaaS知识库工…

佛山分公司迎来重要指导蒋书记一行及杭州区域分公司领导共襄盛举

近日,佛山分公司迎来了一场重要的指导活动。蒋书记携夫人,以及助理黄显文和公司工作人员施晓燕等一行领导莅临佛山分公司,为公司的未来发展提供了宝贵的指导意见。同时,江浙福地区的杭州区域分公司负责人白棋元总和朱建江总也亲临…

宝藏免费音乐软件LX music

欢迎来到我的博客,代码的世界里,每一行都是一个故事 宝藏免费音乐软件LX music 前言LX Music的特色功能:音乐播放的新境界安装与配置:在不同平台上使用LX Music下载页面 主题定制 本文将深入研究LX Music,一款备受欢迎…

socat神器解密:网络数据传输的利器

欢迎来到我的博客,代码的世界里,每一行都是一个故事 socat神器解密:网络数据传输的利器 前言socat简介基本用法常见功能常见功能:1. 端口转发和数据重定向:2. 加密和解密数据流: 高级功能1. 代理服务器和隧…

力扣 | 160. 相交链表

import ListNodeInfo.ListNode;import java.util.HashSet; import java.util.Set;public class Problem_160_IntersectionOfTwoLinkedList {//双指针方法 public ListNode getIntersectionListNode(ListNode headA, ListNode headB){if(headA null || headB null) return nul…

MemberPress配置和使用会员登录页面

目录 隐藏 创建会员登录页面 编辑登录页面 设计您的登录页面 链接到您的登录页面 创建会员登录页面 要创建MemberPress会员登录页面,您需要做的就是导航到 MemberPress > 设置 > 页面选项卡,然后在页面顶部附近的“MemberPress 登录页面”…

【VUE】使用Vue和CSS动画创建滚动列表

使用Vue和CSS动画创建滚动列表 在这篇文章中,我们将探讨如何使用Vue.js和CSS动画创建一个动态且视觉上吸引人的滚动列表。这个列表将自动滚动显示项目,类似于轮播图的方式,非常适合用于仪表盘、排行榜或任何需要在有限空间内展示项目列表的应…

【Python使用】python高级进阶知识md总结第8篇:TCP 网络应用程序开发流程,1. TCP 网络应用程序开发流程的介绍【附代码文档】

python高级进阶全知识知识笔记总结完整教程(附代码资料)主要内容讲述:操作系统,虚拟机软件。ls命令选项,mkdir和rm命令选项。压缩和解压缩命令,文件权限命令。编辑器 vim,软件安装。获取进程编号…

docker安装es和kibana

1.创建网络 docker network create es-net 2.下载镜像 docker pull elasticsearch:7.12.1 docker pull kibana:7.12.1 docker pull mobz/elasticsearch-head:5 3.运行容器 docker run -d \ --restartalways --name es7 \ -e "ES_JAVA_OPTS-Xms512m -Xmx512m" …

B站大数据平台元数据业务分享

背景介绍 元数据是数据平台的衍生数据,比如调度任务信息,离线hive表,实时topic,字段信息,存储信息,质量信息,热度信息等。在数据平台建设初期,这类数据主要散落于各种平台子系统的数…

Ubuntu22.04安装Opencv + opencv_contrib(v4.9.0)

需下载两个文件: opencv-4.9.0.tar.gzopencv_contrib-4.9.0.tar.gz 将上述文件上传到如下目录 rootf5b3d2a6bf04:/opencv# pwd /opencv rootf5b3d2a6bf04:/opencv# ll total 149036 drwxrwxr-x 2 1000 1000 4096 Apr 8 10:07 ./ drwxr-xr-x 1 root root …

【智能算法应用】哈里斯鹰算法(HHO)在WSN覆盖中的应用

目录 1.算法原理2.数学模型3.结果展示4.参考文献 1.算法原理 【智能算法】哈里斯鹰算法(HHO)原理及实现 【智能算法应用】猎人猎物优化算法(HPO)在WSN覆盖中的应用 2.数学模型 3.结果展示 HPO设置区域边长为20,节点数为35&…

redis的三种工作模式

Redis 简介 Redis(Remote Dictionary Server)是一个开源的高性能键值对(key-value)存储系统。它支持多种数据结构,如字符串、列表、集合、有序集合和哈希,并且提供了丰富的功能,包括数据持久化…

C语言 08 类型转换

一种类型的数据转换为另一种类型的数据&#xff0c;这种操作称为类型转换。 类型转换分为自动类型转换和强制类型转换。 自动类型转换 比如现在希望将一个 short 类型的数据转换为 int 类型的数据&#xff1a; #include <stdio.h>int main(){short s 10;// 直接将s的…

2024洗地机名牌排行榜:细数最值得买的4大热门款

传统的清洁地面方式往往费时费力&#xff0c;容易导致腰酸背痛等不适&#xff0c;给人们带来一系列家务问题。然而&#xff0c;随着洗地机、扫地机器人、吸尘器等电动清洁工具的出现&#xff0c;清洁变得更加轻松便捷&#xff0c;受到了广大用户的欢迎。身为一名有着多年家居经…

Javaweb过滤器(Filter)

一、概念 Filter表示过滤器&#xff0c;是Javaweb三大组件&#xff08;Servlet、Filter、Listener&#xff09;之一。 过滤器可以把对资源的请求拦截下来&#xff0c;从而实现一些特殊的功能。 过滤器一般完成一些通用的操作&#xff0c;比如权限控制、统一编码处理、敏感字符处…