atm取款机的简单程序代码_LeNet:一个简单的卷积神经网络PyTorch实现

前两篇文章分别介绍了卷积层和池化层,卷积和池化是卷积神经网络必备的两大基础。本文我们将介绍一个早期用来识别手写数字图像的卷积神经网络:LeNet[1]。LeNet名字来源于论文的第一作者Yann LeCun。1989年,LeNet使用卷积神经网络和梯度下降法,使得手写数字识别达到当时领先水平。这个奠基性的工作第一次将卷积神经网络推上历史舞台,为世人所知。由于LeNet的出色表现,在很多ATM取款机上,LeNet被用来识别数字字符。

本文基于PyTorch和TensorFlow 2的代码已经放在了我的GitHub上:https://github.com/luweizheng/machine-learning-notes/tree/master/neural-network/cnn。

网络模型结构

LeNet的网络结构如下图所示。

c66712173dfb61d76cb0c32572553374.png

LeNet分为卷积层块和全连接层块两个部分。

卷积层块里的基本单位是卷积层后接最大池化层:卷积层用来识别图像里的空间模式,如线条和物体局部,之后的最大池化层则用来降低卷积层对位置的敏感性。卷积层块由卷积层加池化层两个这样的基本单位重复堆叠构成。在卷积层块中,每个卷积层都使用5×5的窗口,并在输出上使用Sigmoid激活函数。整个模型的输入是1维的黑白图像,图像尺寸为28×28。第一个卷积层输出通道数为6,第二个卷积层输出通道数则增加到16。这是因为第二个卷积层比第一个卷积层的输入的高和宽要小,所以增加输出通道使两个卷积层的参数尺寸类似。卷积层块的两个最大池化层的窗口形状均为2×2,且步幅为2。由于池化窗口与步幅形状相同,池化窗口在输入上每次滑动所覆盖的区域互不重叠。

我们通过PyTorch的Sequential类来实现LeNet模型。

class LeNet(nn.Module):    def __init__(self):        super(LeNet, self).__init__()                # 输入 1 * 28 * 28        self.conv = nn.Sequential(            # 卷积层1            # 在输入基础上增加了padding,28 * 28 -> 32 * 32            # 1 * 32 * 32 -> 6 * 28 * 28            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2), nn.Sigmoid(),            # 6 * 28 * 28 -> 6 * 14 * 14            nn.MaxPool2d(kernel_size=2, stride=2), # kernel_size, stride            # 卷积层2            # 6 * 14 * 14 -> 16 * 10 * 10             nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5), nn.Sigmoid(),            # 16 * 10 * 10 -> 16 * 5 * 5            nn.MaxPool2d(kernel_size=2, stride=2)        )        self.fc = nn.Sequential(            # 全连接层1            nn.Linear(in_features=16 * 5 * 5, out_features=120), nn.Sigmoid(),            # 全连接层2            nn.Linear(in_features=120, out_features=84), nn.Sigmoid(),            nn.Linear(in_features=84, out_features=10)        )    def forward(self, img):        feature = self.conv(img)        output = self.fc(feature.view(img.shape[0], -1))        return output复制代码

我们有必要梳理一下模型各层的参数。输入形状为通道数为1的图像(1维黑白图像),尺寸为28×28,经过第一个5×5的卷积层,卷积时上下左右都使用了2个元素作为填充,输出形状为:(28 - 5 + 4 + 1) × (28 - 5 + 4 + 1) = 28 × 28。第一个卷积层输出共6个通道,输出形状为:6 × 28 × 28。最大池化层核大小2×2,步幅为2,高和宽都被折半,形状为:6 × 14 × 14。第二个卷积层的卷积核也为5 × 5,但是没有填充,所以输出形状为:(14 - 5 + 1) × (14 - 5 + 1) = 10 × 10。第二个卷积核的输出为16个通道,所以变成了 16 × 10 × 10。经过最大池化层后,高和宽折半,最终为:16 × 5 × 5。

卷积层块的输出形状为(batch_size, output_channels, height, width),在本例中是(batch_size, 16, 5, 5),其中,batch_size是可以调整大小的。当卷积层块的输出传入全连接层块时,全连接层块会将一个batch中每个样本变平(flatten)。原来是形状是:(通道数 × 高 × 宽),现在直接变成一个长向量,向量长度为通道数 × 高 × 宽。在本例中,展平后的向量长度为:16 × 5 × 5 = 400。全连接层块含3个全连接层。它们的输出个数分别是120、84和10,其中10为输出的类别个数。

训练模型

基于上面的网络,我们开始训练模型。我们使用Fashion-MNIST作为训练数据集,很多框架,比如PyTorc提供了Fashion-MNIST数据读取的模块,我做了一个简单的封装:

def load_data_fashion_mnist(batch_size, resize=None, root='~/Datasets/FashionMNIST'):    """Use torchvision.datasets module to download the fashion mnist dataset and then load into memory."""    trans = []    if resize:        trans.append(torchvision.transforms.Resize(size=resize))    trans.append(torchvision.transforms.ToTensor())        transform = torchvision.transforms.Compose(trans)    mnist_train = torchvision.datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)    mnist_test = torchvision.datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)    if sys.platform.startswith('win'):        num_workers = 0      else:        num_workers = 4    train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)    test_iter = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)    return train_iter, test_iter复制代码

load_data_fashion_mnist()方法返回训练集和测试集。

在训练过程中,我们希望看到每一轮迭代的准确度,构造一个evaluate_accuracy方法,计算当前一轮迭代的准确度(模型预测值与真实值之间的误差大小):

def evaluate_accuracy(data_iter, net, device=None):    if device is None and isinstance(net, torch.nn.Module):        device = list(net.parameters())[0].device    acc_sum, n = 0.0, 0    with torch.no_grad():        for X, y in data_iter:            if isinstance(net, torch.nn.Module):                # set the model to evaluation mode (disable dropout)                net.eval()                 # get the acc of this batch                acc_sum += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()                # change back to train mode                net.train()             n += y.shape[0]    return acc_sum / n复制代码

接着,我们可以构建一个train()方法,用来训练神经网络:

def try_gpu(i=0):    if torch.cuda.device_count() >= i + 1:        return torch.device(f'cuda:{i}')    return torch.device('cpu')def train(net, train_iter, test_iter, batch_size, optimizer, num_epochs, device=try_gpu()):    net = net.to(device)    print("training on", device)    loss = torch.nn.CrossEntropyLoss()    batch_count = 0    for epoch in range(num_epochs):        train_l_sum, train_acc_sum, n = 0.0, 0.0, 0        for X, y in train_iter:            X = X.to(device)            y = y.to(device)            y_hat = net(X)            l = loss(y_hat, y)            optimizer.zero_grad()            l.backward()            optimizer.step()            train_l_sum += l.cpu().item()            train_acc_sum += (y_hat.argmax(dim=1) == y).sum().cpu().item()            n += y.shape[0]            batch_count += 1        test_acc = evaluate_accuracy(test_iter, net)        if epoch % 10 == 0:            print(f'epoch {epoch + 1} : loss {train_l_sum / batch_count:.3f}, train acc {train_acc_sum / n:.3f}, test acc {test_acc:.3f}')复制代码

在整个程序的主逻辑中,设置必要的参数,读入训练和测试数据并开始训练:

def main():    batch_size = 256    lr, num_epochs = 0.9, 100    net = LeNet()    optimizer = torch.optim.SGD(net.parameters(), lr=lr)        # load data    train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)    # train    train(net, train_iter, test_iter, batch_size, optimizer, num_epochs)复制代码

小结

  1. LeNet是一个最简单的卷积神经网络,卷积神经网络包含卷积块部分和全连接层部分。
  2. 卷积块包括一个卷积层和一个池化层。

参考文献

  1. LeCun, Y., Bottou, L., Bengio, Y., & Haffner, P. (1998). Gradient-based learning applied to document recognition. Proceedings of the IEEE, 86(11), 2278-2324.
  2. http://d2l.ai/chapter_convolutional-neural-networks/lenet.html
  3. https://tangshusen.me/Dive-into-DL-PyTorch/#/chapter05_CNN/5.5_lenet

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/469081.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【数据集转换】VOC数据集转COCO数据集·代码实现+操作步骤

在自己的数据集上实验时,往往需要将VOC数据集转化为coco数据集,因为这种需求所以才记录这篇文章,代码出处未知,感谢开源。 在远程服务器上测试目标检测算法需要用到测试集,最常用的是coco2014/2017和voc07/12数据集。 …

idea spring tomcat启动失败_技术篇 | 实用IDEA插件和工具系列

前 言本章主要分享一些工作中常用的IDEA插件(Maven Helper、Lombok、Mybatis Log Plugin、RestfulToolkit、JRebel And XRebel)和实用工具arthas。01Maven Helper作用:能清晰的查看当项目的Maven依赖版本、依赖关系、依赖冲突等情况。使用步骤:①安装后,…

【数据集可视化】VOC数据集标注可视化+代码实现

二、VOC可视化数据集 1、作用 在做目标检测时,首先要检查标注数据。一方面是要了解标注的情况,另一方面是检查数据集的标注和格式是否正确,只有正确的情况下才能进行下一步的训练。 2、代码实现 import os # import sys import cv2 import…

串口UART串行总线协议

串口UART 串行端口是异步的(不传输时钟相关数据),两个设备在使用串口通信时,必须先约定一个数据传输速率,并且这两个设备各自的时钟频率必须与这个速率保持相近,某一方的时钟频率相差很大都会导致数据传输…

基于Springboot外卖系统01:技术构成+功能模块介绍

外卖系统是专门为餐饮企业(餐厅、饭店)定制的一款软件产品,包括 系统管理后台 和 移动端应用 两部分。其中系统管理后台主要提供给餐饮企业内部员工使用,可以对餐厅的分类、菜品、套餐、订单、员工等进行管理维护。移动端应用主要…

HTML5本地图片裁剪并上传

最近做了一个项目,这个项目中需要实现的一个功能是:用户自定义头像(用户在本地选择一张图片,在本地将图片裁剪成满足系统要求尺寸的大小)。这个功能的需求是:头像最初剪切为一个正方形。如果选择的图片小于…

嵌入式就应该这样学!!

嵌入式就应该这样学!! 1、Linux内核 Linux 内核定时器 Linux进程上下文和中断上下文内核空间和用户空间 Linux内核链表 Linux 内核模块编译 Linux内核使用Gdb调试 Linux动态打印kernel日志 Linux的中断可以嵌套吗 Linux内核定时器 Linux 驱动之Ioctl Lin…

基于Springboot外卖系统02:数据库搭建+Maven仓库搭建

1 数据库环境搭建 1.1 创建数据库 可以通过以下两种方式中的任意一种, 来创建项目的数据库: 1).图形界面 注意: 本项目数据库的字符串, 选择 utf8mb4 2).命令行 1.2 数据库表导入 项目的数据库创建好了之后, 可以直接将 资料/数据模型/db_reggie.sql 直接导入到数据库中, …

margin 负边距应用

margin-right:负值&#xff0c;在没有设置DOM元素宽度的前提下&#xff0c;DOM元素宽度变宽。 1 <!DOCTYPE html>2 <html lang"zh-CN">3 4 <head>5 <meta charset"UTF-8">6 <meta http-equiv"X-UA-Co…

基于Springboot外卖系统03:pom.xml导入依赖+数据库配置文件+Boot启动类+静态资源映射

1).在pom.xml中导入依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache…

写给过得很辛苦很迷茫的你~一定要看啊

#前面的话 我是一个农村的孩子&#xff0c;我家很穷&#xff0c;小时候过得非常苦&#xff0c;每次开学是我最害怕的时候&#xff0c;我害怕我爸妈拿不出学费&#xff0c;我害怕我爸妈会让我辍学在家帮忙干活&#xff0c;每次跟我妈吵架的时候&#xff0c;当我妈跟我说不让我读…

flatpickr功能强大的日期时间选择器插件

flatpickr日期时间选择器支持移动手机&#xff0c;提供多种内置的主题效果&#xff0c;并且提供对中文的支持。它的特点还有&#xff1a; 使用SVG作为界面的图标。 兼容jQuery。 支持对各种日期格式的解析。 轻量级&#xff0c;高性能&#xff0c;压缩后的版本仅6K大小。 对…

基于Springboot外卖系统04:后台系统用户登录+登出功能

登录业务流程 ① 在登录页面输入用户名和密码 ② 调用后台接口进行验证 ③ 通过验证之后&#xff0c;根据后台的响应状态跳转到项目主页 2. 登录业务的相关技术点 http 是无状态的通过 cookie 在客户端记录状态通过 session 在服务器端记录状态通过 token 方式维持状态如果前端…

排序算法时间复杂度、空间复杂度、稳定性比较

排序算法分类 排序算法比较表格填空 排序算法平均时间复杂度最坏时间复杂度空间复杂度是否稳定冒泡排序:————-::—–::—–::—–:选择排序:————-::—–::—–::—–:直接插入排序:————-::—–::—–::—–:归并排序:————-::—–::—–::—–:快速排序:———…

基于Springboot外卖系统05:用户非登陆状态的页面拦截器实现

1. 完善登录功能 1.1 问题分析 用户访问接口验证&#xff0c;如果用户没有登录&#xff0c;则不让他访问除登录外的任何接口。 1.前端登录&#xff0c;后端创建session&#xff0c;返给前端 2.前端访问其他接口&#xff0c;失效或不存在&#xff0c;则返回失效提示&#xff…

python删除指定行_关于csv:删除python中的特定行和对应文件

我想删除90%的"转向"值等于0的行。这三个图像都有一个对应的图像文件&#xff0c;中间&#xff0c;左边和右边。我也要删除它们。csv文件如下&#xff1a;我编写了以下代码&#xff0c;以至少获取转向值为0的文件。我所需要的就是随机获取90%的文件并删除它们的代码。…

I2C总线传输协议

简介 I2C&#xff08;Inter-integrated Circuit&#xff09;总线支持设备之间的短距离通信&#xff0c;用于处理器和一些外围设备之间的接口&#xff0c;它只需要两根信号线来完成信息交换。I2C最早是飞利浦在1982年开发设计并用于自己的芯片上&#xff0c;一开始只允许100kHz…

基于Springboot外卖系统06: 新增员工功能+全局异常处理器

2. 新增员工 2.1 需求分析 后台系统中可以管理员工信息&#xff0c;通过新增员工来添加后台系统用户。点击[添加员工]按钮跳转到新增页面&#xff0c;如下 当填写完表单信息, 点击"保存"按钮后, 会提交该表单的数据到服务端, 在服务端中需要接受数据, 然后将数据保…

spring aop实现原理_Spring 异步实现原理与实战分享

最近因为全链路压测项目需要对用户自定义线程池 Bean 进行适配工作&#xff0c;我们知道全链路压测的核心思想是对流量压测进行标记&#xff0c;因此我们需要给压测的流量请求进行打标&#xff0c;并在链路中进行传递&#xff0c;那么问题来了&#xff0c;如果项目中使用了多线…

基于Springboot外卖系统07:员工分页查询+ 分页插件配置+分页代码实现

1. 员工分页查询 1.1 需求分析 在分页查询页面中, 以分页的方式来展示列表数据&#xff0c;以及查询条件 "员工姓名"。 请求参数 搜索条件&#xff1a; 员工姓名(模糊查询) 分页条件&#xff1a; 每页展示条数 &#xff0c; 页码 响应数据 总记录数 结果列表 1…