经典卷积神经网络 LeNet

一、实例图片

#我们传入的是28*28,所以加了padding
net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.Sigmoid(),nn.Linear(120, 84), nn.Sigmoid(),nn.Linear(84, 10))

二、总结

1、LeNet是早期成功的神经网络

2、先使用卷积层来学习图片空间信息

3、然后使用全连接层来转换到类别空间

三、代码

1、评估模型,将参数放到GPU中

def evaluate_accuracy_gpu(net, data_iter, device=None): #@save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):#eval是将模型设置为评估模式,评估模式就不会改变模型参数了可以用来预测结果;eval就是关闭模型中的dropout功能,调到评价模式;与之相对的是train()net.eval()if not device:#如果未提供设备参数,则使用模型第一个参数的设备作为默认设备。这确保了数据和模型在            同一设备上运行device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)#确保在计算精度时不会计算梯度,从而节省显存和提高计算效率with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

2、训练模型

#@save
def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""#对于net里面的所有parameter,都去run一下那个初始化权重的函数。就是说在整个net中的所有层上面都使用init__weights函数来初始化所有现行层和卷积层的权重def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:#xavier能够根据输入输出的大小,使得初始化随机权重能使,输入和输出的方差差别不会很大,保证在模型最开始的时候,结果不会指数爆炸或者消失nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)#SGD是随机梯度下降算法optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():#l是当前批次的平均损失。X.shape[0]是当前批次的样本数。l * X.shape[0]计算的是当前批次的总损失。metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]#这里是可视化#(i + 1) % (num_batches // 5)每训练到一个阶段时(5次中的每一次)会更新可视化数据#i == num_batches - 1:当训练到最后一个批次时,条件也会满足if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/866148.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux Swap机制关键点分析

1. page被swap出去之后,再次缺页是怎么找到找个换出的页面? 正常内存的页面是通过pte映射找到page的,swap出去的page有其特殊的方式:swap的页面page->private字段保存的是:swap_entry_t通过swap_entry_t就能找到该页面的扇区号sector_t,拿到扇区号就可以从块设备中读…

Werkzeug库介绍:Python WSGI工具集

Werkzeug库介绍:Python WSGI工具集 1. 什么是Werkzeug?2. 基本概念3. 安装Werkzeug4. 基本用法示例4.1 创建一个简单的WSGI应用4.2 路由和URL构建4.3 处理表单数据 5. 高级特性5.1 中间件5.2 Sessions5.3 文件上传 6. 性能考虑7. 注意事项8. 结语 1. 什么是Werkzeug? Werkze…

day04-组织架构

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 1.组织架构-树组件应用树形组件-用层级结构展示信息,可展开或折叠。 2.组织架构-树组件自定义结构3.组织架构-获取组织架构数据4.组织架构-递归转化树形…

算力共享所面临的痛点问题和现有解决办法,怎样和人工智能相结合

目录 算力共享所面临的痛点问题和现有解决办法,怎样和人工智能相结合 算力共享所面临的痛点问题 现有解决办法 与人工智能的结合 怎样利用分布式计算技术将算力下沉到更接近用户的地方,减少延迟和提高可用性。 一、分布式计算技术的应用 二、算力下沉的策略 算力共享所…

Redis 典型应用——分布式锁

一、什么是分布式锁 在一个分布式的系统中,也会涉及到多个节点访问同一个公共资源的情况,此时就需要通过锁来做互斥控制,避免出现类似于 "线程安全" 的问题; 而 Java 中的 synchronized,只能在当前进程中生…

day60---面试专题(微服务面试题-参考回答)

微服务面试题 **面试官:**Spring Cloud 5大组件有哪些? 候选人: 早期我们一般认为的Spring Cloud五大组件是 Eureka : 注册中心Ribbon : 负载均衡Feign : 远程调用Hystrix : 服务熔断Zuul/Gateway : 网关 随着SpringCloudAlibba在国内兴起 , …

HOW - React Router Feature 实践(react-router-dom)

目录 基本特性ranked routes matchingactive linksNavLinkuseMatch relative links1. 相对路径的使用2. 嵌套路由的增强行为3. 优势和注意事项4. . 和 ..5. 总结 data loading 基本特性 client side routingnested routesdynamic segments 比较好理解,这里不赘述。…

【C语言】控制台扫雷(C语言实现)

目录 博文目的实现思路项目创建文件解释 具体实现判断玩家进行游戏还是退出扫雷棋盘的确定地图初始化埋雷玩家扫雷的实现雷判断函数 源码game.cgame.h扫雷.c 博文目的 相信不少人都学习了c语言的函数,循环,分支那我们就可以写一个控制台的扫雷小游戏来检…

面向对象-封装

一.包 1.简介 当我们把所有的java类都写src下的第一层级,如果是项目中,也许会有几百个java文件。 src下的文件会很多,开发的时候不方便查找,也不方便维护如果较多的文件中有同名的,十分麻烦 模块1中有一个叫test.ja…

android应用的持续构建CI(二)-- jenkins集成

一、背景 接着上一篇文章,本文我们将使用jenkins把所有的流程串起来。 略去了对android应用的加固流程,重点是jenkins的job该如何配置。 二、配置jenkins job 0、新建job 选择一个自由风格的软件项目 1、参数赋值 你可以增加许多参数,这…

Games101学习笔记 Lecture16 Ray Tracing 4 (Monte Carlo Path Tracing)

Lecture16 Ray Tracing 4 (Monte Carlo Path Tracing 一、蒙特卡洛积分 Monte Carlo Integration二、路径追踪 Path tracing1.Whitted-Style Ray Tracings Problems2.只考虑直接光照时3.考虑全局光照①考虑物体的反射光②俄罗斯轮盘赌 RR (得到正确shade函数&#x…

全球投资中如何规避国别风险

不管邓普顿在全球投资中是自上而下的选股,还是自下而上的选股,他都不得不面临在不在某一个国家大规模投资的问题。尽管我暂时不会考虑跨国投资,不过还是可以学习一下。那么,他是怎么规避国别风险的呢?劳伦在《逆向投资…

Linux-Kafka 3.7.0 Kraft+SASL认证模式 集群安装与部署超详细

1.集群规划 一般模式下,元数据在 zookeeper 中,运行时动态选举 controller,由controller 进行 Kafka 集群管理。kraft 模式架构(实验性)下,不再依赖 zookeeper 集群,而是用三台 controller 节点…

嵌入式底层系统了解

当裸机功能不复杂的时候,即类似与点亮一个LED灯,驱动LCD和OLED这样的模块,以及各位大学生的搭积木式的毕业设计(狗头保命),此时可以简单地分为硬件和软件层(应用层),以及以中间层作为中间联系。 当需要实现…

深入Kafka:如何保证数据一致性与可靠性?

我是小米,一个喜欢分享技术的29岁程序员。如果你喜欢我的文章,欢迎关注我的微信公众号“软件求生”,获取更多技术干货! Hello, 大家好!我是小米,今天我们来聊一聊Kafka的一致性问题。Kafka作为一个高性能的分布式流处理平台,一直以来都备受关注。今天,我将深入探讨Kaf…

C++(第四天----拷贝函数、类的组合、类的继承)

一、拷贝构造函数(复制构造函数) 1、概念 拷贝构造函数,它只有一个参数,参数类型是本类的引用。如果类的设计者不写拷贝构造函数,编译器就会自动生成拷贝构造函数。大多数情况下,其作用是实现从源对象到目…

Python获取QQ音乐歌单歌曲

准备工作 歌单分享的url地址 比如: https://i.y.qq.com/n2/m/share/details/taoge.html?hosteuin=oKvzoK4l7evk7n**&id=9102222552&appversion=130605&ADTAG=wxfshare&appshare=iphone_wx 代码实现 def mu(share_url):share_url = share_url.split(id=)[1…

目标检测入门:3.目标检测损失函数(IOU、GIOU、GIOU)

目录 一、IOU 二、GIOU 三、DIOU 四、DIOU_Loss实战 在前面两章里面训练模型时,损失函数都是选择L1Loss(平均绝对值误差(MAE))损失函数,L1Loss损失函数公式如下: 由公式可知,L1Loss损失函数…

为PPT加密:如何设置和管理“打开密码”?

在保护演示文稿的内容时,给PPT文件设置“打开密码”是一个简单而有效的方法。今天一起来看看如何设置和管理PPT文件的“打开密码”吧! 一、设置PPT“打开密码” 首先,打开需要加密的PPT文件,点击左上角的“文件”选项卡&#x…

oracle如何判定数据库的时区并进行时间的时区转换

在Oracle数据库中,判断和设置时区以及进行时区的转换是很重要的功能。以下是一些基本的步骤和方法: 1. 判定数据库的时区 要查看Oracle数据库的时区,你可以查询DBTIMEZONE。例如: sql SELECT DBTIMEZONE FROM DUAL; 这将返回…