卷积神经网络实战

构建卷积神经网络

  • 卷积网络中的输入和层与传统神经网络有些区别,需重新设计,训练模块基本一致

1.首先读取数据

 - 分别构建训练集和测试集(验证集)
- DataLoader来迭代取数据

# 定义超参数 
input_size = 28  #图像的总尺寸28*28
num_classes = 10  #标签的种类数
num_epochs = 3  #训练的总循环周期
batch_size = 64  #一个撮(批次)的大小,64张图片# 训练集
train_dataset = datasets.MNIST(root='./data',  train=True,   transform=transforms.ToTensor(),  download=True) # 测试集
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())# 构建batch数据
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

2.卷积网络模块构建

- 一般卷积层,relu层,池化层可以写成一个套餐
- 注意卷积最后结果还是一个特征图,需要把图转换成向量才能做分类或者回归任务

class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Sequential(         # 输入大小 (1, 28, 28)nn.Conv2d(in_channels=1,              # 灰度图out_channels=16,            # 要得到几多少个特征图kernel_size=5,              # 卷积核大小stride=1,                   # 步长padding=2,                  # 如果希望卷积后大小跟原来一样,需要设置padding=(kernel_size-1)/2 if stride=1),                              # 输出的特征图为 (16, 28, 28)nn.ReLU(),                      # relu层nn.MaxPool2d(kernel_size=2),    # 进行池化操作(2x2 区域), 输出结果为: (16, 14, 14))self.conv2 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(16, 32, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),                      # relu层nn.Conv2d(32, 32, 5, 1, 2),nn.ReLU(),nn.MaxPool2d(2),                # 输出 (32, 7, 7))self.conv3 = nn.Sequential(         # 下一个套餐的输入 (16, 14, 14)nn.Conv2d(32, 64, 5, 1, 2),     # 输出 (32, 14, 14)nn.ReLU(),             # 输出 (32, 7, 7))self.out = nn.Linear(64 * 7 * 7, 10)   # 全连接层得到的结果def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x)x = x.view(x.size(0), -1)           # flatten操作,结果为:(batch_size, 32 * 7 * 7)output = self.out(x)return output

 

3.准确率作为评估标准

def accuracy(predictions, labels):pred = torch.max(predictions.data, 1)[1] rights = pred.eq(labels.data.view_as(pred)).sum() return rights, len(labels) 

 

4训练网络模型

# 实例化
net = CNN() 
#损失函数
criterion = nn.CrossEntropyLoss() 
#优化器
optimizer = optim.Adam(net.parameters(), lr=0.001) #定义优化器,普通的随机梯度下降算法#开始训练循环
for epoch in range(num_epochs):#当前epoch的结果保存下来train_rights = [] for batch_idx, (data, target) in enumerate(train_loader):  #针对容器中的每一个批进行循环net.train()                             output = net(data) loss = criterion(output, target) optimizer.zero_grad() loss.backward() optimizer.step() right = accuracy(output, target) train_rights.append(right) if batch_idx % 100 == 0: net.eval() val_rights = [] for (data, target) in test_loader:output = net(data) right = accuracy(output, target) val_rights.append(right)#准确率计算train_r = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))val_r = (sum([tup[0] for tup in val_rights]), sum([tup[1] for tup in val_rights]))print('当前epoch: {} [{}/{} ({:.0f}%)]\t损失: {:.6f}\t训练集准确率: {:.2f}%\t测试集正确率: {:.2f}%'.format(epoch, batch_idx * batch_size, len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.data, 100. * train_r[0].numpy() / train_r[1], 100. * val_r[0].numpy() / val_r[1]))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/797256.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【代码】二分法求最小值

仅适用于以下情况:区间内单调或者最多一个极小值 代码 以[0,pi]内的三角函数为例 clc clear close allx0:pi/1000:pi; ytest(x); figure() plot(x,y,.)cutnum100;x1x(1); x2x(end); error_max10^-1000;%能接受的误差上限 for i1:cutnum%这里cutnum是取值上限num(…

电池二次利用走向可持续大循环周期的潜力和挑战(第一篇)

一、背景 当前,气候变化是全球可持续发展面临的重大挑战。缓解气候变化最具挑战性的目标是在本世纪中期实现碳中和(排放量低到足以被自然系统安全吸收),其中电动汽车(EV)的引入是一项关键举措。电动汽车在…

对代理模式的理解

目录 一、前言二、案例1 代码2 自定义代理类【静态代理】2.1 一个接口多个实现,到底注入哪个依赖呢?2.1.1 Primary注解2.1.2 Resource注解(指定name属性)2.1.3 Qualifier注解 2.2 面向接口编程2.3 如果没接口咋办呢?2.…

阿里巴巴中国站获得1688商品详情 API:如何通过API接口批量获取价格、标题、图片、库存等数据

在数字化时代,数据的重要性不言而喻。对于电商从业者来说,获取商品详情数据是提升业务效率和用户体验的关键。阿里巴巴中国站作为电商行业的巨头,提供了丰富的API接口,方便开发者们批量获取商品信息。本文将详细叙述如何通过阿里巴…

C语言——详解字符函数和字符串函数(二)

Hi,铁子们好呀!之前博主给大家简单地介绍了部分字符和字符串函数,那么这次,博主将会把这些字符串函数给大家依次讲完! 今天讲的具体内容如下: 文章目录 6.strcmp函数的使用及模拟实现6.1 strcmp函数介绍和基本使用6.1.1 strcmp函…

总结:微信小程序中跨组件的通信、状态管理的方案

在微信小程序中实现跨组件通信和状态管理,有以下几种主要方案: 事件机制 通过事件机制可以实现父子组件、兄弟组件的通信。 示例: 父组件向子组件传递数据: 父组件: <child binddata"handleChildData" /> 子组件: Component({..., methods: { handleChildData(…

Linux网卡与IP地址:通往网络世界的通行证

在探索Linux网卡和IP地址的关系之前&#xff0c;我们得先理解Linux网卡是怎么工作的。想象一下&#xff0c;每台计算机都是一个世界&#x1f30e;&#xff0c;而网卡就是连接这些世界的门户&#x1f6aa;。网卡的工作就是接收和发送数据包&#xff0c;就像邮差&#x1f4ec;递送…

RabbitMQ3.13.0起支持MQTT5.0协议及MQTT5.0特性功能列表

RabbitMQ3.13.0起支持MQTT5.0协议及MQTT5.0特性功能列表 文章目录 RabbitMQ3.13.0起支持MQTT5.0协议及MQTT5.0特性功能列表1. MQTT概览2. MQTT 5.0 特性1. 特性概要2. Docker中安装RabbitMQ及启用MQTT5.0协议 3. MQTT 5.0 功能列表1. 消息过期1. 描述2. 举例3. 实现 2. 订阅标识…

洛谷 1126.机器人搬重物

思路&#xff1a;BFS 这道BFS可谓是细节爆炸&#xff0c;对于编程能力和判断条件的能力的考察非常之大。 对于这道题&#xff0c;我们还需要额外考虑一些因素&#xff0c;那就是对于障碍物的考虑和机器人方位的考虑。 首先我们看第一个问题&#xff0c;就是对于障碍物的考虑…

【洛谷】P9236 [蓝桥杯 2023 省 A] 异或和之和

题目链接 P9236 [蓝桥杯 2023 省 A] 异或和之和 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 思路 1. 暴力求解 直接枚举出所有子数组&#xff0c;求每个子数组的异或和&#xff0c;再对所有的异或和求和 枚举所有子数组的时间复杂度为O&#xff08;N^2&#xff09;&…

Qt+OpenGL-part3

1-4EBO画矩形_哔哩哔哩_bilibili 可以绘制两个三角形来组成一个矩形&#xff08;OpenGL主要处理三角形&#xff09; 直接画两个三角形&#xff1a; #include "openglwidget.h" #include <QDebug>unsigned int VBO,VAO; unsigned int shaderProgram;//顶点着…

Leetcode 215. 数组中的第K个最大元素

心路历程&#xff1a; 这道题本质上是排序不完全的过程&#xff0c;而且这道题有bug&#xff0c;直接用python的排序算法其实就能AC。 可以按照快排排到找到k-1个large元素的思维去做&#xff0c;不过这道题需要考虑空间复杂度&#xff0c;所以需要用指针快排。 其实也可以考虑…

序列超图的下一项推荐 笔记

1 Title Next-item Recommendation with Sequential Hypergraphs&#xff08;Jianling Wang、Kaize Ding、Liangjie Hong、Huan Liu、James Caverlee&#xff09;【SIGIR 2020】 2 Conclusion This study explores the dynamic meaning of items in realworld scenarios and p…

RocketMQ的简单使用

这里需要创建2.x版本的springboot项目 导入依赖 <dependencies><dependency><groupId>org.apache.rocketmq</groupId><artifactId>rocketmq-spring-boot-starter</artifactId><version>2.2.3</version></dependency>&…

基于SSM+Jsp+Mysql的人事管理系统

开发语言&#xff1a;Java框架&#xff1a;ssm技术&#xff1a;JSPJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包…

深入理解JVM的内存结构及GC机制(2)

虚拟机栈占用的是操作系统内存&#xff0c;每个线程对应一个虚拟机栈&#xff0c;它是线程私有的&#xff0c;生命周期和线程一样&#xff0c;每个方法被执行时产生一个栈帧&#xff08;Statck Frame&#xff09;&#xff0c;栈帧用于存储局部变量表、动态链接、操作数和方法出…

大语言模型落地的关键技术:RAG

1、什么是RAG&#xff1f; RAG 是检索增强生成&#xff08;Retrieval-Augmented Generation&#xff09;的简称&#xff0c;是当前最火热的大语言模型应用落地的关键技术&#xff0c;主要用于提高语言模型的效果和准确性。它结合了两种主要的NLP方法&#xff1a;检索&#xff…

post请求搜索功能爬虫

<!--爬虫仅支持1.8版本的jdk--> <!-- 爬虫需要的依赖--> <dependency> <groupId>org.apache.httpcomponents</groupId> <artifactId>httpclient</artifactId> <version>4.5.2</version> </dependency>…

2023年下半年网络工程师上午真题及答案解析

1.当计算机突然断电时&#xff0c;( )中存储的信息会丢失。 A.光盘 B.ROM C.RAM D.硬盘 2.进程的状态有就绪态、运行态、阻塞态&#xff0c;其中( )的变化是不可能直接发生的。 A.就绪态到运行态 B.阻塞态到就绪态 C.运行态到阻塞态 D.阻塞态到运行态 3.分…

老板们注意了,AI可能在悄悄威胁你的工作

前天,科技新闻大佬The Register发了一篇文章,说的是AI在科研领域的管理角色越来越大,可能会让管理岗位变得过时,听起来是不是有点儿疯狂? ESMT Berlin的研究小伙伴们发现,AI能够以更大的规模和效率来管理研究项目,比如审查科学文献和预测创新化合物等等,而不是取代人类…