【深度学习】(5)--搭建卷积神经网络

文章目录

  • 搭建卷积神经网络
    • 一、数据预处理
      • 1. 下载数据集
      • 2. 创建DataLoader(数据加载器)
    • 二、搭建神经网络
    • 三、训练数据
    • 四、优化模型
  • 总结

搭建卷积神经网络

一、数据预处理

1. 下载数据集

在PyTorch中,有许多封装了很多与图像相关的模型、数据集,那么如何获取数据集呢?

导入datasets模块:

from torchvision import datasets #封装了很多与图像相关的模型,数据集

以datasets模块中的MNIST数据集为例,包含70000张手写数字图像:60000张用于训练,10000张用于测试。图像是灰度的,28*28像素,并且居中的,以减少预处理和加快运行。

在这里插入图片描述

from torch.utils.data import DataLoader #数据包管理工具,打包数据
from torchvision import datasets #封装了很多与图像相关的模型,数据集
from torchvision.transforms import ToTensor # 数据转换,张量,将其他类型数据转换为tensor张量
"""-----下载训练集数据集-----"""
training_data = datasets.MNIST(root="data",train=True,# 取训练集download=True,transform=ToTensor(),# 张量,图片是不能直接传入神经网络模型的
) # 对于pytorch库能够识别的数据,一般是tensor张量"""-----下载测试集数据集-----"""
test_data = datasets.MNIST(root="data",train=False,download=True,transform=ToTensor(),
)# numpy数组只能在CPU上运行,Tensor可以在GPU上运行,这在深度学习中可以显著提高计算速度

在这里插入图片描述

下载完成之后可在project栏查看。

2. 创建DataLoader(数据加载器)

在PyTorch中,创建DataLoader的主要作用是将数据集(Dataset)加载到模型中,以便进行训练或推理。DataLoader通过封装数据集,提供了一个高效、灵活的方式来处理数据。

DataLoader通过batch_size参数将数据集自动划分为多个小批次(batch),每一批次的放入模型训练,减少内存的使用,提高训练速度。

import torch
from torch.utils.data import DataLoader
"""
创建数据DataLoader(数据加载器)
batch_size:将数据集分成多份,每一份为batch_size(指定数值)个数据。
优点:减少内存的使用,提高训练速度
"""
train_dataloder = DataLoader(training_data,batch_size=64)# 64张图片为一个包
test_datalodar = DataLoader(test_data,batch_size=64)
# 查看打包好的数据
for x,y in test_datalodar: #x是表示打包好的每一个数据包print(f"Shape of x [N, C, H, W]:{x.shape}")print(f"Shape of y:{y.shape} {y.dtype}")break
-----------------------
Shape of x [N, C, H, W]:torch.Size([64, 1, 28, 28])
Shape of y:torch.Size([64]) torch.int64

二、搭建神经网络

在这里插入图片描述

注意:同普通的神经网络不同,卷积神经网络在传入图片时不需要将其展开,因为对图片进行卷积就是在原图上进行内积,不能展开。

卷积神经网络是由输入层、卷积层、激活层、池化层、全连接层、输出层组成。所以在结构上我们也同这样式的来,但是可以搭建多层卷积哦!

"""---判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU"""
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu""""-----定义神经网络-----"""
class CNN(nn.Module):def __init__(self): # 输入大小(1,28,28)super(CNN,self).__init__()self.conv1 = nn.Sequential( # 将多个层组合在一起nn.Conv2d(         # 2d一般用于图像,3d用于视频数据(多一个时间维度),1d一般用于结构化的序列数据in_channels=1, # 图像通道个数,1表示灰度图(确定卷积核 组中的个数)out_channels=16, # 要得到多少特征图,卷积核的个数kernel_size=5,  # 卷积核大小stride=1,   # 步长padding=2   # 边界填充大小), # 输出的特征图为(16,28,28)-->16个大小28*28的图像nn.ReLU(), # relu层,不会改变特征图的大小nn.MaxPool2d(kernel_size=2) # 进行池化操作(2*2区域),输出结果为(16,14,14))self.conv2 = nn.Sequential( # 输入(16,14,14)nn.Conv2d(16,32,5,1,2), # 输出(32*14*14)nn.ReLU(),nn.Conv2d(32,32,5,1,2), # 输出(32*14*14)nn.ReLU(),nn.MaxPool2d(2) # 输出(32,7,7))self.conv3 = nn.Sequential( # 输入(32,7,7)nn.Conv2d(32,64,5,1,2), # 输出(64,7,7)nn.ReLU())self.out = nn.Linear(64*7*7,10)def forward(self,x):x = self.conv1(x)x = self.conv2(x)x = self.conv3(x) # 输出(64,7,7)x = x.view(x.size(0),-1) # flatten 操作,结果为:(batch_size,64*7*7)output = self.out(x)return outputmodel = CNN().to(device)

三、训练数据

  • optimizer优化器
optimizer = torch.optim.Adam(model.parameters(),lr=0.001)
  • loss_fn损失函数

在PyTorch中,**nn.CrossEntropyLoss()**是一个常用的损失函数,它结合了 nn.LogSoftmax()nn.NLLLoss()(负对数似然损失)在一个单独的类中。

loss_fn = nn.CrossEntropyLoss()
  • 训练集
from torch import nn #导入神经网络模块
def train(dataloader,model,loss_fn,optimizer):model.train()# 设置模型为训练模式batch_size_num =1# 迭代次数 for x,y in dataloader:x,y = x.to(device),y.to(device)  # 将数据和标签发送到指定设备  pred = model.forward(x)  # 前向传播  loss = loss_fn(pred,y)  # 计算损失  optimizer.zero_grad()  # 清除之前的梯度  loss.backward()  # 反向传播  optimizer.step()  # 更新模型参数  loss_value = loss.item()  # 获取损失值if batch_size_num %200 == 0:  # 每200次迭代打印一次损失  print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1
train(train_dataloder,model,loss_fn,optimizer)
------------------------
loss:0.158841 [number:200]
loss:0.242431 [number:400]
loss:0.173504 [number:600]
loss:0.020542 [number:800]
  • 测试集
"""-----测试集-----"""
def test(dataloader,model,loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss,correct = 0,0with torch.no_grad():for x,y in dataloader:x,y = x.to(device),y.to(device)pred = model.forward(x)test_loss += loss_fn(pred,y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= sizecorrect = round(correct, 4)print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")test(test_datalodar,model,loss_fn)
--------------------
Test result: Accuracy:98.11999999999999%,Avg loss:0.05511626677004996

四、优化模型

通过多次迭代,神经网络不断调整其内部参数(如权重和偏置),以最小化预测值与实际值之间的误差。这种优化过程使得神经网络能够更准确地处理输入数据,提高分类、回归等任务的性能。

epochs = 5
for t in range(epochs):print(f"Epoch {t+1} \n-------------------------")train(train_dataloder,model,loss_fn,optimizer)
print("Done!")
test(test_datalodar,model,loss_fn)

输出结果:

Epoch 1 
-------------------------
loss:0.182339 [number:200]
loss:0.229839 [number:400]
loss:0.210450 [number:600]
loss:0.028532 [number:800]
Epoch 2 
-------------------------
loss:0.066216 [number:200]
loss:0.149762 [number:400]
loss:0.084482 [number:600]
loss:0.003749 [number:800]
…………
Done!
Test result: Accuracy:98.99%,Avg loss:0.03138259953491878

总结

本篇介绍了如何搭建卷积神经网络,其主要的构造部分为卷积层、激活层以及池化层,可以搭建多层该部分对数据进行多次卷积、池化。
注意:同普通的神经网络不同,卷积神经网络在传入图片时不需要将其展开,因为对图片进行卷积就是在原图上进行内积,不能展开。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/54916.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue3 通过 axios + jsonp 实现根据公网 ip, 查询天气信息

前提 安装 axios 的 jsonp 适配器。 pnpm install pingtou/axios-jsonp 简单使用说明:当与后端约定的请求 callback 参数名称不为为 callback 时,可修改。一般无需添加。 1. 获取当前电脑 ip 和城市信息 请求地址: https://whois.pconl…

Linux之我不会

一、常用命令 1.系统管理 1.1 systemctl start | stop | restart | status 服务名 案例实操 1 查看防火墙状态 systemctl status firewalld2 停止防火墙服务 systemctl stop firewalld3 启动防火墙服务 systemctl start firewalld4 重启防火墙服务 systemctl restart f…

【Canvas与诗词】秋夕.杜牧(银烛秋光冷画屏......)

【成图】 【代码】 <!DOCTYPE html> <html lang"utf-8"> <meta http-equiv"Content-Type" content"text/html; charsetutf-8"/> <head><title>金六边形外圈绿色底录杜牧秋夕诗</title><style type"…

【电商搜索】现代工业级电商搜索技术-Facebook语义搜索技术QueSearch

【电商搜索】现代工业级电商搜索技术-Facebook语义搜索技术Que2Search 目录 文章目录 【电商搜索】现代工业级电商搜索技术-Facebook语义搜索技术Que2Search目录0. 论文信息1. 研究背景&#xff1a;2. 技术背景和发展历史&#xff1a;3. 算法建模3.1 模型架构3.1.1 双塔与分类 …

NLP:BERT的介绍

1. BERT 1.1 Transformer Transformer架构是一种基于自注意力机制(self-attention)的神经网络架构&#xff0c;它代替了以前流行的循环神经网络和长短期记忆网络&#xff0c;已经应用到多个自然语言处理方向。   Transformer架构由两个主要部分组成&#xff1a;编码器(Encod…

【HarmonyOS】应用引用media中的字符串资源如何拼接字符串

【HarmonyOS】应用引用media中的字符串资源如何拼接字符串 一、问题背景&#xff1a; 鸿蒙应用中使用字符串资源加载&#xff0c;一般文本放置在resoutces-base-element-string.json字符串配置文件中。便于国际化的处理。当然小项目一般直接引用字符串&#xff0c;不需要加载s…

python爬虫:从12306网站获取火车站信息

代码逻辑 初始化 (init 方法)&#xff1a; 设置请求头信息。设置车站版本号。 同步车站信息 (synchronization 方法)&#xff1a; 发送GET请求获取车站信息。返回服务器响应的文本。 提取信息 (extract 方法)&#xff1a; 从服务器响应中提取车站信息字符串。去掉字符串末尾的…

如何通过Dockfile更改docker中ubuntu的apt源

首先明确我们有一个宿主机和一个docker环境&#xff0c;接下来的步骤是基于他们两个完成的 1.在宿主机上创建Dockerfile 随便将后面创建的Dockerfile放在一个位置,我这里选择的是 /Desktop 使用vim前默认你已经安装好了vim 2.在输入命令“vim Dockerfile”之后&#xff0c;…

知识付费APP开发指南:基于在线教育系统源码的技术详解

本篇文章&#xff0c;我们将探讨基于在线教育系统源码的知识付费APP开发的技术细节&#xff0c;帮助开发者和企业快速入门。 一、选择合适的在线教育系统源码 选择合适的在线教育系统源码是开发的关键一步。市场上有许多开源和商业化的在线教育系统源码&#xff0c;开发者需要…

花都狮岭寄宿自闭症学校:开启孩子的生命之门

在花都狮岭这片充满温情的土地上&#xff0c;有一所特别的学校&#xff0c;它像一把钥匙&#xff0c;轻轻旋转&#xff0c;为自闭症儿童们开启了一扇通往无限可能的生命之门——这就是广州星贝育园自闭症儿童寄宿制学校。这所学校不仅是知识的摇篮&#xff0c;更是孩子们心灵成…

React 启动时webpack版本冲突报错

报错信息&#xff1a; 解决办法&#xff1a; 找到全局webpack的安装路径并cmd 删除全局webpack 安装所需要的版本

Python(六)-拆包,交换变量名,lambda

目录 拆包 交换变量值 引用 lambda函数 lambda实例 字典的lambda 推导式 列表推导式 列表推导式if条件判断 for循环嵌套列表推导式 字典推导式 集合推导式 拆包 看一下在Python程序中的拆包&#xff1a;把组合形成的元组形式的数据&#xff0c;拆分出单个元素内容…

影响上证50股指期货价格的因素有哪些?

上证50股指期货&#xff0c;作为反映上海证券交易所最具代表性50只股票整体表现的期货合约&#xff0c;其价格同样受到一系列复杂因素的驱动。以下是对影响上证50股指期货价格的主要因素进行的详细分析。 因素一、期货合约的供求关系 股指期货市场是一个由多头和空头双方共同…

具身智能综述:鹏城实验室中大调研近400篇文献,深度解析具身智能

具身智能是实现通用人工智能的必经之路&#xff0c;其核心是通过智能体与数字空间和物理世界的交互来完成复杂任务。近年来&#xff0c;多模态大模型和机器人技术得到了长足发展&#xff0c;具身智能成为全球科技和产业竞争的新焦点。然而&#xff0c;目前缺少一篇能够全面解析…

面试遇到的质量体系10个问题(深度思考)

在某大型公司的招聘面试中关于质量体系本身及建设实践方面的10个问题&#xff0c;这些问题都是偏理论性强一些&#xff0c;但是可以通过这些问题来了解大型公司对质量体系的一些想法和预期的内容&#xff0c;本期先抛出来这10个问题&#xff0c;不附答案&#xff0c;目的就是让…

AI绘画:Stable Diffusion 终极炼丹宝典:从入门到精通

前言 我是Lison&#xff0c;以浅显易懂的方式&#xff0c;与大家分享那些实实在在可行之宝藏。 历经耗时数十个小时&#xff0c;总算将这份Stable Diffusion的使用教程整理妥当。 从最初的安装与配置&#xff0c;细至界面功能的详解&#xff0c;再至实战案例的制作&#xff…

使用mendeley生成APA格式参考文献

mendeley 是一款文献管理工具&#xff0c;可以在word中方便的插入引用文献。 效果对比&#xff1a; 注&#xff1a;小绿鲸有三种导出格式&#xff0c;分别为复制、导出为Bibtex和导出为Endnote三种。 mendeley 下载与安装 Download Mendeley Reference Manager For Desktop mac…

98问答网是一个怎样的平台?它主要提供哪些服务?

98问答网是一个集知识分享、问题解答与社区交流为一体的综合性在线问答平台。该平台旨在通过汇聚来自各行各业的专家、学者以及广大网友的智慧&#xff0c;为用户提供一个快速获取准确信息、解决生活工作中遇到的各种问题的渠道。 主要服务包括&#xff1a; 问题提问与解答&am…

10.C++程序中的循环语句

C中提供了三种循环语句&#xff08;for循环&#xff0c;while循环以及do-while循环)来使程序员可以更方便地对数据进行迭代操作。 if语句 for语句的格式为&#xff1a; for(初始化语句&#xff1b;循环条件&#xff1b;迭代语句&#xff09; &#xff5b; 代码块 &#x…

【中级通信工程师】终端与业务(十一):市场营销计划、实施与控制

【零基础3天通关中级通信工程师】 终端与业务(十一)&#xff1a;市场营销计划、实施与控制 本文是中级通信工程师考试《终端与业务》科目第十一章《市场营销计划、实施与控制》的复习资料和真题汇总。终端与业务是通信考试里最简单的科目&#xff0c;有效复习通过率可达90%以上…