pytorch-构建卷积神经网络

构建卷积神经网络

  • 卷积网络中的输入和层与传统神经网络有些区别,需重新设计,训练模块基本一致
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    from torchvision import datasets,transforms 
    import matplotlib.pyplot as plt
    import numpy as np
    %matplotlib inline

    首先读取数据

  • 分别构建训练集和测试集(验证集)
  • DataLoader来迭代取数据
    # 定义超参数 
    input_size = 28  #图像的总尺寸28*28
    num_classes = 10  #标签的种类数
    num_epochs = 3  #训练的总循环周期
    batch_size = 64  #一个撮(批次)的大小,64张图片# 训练集
    train_dataset = datasets.MNIST(root='./data',  train=True,   transform=transforms.ToTensor(),  download=True) # 测试集
    test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

    卷积网络模块构建

  • 一般卷积层,relu层,池化层可以写成一个套餐
  • 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务
    class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1,              # 灰度图out_channels=16,            # 要得到几多少个特征图kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),                      # relu层nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.conv3 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(32, 64, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),             # 输出 (32, 7, 7))self.out = nn.Linear(64 * 7 * 7, 10)   # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)output = self.out(x)return output

    准确率作为评估标准

    def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels) 

    训练网络模型

    # 实例化
    net = CNN() 
    #损失函数
    criterion = nn.CrossEntropyLoss() 
    #优化器
    optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
    for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = [] for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环net.train()                             output = net(data) loss = criterion(output, target) optimizer.zero_grad() loss.backward() optimizer.step() right = accuracy(output, target) train_rights.append(right) if batch_idx % 100 == 0: net.eval() val_rights = [] for (data, target) in test_loader:output = net(data) right = accuracy(output, target) val_rights.append(right)#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1], 100. * val_r[0].numpy() / val_r[1]))
    当前epoch: 0 [0/60000 (0%)]	损失: 2.300918	训练集准确率: 10.94%	测试集正确率: 10.10%
    当前epoch: 0 [6400/60000 (11%)]	损失: 0.204191	训练集准确率: 78.06%	测试集正确率: 93.31%
    当前epoch: 0 [12800/60000 (21%)]	损失: 0.039503	训练集准确率: 86.51%	测试集正确率: 96.69%
    当前epoch: 0 [19200/60000 (32%)]	损失: 0.057866	训练集准确率: 89.93%	测试集正确率: 97.54%
    当前epoch: 0 [25600/60000 (43%)]	损失: 0.069566	训练集准确率: 91.68%	测试集正确率: 97.68%
    当前epoch: 0 [32000/60000 (53%)]	损失: 0.228793	训练集准确率: 92.85%	测试集正确率: 98.18%
    当前epoch: 0 [38400/60000 (64%)]	损失: 0.111003	训练集准确率: 93.72%	测试集正确率: 98.16%
    当前epoch: 0 [44800/60000 (75%)]	损失: 0.110226	训练集准确率: 94.28%	测试集正确率: 98.44%
    当前epoch: 0 [51200/60000 (85%)]	损失: 0.014538	训练集准确率: 94.78%	测试集正确率: 98.60%
    当前epoch: 0 [57600/60000 (96%)]	损失: 0.051019	训练集准确率: 95.14%	测试集正确率: 98.45%
    当前epoch: 1 [0/60000 (0%)]	损失: 0.036383	训练集准确率: 98.44%	测试集正确率: 98.68%
    当前epoch: 1 [6400/60000 (11%)]	损失: 0.088116	训练集准确率: 98.50%	测试集正确率: 98.37%
    当前epoch: 1 [12800/60000 (21%)]	损失: 0.120306	训练集准确率: 98.59%	测试集正确率: 98.97%
    当前epoch: 1 [19200/60000 (32%)]	损失: 0.030676	训练集准确率: 98.63%	测试集正确率: 98.83%
    当前epoch: 1 [25600/60000 (43%)]	损失: 0.068475	训练集准确率: 98.59%	测试集正确率: 98.87%
    当前epoch: 1 [32000/60000 (53%)]	损失: 0.033244	训练集准确率: 98.62%	测试集正确率: 99.03%
    当前epoch: 1 [38400/60000 (64%)]	损失: 0.024162	训练集准确率: 98.67%	测试集正确率: 98.81%
    当前epoch: 1 [44800/60000 (75%)]	损失: 0.006713	训练集准确率: 98.69%	测试集正确率: 98.17%
    当前epoch: 1 [51200/60000 (85%)]	损失: 0.009284	训练集准确率: 98.69%	测试集正确率: 98.97%
    当前epoch: 1 [57600/60000 (96%)]	损失: 0.036536	训练集准确率: 98.68%	测试集正确率: 98.97%
    当前epoch: 2 [0/60000 (0%)]	损失: 0.125235	训练集准确率: 98.44%	测试集正确率: 98.73%
    当前epoch: 2 [6400/60000 (11%)]	损失: 0.028075	训练集准确率: 99.13%	测试集正确率: 99.17%
    当前epoch: 2 [12800/60000 (21%)]	损失: 0.029663	训练集准确率: 99.26%	测试集正确率: 98.39%
    当前epoch: 2 [19200/60000 (32%)]	损失: 0.073855	训练集准确率: 99.20%	测试集正确率: 98.81%
    当前epoch: 2 [25600/60000 (43%)]	损失: 0.018130	训练集准确率: 99.16%	测试集正确率: 99.09%
    当前epoch: 2 [32000/60000 (53%)]	损失: 0.006968	训练集准确率: 99.15%	测试集正确率: 99.11%
    

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/70589.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Agisoft Metashape相机标定笔记

Lens Calibration(镜头标定) 使用Metashape进行自动相机标定是可能的。Metashape使用LCD显示屏作为标定目标(可选:使用打印的棋盘格图案,但需保证它是平坦的且单元格是正方形)。 相机标定步骤支持全相机标定矩阵的估计&#xff…

pg 配置 -- chatGPT

问:pg 配置不生成 log gpt: 如果你想在 PostgreSQL 中禁用日志记录(不生成日志),你可以采取以下步骤: **1. 编辑 PostgreSQL 配置文件:** 打开 PostgreSQL 的配置文件,通常位于 /etc/postgr…

蓝桥杯打卡Day3

文章目录 吃糖果递推数列 一、吃糖果IO链接 本题思路:本题题意就是斐波那契数列&#xff01; #include <bits/stdc.h>typedef uint64_t i64;i64 f(i64 n) {if(n1) return 1;if(n2) return 2;return f(n-1)f(n-2); }signed main() {std::ios::sync_with_stdio(false);s…

kubernetes——ingress

简介 ingress: 是k8s内部的一个资源对象ingress controller -> ingress控制器&#xff1a; 是k8s里启动的一个pod&#xff0c;运行的是nginx的镜像&#xff0c;实现k8s内部的service&#xff08;ClusterIP类型&#xff09;的负载均衡 ingress 和ingress controller 的关…

docker容器运行成功但无法访问,原因分析及对应解决方案(最新,以Tomcat为例,亲测有效)

原因分析&#xff1a; 是否能访问当运行docker容器虚拟机&#xff08;主机&#xff09;地址 虚拟机对应的端口号是否开启或者防墙是否关闭 端口映射是否正确&#xff08;这个是我遇到的&#xff09; tomcat下载的是最新版&#xff0c;docker运行后里面是没有东西的&am…

大厂密集背后,折叠屏市场“暗战”已起

在经过长达10余年的高速增长之后&#xff0c;智能手机行业在过去两年出现了显著的“颓势”&#xff0c;高中低端各个不同市场的分化也越发明显&#xff0c;但新兴的折叠屏市场却滑出了一条“向上”的曲线&#xff0c;引发市场关注。它从形态上颠覆了用户的认知&#xff0c;给市…

单片机-蜂鸣器

简介 蜂鸣器是一种一体化结构的电子讯响器&#xff0c;采用直流电压供电 蜂鸣器主要分为 压电式蜂鸣器 和 电磁式蜂鸣器 两 种类型。 压电式蜂鸣器 主要由多谐振荡器、压电蜂鸣片、阻抗匹配器及共鸣箱、外壳等组成。多谐振荡器由晶体管或集成电路构成&#xff0c;当接通电源后&…

Java ArrayList

简介 ArrayList类示一个可以动态修改的数组&#xff0c;与普通数组的区别是它没有固定大小的限制&#xff0c;可以添加和删除元素。 适用情况&#xff1a; 频繁的访问列表中的某一元素只需要在列表末尾进行添加和删除某些元素 实例 ArrayList 是一个数组队列&#xff0c;提…

PostgreSQL本地化

本地化的概念 本地化的目的是支持不同国家、地区的语言特性、规则。比如拥有本地化支持后&#xff0c;可以使用支持汉语、法语、日语等等的字符集。除了字符集以外&#xff0c;还有字符排序规则和其他语言相关规则的支持&#xff0c;例如我们知道(‘a’,‘b’)该如何排序&…

移动安全测试框架-MobSF WINDOWS 环境搭建

安装python python-3.11.5-amd64.exe 安装Win64OpenSSL-3_1_2.exe 安装VisualStudioSetup.exe github下载安装包 https://github.com/MobSF/Mobile-Security-Framework-MobSF/archive/refs/heads/master.zip GitHub - MobSF/Mobile-Security-Framework-MobSF: Mobile Secur…

亲测可用的1688商品详情API和关键字搜索API,需要自取!

引言 1688是中国最大的B2B电商平台之一&#xff0c;拥有众多的供应商和商品。为了方便用户查找需要的商品&#xff0c;1688提供了关键词搜索商品的API接口。通过该接口开发者可以在自己的系统中调用平台上的数据和功能&#xff0c;例如获取商品信息、下单等操作1。 1688平台提…

mysql 数据库面试题整理

Mysql 中 MyISAM 和 InnoDB 的区别 1、InnoDB 支持事务MyISAM 不支持 2、InnoDB 支持外键MyISAM 不支持 3、InnoDB 是聚集索引&#xff0c;MyISAM 是非聚集索引 4、InnoDB 不保存表的具体行数 5、InnoDB 最小的锁粒度是行锁&#xff0c;MyISAM是表锁 mysql中有就更新&#xf…

gismo程序示例:边长为 8 16 32 的长方体 受均布载荷

文章目录 前言一、一、8*32面 受均布载荷 二、最小的面&#xff08;8*16&#xff09;受均布载荷三、最大的面受均布载荷 前言 只是为方便学习&#xff0c;不做其他用途&#xff0c; 一、 一、8*32面 受均布载荷 /// This is an example of using the linear elasticity solver…

【进阶篇】Redis缓存击穿, 穿透, 雪崩, 污染详解

【进阶篇】Redis缓存穿击, 穿透, 雪崩, 污染详解 文章目录 【进阶篇】Redis缓存穿击, 穿透, 雪崩, 污染详解0. 前言大纲缓存穿击缓存穿透缓存雪崩缓存污染 1. 什么是缓存穿透&#xff1f;1.1. 发生原因1.2. 导致问题1.3. 解决方案 2. 什么是缓存击穿3.1. 发生原因3.1. 解决方案…

Spring-Cloud-Openfeign如何传递用户信息?

用户信息传递 微服务系统中&#xff0c;前端会携带登录生成的token访问后端接口&#xff0c;请求会首先到达网关&#xff0c;网关一般会做token解析&#xff0c;然后把解析出来的用户ID放到http的请求头中继续传递给后端的微服务&#xff0c;微服务中会有拦截器来做用户信息的…

PY32F003F18端口复用功能映射

PY32F003F18端口复用功能映射&#xff0c;GPIO引脚可配置为"输入&#xff0c;输出,模拟或复用功能。 一、端口A复用功能映射 端口A复用功能映射表里&#xff0c;每个引脚都有AF0~AF15&#xff0c;修改AF0~AF15的值&#xff0c;就可以将对应复用用能引脚映射到CPU引脚上。…

leetcode 925. 长按键入

2023.9.7 我的基本思路是两数组字符逐一对比&#xff0c;遇到不同的字符&#xff0c;判断一下typed与上一字符是否相同&#xff0c;不相同返回false&#xff0c;相同则继续对比。 最后要分别判断name和typed分别先遍历完时的情况。直接看代码&#xff1a; class Solution { p…

【业务功能篇99】微服务-springcloud-springboot-电商订单模块-生成订单服务-锁定库存

八、生成订单 一个是需要生成订单信息一个是需要生成订单项信息。具体的核心代码为 /*** 创建订单的方法* param vo* return*/private OrderCreateTO createOrder(OrderSubmitVO vo) {OrderCreateTO createTO new OrderCreateTO();// 创建订单OrderEntity orderEntity build…

[HNCTF 2022] web 刷题记录

文章目录 [HNCTF 2022 Week1]easy_html[HNCTF 2022 Week1]easy_upload[HNCTF 2022 Week1]Interesting_http[HNCTF 2022 WEEK2]ez_SSTI[HNCTF 2022 WEEK2]ez_ssrf[HNCTF 2022 WEEK2]Canyource [HNCTF 2022 Week1]easy_html 打开题目提示cookie有线索 访问一下url 发现要求我们…

C语言中计算函数执行时间

要计算C语言程序中某个函数从开始到结束执行的时间&#xff0c;可以使用C标准库中的clock()函数。clock()函数返回程序执行开始到当前时刻的时钟数&#xff0c;除以某个常数&#xff08;CLOCKS_PER_SEC&#xff09;&#xff0c;就可以得到程序执行的秒数。 下面是一个示例代码…