PyTorch入门学习(八):神经网络-卷积层

目录

一、数据准备

二、创建卷积神经网络模型

三、可视化卷积前后的图像


一、数据准备

首先,需要准备一个数据集来演示卷积层的应用。在这个示例中,使用了CIFAR-10数据集,该数据集包含了10个不同类别的图像数据,用于分类任务。使用PyTorch的torchvision库来加载CIFAR-10数据集,并进行必要的数据转换。

import torch
import torchvision
from torch.utils.data import DataLoader# 数据集准备
dataset = torchvision.datasets.CIFAR10(root="D:\\Python_Project\\pytorch\\dataset2", train=False, transform=torchvision.transforms.ToTensor(), download=True)# 使用DataLoader加载数据集,每批次包含64张图像
dataLoader = DataLoader(dataset, batch_size=64)

二、创建卷积神经网络模型

接下来,创建一个简单的卷积神经网络模型,以演示卷积层的使用。这个模型包含一个卷积层,其中设置了输入通道数为3(因为CIFAR-10中的图像是彩色的,有3个通道),卷积核大小为3x3,输出通道数为6,步长为1,填充为0。

import torch.nn as nn
from torch.nn import Conv2dclass Tudui(nn.Module):def __init__(self):super(Tudui, self).__init__()# 卷积层self.conv1 = Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)def forward(self, x):x = self.conv1(x)return xtudui = Tudui()
print(tudui)

上述代码定义了一个Tudui类,该类继承了nn.Module,并在初始化方法中创建了一个卷积层。forward方法定义了数据在模型中的前向传播过程。

三、可视化卷积前后的图像

卷积层通常会改变图像的维度和特征。使用TensorBoard来可视化卷积前后的图像以更好地理解卷积操作。首先,导入SummaryWriter类,并创建一个SummaryWriter对象用于记录日志。

from torch.utils.tensorboard import SummaryWriterwriter = SummaryWriter("logs")

然后,使用DataLoader遍历数据集,对每个批次的图像应用卷积操作,并将卷积前后的图像以及输入的图像写入TensorBoard。

step = 0
for data in dataLoader:imgs, targets = data# 卷积操作output = tudui(imgs)# 将输入图像写入TensorBoardwriter.add_images("input", imgs, step)# 由于TensorBoard不能直接显示具有多个通道的图像,我们需要重定义输出图像的大小output = torch.reshape(output, (-1, 3, 30, 30))# 将卷积后的图像写入TensorBoardwriter.add_images("output", output, step)step += 1writer.close()

在上述代码中,使用writer.add_images将输入和输出的图像写入TensorBoard,并使用torch.reshape来重定义输出图像的大小,以满足TensorBoard的显示要求。

运行上述代码后,将在TensorBoard中看到卷积前后的图像,有助于理解卷积操作对图像的影响。

完整代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
#数据集准备
dataset = torchvision.datasets.CIFAR10("D:\\Python_Project\\pytorch\\dataset2",train=False,transform=torchvision.transforms.ToTensor(),download=True)
#使用dataloader加载数据集,批次数为64
dataLoader = DataLoader(dataset,batch_size=64)class Tudui(nn.Module):def __init__(self):super(Tudui,self).__init__()# 该神经网络调用conv2d进行一层卷积,输入通道为3层(彩色图像为3通道),卷积核大小为3*3,输出通道为6,设置步长为1,padding为0,不进行填充。self.conv1 = Conv2d(in_channels=3,out_channels=6,kernel_size=3,stride=1,padding=0)def forward(self,x):x = self.conv1(x)return xtudui = Tudui()
print(tudui)# 生成日志
writer = SummaryWriter("logs")step = 0
# 输出卷积前的图片大小和卷积后的图片大小
for data in dataLoader:imgs,targets = data# 卷积操作output = tudui(imgs)print(imgs.shape)print(output.shape)writer.add_images("input",imgs,step)"""注意:使用tensorboard输出时需要重新定义图片大小对于输入的图片集imgs来说,tensor.size([64,3,32,32]),即一批次为64张,一张图片为三个通道,大小为32*32对于经过卷积后输出的图片集output来说,tensor.size([64,6,30,30]),通道数变成了6,tensorboard不知道怎么显示通道数为6的图片,所以如果直接输出会报错解决方案:使用reshape方法对outputs进行重定义,把通道数改成3,如果不知道批次数大小,可以使用-1代替,程序会自动匹配批次大小。"""#重定义输出图片的大小output = torch.reshape(output,(-1,3,30,30))# 显示输出的图片writer.add_images("output",output,step)step = step + 1
writer.close()

参考资料:

视频教程:PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/124908.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

文章分类管理接口

目录 前言 新建表 获取文章分类列表接口 初始化路由模块 将路由对象导出并使用 初始化路由对象处理函数 修改路由代码 导入数据库 定义sql语句 调用db.query() 完整的获取文章分类列表处理函数 新增文章分类接口 定义路由和处理函数 验证表单数据 查询分类名称与…

<基础数学> 三个点生成一个圆

三个点生成一个圆 如果给定三个点的坐标,我们可以通过这三个点来确定一个圆。以下是一种求解方法: 假设给定的三个点分别为 A ( x 1 , y 1 ) 、 B ( x 2 , y 2 ) 、 C ( x 3 , y 3 ) A(x_1, y_1)、B(x_2, y_2)、C(x_3, y_3) A(x1​,y1​)、B(x2​,y2​…

UDP网络编程的接受与发送信息

/发送端B>可以接受数据 public class UDPSenderB {public static void main(String[] args) throws IOException {//创建一个DatagramSocket 对象,准备发送和接受数据DatagramSocket socket new DatagramSocket(9998);//将需要发送的数据,封装到Data…

空号检测API如何助力于提高客户关系管理

引言 在现代商业世界中,客户关系管理已经成为企业成功的关键要素之一。CRM不仅涉及到如何吸引新客户,还包括如何维护并与现有客户建立持久而有益的关系。在这个过程中,通信是至关重要的。为了确保您的客户数据库保持最新和准确,空…

navicat15 恢复试用方法

1.运行,输入regedit,打开注册表 2.注册表中搜索 HKEY_CURRENT_USER\Software\PremiumSoft\NavicatPremium,删除下面的Registration15XCS文件夹 3.注册表中再搜索 HKEY_CURRENT_USER\Software\Classes\CLSID 然后拉到文件夹目录的最后&#x…

「永不失联」产品创新与升级系列发布,预约直播“即将发车”

数字化浪潮下,北斗时空智能正成为我国重要的新型基础设施。 通过将卫星定位精度提升至厘米级乃至毫米级,时空智能满足了数字化时代智能驾驶、共享出行、智慧城市等多种智能终端对时空信息的爆发式增长需求,同步印证着测绘地理信息领域的技术应…

k8s 集群的组成和原理

集群 集群是一组节点,这些节点可以是物理服务器也可以是虚拟机,在它们中安装了k8s的环境。 集群的组成 k8s 集群由 worker 节点和 node 节点组成,其中worker节点由Controller Manager(控制管理器)、etcd(键值数据库)、scheduler(调度器)、…

ConcurrentHashMap vs Hashtable

1.ConcurrentHashMap 1.7 ReentrantLock Segment HashEntry。 1.8 CAS synchronized HashEntry 红黑树。 public V put(K key, V value) {return putVal(key, value, false);}final V putVal(K key, V value, boolean onlyIfAbsent) {if (key null || value null) th…

什么是Vue.js中的指令(directive)?举例说明一些常见的指令。

聚沙成塔每天进步一点点 ⭐ 专栏简介 前端入门之旅:探索Web开发的奇妙世界 欢迎来到前端入门之旅!感兴趣的可以订阅本专栏哦!这个专栏是为那些对Web开发感兴趣、刚刚踏入前端领域的朋友们量身打造的。无论你是完全的新手还是有一些基础的开发…

Azure机器学习 - 使用与Azure集成的Visual Studio Code实战教程

本文介绍如何启动远程连接到 Azure 机器学习计算实例的 Visual Studio Code。 借助 Azure 机器学习资源的强大功能,使用 VS Code 作为集成开发环境 (IDE)。 在VS Code中将计算实例设置为远程 Jupyter Notebook 服务器。 关注TechLead,分享AI全维度知识。…

Spring本地jar包依赖项目改为maven依赖

1.简介 我们在做项目的时候,可能会偶尔接手较为古老的项目,这些项目使用了较为老旧的版本管理或依赖管理方法,对于新开发项目来说,这些老旧的依赖管理方式会影响开发效率,所以,一般我们会选择将老项目的依…

JavaEE入门介绍,HTTP协议介绍,常用状态码及含义,服务器介绍(软件服务器、云服务器)

一、JavaEE入门 JavaEE(Java Enterprise Edition),Java企业版,是一个用于企业级web开发(不需要使用控制台)平台。最早由Sun公司定制并发布,后由Oracle负责维护。 JavaEE平台规范了在开发企业级w…

NB-IOT的粮库挡粮门异动监测装置

一种基于NBIOT的粮库挡粮门异动监测装置,包括若干个NBIOT开门监测装置,物联网后台管理系统,NBIOT低功耗广域网络和用户访问终端;各个NBIOT开门监测装置通过NBIOT低功耗广域网络与物联网后台管理系统连接,物联网后台管理系统与用户访问终端连接.NBIOT开门监测装置能够对粮库挡粮…

LeetCode:274. H 指数、275. H 指数 II(C++)

目录 274. H 指数 题目描述: 实现代码与解析: 排序暴力 275. H 指数 II 题目描述: 实现代码与解析: 二分 比较简单,不再写解析,注意二分的时候,r指针为n,含义为个数&#xf…

HarmonyOS开发:基于http开源一个网络请求库

前言 网络封装的目的,在于简洁,使用起来更加的方便,也易于我们进行相关动作的设置,如果,我们不封装,那么每次请求,就会重复大量的代码逻辑,如下代码,是官方给出的案例&am…

Ubuntu安装ddns-go使用阿里ddns解析ipv6

Ubuntu安装ddns-go 1.何为ddns-go2.安装环境3.获取ddns-go安装包4.解压ddns-go5.安装ddns-go6.配置ddns-go 1.何为ddns-go DDNS-GO是简单好用的DDNS,它可以帮助你自动更新域名解析到公网IP。比如你希望在本地部署网站,但是因为公网IP是动态的&#xff0…

CrackRTF

加密。 解密 import hashlib# 选择哈希算法(例如SHA-256) hash_algorithm hashlib.sha1()flag2"DBApp"for i in range(100000,999999):datastr(i)flag2hash_valuehashlib.sha1(data.encode())hex_value hash_value.hexdigest()if "6E…

【Linux】虚拟机项目部署与发布

目录 一、Linux部署单机项目 1.1 优缺点 1.2 将项目共享到虚拟机 1.3 解压后将war包放入tomcat 1.4 数据库导入脚本 1.5 Tomcat启动项目 二、部署前后端分离项目 2.1 准备工作 2.2 部署SPA项目 2.2.1 nginx反向代理 2.2.2 SPA项目宿主机访问 一、Linux部署单机项目…

854数据结构简答题---图

1.(2015期末)已知无环路有向图如图3.1,请在表2、表3中填写出各事件的最早发生时间、最迟发生时间、活动的最早、最迟开始时间,给出关键活动及关键路径。 从源点到汇点的有向路径可能有多条,所有路径中,具有最大路径长…

CodeWhisperer 初体验

文章作者:1颗 orange 最近用了一个叫 CodeWhisperer 的插件,这个软件对于来说开发人员,插件有好多实用的功能,编码更高效,代码质量也提升了很多。 CodeWhisperer 简介 CodeWhisperer 是亚⻢逊出品的一款基于机器学习…