pytorch实现自己的深度神经网络(公共数据集)

一、训练文件——train.py

  注意:在运行此代码之前,需要配置好pytorch-GPU版本的环境,具体再次不谈。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms# 检查GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)# 数据预处理的转换
transform = transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载CIFAR-10训练数据集
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8,shuffle=True, num_workers=0)# 定义神经网络模型
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(128 * 32 * 32, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = self.pool(torch.relu(self.conv3(x)))x = x.view(-1, 128 * 32 * 32)x = torch.relu(self.fc1(x))x = self.fc2(x)return x# 实例化模型,并将其移动到可用设备上
model = CNN().to(device)# 定义损失函数
criterion = nn.CrossEntropyLoss()# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)if __name__ == '__main__':# 训练神经网络for epoch in range(5):running_loss = 0.0for i, data in enumerate(train_loader, 0):inputs, labels = data[0].to(device), data[1].to(device)# 梯度清零optimizer.zero_grad()# 正向传播outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播 + 优化loss.backward()optimizer.step()# 打印统计信息running_loss += loss.item()if i % 200 == 199:print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 200))running_loss = 0.0print('Finished Training')# 保存模型至文件torch.save(model.state_dict(), 'cifar10_cnn_model.pth')

二、测试文件——val.py

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import cv2# 检查GPU是否可用
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)# 数据预处理的转换
transform = transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载CIFAR-10测试数据集
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)# 创建测试数据加载器
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8,shuffle=False, num_workers=0)# 加载模型并将其移动到可用设备上
class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(3, 32, 3, padding=1)self.conv2 = nn.Conv2d(32, 64, 3, padding=1)self.conv3 = nn.Conv2d(64, 128, 3, padding=1)self.pool = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(128 * 32 * 32, 512)self.fc2 = nn.Linear(512, 10)def forward(self, x):x = self.pool(torch.relu(self.conv1(x)))x = self.pool(torch.relu(self.conv2(x)))x = self.pool(torch.relu(self.conv3(x)))x = x.view(-1, 128 * 32 * 32)x = torch.relu(self.fc1(x))x = self.fc2(x)return x
# 显示函数
def imshow(img):img = img / 2 + 0.5npimg = img.numpy()# 坐标转换plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()model = CNN().to(device)
model.load_state_dict(torch.load('cifar10_cnn_model.pth'))
model.eval()if __name__ == '__main__':# 在测试集上测试模型correct = 0total = 0with torch.no_grad():for data in test_loader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)# 预测值的最大值以及最大值的类别索引_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy on the test images: %d %%' % (100 * correct / total))# 显示测试集中的一些图片及其预测结果# 生成一个迭代器,从数据加载器中取出数据dataiter = iter(test_loader)# 从迭代器中获取下一个批次的数据images, labels = dataiter.next()# 将获取到的批次数据移动到device上,在这里也就是GPU上images, labels = images.to(device), labels.to(device)dip_flag = Falseif dip_flag == True:# -------------------------------------------# 可以选择 使用opencv显示# -------------------------------------------np_images = images.cpu().numpy()# 循环遍历并显示所有测试集图片for i in range(len(np_images)):# 从归一化中还原图像数据np_image = np.transpose(np_images[i], (1, 2, 0))   # 从CHW转换为HWCnp_image = np_image * 0.5 + 0.5# 将图像数据从float类型转换为unit8类型np_image = (np_image * 255).astype(np.uint8)# 使用opencv显示图像cv2.imshow("Image {}".format(i+1), np_image)cv2.waitKey(0)# 等待用户按下任意键继续显示下一张图像cv2.destroyAllWindows()imshow(torchvision.utils.make_grid(images.cpu()))print('GroundTruth: ', ' '.join('%5s' % test_dataset.classes[labels[j]] for j in range(8)))outputs = model(images)_, predicted = torch.max(outputs, 1)print('Predicted: ', ' '.join('%5s' % test_dataset.classes[predicted[j]]for j in range(8)))


直接运行即可,亲测可以运行

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/822491.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Midjourney 实现角色一致性的新方法

AI 绘画的奇妙之处,实乃令人叹为观止!就像大千世界中,寻不见两片完全相同的树叶一般,AI 绘画亦复如是。同一提示之词,竟能催生出千变万化的图像,使得AI所绘之作,宛如自然之物般独特,…

项目7-音乐播放器2(上传音乐+查询音乐+拦截器)

0.加入拦截器 之后就不用对用户是否登录进行判断了 0.1 定义拦截器 0.2 注册拦截器 生效 1.上传音乐的接口设计 请求: { post, /music/upload {singer,MultipartFile file}, } 响应: { "status": 0, "message&…

Linux 在后台执行 shell 指令的方法

Shell 指令后台运行是指在Linux或Unix操作系统中执行一个Shell命令或脚本时,使其在后台模式下运行,即在不占用当前终端会话的交互性、不影响用户在该终端进行其他操作的情况下持续执行。这种执行方式允许用户在提交命令后立即返回到命令提示符&#xff0…

PostgreSQL 窗口函数汇总

文章目录 前言一、什么是窗口函数?二、常用的4类窗口函数三、PARTITION BY 子句四、窗口函数示例1. 聚合计算1.1 sum() 函数1.2 count() 函数1.3 avg() 函数2. 分组排序2.1 row_number() 函数2.2 rank() 函数2.3 dense_rank() 函数3. 分组查询

ARM LPD-500 和PCK-600介绍

LPD-500是ARM公司提供的一个低功耗分发器(Low Power Distributor),它是一个独立的可配置组件,用于分发Q-Channel接口到多个设备。LPD-500支持多个设备之间的电源管理,允许它们在不同的低功耗状态之间转换,从…

力扣练习题(2024/4/16)

1买卖股票的最佳时机 给定一个数组 prices ,它的第 i 个元素 prices[i] 表示一支给定股票第 i 天的价格。 你只能选择 某一天 买入这只股票,并选择在 未来的某一个不同的日子 卖出该股票。设计一个算法来计算你所能获取的最大利润。 返回你可以从这笔…

单链表经典算法题分析

目录 一、链表的中间节点 1.1 题目 1.2 题解 1.3 收获 二、移除链表元素 2.1 题目 2.2 题解 2.3 收获 2.4递归详解 三、反转链表 3.1 题目 3.2 题解 3.3 解释 四、合并两个有序列表 4.1 题目 4.2 题解 4.3 递归详解 声明:本文所有题目均摘自leetco…

《手把手教你》系列基础篇(九十二)-java+ selenium自动化测试-框架设计基础-POM设计模式简介(详解教程)

1.简介 页面对象模型(Page Object Model)在Selenium Webdriver自动化测试中使用非常流行和受欢迎,作为自动化测试工程师应该至少听说过POM这个概念。本篇介绍POM的简介,接下来宏哥一步一步告诉你如何在你JavaSelenium3自动化测试…

13个Java基础面试题

Hi,大家好,我是王二蛋。 金三银四求职季,特地为大家整理出13个 Java 基础面试题,希望能为正在准备或即将参与面试的小伙伴们提供些许帮助。 后续还会整理关于线程、IO、JUC等Java相关面试题,敬请各位持续关注。 这1…

【ROS2】搭建ROS2-Humble + Vscode开发流程

【ROS2】搭建ROS2-Humble Vscode开发流程 文章目录 【ROS2】搭建ROS2-Humble Vscode开发流程1.基本环境配置2.搭建Vscode开发环境 1.基本环境配置 基本的环境配置包括以下步骤: 安装ROS2-Humble,可以参考这里安装一些基本的工具,可以参考…

nuxt3项目使用swiper11插件实现点击‘’返回顶部按钮‘’返回到第一屏

该案例主要实现点击返回顶部按钮返回至swiper第一个slide。 版本: "nuxt": "^3.10.3", "pinia": "^2.1.7", "swiper": "^11.0.7", 官方说明 swiper.slideTo(index, speed, runCallbacks) Run transit…

浅析MySQL 8忘记密码处理方式

对MySQL有研究的读者,可能会发现MySQL更新很快,在安装方式上,MySQL提供了两种经典安装方式:解压式和一键式,虽然是两种安装方式,但我更提倡选择解压式安装,不仅快,还干净。在操作系统…

【数据结构1-基本概念和术语】

这里写自定义目录标题 0.数据,数据元素,数据项,数据对项,数据结构,逻辑结构,存储结构1.结构1.1逻辑结构1.2存储结构1.2.1 顺序结构1.2.2链式结构 1.3数据结构1.3.1基本数据类型1.3.2抽象数据类型1.3.2.1一个…

Java SpringBoot基于微信小程序的高速公路服务区充电桩在线预定系统,附源码

博主介绍:✌IT徐师兄、7年大厂程序员经历。全网粉丝15W、csdn博客专家、掘金/华为云//InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇&#x1f3…

05节-51单片机-模块化编程

1.两种编程方式的对比 传统方式编程: 所有的函数均放在main.c里,若使用的模块比较多,则一个文件内会有很多的代码,不利于代码的组织和管理,而且很影响编程者的思路 模块化编程: 把各个模块的代码放在不同的…

STM32外设配置以及一些小bug总结

USART RX的DMA配置 这里以UART串口1为例,首先点ADD添加RX和TX配置DMA,然后模式一般会选择是normal,这个模式是当DMA的计数器减到0的时候就不做任何动作了,还有一种循环模式,是计数器减到0之后,计数器自动重…

Echats 引入地图(二) 之中国地图省份高亮

效果图: 代码: series: [{type: map,map: china,zoom: 1.2, // 地图放大aspectScale: 0.8, //地图宽高比例roam: true, //地图缩放、平移// 滚轮缩放的极限控制scaleLimit: {min: 0.5, //缩放最小大小max: 6, //缩放最大大小},itemStyle…

使用Android studio,安卓手机编译安装yolov8部署ncnn,频繁出现编译错误

从编译开始就开始出现错误,解决步骤: 1.降低graddle版本,7.2-bin --->>> 降低为 6.1.1-all #distributionUrlhttps\://services.gradle.org/distributions/gradle-7.2-bin.zip distributionUrlhttps\://services.gradle.org/di…

springboot配置文件详解

springboot配置文件详解 ​ 在之前的项目开发中,我们可以使用xml,properties进行相关的配置,这种配置方式比较简单,但是在应对复杂的商业需求下,多环境和编程化的配置无法得到满足,因此springboot为我们提供了YAML的配…

5.HC-05蓝牙模块

配置蓝牙模块 注意需要将蓝牙模块接5v,实测接3.3v好像不太好使的样子 首先需要把蓝牙模块通过TTL串口模块接到我们的电脑,然后打开我们的串口助手 注意,我们现在是配置蓝牙模块,所以需要进入AT模式,需要按着蓝牙模块上的黑色小按钮再上电,这时候模块上的LED灯以一秒慢闪一次…