深度学习基础实战使用MNIST数据集对图片分类

本文代码完全借鉴pytorch中文手册

'''我们找到数据集,对数据做预处理,定义我们的模型,调整超参数,测试训练,再通过训练结果对超参数进行调整或者对模型进行调整。'''
import torch
import torch.nn as nn
import torch.nn.functional as F 
import torch.optim as optim  #实现各种优化算法的库
from torchvision import datasets,transformsBATCH_SIZE=512  #大概需要2G的显存
EPOCHS=20       #总共训练20次
DEVICE=torch.device("cuda" if torch.cuda.is_available() else "cpu") #让torch判断是否使用GPU#对数据进行预处理
transforms=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))])
#准备数据集,路径是关于py文件的相对路径
trainset=datasets.MNIST(root='./MNIST_data',train=True,download=False,transform=transforms)#加载数据集
train_loader=torch.utils.data.DataLoader(trainset,batch_size=BATCH_SIZE,shuffle=True)#准备测试集
testset=datasets.MNIST(root='./MNIST_data',train=True,download=False,transform=transforms)#加载测试集
test_loader=torch.utils.data.DataLoader(testset,batch_size=BATCH_SIZE,shuffle=True)#定义卷积神经网络
class ConvNet(nn.Module):def __init__(self):super().__init__()#batch*1*28*28(每次会送入batch个样本,输入通道数1(黑白图像)),图像分辨率是28*28)#下面的卷积层Conv2d的第一个参数指输入通道数,第二个参数指输出通道数,第三个参数指卷积核的大小self.conv1=nn.Conv2d(1,10,5)self.conv2=nn.Conv2d(10,20,3)#下面的全连接层Linear的第一个参数指输入通道数,第二个参数指输出通道数self.fc1=nn.Linear(20*10*10,500) #输入通道数是2000,输出通道数是500self.fc2=nn.Linear(500,10)  #输入通道数是500,输出通道数是10,即10分类def forward(self,x):in_size=x.size(0)  #在本例中in_size=512,也就是BATCH_SIZE的值。输入的x可以看成是512*1*28*28的张量out=self.conv1(x)  #batch*1*28*28 -> batch*10*24*24(28×28的图像经过一次核为5×5的卷积,输出变为24×24)out=F.relu(out)    #batch*10*24*24(激活函数ReLU不改变形状)out=F.max_pool2d(out,2,2)#batch*10*24*24 -> batch*10*12*12(2×2的池化层会减半)out=self.conv2(out) #batch*10*12*12 -> batch*20*10*10(再卷积一次,核的大小是3)out=F.relu(out)out=out.view(in_size,-1) #batch*20*10*10 -> batch*2000(out的第二维是-1,说明是自动推算,本例中第二维是20*10*10)out=self.fc1(out)        #batch*2000 -> batch*500out=F.relu(out)         out=self.fc2(out)		 #batch*500 -> batch*10out=F.log_softmax(out,dim=1) #计算log(softmax(x)),用log是为了防止数过大。return out#我们实例化一个网络,实例化后使用.to方法将网络移动到GPU
model=ConvNet().to(DEVICE)#优化器我们也直接选择简单暴力的Adam
optimizer=optim.Adam(model.parameters())#定义一个训练函数
def train(model,device,train_loader,optimizer,epoch):model.train() #启用BatchNormalization和Dropout,将BatchNormalization和Dropout置为Truefor batch_idex,(data,target) in enumerate(train_loader): ##将迭代器的数据组成一个索引系列,并输出索引和值,batch_diex是序号,后者是数据data,target=data.to(device),target.to(device)  #在gpu上跑optimizer.zero_grad() #梯度清零output=model(data)  #将数据放入模型loss=F.nll_loss(output,target)  #计算损失函数loss.backward()    #计算梯度optimizer.step()if (batch_idex+1)%30 == 0:  #每训练30个打印一次print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch,batch_idex*len(data),len(train_loader.dataset),100. * batch_idex/len(train_loader.dataset),loss.item()))#定义一个测试函数
def test(model,device,test_loader):model.eval()  #不启用BatchNormalization和Dropout,将BatchNormalization和Dropout置为Falsetest_loss=0correct=0with torch.no_grad():for data,target in test_loader:data,target=data.to(device),target.to(device)  #在gpu上跑output=model(data)test_loss+=F.nll_loss(output,target,reduction='sum').item()  #将一批的损失相加pred=output.max(1,keepdim=True)[1] #找到概率最大的下标correct+=pred.eq(target.view_as(pred)).sum().item()test_loss/=len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss,correct,len(test_loader.dataset),100. * correct/len(test_loader.dataset)))#下面开始训练,这里就体现出封装起来的好处了,只要写两行就可以了
for epoch in range(1,EPOCHS+1):train(model,DEVICE,train_loader,optimizer,epoch)test(model,DEVICE,test_loader)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/333970.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux 命令之 curl -- 文件传输工具/下载工具/网络接口调试

文章目录 一、命令介绍二、常用选项三、wget 与 curl 对比四、命令示例(一)以 post 方式提交数据/以 post 方式传递请求参数(二)查看网页的源码内容(三)保存访问的网页源码内容(四)将服务器的回应保存成文件/将输出保存成文件(五)显示 http response 头信息,打印出服…

python cookie使用_Python使用cookielib模块操作cookie的实例教程

cookielib是一个自动处理cookies的模块,如果我们在使用爬虫等技术的时候需要保存cookie,那么cookielib会让你事半功倍!他最常见的搭档模块就是python下的urllib和request。核心类1.Cookie该类实现了Netscape and RFC 2965 cookies定义的cooki…

pytorch中unsqueeze()函数理解

unsqueeze()函数起升维的作用,参数表示在哪个地方加一个维度。 在第一个维度(中括号)的每个元素加中括号 0表示在张量最外层加一个中括号变成第一维。 直接看例子: import torch inputtorch.arange(0,6) print(input) print(input.shape) 结果: tensor([0, 1, 2, 3…

Linux 命令之 ifconfig -- 配置和显示网卡的网络参数

文章目录一、命令介绍二、常用选项三、参考示例(一)显示网络设备信息(激活状态的)(二)启动关闭指定网卡(三)显示所有配置的网络接口,不论其是否激活(四&#…

版本交付_连续交付友好的Maven版本

版本交付持续交付管道需要可预测的软件和依赖版本。 Maven软件项目中常见的快照版本与“持续交付”背后的动机背道而驰。 为了将快照版本更新为发行版本,开发人员通常手动或通过诸如maven-release-plugin来编辑pom.xml文件。 但是,Maven还提供了将版本号…

shell开启飞行模式_今天才知道,原来手机的飞行模式用处那么多,看完涨知识了...

想必大家都知道,手机里有个飞行模式,是在乘坐飞机时使用的。其实除了这个功能之外,飞行模式还有很多其他的妙用,下面笔者就为大家一一进行介绍。一、 加快充电速度有些特殊情况你想加快手机的充电速度时,可以试着开启飞…

Anaconda安装库

有时候pip安装库特别慢,就采用conda别的方法装 conda install -c conda-forge 库名#如pydicom,gdcm

Linux 如何安装 SRPM 包(源代码 rpm 软件包,以 .src.rpm 为后缀名)/rpm 格式的源码软件包/源码包

文章目录一、SRPM 介绍二、SRPM 命名格式三、SRPM 的安装(一)直接使用命令 rpmbuild(二)利用 *.spec 文件编译(三)使用命令 make 编译和安装四、写在最后一、SRPM 介绍 SRPM 包,比 RPM 包多了一…

payara 创建 集群_在Payara Server和GlassFish中配置密码

payara 创建 集群回答Stackoverflow问题可以为我发现我最喜欢的开源工具的正式文档中的空白提供很好的反馈。 我在这里回答的问题之一是如何在docker容器中更改Payara Server主密码 。 显然,在标准服务器安装中,这很简单–只需使用asadmin change-master…

axure怎么做5秒倒计时_五个月宝宝早教,5个月婴儿早教怎么做

五个月宝宝早教,5个月婴儿早教怎么做,5个月宝宝是需要开始有意识的进行精细动作的家庭训练了5个月宝宝的一般特点:到了5个月时,能用眼睛观察周围的物体了,而且对什么都感到新奇好玩,能在眼睛的支配下抓住东…

python中enumerate()的理解

enumerate()函数的作用是通过迭代来遍历一个字符串、列表或字典等,并且为其增加索引,返回值为enumerate类。 代码举例如下: list[1,2,3,4,5,6] for i,j in enumerate(list):print(i,j) #结果: 0 1 1 2 2 3 3 4 4 5 5 6namesaber…

jpa执行sql脚本_JPA persistence.xml SQL脚本定义

jpa执行sql脚本您可以在将在运行时执行的JPA持久性上下文定义中定义并链接到SQL脚本。 有标准化的属性来定义脚本&#xff0c;以分别说明如何创建模式&#xff0c;批量加载数据和删除模式&#xff1a; <persistence version"2.1" xmlns"http://xmlns.jcp.or…

RPM 软件包默认的安装路径

通常情况下&#xff0c;RPM 包采用系统默认的安装路径&#xff0c;所有安装文件会按照类别分散安装到表 1 所示的目录中。 表 1 RPM 包默认安装路径安装路径含义/etc/配置文件安装目录/usr/bin/可执行的命令文件安装目录/usr/lib/程序所使用的函数库保存位置/usr/share/doc/基本…

图像融合亮度一致_重磅干货低光图像处理方案

点击上方“AIWalker”&#xff0c;选择加“星标”或“置顶” 重磅干货&#xff0c;第一时间送达Tips&#xff1a;一点点提示&#xff0c;因内容较多建议先关注&#xff0c;再置顶&#xff0c;最后端杯茶来精心浏览。背景低光图像是夜晚拍照时极为常见的一种现象。不充分的光…

修改本地文件的名字

将名字叫做megumi的文本文件改成名字叫做asuna的文本文件。主要用到os库的rename方法。 代码如下: import os folder"C:/Users/13451/Desktop" oldos.path.join(folder,megumi.txt) #或者oldfolder/megumi.txt newos.path.join(folder,asuna.txt) os.rename(old,n…

Adobe PhotoShop(PS) for Mac 如何隐藏切片框?

如何取消显示如下图所示的切片框&#xff1a; 打开『视图』➟ 『显示』&#xff0c;把『切片』前面的勾去掉&#xff0c;如下图所示&#xff1a;

groovy grails_在Grails战争中添加一个“精简”的Groovy Web控制台

groovy grails假设您已将Grails应用程序部署到服务器上–如何查找应用程序的配置方式&#xff1f; 如果您有来源&#xff0c;则可以查看Config.groovy &#xff0c; BuildConfig.groovy等&#xff08;在这种情况下&#xff0c;我正在谈论Grails 2应用程序&#xff0c;但是这些想…

邮宝打印面单尺寸调整_如何打印身份证的实际尺寸?怎样用照片打印身份证复印件...

点击上面 蓝色 文字关注我们&#xff0c;了解选购百科知识&#xff0c;快乐健康不停&#xff01;怎样打印身份证复印件&#xff1f;可以把身份证的照片导入电脑或者扫描件打印黑白的即可。如何打印身份证的实际尺寸&#xff1f;二代身份证的实际尺寸是&#xff1a;85.6MM X 54M…

将一个文件夹的文件复制到另一个文件夹

将桌面的文件复制到F盘anime文件夹下&#xff0c;主要用到shutil库下的copy方法。 from shutil import copy import os from_pathC:/Users/13451/Desktop/asuna.txt #asuna.txt是文件 to_pathF:/anime #anime是个文件夹 copy(from_path, to_path)

取模和求余运算

文章目录背景探究总结被除数 dividend 用 a 表示&#xff1b; 除数 divisor 用 b 表示&#xff1b; 商 quotient 用 q 表示&#xff1b; 余 remainder 用 rem 表示&#xff1b; 模 modulo 用 mod 表示。 背景 最近在一道 Java 习题中&#xff0c;看到这样的一道题&#xff1a;…