快速入门Torch构建自己的网络模型

真有用构建自己的网络模型

    • 读前必看
    • 刚学完Alex网络感觉很厉害的样子,我也要搭建一个
    • 可以看着网络结构实现上面的代码你已经很强了,千万不要再想实现VGG等网络!!!90%你能了解到的模型大佬早已实现好,直接调用就OK
    • 下面是源码用nn.Module实现的AlexNet,和我们实现的区别并不大,将模型print出来能看懂就可以
    • 不忘初心,构建自己的网络模型,将AlexNet输入改为单通道图片:
    • Tips

读前必看

  • 如何用框架复现论文中的模型不重要,重要的是明白网络模型原理!!!
  • 如何用框架复现论文中的模型不重要,重要的是明白网络模型原理!!!
  • 如何用框架复现论文中的模型不重要,重要的是明白网络模型原理!!!

刚学完Alex网络感觉很厉害的样子,我也要搭建一个

在这里插入图片描述

回想一下torch构建网络的几种方法

  • nn.Sequential直接顺序实现
  • nn.Module继承基类构建自定义模型
feature = nn.Sequential(nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(64, 192, kernel_size=5, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(192, 384, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),
)

现在需要计算卷积后图像的维度,根据公式 image_shape = (image_shape - kernel_size + 2 * padding) / stride + 1计算

in_shape= 224
conv_size = [11, 5, 3, 3, 3]
padding_size = [2, 2, 1, 1, 1]
stride_size = [4, 1, 1, 1, 1]
# image_shape = (image_shape - kernel_size + 2 * padding) / stride + 1
for i in range(len(conv_size)):in_shape = (in_shape - conv_size[i] + 2 * padding_size[i]) / stride_size[i] + 1in_shape = math.floor(in_shape)if i in [0, 1, 4]:in_shape = (in_shape - 3 + 2 * 0) / 2 + 1in_shape = math.floor(in_shape)
print(in_shape)

计算结果是6,输出通道是256,所以特征有25666个,将下面代码添加到Sequential中完成自定义AlexNet构建

nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(p=dropout),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes)

可以看着网络结构实现上面的代码你已经很强了,千万不要再想实现VGG等网络!!!90%你能了解到的模型大佬早已实现好,直接调用就OK

下面是源码用nn.Module实现的AlexNet,和我们实现的区别并不大,将模型print出来能看懂就可以

class AlexNet(nn.Module):def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:super().__init__()# _log_api_usage_once(self)self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(64, 192, kernel_size=5, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(192, 384, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)self.avgpool = nn.AdaptiveAvgPool2d((6, 6))self.classifier = nn.Sequential(nn.Dropout(p=dropout),nn.Linear(256 * 6 * 6, 4096),nn.ReLU(inplace=True),nn.Dropout(p=dropout),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes),)def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

不忘初心,构建自己的网络模型,将AlexNet输入改为单通道图片:

model = AlexNet()
model.features[0] = nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2)
print(model)

Tips

Q1: padding是卷积之后还是卷积之前还是卷积之后实现的?
padding是在卷积之前补0,如果愿意的话,可以通过使用torch.nn.Functional.pad来补非0的内容。

Q2:padding补0的默认策略是什么?
四周都补!如果pad输入是一个tuple的话,则第一个参数表示高度上面的padding,第2个参数表示宽度上面的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/629068.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

MFC CAsyncSocket类作为客户端示例

之前写过CAsyncSocket类使用的博客;进一步看一下; VS新建一个MFC 对话框工程; 添加一个类,从CAsyncSocket继承,起个自己的名字; 对话框添加几个编辑框,按钮,静态控件; 为自己的CxxxAsyncSocket类添加重写的虚函数,OnConnect、OnReceive、OnSend; 自己的CAsyncSoc…

pytorch一致数据增强—独用增强

前作 [1] 介绍了一种用 pytorch 模仿 MONAI 实现多幅图(如:image 与 label)同用 random seed 保证一致变换的写法,核心是 MultiCompose 类和 to_multi 包装函数。不过 [1] 没考虑各图用不同 augmentation 的情况,如&am…

生物制药厂污水处理需要哪些工艺设备

生物制药厂是一种特殊的工业场所,由于其生产过程中涉及的有机物较多,导致废水中含有高浓度的有机物和微生物等污染物,因此需要采用一些特殊的工艺设备来进行污水处理。本文将介绍生物制药厂污水处理中常用的工艺设备。 首先,对于生…

【面试合集】说说微信小程序的支付流程?

面试官:说说微信小程序的支付流程? 一、前言 微信小程序为电商类小程序,提供了非常完善、优秀、安全的支付功能 在小程序内可调用微信的API完成支付功能,方便、快捷 场景如下图所示: 用户通过分享或扫描二维码进入商…

2024年华数杯国际赛赛题浅析

21号完赛,28号出成绩的华数杯国际赛,作为美赛最合适的练手赛正式开赛。为了让大家更好地比赛,首先为大家带来本次竞赛两道题目的浅要解析。主要分析两道题目适合的群体,未来大家求解过程中可能遇到的问题。方便大家快速完成选题。…

农业四情监测管理系统的特点优势

TH-Q2农业四情监测管理系统是一种利用现代信息技术,对农业生产中的墒情、苗情、虫情、灾情进行实时监测和管理的智能化系统。那么,这个系统到底有哪些特点和优势呢?让我们一起来了解一下。 1.实时监测:通过传感器、摄像头等设备&a…

美国智库发布《用人工智能展望网络未来》的解析

文章目录 前言一、人工智能未来可能改善网络安全的方式二、人工智能可能损害网络安全的方式三、人工智能使用的七条建议四、人工智能的应用和有效使用AI五、安全有效地使用人工智能制定具体建议六、展望网络未来的人工智能(一)提高防御者的效率&#xff…

DWM1000 MAC层

DWM1000 MAC层 MAC层 概述 MAC层,即媒体访问控制层,是数据通信协议栈中的一个重要部分,位于链路层的下半部分,紧邻物理层。在OSI模型中,它属于第二层,即数据链路层的一部分。MAC层的主要职责是控制如何在…

框架基础-Maven+SpringBoot入门

框架基础 Maven基础 Maven概述 Maven是为Java项目提供项目构建和依赖管理的工具 Maven三大功能 - 项目构建构建:是一个将代码从开发阶段到生产阶段的一个过程:清理,编译,测试,打包,安装,部署…

运筹说 第46期 | 目标规划-数学模型

经过前几期的学习,想必大家已经对线性规划问题有了详细的了解,但线性规划作为一种决策工具,在解决实际问题时,存在着一定的局限性:(1)线性规划只能处理一个目标,而现实问题往往存在多个目标;(2)…

GEE时序——利用sentinel-2(哨兵-2)数据进行地表物候学分析(时间序列平滑法估算和非平滑算法代码)

简介 哨兵-2A/B 串联卫星的空间分辨率高、重访时间长,有可能改进对陆地表面物候的检索。不过,生物群落和区域特征在很大程度上限制了陆表物候学算法的设计。在北极地区,这种生物群落特有的特征包括长期积雪、持续云层覆盖和生长季节短暂。在此,我们评估了哨兵-2 获取北极高…

【jupyter添加虚拟环境内核(pytorch、tensorflow)- 实操可行】

jupyter添加虚拟环境内核(pytorch、tensorflow)- 实操可行 1、查看当前状态(winR,cmd进入之后)2、激活虚拟环境并进入3、安装ipykernel5、完整步骤代码总结6、进入jupyter 添加pytorch、tensorflow内核操作相同,以下内容默认已经安…

Klocwork—符合功能安全要求的自动化静态测试工具

产品概述 Klocwork是Perforce公司产品,主要用于C、C、C#、Java、 python和Kotlin代码的自动化静态分析工作,可以提供编码规则检查、代码质量度量、测试结果管理等功能。Klocwork可以扩展到大多数规模的项目,与大型复杂环境、各种开发工具集成…

低代码自动化测试的实践

何为低代码测试 传统上,功能、 UI、端到端等测试自动化的实现都涉及编写测试脚本,代替测试人员执行重复的手动测试任务。自动化脚本的开发工作通常由 QA 工程师或开发人员完成,这需要编写大量代码。 而低代码甚至无代码的理念也是在自动化测…

【ubuntu】ubuntu 20.04安装docker,使用nginx部署前端项目,nginx.conf文件配置

docker 官网:Install Docker Engine on Ubuntu 一、安装docker 1.将apt升级到最新 sudo apt update2.使用apt安装 docker 和 docker-compose (遇到提示输入y) sudo apt install docker.io docker-compose3.将当前用户添加到docker用户组 …

HNU-模式识别-作业2-面向应用分类系统

模式识别-作业2 计科210X 甘晴void 202108010XXX 【具体实现思路是按照去年数学建模国赛题来做的,就放个思路,完整不放全了】 题目: 查阅文献资料,构建一个面向应用的分类系统。 要求: 至少3页A4纸,文…

Leetcode951. 翻转等价二叉树

Every day a Leetcode 题目来源:951. 翻转等价二叉树 解法1:递归 存在三种情况: 如果 root1 或者 root2 是 nullptr,那么只有在他们都为 nullptr 的情况下这两个二叉树才等价。如果 root1,root2 的值不相等&#x…

人民网(人民日报官网)投稿方式解读

当今社交媒体风行的时代,人民网作为我国最大的新闻门户网站,扮演着传递信息、引导民意的重要角色。人民网积极鼓励社会各界朋友,尤其是普通民众,分享自己的见闻和观点,通过投稿的方式参与到新闻报道中来。那么&#xf…

已实现:JS如何根据视频的http(s)地址,来截取帧图片,并实现大图压缩的功能

现在&#xff0c;我们已经有了视频的http地址&#xff0c;我们怎么截取帧图片呢&#xff1f;我以Vue为基础架构&#xff0c;来写写代码。 1、先写布局&#xff0c;先得有video&#xff0c;然后得有canvas <video id"videoPlay" style"width: 100%; height:1…

查看centos的CPU、内存、磁盘空间等配置信息

目录 查看CPU/proc/cpuinfo中的信息 查看内存/proc/meminfo中的信息 查看磁盘空间df 命令du命令使用fdisk命令 查看CPU /proc/cpuinfo中的信息 前置&#xff1a; [ltkjltkj front]$ cat /proc/cpuinfo| grep "physical id" physical id : 0 physical id : 0 physi…