PyTorch|构建自己的卷积神经网络

图片

如何搭建网络,这在深度学习中非常重要。简单来讲,我们是要实现一个类,这个类中有属性和方法,能够进行计算

一般来讲,使用PyTorch创建神经网络需要三步:

  1. 继承基类:nn.Module

  2. 定义层属性

  3. 实现前向传播方法

如果你对于python面向对象编程非常熟练,那么这里也就非常简单,就是定义一些属性,实现一些方法。

开始建立一个网络,就像这样:

import torchimport torch.nn as nnimport torch.nn.functional as Fclass Network(nn.Module):    def __init__(self):        super(Network, self).__init__()        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)        self.fc2 = nn.Linear(in_features=120,out_features=60)        self.out = nn.Linear(in_features=60, out_features=10)    def forward(self,t):        # (1) input layer        t = t        # (2) hidden conv layer1        t = self.conv1(t)        t = F.relu(t)        t = F.max_pool2d(t, kernel_size=2, stride=2)        # (3) hidden conv layer2        t = self.conv2(t)        t = F.relu(t)        t = F.max_pool2d(t, kernel_size=2, stride=2)        # relu 和 max pooling 都没有权重;激活层和池化层的本质都是对传入的数据按照一定的算法变换。        #(4)hidden linear layer2        t = t.reshape(-1, 12*4*4)        t = self.fc1(t)        t = F.relu(t)        # (5) hidden linear layer2        t = self.fc2(t)        t = F.relu(t)        # (6) output layer        t = self.out(t)

对于上述代码,就是定义了一个新的类,叫做Network,这个类继承自nn.Module,同时,我们又新定义了一些属性,比如conv1,conv2,fc1,fc2,out,并实现了一个方法,叫做:forawrd。

好吧,接下来我们进行初始化

>>>network=Network()

访问对象network的一些属性

>>> network.conv1Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))>>> network.conv2Conv2d(6, 12, kernel_size=(5, 5), stride=(1, 1))>>> network.outLinear(in_features=60, out_features=10, bias=True)

更进一步,我们也可获得每一层的权重,形状

​​​​​​​

>>> network.conv2.weightParameter containing:tensor([[[[-8.0741e-02, -7.1281e-02,  7.6540e-02,  6.2786e-02,  1.1018e-03],          [ 6.5041e-02, -3.5665e-02,  7.8475e-02, -1.1228e-02, -2.9447e-02],          [-8.0508e-02,  7.0457e-02,  7.7877e-02,  7.2872e-02,  4.5671e-02],          [ 3.9757e-03,  7.7676e-02,  3.3951e-02,  6.3745e-02,  7.0577e-02],          [-2.0165e-02,  2.2356e-02,  2.9137e-02,  8.0388e-02,  5.9048e-02]]............>>> network.conv2.weight.shapetorch.Size([12, 6, 5, 5])

吧,上述操作比较繁琐,我们不想这样做,但我们还是非常想了解一个神经网络,那么应该怎么办呢?其实,可以这样。

好吧,接下来我们进行初始化​​​​​​​

>>>network=Network()>>> networkNetwork(  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))  (conv2): Conv2d(6, 12, kernel_size=(5, 5), stride=(1, 1))  (fc1): Linear(in_features=192, out_features=120, bias=True)  (fc2): Linear(in_features=120, out_features=60, bias=True)  (out): Linear(in_features=60, out_features=10, bias=True))
>>> for name, param in network.named_parameters():    print(name,'\t\t', param.shape)    conv1.weight      torch.Size([6, 1, 5, 5])conv1.bias      torch.Size([6])conv2.weight      torch.Size([12, 6, 5, 5])conv2.bias      torch.Size([12])fc1.weight      torch.Size([120, 192])fc1.bias      torch.Size([120])fc2.weight      torch.Size([60, 120])fc2.bias      torch.Size([60])out.weight      torch.Size([10, 60])out.bias      torch.Size([10])

我们实现的新的类是基于基类nn.Module实现的,nn.Linear(in_features=120, out_features=60)是一个线性层,是PyTorch已经实现好的。主要功能就是我们输入一个数据,这个层对数据进行一个线性变换,最后输出一个数据。

还记得之前我们获得了神经网络各层的权重尺寸:

>>> network.fc2.weight.shapetorch.Size([60, 120])

没错,是一个60x120的矩阵。

一个1x120的数据,进入线性层,经过线性变换,最后变成形状为1x60的矩阵。当然,这里解释不是非常正确,Linear层复杂得多,这里仅仅是为了便于理解。

如果你还不理解,下面这个例子可能更加简单:​​​​​​​

# 1. 张量的乘法in_features = torch.tensor([1,2,3,4], dtype=torch.float32)weight_matrix = torch.tensor([    [1,2,3,4],    [2,3,4,5],    [3,4,5,6]], dtype = torch.float32)result1=weight_matrix.matmul(in_features)# 可将上述的权重矩阵看作是一个线性映射.
# 在parameter类中包装一个权重矩阵,以使得输出结果与1中一样fc = nn.Linear(in_features=4, out_features=3)fc.weight= nn.Parameter(weight_matrix)result2=fc(in_features)# 此时的结果接近1中的结果却不精确,是因为由bias的存在
>>> print(result1)tensor([30., 40., 50.])>>> print(result2)tensor([29.5786, 39.9564, 49.7925], grad_fn=<AddBackward0>)

上文实现的网络除了线性层(全连接层),还有卷积层,池化,激活函数,等等,这些内容是卷积神经网络的核心。当然,上述各层都有许多参数,如何确定每一层的参数需要一定的计算。

图片

到了这里,我们对于如何构建神经网络,以及访问神经网络的一些属性,以及线性层大致做了什么有了一个大致的理解。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/601384.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

动态规划(不同路径1,不同路径2,整数拆分)

62.不同路径 力扣题目链接(opens new window) 一个机器人位于一个 m x n 网格的左上角 &#xff08;起始点在下图中标记为 “Start” &#xff09;。 机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角&#xff08;在下图中标记为 “Finish” &#xff09;。…

手机上连网络转接app,电脑连接手机,共用网络转接app的办法

方法一&#xff0c;&#xff08;不推荐&#xff09; 因为太简单了所以写一下 电脑安装MuMu模拟器&#xff0c;之后安装网络转接app&#xff0c;这个模拟器设置了从电脑上安装app和&#xff0c;安卓与电脑同步文件夹功能&#xff0c;实现文件共享。所以直接用就可以了。 方法二…

启动 Mac 时显示闪烁的问号

启动 Mac 时显示闪烁的问号 如果启动时在 Mac 屏幕上看到闪烁的问号&#xff0c;这意味着你的 Mac 无法找到自身的系统软件。 如果 Mac 启动时出现闪烁的问号且无法继续启动&#xff0c;请尝试以下步骤。 1.通过按住其电源按钮几秒钟来关闭 Mac。 2.按一下电源按钮&#xf…

你珍藏的那个表情包女孩,现在滤镜碎了一地。

♥ 为方便您进行讨论和分享&#xff0c;同时也为能带给您不一样的参与感。请您在阅读本文之前&#xff0c;点击一下“关注”&#xff0c;非常感谢您的支持&#xff01; 文|猴哥聊娱乐 编辑|侯欢庭 七年前&#xff0c;一个年仅三岁的小女孩以其无邪的表情包风靡网络&#xff0…

FindMy技术用于键盘

键盘是我们生活中不可或缺的输入工具&#xff0c;是人与计算机之间沟通的桥梁&#xff0c;无论是编写文档、浏览网页、玩游戏、或是进行复杂的数据分析&#xff0c;键盘都在其中发挥着关键的作用。此外&#xff0c;键盘还是各种软件的快捷键操作的关键。通过熟练地运用快捷键&a…

大学物理-实验篇——测量误差与数据处理(测量分类、误差、有效数字、逐差法)

目录 测量分类 测量次数角度 测量条件角度 误差 误差分类 系统误差 随机误差 异常值 误差描述 精密度&#xff08;Precision&#xff09; 正确度&#xff08;Trueness&#xff09; 准确度/精确度&#xff08;Accuracy&#xff09; 随机误差的处理 直接测量 算术…

Opencv(C++)学习之cv::calcHist 任意bin数量进行直方图计算

**背景&#xff1a;**当前网上常见的直方图使用方法都是默认使用256的范围&#xff0c;而对于使用特定范围的直方图方法讲的不够清楚。仔细研究后总结如下&#xff1a; 1、常见使用方法&#xff0c;直接对灰度图按256个Bin进行计算。 Mat mHistUn; int channels[1] { 0 }; {…

rollup 原理解析

✨专栏介绍 Rollup专栏是一个专门介绍Rollup打包工具的系列文章。Rollup是一个现代化的JavaScript模块打包工具&#xff0c;它可以将多个模块打包成一个或多个文件&#xff0c;以提高应用程序的性能和加载速度。 在Rollup专栏中&#xff0c;您将学习到如何安装和配置Rollup&a…

Ubuntu18 安装chatglm2-6b

记了下Ubuntu18 上安装chatglm2-6遇到的问题。 环境&#xff1a;Ubuntu18.04 V100(显卡) nvcc 11.6 显卡驱动cudacudnnaniconda chatglm6b 的安装 网上有很多&#xff0c; 不记录 了。 chatglm2-6b 我从别的地方拷贝的&#xff0c; 模型也包含了。 遇到的问题&#xf…

vue简单实现滚动条

背景&#xff1a;产品提了一个需求在一个详情页&#xff0c;一个form表单元素太多了&#xff0c;需要滚动到最下面才能点击提交按钮&#xff0c;很不方便。他的方案是&#xff0c;加一个滚动条&#xff0c;这样可以直接拉到最下面。 优化&#xff1a;1、支持滚动条&#xff0c;…

阿里云2核2G3M配置的云服务器可以搭建几个网站?

阿里云2核2G3M配置的云服务器可以搭建几个网站&#xff1f;对于这个问题&#xff0c;没有一个具体的答案&#xff0c;因为这取决于您的网站的流量和复杂程度。在一般情况下&#xff0c;这个配置可以支持搭建几个中小型网站。若您的网站需要大量的资源或处理高并发请求&#xff…

机器学习股票崩盘预测模型(企业建模_论文科研)AI model for stock crash prediction

对齐颗粒度&#xff0c;打通股票崩盘底层逻辑&#xff0c;形成一套组合拳&#xff0c;形成信用评级机制良性生态圈&#xff0c;重振股市信心&#xff01;--中国股市新展望&#xff01;By Toby&#xff01;2024.1.3 综合介绍 股票崩盘&#xff0c;是指证券市场上由于某种原因&a…

ctfshow——PHP特性

文章目录 web 89web 90web 91web 92web 93web 94web 95web 96web 97web 98web 99web 100——优先级、eval()用法web 101——RefelctionClass反射类web 102——php伪协议、hex2bin()web103web 104——sha1绕过web 105 web 89 使用人工分配 ID 键的数值型数组绕过preg_match. 两个…

Vue2 实现内容拖拽或添加 HTML 到 Tinymce 富文本编辑器的高级功能详解

在 Web 开发中&#xff0c;Tinymce 被广泛应用作为富文本编辑器。除了基础的文本编辑功能&#xff0c;Tinymce 还提供了一系列高级功能&#xff0c;使得文本编辑更加灵活和便捷。本文将介绍如何在 Tinymce 中实现一些高级功能&#xff0c;并深入了解每个工具的使用。 Tinymce …

Python中的eval和exec函数:深度解析两者的区别与使用场景

概要 Python中的eval和exec函数&#xff0c;它们都是非常强大的工具&#xff0c;用于动态执行代码。然而&#xff0c;它们在用途、用法和安全性方面存在显著的区别。在本文中&#xff0c;将深入探讨eval和exec函数的区别、用法以及示例代码&#xff0c;以帮助大家更好地理解和…

用单片机设计PLC电路图

自记&#xff1a; 以下为PMOS推挽输出及集成块光耦&#xff1a;

MediaPipeUnityPlugin Win10环境搭建(22年3月的记录,新版本已完全不同,这里只做记录)

https://github.com/homuler/MediaPipeUnityPlugin You cannot build libraries for Android with the following steps. 1、安装msys2配置系统环境变量Path添加 C:\msys64\usr\bin 执行 pacman -Su 执行 pacman -S git patch unzip 2、安装Python3.9.10 勾选系统环境变量 …

【性能测试入门】详解客户端性能测试和服务器端性能测试!

一&#xff1a;客户端性能测试和服务器端性能测试 客户端性能测试和服务器端性能测试是两个不同但相关的概念: 客户端性能测试: - 测试应用程序客户端(如Web浏览器、移动应用等)的性能,例如加载时间,响应时间等。 - 测试在不同系统配置(CPU、内存、网络等)下客户端的运行性…

QT自定义信号和槽

信号和槽 介绍实现创建文件对teacher的h和cpp文件进行处理对student的h和cpp文件进行处理对widget的h和cpp文件进行处理 介绍 Qt中的信号和槽是一种强大的机制&#xff0c;用于处理对象之间的通信。它们是Qt框架中实现事件驱动编程的核心部分。 信号&#xff08;Signal&#x…

SpringCloud微服务架构,适合接私(附源码)

一个由商业级项目升级优化而来的微服务架构&#xff0c;采用SpringBoot 2.7 、SpringCloud 等核心技术构建&#xff0c;提供基于React和Vue的两个前端框架用于快速搭建企业级的SaaS多租户微服务平台。 架构图 项目介绍 用户权益 仅允许免费用于学习、毕设、公司项目、私活等。…