PyTorch入门学习(九):神经网络-最大池化使用

目录

一、数据准备

二、创建神经网络模型

三、可视化最大池化效果


一、数据准备

首先,需要准备一个数据集来演示最大池化层的应用。在本例中,使用了CIFAR-10数据集,这是一个包含10个不同类别图像的数据集,用于分类任务。我们使用PyTorch的torchvision库来加载CIFAR-10数据集并进行必要的数据转换。

import torch
import torchvision
from torch.utils.data import DataLoader# 数据集准备
dataset = torchvision.datasets.CIFAR10("D:\\Python_Project\\pytorch\\dataset2", train=False, transform=torchvision.transforms.ToTensor(), download=True)# 使用DataLoader加载数据集,每批次包含64张图像
dataLoader = DataLoader(dataset, batch_size=64)

二、创建神经网络模型

接下来,创建一个简单的神经网络模型,其中包含一个卷积层和一个最大池化层。这个模型将帮助演示最大池化层的效果。首先定义一个Tudui类,该类继承了nn.Module,并在初始化方法中创建了一个卷积层和一个最大池化层。

import torch.nn as nn
from torch.nn import Conv2d
from torch.nn.functional import max_pool2dclass Tudui(nn.Module):def __init__(self):super(Tudui, self).__init()# 卷积层self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)# 最大池化层self.pool = nn.MaxPool2d(kernel_size=2, stride=2)def forward(self, x):x = self.conv1(x)x = self.pool(x)return xtudui = Tudui()
print(tudui)

上述代码中,定义了Tudui类,包括了一个卷积层和一个最大池化层。在forward方法中,数据首先经过卷积层,然后通过最大池化层,以减小图像的维度。

三、可视化最大池化效果

最大池化层有助于减小图像的维度,提取图像中的主要特征。接下来将使用TensorBoard来可视化最大池化的效果,以更好地理解它。首先,导入SummaryWriter类并创建一个SummaryWriter对象。

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("logs")

然后,遍历数据集,对每个批次的图像应用卷积和最大池化操作,并将卷积前后的图像写入TensorBoard。

step = 0
for data in dataLoader:imgs, targets = data# 卷积和最大池化操作output = tudui(imgs)# 将输入图像写入TensorBoardwriter.add_images("input", imgs, step)# 由于TensorBoard不能直接显示多通道图像,我们需要重定义输出图像的大小output = torch.reshape(output, (-1, 6, 15, 15))# 将卷积和最大池化后的图像写入TensorBoardwriter.add_images("output", output, step)step += 1writer.close()

在上述代码中,使用writer.add_images将输入和输出的图像写入TensorBoard,并使用torch.reshape来重定义输出图像的大小,以适应TensorBoard的显示要求。

运行上述代码后,将在TensorBoard中看到卷积和最大池化的效果。最大池化层有助于提取图像中的关键信息,减小图像维度,并提高模型的计算效率。

完整代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
#数据集准备
dataset = torchvision.datasets.CIFAR10("D:\\Python_Project\\pytorch\\dataset2",train=False,transform=torchvision.transforms.ToTensor(),download=True)
#使用dataloader加载数据集,批次数为64
dataLoader = DataLoader(dataset,batch_size=64)class Tudui(nn.Module):def __init__(self):super(Tudui,self).__init__()# 该神经网络调用conv2d进行一层卷积,输入通道为3层(彩色图像为3通道),卷积核大小为3*3,输出通道为6,设置步长为1,padding为0,不进行填充。self.conv1 = Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)def forward(self,x):x = self.conv1(x)return xtudui = Tudui()
print(tudui)# 生成日志
writer = SummaryWriter("logs")step = 0
# 输出卷积前的图片大小和卷积后的图片大小
for data in dataLoader:imgs,targets = data# 卷积操作output = tudui(imgs)print(imgs.shape)print(output.shape)writer.add_images("input",imgs,step)"""注意:使用tensorboard输出时需要重新定义图片大小对于输入的图片集imgs来说,tensor.size([64,3,32,32]),即一批次为64张,一张图片为三个通道,大小为32*32对于经过卷积后输出的图片集output来说,tensor.size([64,6,30,30]),通道数变成了6,tensorboard不知道怎么显示通道数为6的图片,所以如果直接输出会报错解决方案:使用reshape方法对outputs进行重定义,把通道数改成3,如果不知道批次数大小,可以使用-1代替,程序会自动匹配批次大小。"""#重定义输出图片的大小output = torch.reshape(output,(-1,3,30,30))# 显示输出的图片writer.add_images("output",output,step)step = step + 1
writer.close()

参考资料:

视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/125208.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

前端环境的安装 Node npm yarn

一 node npm 1.下载NodeJS安装包 下载地址:Download | Node.js 2.开始安装 打开安装包后,一直Next即可。当然,建议还是修改一下安装位置,NodeJS默认安装位置为 C:\Program Files 3.验证是否安装成功 打开DOS命令界面&#…

MINIO 对象存储服务

MINIO 官网下载地址: 注:需要下载 MINIO SERVER(服务端) 和 MINIO CLIENT(客户端)两个文件 WINDOWS版本下载地址:https://min.io/download#/windows LINUX版本下载地址:https://mi…

已解决:conda找不到对应版本的cudnn如何解决?

1.解决方法 配置深度学习环境时,打算安装cudatoolkit11.2和cudnn8.1,当使用conda install cudnn8.0时,却搜索不到这个版本的包,解决方法如下: conda search cudnn -c conda-forge然后就可以使用如下命令进行安装对应…

Python web自动化测试 —— 文件上传

​文件上传三种方式: (一)查看元素标签,如果是input,则可以参照文本框输入的形式进行文件上传 方法:和用户输入是一样的,使用send_keys 步骤:1、找到定位元素,2&#…

如何解决缓存雪崩?

缓存雪崩是指在缓存中大量的数据同时失效,导致大量请求直接访问数据库,造成数据库负载急剧增加的情况。为了解决缓存雪崩问题,可以采取以下一些策略和方法: 合理设置缓存的过期时间 分散缓存的过期时间,避免在同一时间…

【大数据基础平台】星环TDH社区开发版单机部署

🍁 博主 "开着拖拉机回家"带您 Go to New World.✨🍁 🦄 个人主页——🎐开着拖拉机回家_大数据运维-CSDN博客 🎐✨🍁 🪁🍁🪁🍁🪁&#…

Vue之CSS基础

CSS:层叠样式表 1、选择器 从模板template中选择某元素进行样式设置 需要注意的是作用域到底是当前模板还是整个html文档 1.1 基础(单一)选择器 标签、类、 id、通配符 标签、直接使用标签名,比如div,span… 优点:全选 模板中的名{。。。}…

反射率检测仪如何检测后视镜

后视镜反射率检测是评估后视镜质量的重要步骤,可以反映后视镜的反射效果是否满足设计要求。一般来说,后视镜的反射率越高,驾驶员观察车后的道路状况就越清晰,从而能够更好地判断与后方车辆的距离和速度差。 后视镜反射率检测的原理…

你被骗了吗?别拿低价诱骗机器视觉小白,4000元机器视觉系统怎么来的?机器视觉工程师自己组装一个2000元不到,还带深度学习

淘宝闲鱼,大家搜搜铺价格,特别是机器视觉小白。 机架:(新的)200元以下。(看需求,自己简单打光,买个50元的。如果复杂,就拿给供应商免费打光) 相机,镜头:&am…

Spring MVC的常用注解(设置响应篇)

目录 1.返回静态页面 2.返回数据 3.返回HTML代码片段 4.返回json 5.设置状态码 6.设置Header (1).设置 Content-Type (2).设置其他Header 推荐先看前篇博客Spring MVC的常用注解(接收请求数据篇) 接收…

【数据结构】时间复杂度与空间复杂度

算法在编写成可执行程序后,运行时需要耗费时间资源和空间(内存)资源 。因此衡量一个算法的好坏,一般是从时间和空间两个维度来衡量的,即时间复杂度和空间复杂度。 时间复杂度: 主要衡量一个算法的运行快慢 空间复杂度:…

C#学习相关系列之多线程---ConfigureAwait的用法

一、ConfigureAwait的作用 ConfigureAwait方法是Task类中的一个实例方法,它用于配置任务的运行上下文。运行上下文指的是任务在执行期间所处的环境,包括线程、同步上下文等。ConfigureAwait方法接受一个布尔值参数,用于决定是否捕获上下文。当…

磁盘的命令

目录 1- 磁盘空间命令1.1 df1.2 du只想查看目录的权限 加 -d 参数 1- 磁盘空间命令 1.1 df 全称 disk free 快速获取磁盘被占用了多少空间, 目前还剩下所少空间 常用命令df -hdf 是从总体上统计系统各磁盘的占用情况,不能统计具体的文件夹或文件的大小 1.2 du 全称 disk u…

【JAVA学习笔记】53 - 集合-List类及其子类Collection、ArrayList、LinkedList类

项目代码 https://github.com/yinhai1114/Java_Learning_Code/tree/main/IDEA_Chapter14/src/com/yinhai/collection_ https://github.com/yinhai1114/Java_Learning_Code/tree/main/IDEA_Chapter14/src/com/yinhai/list_ 集合 一、集合引入 前面我们保存多个数据使用的是数组…

Elasticsearch下载安装,IK分词器、Kibana下载安装使用,elasticsearch使用演示

首先给出自己使用版本的网盘链接:自己的版本7.17.14 链接:https://pan.baidu.com/s/1FSlI9jNf1KRP-OmZlCkEZw 提取码:1234 一般情况下 Elastic Search(ES) 并不单独使用,例如主流的技术组合 ELK&#xff08…

MyBatis Plus之wrapper用法

一、条件构造器关系 条件构造器关系介绍: 绿色框:抽象类 abstract 蓝色框:正常 class 类,可 new 对象 黄色箭头:父子类关系,箭头指向为父类 wrapper介绍: Wrapper :条件构造抽象类…

mac m1下navicat执行mongorestore 到mongodb

首先,下载https://www.mongodb.com/try/download/mongocli 解压缩后 有可执行文件使用navicat打开 加载后再重新点击 选择 要恢复的文件即可

动态规划23(Leetcode354俄罗斯套娃信封问题)

和题解里方法一几乎一样但是超时 class Solution {public int maxEnvelopes(int[][] envelopes) {int n envelopes.length;Arrays.sort(envelopes,new Comparator<int[]>(){public int compare(int[] arr1,int[] arr2){if(arr1[0]!arr2[0]){return arr1[0]-arr2[0];}el…

CDN加速技术海外与大陆优劣势对比

内容分发网络&#xff08;CDN&#xff09;是一项广泛应用于网络领域的技术&#xff0c;旨在提高网站和应用程序的性能、可用性和安全性。CDN是一种通过将内容分发到全球各地的服务器来加速数据传输的服务。本文将探讨使用CDN的优势以及国内CDN和海外CDN之间的不同优势和劣势。 …

【C语言】函数指针存疑调试及回调函数编写(结构体内的Callback回调函数传参和虚伪的回调函数__weak声明)

【C语言】函数指针存疑调试及回调函数编写&#xff08;结构体内的Callback回调函数传参和虚伪的回调函数__weak声明&#xff09; 文章目录 函数指针存疑调试函数指针函数调用 回调函数编写结构体内的回调函数虚伪的回调函数 附录&#xff1a;压缩字符串、大小端格式转换压缩字符…