【Pytorch神经网络实战案例】02 CIFAR-10数据集:Pytorch使用GPU训练CNN模版-方法②

import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriterfrom torch.utils.data import DataLoader# 取消全局证书验证(当项目对安全性问题不太重视时,推荐使用,可以全局取消证书的验证,简易方便)
import ssl
ssl._create_default_https_context = ssl._create_unverified_context# 定义训练设备===》使用CPU进行训练方法②
# device=torch.device("cpu")
# 定义训练设备===》使用GPU进行训练方法②
# device=torch.device("cuda")
# 定义训练设备===》使用第X张GPU进行训练方法②
# device=torch.device("cuda:0")
# 定义训练设备===》根据机器情况选择能否用GPU进行训练方法②
device=torch.device("cuda" if torch.cuda.is_available() else "cup")print("当前程序正在{}上运行".format(device))# 准备数据集
train_data=torchvision.datasets.CIFAR10("datas-train",train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data=torchvision.datasets.CIFAR10("datas-test",train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 获得数据集的长度
train_data_size=len(train_data)
test_data_size=len(test_data)
print("训练--数据集的长度为:{}".format(train_data_size))
print("测试--数据集的长度为:{}".format(test_data_size))# 利用DataLoader加载数据集
# Batch Size定义:一次训练所选取的样本数。
# Batch Size的大小影响模型的优化程度和速度。同时其直接影响到GPU内存的使用情况,假如你GPU内存不大,该数值最好设置小一点。
train_dataloader=DataLoader(train_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)# 创建网络模型
# 搭建神经网络
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.model = nn.Sequential(# Conv2d中##in_channels:输入的通道数目 【必选】##out_channels: 输出的通道数目 【必选】##kernel_size:卷积核的大小,类型为int 或者元组,当卷积是方形的时候,只需要一个整数边长即可,卷积不是方形,要输入一个元组表示 高和宽。【必选】##stride: 卷积每次滑动的步长为多少,默认是 1 【可选】##padding(手动计算):设置在所有边界增加值为0的边距的大小(也就是在feature map 外围增加几圈 0 ),##                 例如当 padding =1 的时候,如果原来大小为 3 × 3 ,那么之后的大小为 5 × 5 。即在外围加了一圈 0 。【可选】##dilation:控制卷积核之间的间距【可选】nn.Conv2d(3, 32, 5, 1, 2),# MaxPool2d中:# #kernel_size(int or tuple) - max pooling的窗口大小,# # stride(int or tuple, optional) - max pooling的窗口移动的步长。默认值是kernel_size# # padding(int or tuple, optional) - 输入的每一条边补充0的层数# # dilation(int or tuple, optional) – 一个控制窗口中元素步幅的参数# # return_indices - 如果等于True,会返回输出最大值的序号,对于上采样操作会有帮助# # ceil_mode - 如果等于True,计算输出信号大小的时候,会使用向上取整,代替默认的向下取整的操作nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),# nn.Linear()是用于设置网络中的全连接层的,在二维图像处理的任务中,全连接层的输入与输出一般都设置为二维张量,形状通常为[batch_size, size]#              相当于一个输入为[batch_size, in_features]的张量变换成了[batch_size, out_features]的输出张量。nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xtudui=Tudui()# 定义训练设备===》使用GPU进行训练方法②
tudui=tudui.to(device)# 损失函数
# 使用交叉熵==>分类
loss_fn=nn.CrossEntropyLoss()# 定义训练设备===》使用GPU进行训练方法②
loss_fn=loss_fn.to(device)# 优化器
learning_rate=0.01 #学习速率
optimizer=torch.optim.SGD(tudui.parameters(),lr=learning_rate)#设置训练网络的参数
#记录训练的次数
total_train_step=0
# 记录测试的次数
test_train_step=0
# 训练的轮次
epoch=10
# 添加tensorboard
writer=SummaryWriter("firstjuan")for i in range(epoch): #0-9print("-----------第{}轮训练开始-----------".format(i+1))tudui.train()# 训练步骤开始for data in train_dataloader:imgs,targets=data# 定义训练设备===》使用GPU进行训练方法②imgs = imgs.to(device)targets = targets.to(device)outputs=tudui(imgs)#将计算所得的output的数值与真实数值进行对比,即求差loss=loss_fn(outputs,torch.squeeze(targets).long())#优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()# 记录训练train的次数+1total_train_step=total_train_step+1if total_train_step %100==0 : #减少输出 方便查看测试结果print("训练次数:{},损失值loss:{}".format(total_train_step,loss))writer.add_scalar("train_loss",loss.item(),total_train_step)# 测试步骤开始tudui.eval()total_test_loss=0# 正确率total_accuracy=0with torch.no_grad(): #保证网络模型的梯度保持没有,仅需要测试,不需要对梯度进行优化与调整for data in test_dataloader:imgs,targets=data# 定义训练设备===》使用GPU进行训练方法②imgs = imgs.to(device)targets = targets.to(device)outputs=tudui(imgs)loss=loss_fn(outputs,targets)total_test_loss=total_test_loss+loss.item()# 1为横向 0为竖 计算正确率accuracy=(outputs.argmax(1)==targets).sum()total_accuracy=total_accuracy+accuracyprint("整体测试集上的Loss:{}".format(total_test_loss))print("整体测试集上的正确率accuracy:{}".format(total_accuracy/test_data_size))writer.add_scalar("test_accuracy", total_test_loss, test_train_step)writer.add_scalar("test_loss",total_accuracy/test_data_size,test_train_step)# 记录测试test的次数+1test_train_step=test_train_step+1# 保存模型# torch.save(tudui.state_dict(),"tudui_{}".format(i))torch.save(tudui,"tudui_{}.pth".format(i))print("模型已经保存")
writer.close()

import torch
from torch import nn# 搭建神经网络
class Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()self.model = nn.Sequential(# Conv2d中##in_channels:输入的通道数目 【必选】##out_channels: 输出的通道数目 【必选】##kernel_size:卷积核的大小,类型为int 或者元组,当卷积是方形的时候,只需要一个整数边长即可,卷积不是方形,要输入一个元组表示 高和宽。【必选】##stride: 卷积每次滑动的步长为多少,默认是 1 【可选】##padding(手动计算):设置在所有边界增加值为0的边距的大小(也就是在feature map 外围增加几圈 0 ),##                 例如当 padding =1 的时候,如果原来大小为 3 × 3 ,那么之后的大小为 5 × 5 。即在外围加了一圈 0 。【可选】##dilation:控制卷积核之间的间距【可选】nn.Conv2d(3, 32, 5, 1, 2),# MaxPool2d中:# #kernel_size(int or tuple) - max pooling的窗口大小,# # stride(int or tuple, optional) - max pooling的窗口移动的步长。默认值是kernel_size# # padding(int or tuple, optional) - 输入的每一条边补充0的层数# # dilation(int or tuple, optional) – 一个控制窗口中元素步幅的参数# # return_indices - 如果等于True,会返回输出最大值的序号,对于上采样操作会有帮助# # ceil_mode - 如果等于True,计算输出信号大小的时候,会使用向上取整,代替默认的向下取整的操作nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),# nn.Linear()是用于设置网络中的全连接层的,在二维图像处理的任务中,全连接层的输入与输出一般都设置为二维张量,形状通常为[batch_size, size]#              相当于一个输入为[batch_size, in_features]的张量变换成了[batch_size, out_features]的输出张量。nn.Linear(64*4*4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return xif __name__ == '__main__':tudui = Tudui()input = torch.ones((64, 3, 32, 32))output = tudui(input)print(output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/469862.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

三、Unity中的鼠标、键盘的获取

在Unity中,我们经常会处理点击鼠标的事件检测和键盘的事件检测。所以,我觉的应该将这个小知识点进行一个整理。 1.按下键盘的事件检测: 1.GetKey: 当通过名称指定的按键被用户按住时返回true ------ 持续按下,会一直…

三星的S3C2440A 存储器控制器

对存储器的BANK那个东西我一直是一知半解,感觉很模糊、关于ARM的体系结构可能学得不够深入 三星S3C2440A的存储器控制器 s3c2440A 的存储器控制器提供访问外部存储器所需的存储器控制信号。 s3c2440A 的存储器控制器有以下特性: - 大小端(通过软件选择) - 地址空间:每个ba…

【Pytorch神经网络实战案例】01 CIFAR-10数据集:Pytorch使用GPU训练CNN模版-方法①

import torch import torchvision from torch import nn from torch.utils.tensorboard import SummaryWriterfrom torch.utils.data import DataLoader# 取消全局证书验证(当项目对安全性问题不太重视时,推荐使用,可以全局取消证书的验证&am…

html5自动调整布局,html5移动端自适应布局的实现

场景:为适应各种大小的屏幕自适应布局我知道的两种方式1.使用媒体查询,下面制定了几种适应方式,例如第一个表示屏幕宽度在320px-360px之间的,html字体大小适配为13.65pxmedia only screen and (max-width: 360px) and (min-width:…

C/C++ 中判断某一文件或目录是否存在

方法一&#xff1a;C中比较简单的一种办法&#xff08;使用文件流打开文件&#xff09; 1 #include <iostream>2 #include <fstream>3 4 using namespace std;5 6 #define FILENAME "*.dat" // 指定文件名7 8 int main( void )9 { 10 fstream _fi…

建立交叉编译环境

做什么之前没有编译器是不行的、 1、打开终端运行arm-linux-gcc -v 如果提示这个命令不存在,好吧、照着下面的步骤安装一个吧 在Linux 平台下,要为开发板编译内核,图形界面Qtopia ,bootloader,还有其他一 些应用程序,均需要交叉编译工具链。 之前的系统,要使用不同的编…

logistic模型原理与推导过程分析(1)

从线性分类器谈起 给定一些数据集合&#xff0c;他们分别属于两个不同的类别。例如对于广告数据来说&#xff0c;是典型的二分类问题&#xff0c;一般将被点击的数据称为正样本&#xff0c;没被点击的数据称为负样本。现在我们要找到一个线性分类器&#xff0c;将这些数据分为两…

android转流媒体,android 4.4中的流媒体渲染过程

第一次写blog&#xff0c;只是为了记下学习的过程。android中东西很多&#xff0c;架构和流程都很复杂&#xff0c;经常发现以前学习过的很多东西&#xff0c;即使当时看明白没多久就忘记了&#xff0c;只能重新拾起再看。于是想起blog这个东东&#xff0c;写下来总不会忘记&am…

【错误记录】python requests库 Response 判断坑

在requests访问之后, 我直接判断resp的值, 如下&#xff1a; if resp:do something发现当Response 为500的时候没有进入if分支, 检查源码&#xff0c;发现Response重写了__bool__方法, 根据resp.raise_for_status来确定是否为True, 当为500时, 为假, 记录一下转载于:https://ww…

linux Hello World 模块编程

折腾了差不多一个晚上: 1、关于在前面加上TAB,这个是有必要的、 2、Makefile的编写也是有些差异的 3、关于内核的版本可以通过uname -r来查看一下 我

logistic模型原理与推导过程分析(2)

二项逻辑回归模型 既然logistic回归把结果压缩到连续的区间(0,1)&#xff0c;而不是离散的0或者1&#xff0c;然后我们可以取定一个阈值&#xff0c;通常以0.5为阈值&#xff0c;如果计算出来的概率大于0.5&#xff0c;则将结果归为一类&#xff08;1&#xff09;&#xff0c;…

怎么样批量修改html里的内容,批量修改替换多个Word文档中同一内容的方法

批量修改替换多个Word文档中同一内容的方法群里一位朋友问到&#xff0c;如何一次性批量替换多个word文档中的同一内容。其实&#xff0c;实现多个Word文档的字符进行批量替换的方法有多种。第一种方法&#xff0c;可以利用第三方软件&#xff1a;全能字符串批量替换机。在网上…

织梦建站

这是调用对应文章连接的标签 [field:arcurl/]&#xff0c;例如下面这个调用文章标签里面&#xff0c;就调用了文章的连接&#xff1a; {dede:arclist row 10<a href"[field:arcurl/]">[field:title/]</a>dede:arclist/} 1234{dede:arclist flagh typeid…

linux下的字符设备驱动

Linux字符设备驱动程序的一个简单示例一.开发环境&#xff1a; 主 机&#xff1a;VMWare--Fedora 9 开发板&#xff1a;友善之臂mini2440--256MB Nandflash 编译器&#xff1a;arm-linux-gcc-4.3.2 二.驱动源码&#xff1a; 该源码很浅显易懂&#xff0c;非常适合初学者。 me…

logistic模型原理与推导过程分析(3)

附录&#xff1a;迭代公式向量化 θ相关的迭代公式为&#xff1a; ​ 如果按照此公式操作的话&#xff0c;每计算一个θ需要循环m次。为此&#xff0c;我们需要将迭代公式进行向量化。 首先我们将样本矩阵表示如下&#xff1a; 将要求的θ也表示成矩阵的形式&#xff1a; 将x…

计算机表示法是知识 表示法么,计算机三级考试关于IP地址知识点

计算机三级考试关于IP地址知识点IP地址是IP协议提供的一种统一的地址格式&#xff0c;它为互联网上的每一个网络和每一台主机分配一个逻辑地址&#xff0c;以此来屏蔽物理地址的差异&#xff0c;同时也是计算机三级考试的重要内容&#xff0c;小编整理了相关知识点&#xff0c;…

VMware 下Linux无法上网 新增支持WIFI方式 无线连接

试过很管用、如果宿主机可以上网、不管是有线还是无线、只在在连接网络那里是.net的方式就可以正常上网的 ADSL-VMware 共享上网 单机环境&#xff0c;ADSL拨号上网&#xff0c;安装VMware后&#xff0c;客户机如何与宿主机共享上网&#xff1f;网友经常问这个问题&#xff0…

Http学习笔记

在 MIME 扩展中会使用一种称为多部分对象集合&#xff08;Multipart&#xff09;的方法&#xff0c;来容纳多份不同类型的数据。包含的对象如下&#xff1a; form-data在 Web 表单文件上传时使用。byteranges状态码 206&#xff08;Partial Content&#xff0c;部分内容&#x…

监督学习与无监督学习

监督学习 用一个例子介绍什么是监督学习把正式的定义放在后面介绍。 假如说你想预测房价。前阵子&#xff0c;一个学生从波特兰俄勒冈州的研究所收集了一些房价的数据。你把这些数据画出来&#xff0c;看起来是这个样子&#xff1a; 横轴表示房子的面积&#xff0c;单位是平…

html css配色方案,链接css不同的配色方案问题

为什么导航链接采用正常链接的风格&#xff1f;这是一个基本的导航菜单&#xff1a;Home |Autobelettering |Reclame |Prints |Textiel |Ontwerpen |Aanleveren |Contact这是CSS/* Normal links */a {font-size: 12px;color: #DC342F;}a:link {text-decoration: none;color: #D…