[pytorch、学习] - 5.5 卷积神经网络(LeNet)

参考

5.5 卷积神经网络(LeNet)

卷积层尝试解决两个问题:

  1. 卷积层保留输入形状,使图像的像素在高和宽两个方向上的相关性均可能被有效识别;
  2. 卷积层通过滑动窗口将同一卷积核和不同位置的输入重复计算,从而避免参数尺寸过大。

在这里插入图片描述

5.5.1 LeNet模型

LeNet分为卷积层块和全连接层块两个部分.

卷积层块的基本单位是卷积层后接最大池化层: 卷积层用来识别图像里的空间模式(线条和物体局部),之后最大池化用来降低卷积层对位置的敏感性。卷积层块由两个这样的基本单位重复堆叠构成。

在卷积层块中,每个卷积层都使用5×5的窗口,并在输出上使用sigmoid激活函数。第一个卷积层输出通道数为6,第二个卷积层输出通道数则增加到16。这是因为第二个卷积层比第一个卷积层的输入的高和宽要小,所以增加输出通道使两个卷积层的参数尺寸类似。卷积层块的两个最大池化层的窗口形状均为2×2,且步幅为2。由于池化窗口与步幅形状相同,池化窗口在输入上每次滑动所覆盖的区域互不重叠。

卷积层块的输出形状为(批量大小, 通道, 高, 宽)。当卷积层块的输出传入全连接层块时,全连接层块会将小批量中每个样本变平(flatten)。也就是说,全连接层的输入形状将变成二维,其中第一维是小批量中的样本,第二维是每个样本变平后的向量表示,且向量长度为通道、高和宽的乘积。全连接层块含3个全连接层。它们的输出个数分别是120、84和10,其中10为输出的类别个数。

下面通过Sequential类来实现LeNet模型

import time
import torch
import torch.nn as nn
import sys
sys.path.append("..")
import d2lzh_pytorch as d2ldevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv = nn.Sequential(nn.Conv2d(1, 6, 5),   # in_channels, out_channels, kernel_size: (1, 1, 28, 28) -> (6, 1, 24, 24)nn.Sigmoid(),nn.MaxPool2d(2, 2),   #  kernel_size, stride: (6, 24, 24) -> (6, 1,12, 12)nn.Conv2d(6, 16, 5),  # (6, 1, 12, 12) -> (16, 1, 8, 8)nn.Sigmoid(),nn.MaxPool2d(2, 2)    # (16, 1, 8, 8) -> (16, 1, 4, 4))self.fc = nn.Sequential(nn.Linear(16*4*4, 120), # (16, 1, 4, 4) -> (256) -> (120)nn.Sigmoid(),nn.Linear(120, 84),  # (120) -> (84)nn.Sigmoid(),nn.Linear(84, 10)  # (84) -> (10))def forward(self, img):# img: 1 * 1 * 28 * 28feature = self.conv(img)  output = self.fc(feature.view(img.shape[0], -1))return output
net = LeNet()
print(net)

在这里插入图片描述
可以看到,在卷积层块中输入的高和宽在逐层减小。卷积层由于使用高和宽均为5的卷积核,从而将高和宽分别减小4,而池化层则将高和宽减半,但通道数则从1增加到16。全连接层则逐层减少输出个数,直到变成图像的类别数10。

5.5.2 获取数据和训练模型

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size = batch_size)
# 使用GPU计算
def evaluate_accuracy(data_iter, net, device=None):if device is None and isinstance(net, torch.nn.Module):# 如果没指定device就使用net的devicedevice = list(net.parameters())[0].deviceacc_sum, n = 0.0, 0with torch.no_grad():for X, y in data_iter:if isinstance(net, torch.nn.Module):net.eval() # 评估模式, 这会关闭dropoutacc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()net.train() # 改回训练模式else: # 自定义的模型, 3.13节之后不会用到, 不考虑GPUif('is_training' in net.__code__.co_varnames): # 如果有is_training这个参数# 将is_training设置成Falseacc_sum += (net(X, is_training=False).argmax(dim=1) == y).float().sum().item() else:acc_sum += (net(X).argmax(dim=1) == y).float().sum().item() n += y.shape[0]return acc_sum / n
def train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs):net = net.to(device)print("training on ", device)loss = torch.nn.CrossEntropyLoss()for epoch in range(num_epochs):train_l_sum, train_acc_sum, n, batch_count, start = 0.0, 0.0, 0, 0, time.time()for X, y in train_iter:X = X.to(device)y = y.to(device)y_hat = net(X)l = loss(y_hat, y)optimizer.zero_grad()l.backward()optimizer.step()train_l_sum += l.cpu().item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()n += y.shape[0]batch_count += 1test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'% (epoch + 1, train_l_sum / batch_count, train_acc_sum / n, test_acc, time.time() - start))
lr, num_epochs = 0.001, 10
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/250132.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android内存管理机制

好文摘录 原作: https://www.cnblogs.com/nathan909/p/5372981.html 1、基于Linux内存管理 Android系统是基于Linux 2.6内核开发的开源操作系统,而linux系统的内存管理有其独特的动态存储管理机制。不过Android系统对Linux的内存管理机制进行了优化&…

【Ruby】Ruby 类案例

阅读目录 Ruby类案例保存并执行代码Ruby类案例 下面将创建一个名为 Customer 的 Ruby 类,声明两个方法: display_details:该方法用于显示客户的详细信息。total_no_of_customers:该方法用于显示在系统中创建的客户总数量。实例 #!…

[pytorch、学习] - 5.6 深度卷积神经网络(AlexNet)

参考 5.6 深度卷积神经网络(AlexNet) 在LeNet提出后的将近20年里,神经网络一度被其他机器学习方法超越,如支持向量机。虽然LeNet可以在早期的小数据集上取得好的成绩,但是在更大的真实数据集上的表现并不尽如人意。一方面,神经网络计算复杂。虽然20世纪…

Springboot---Model,ModelMap,ModelAndView

Model(org.springframework.ui.Model) Model是一个接口,包含addAttribute方法,其实现类是ExtendedModelMap。 ExtendedModelMap继承了ModelMap类,ModelMap类实现了Map接口。 public class ExtendedModelMap extends M…

东南亚支付——柬埔寨行

考察时间:2018.5.28 至 2018.6.6 为了解柬埔寨大概国情和市场,在柬埔寨开展了为期近10天的工作。 观察了交通情况,周边街道的店面与商品,摊贩等,也走访了大学校区,看了永旺商超、本地超市和中国超市&#x…

Puzzle (II) UVA - 519

题目链接: https://vjudge.net/problem/UVA-519 思路: 剪枝回溯 这个题巧妙的是他按照表格的位置开始搜索,也就是说表格是定的,他不断用已有的图片从(0,0)开始拼到(n-1,m-1) 剪枝的地方: 1.由于含F的面只能拼到边上&am…

[pytorch、学习] - 5.7 使用重复元素的网络(VGG)

参考 5.7 使用重复元素的网络(VGG) AlexNet在LeNet的基础上增加了3个卷积层。但AlexNet作者对它们的卷积窗口、输出通道数和构造顺序均做了大量的调整。虽然AlexNet指明了深度卷积神经网络可以取得出色的结果,但并没有提供简单的规则以指导…

springboot---mybits整合

配置 POM文件 <parent> <groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>1.5.6.RELEASE</version><relativePath /> </parent><properties><proj…

使用airdrop进行文件共享

使用airdrop进行文件共享 学习了&#xff1a; https://support.apple.com/zh-cn/HT203106 https://zh.wikihow.com/%E5%9C%A8Mac%E4%B8%8A%E7%94%A8%E8%BF%91%E6%9C%BA%E6%8D%B7%E4%BC%A0%EF%BC%88Airdrop%EF%BC%89%E5%85%B1%E4%BA%AB%E6%96%87%E4%BB%B6 转载于:https://www.cn…

【链表】逆序打印链表

1 public class Main {2 3 // 逆序打印链表4 public void reversePrint(Node node) {5 if (node null){6 return;7 }8 reversePrint(node.next);9 System.out.println(node.data); 10 } 11 12 public Node crea…

[pytorch、学习] - 5.8 网络中的网络(NiN)

参考 5.8 网络中的网络&#xff08;NiN&#xff09; 前几节介绍的LeNet、AlexNet和VGG在设计上的共同之处是&#xff1a;先以由卷积层构成的模块充分抽取空间特征&#xff0c;再以由全连接层构成的模块来输出分类结果。其中&#xff0c;AlexNet和VGG对LeNet的改进主要在于如何…

springboot---集成mybits方法

SpringBoot集成mybatis配置 一个有趣的现象&#xff1a;传统企业大都喜欢使用hibernate,互联网行业通常使用mybatis&#xff1b;之所以出现这个问题感觉与对应的业务有关&#xff0c;比方说&#xff0c;互联网的业务更加的复杂&#xff0c;更加需要进行灵活性的处理&#xff0c…

jQuery源码解读

参考 &#xff1a; https://www.cnblogs.com/yuqingfamily/p/5785593.html 转载于:https://www.cnblogs.com/wfblog/p/9172622.html

info.plist文件里面添加描述 - 配置定位,相册等

<key>NSAppleMusicUsageDescription</key> <string>App需要您的同意,才能访问媒体资料库</string> <key>NSBluetoothPeripheralUsageDescription</key> <string>App需要您的同意,才能访问蓝牙</string> <key>NSCalendar…

[pytorch、学习] - 5.9 含并行连结的网络(GoogLeNet)

参考 5.9 含并行连结的网络&#xff08;GoogLeNet&#xff09; 在2014年的ImageNet图像识别挑战赛中&#xff0c;一个名叫GoogLeNet的网络结构大放异彩。它虽然在名字上向LeNet致敬&#xff0c;但在网络结构上已经很难看到LeNet的影子。GoogLeNet吸收了NiN中网络串联网络的思…

mybits注解详解

一、mybatis 简单注解 关键注解词 &#xff1a; Insert &#xff1a; 插入sql , 和xml insert sql语法完全一样 Select &#xff1a; 查询sql, 和xml select sql语法完全一样 Update &#xff1a; 更新sql, 和xml update sql语法完全一样 Delete &#xff1a; 删除sql, 和xml d…

使用python装饰器计算函数运行时间的实例

使用python装饰器计算函数运行时间的实例 装饰器在python里面有很重要的作用&#xff0c; 如果能够熟练使用&#xff0c;将会大大的提高工作效率 今天就来见识一下 python 装饰器&#xff0c;到底是怎么工作的。 本文主要是利用python装饰器计算函数运行时间 一些需要精确的计算…

SQLServer用存储过程实现插入更新数据

实现 1&#xff09;有同样的数据&#xff0c;直接返回&#xff08;返回值&#xff1a;0&#xff09;。 2&#xff09;有主键同样。可是数据不同的数据。进行更新处理&#xff08;返回值&#xff1a;2&#xff09;&#xff1b; 3&#xff09;没有数据&#xff0c;进行插入数据处…

[pytorch、学习] - 9.1 图像增广

参考 9.1 图像增广 在5.6节(深度卷积神经网络)里我们提过,大规模数据集是成功应用神经网络的前提。图像增广(image augmentation)技术通过对训练图像做一系列随机改变,来产生相似但又不相同的训练样本,从而扩大训练数据集的规模。图像增广的另一种解释是,随机改变训练样本可以…

mysql绿色版安装

导读&#xff1a;MySQL是一款关系型数据库产品&#xff0c;官网给出了两种安装包格式&#xff1a;MSI和ZIP。MSI格式是图形界面安装方式&#xff0c;基本只需下一步即可&#xff0c;这篇文章主要介绍ZIP格式的安装过程。ZIP Archive版是免安装的。只要解压就行了。 一、首先下…