pytorch04:网络模型创建

目录

  • 一、模型创建过程
    • 1.1 以LeNet网络为例
    • 1.2 LeNet结构
    • 1.3 nn.Module
  • 二、网络层容器(Containers)
    • 2.1 nn.Sequential
      • 2.1.1 常规方法实现
      • 2.1.2 OrderedDict方法实现
    • 2.2 nn.ModuleList
    • 2.3 nn.ModuleDict
    • 2.4 三种容器构建总结
  • 三、AlexNet网络构建

一、模型创建过程

在这里插入图片描述

1.1 以LeNet网络为例

在这里插入图片描述

网络代码如下:

class LeNet(nn.Module):def __init__(self, classes):super(LeNet, self).__init__()  # 调用父类方法,作用是调用nn.Module类的构造函数,# 确保LeNet类被正确地初始化,并继承了nn.Module 的所有属性和方法self.conv1 = nn.Conv2d(3, 6, 5) # 卷积层self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120) # 全连接层self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, classes)def forward(self, x):out = F.relu(self.conv1(x))out = F.max_pool2d(out, 2)out = F.relu(self.conv2(out))out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out

1.2 LeNet结构

在这里插入图片描述

LeNet:conv1–>pool1–>conv2–>pool2–>fc1–>fc2–>fc3
在这里插入图片描述

1.3 nn.Module

Module是nn模块中的功能,nn模块还有Parameter、functional等模块。
在这里插入图片描述
nn.Module主要有以下参数:
• parameters : 存储管理nn.Parameter类
• modules : 存储管理nn.Module类
• buffers:存储管理缓冲属性,如BN层中的running_mean

二、网络层容器(Containers)

在这里插入图片描述

2.1 nn.Sequential

nn.Sequential 是 nn.module的容器,也是最常用的容器,用于按顺序包装一组网络层
• 顺序性:各网络层之间严格按照顺序构建
• 自带forward():自带的forward里,通过for循环依次执行前向传播运算

2.1.1 常规方法实现

LeNet网络由两部分构成,中间的卷积池化特征提取部分(features),以及最后的分类部分(classifier)。
在这里插入图片描述
具体代码如下:

class LeNetSequential(nn.Module):def __init__(self, classes):super(LeNetSequential, self).__init__()self.features = nn.Sequential(  #特征提取部分nn.Conv2d(3, 6, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(  #分类部分nn.Linear(16*5*5, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU(),nn.Linear(84, classes),)def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return x

打印网络层:
在这里插入图片描述

2.1.2 OrderedDict方法实现

使用有序字典的方法构建Sequential
代码如下:

class LeNetSequentialOrderDict(nn.Module):def __init__(self, classes):super(LeNetSequentialOrderDict, self).__init__()self.features = nn.Sequential(OrderedDict({'conv1': nn.Conv2d(3, 6, 5),'relu1': nn.ReLU(inplace=True),'pool1': nn.MaxPool2d(kernel_size=2, stride=2),'conv2': nn.Conv2d(6, 16, 5),'relu2': nn.ReLU(inplace=True),'pool2': nn.MaxPool2d(kernel_size=2, stride=2),}))self.classifier = nn.Sequential(OrderedDict({'fc1': nn.Linear(16 * 5 * 5, 120),'relu3': nn.ReLU(),'fc2': nn.Linear(120, 84),'relu4': nn.ReLU(inplace=True),'fc3': nn.Linear(84, classes),}))def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return x

先看一下Sequential函数中init初始化的两种方法,当我们使用OrderedDict方法时,会进行判断,使用self.add_module(key, module)方法将字典中的key和value取出来添加到Sequential中。

class Sequential(Module):def __init__(self, *args):super().__init__()if len(args) == 1 and isinstance(args[0], OrderedDict):for key, module in args[0].items():self.add_module(key, module)else:for idx, module in enumerate(args):self.add_module(str(idx), module)

通过这种方法构建可以给每一网络层添加一个名称,网络输出结果如下:
在这里插入图片描述

2.2 nn.ModuleList

nn.ModuleList是 nn.module的容器,用于包装一组网络层,以迭代方式调用网络层
主要方法:
• append():在ModuleList后面添加网络层
• extend():拼接两个ModuleList
• insert():指定在ModuleList中位置插入网络层

使用列表生成式,通过一行代码就能构建20个网络层。
代码演示:

class ModuleList(nn.Module):def __init__(self):super(ModuleList, self).__init__()# 使用列表生成式构建20个全连接层,每个全连接层10个神经元的网络self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])def forward(self, x):for i, linear in enumerate(self.linears):x = linear(x)return xnet = ModuleList()

2.3 nn.ModuleDict

nn.ModuleDict是 nn.module的容器,用于包装一组网络层,以索引方式调用网络层,可以用过参数的形式选取想要调用的网络层。
主要方法:
• clear():清空ModuleDict
• items():返回可迭代的键值对(key-value pairs)
• keys():返回字典的键(key)
• values():返回字典的值(value)
• pop():返回一对键值,并从字典中删除

代码展示,只选取conv和relu两个网络层:

class ModuleDict(nn.Module):def __init__(self):super(ModuleDict, self).__init__()self.choices = nn.ModuleDict({'conv': nn.Conv2d(10, 10, 3),'pool': nn.MaxPool2d(3)})# 激活函数self.activations = nn.ModuleDict({'relu': nn.ReLU(),'prelu': nn.PReLU()})def forward(self, x, choice, act):  # 传入两个参数 用来选择网络层x = self.choices[choice](x)x = self.activations[act](x)return x
net = ModuleDict()
fake_img = torch.randn((4, 10, 32, 32))
output = net(fake_img, 'conv', 'relu')  #只选取conv和relu两个网络层。
print(output)

2.4 三种容器构建总结

• nn.Sequential:顺序性,各网络层之间严格按顺序执行,常用于block构建
• nn.ModuleList:迭代性,常用于大量重复网构建,通过for循环实现重复构建
• nn.ModuleDict:索引性,常用于可选择的网络层

三、AlexNet网络构建

AlexNet:2012年以高出第二名10多个百分点的准确率获得ImageNet分类任务冠军,开创了卷积神经网络的新时代
AlexNet特点如下:

  1. 采用ReLU:替换饱和激活函数,减轻梯度消失
  2. 采用LRN(Local Response Normalization):对数据归一化,减轻梯度消失
  3. Dropout:提高全连接层的鲁棒性,增加网络的泛化能力
  4. Data Augmentation:TenCrop,色彩修改

网络结构图如下:
在这里插入图片描述
构建代码:

import torch.nn as nn
import torch
from torchsummary import summary
# 定义一个名为AlexNet的神经网络模型,继承自nn.Module基类
class AlexNet(nn.Module):# 构造函数,初始化网络的参数def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:# 调用父类的构造函数super().__init__()# 定义神经网络的特征提取部分,包含多个卷积层和池化层self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),  # 输入通道3,输出通道64,卷积核大小11x11,步长4,填充2nn.ReLU(inplace=True),  # 使用ReLU激活函数,inplace=True表示原地操作,节省内存nn.MaxPool2d(kernel_size=3, stride=2),  # 最大池化层,核大小3x3,步长2nn.Conv2d(64, 192, kernel_size=5, padding=2),  # 输入通道64,输出通道192,卷积核大小5x5,填充2nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(192, 384, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)# 定义自适应平均池化层,将输入的任意大小的特征图池化为固定大小6x6self.avgpool = nn.AdaptiveAvgPool2d((6, 6))# 定义分类器部分,包含全连接层和Dropout层self.classifier = nn.Sequential(nn.Dropout(p=dropout),  # 使用Dropout进行正则化,随机丢弃一部分神经元以防止过拟合nn.Linear(256 * 6 * 6, 4096),  # 输入大小为256*6*6,输出大小为4096nn.ReLU(inplace=True),nn.Dropout(p=dropout),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes),  # 最后的全连接层输出类别数)# 前向传播函数,定义数据在网络中的传播过程def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.features(x)  # 特征提取x = self.avgpool(x)  # 平均池化x = torch.flatten(x, 1)  # 将特征图展平成一维向量x = self.classifier(x)  # 分类器return xif __name__ == '__main__':net = AlexNet().cuda()summary(net, (3, 256, 256))

打印出的网络结构图如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/594117.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何使用内网穿透工具实现远程SSH访问Deepin系统

文章目录 前言1. 开启SSH服务2. Deppin安装Cpolar3. 配置ssh公网地址4. 公网远程SSH连接5. 固定连接SSH公网地址6. SSH固定地址连接测试 前言 Deepin操作系统是一个基于Debian的Linux操作系统,专注于使用者对日常办公、学习、生活和娱乐的操作体验的极致&#xff0…

leetcode递归算法题总结

递归本质是找重复的子问题 本章目录 1.汉诺塔2.合并两个有序链表3.反转链表4.两两交换链表中的节点5.Pow(x,n) 1.汉诺塔 汉诺塔 //面试写法 class Solution { public:void hanota(vector<int>& a, vector<int>& b, vector<int>& c) {dfs(a,b…

踩坑记录-安装nuxt3报错:Error: Failed to download template from registry: fetch failed;

报错复现 安装nuxt3报错&#xff1a;Error: Failed to download template from registry: fetch failednpx nuxi init nuxt-demo 初始化nuxt 项目 报错 Error: Failed to download template from registry: fetch faile 解决方法 配置hosts Mac电脑&#xff1a;/etc/hostswin电…

众和策略:中一签最多赚超2万元!又有新股大涨

A股商场整体低位轰动 今天上午A股商场整体低位轰动。 板块和赛道方面&#xff0c;煤炭板块再度领涨&#xff0c;板块涨幅逾越1%&#xff0c;云煤动力涨停。 公用事业、钢铁、传媒、石油石化等板块涨幅居前。电子、计算机、通讯等板块跌幅居前。 概念板块方面&#xff0c;化…

深入探索小红书笔记详情API:解锁内容创新的无尽潜力

一、引言 在当今信息爆炸的时代&#xff0c;内容创新已经成为品牌和个人脱颖而出的关键。小红书&#xff0c;作为全球最大的消费类口碑库之一&#xff0c;每天产生大量的用户生成内容。而小红书笔记详情API&#xff0c;作为一个强大的工具&#xff0c;能够为内容创作者提供深入…

fpga xvc 调试实现,支持多端口同时调试多颗FPGA芯片

xilinx 推荐的实现结构方式如下&#xff1a; 通过一个ZYNQ运行xvc服务器&#xff0c;然后通过zynq去配置其他的FPGA&#xff0c;具体参考设计可以参考手册xapp1251&#xff0c;由于XVC运行的协议是标准的TCP协议&#xff0c;这种方式需要ZYNQ运行TCP协议&#xff0c;也就需要运…

【普中开发板】基于51单片机音乐盒LCD1602显示( proteus仿真+程序+设计报告+讲解视频)

【普中开发板】基于51单片机音乐盒LCD1602显示( proteus仿真程序设计报告讲解视频&#xff09; 仿真图proteus7.8及以上 程序编译器&#xff1a;keil 4/keil 5 编程语言&#xff1a;C语言 设计编号&#xff1a;P08 1. 主要功能&#xff1a; 基于51单片机AT89C51/52&#…

LLVM(简介)

历史 LLVM(low level virtual machine)起源于伊利诺伊大学的一个编译器实验项目&#xff0c;目前已经发展成一个集编译器和工具链为一体的商业开源项目&#xff0c;因此其英文名称的含义被扩大&#xff0c;不再仅仅是字面意思。其创始人为 Chris Lattner。LLVM项目遵循的开源许…

Go语言命令行参数及cobra使用教程

Go语言命令行参数及cobra使用教程 1.原生命令行参数2.使用CIL框架Cobra创建 rootCmd创建你的 main.go创建其他命令子命令返回和处理错误 3.cobra使用标志4.Cobra位置参数和自定义参数5.Cobra PreRun和PostRun钩子 1.原生命令行参数 os 包以跨平台的方式&#xff0c;提供了一些…

Spring Boot 整合 MinIO自建对象存储服务

GitHub 地址&#xff1a;GitHub - minio/minio: The Object Store for AI Data Infrastructure 另外&#xff0c;MinIO 可以用来作为云原生应用的主要存储服务&#xff0c;因为云原生应用往往需要更高的吞吐量和更低的延迟&#xff0c;而这些都是 MinIO 的优势。安装过程跳过。…

Numpy基础

目录&#xff1a; 一、简介:二、array数组ndarray&#xff1a;1.array( )创建数组&#xff1a;2.数组赋值和引用的区别&#xff1a;3.arange( )创建区间数组&#xff1a;4.linspace( )创建等差数列&#xff1a;5.logspace( )创建等比数列&#xff1a;6.zeros( )创建全0数组&…

半导体设备系列:半导体制造产能扩张,设备零部件需求旺盛

近年来国内半导体制造产能不断扩张&#xff0c;半导体设备厂商加速成长。我们认为下游发展将拉动上游本地化配套需求&#xff0c;半导体设备零部件迎来高增长阶段。 摘要 半导体设备零部件包含密封圈、EFEM、射频电源、静电吸盘、硅电极、真空泵、气体流量计、喷淋头等产品&a…

JVM虚拟机:各种JVM报错总结

错误 java.lang.StackOverflowError java.lang.OutOfMemoryError:java heap space java.lang.OutOfMemoryError:GC overhead limit exceeded java.lang.OutOfMemoryError:Direct buffer memory java.lang.OutOfMemoryError:unable to create new native thread java.lang.OutOf…

线程的深入学习(二)

前言 上一篇讲了线程池的相关知识&#xff0c;这篇文章主要讲解一个 1.并发工具类如CountDownLatch、CyclicBarrier等。 2.线程安全和并发集合&#xff1a; 3.学习如何使用Java提供的线程安全的集合类&#xff0c;如ConcurrentHashMap、CopyOnWriteArrayList等。 并发工具类 …

Linux学习记录——삼십삼 http协议

文章目录 1、URL2、http协议的宏观构成3、详细理解http协议1、http请求2、http响应1、有效载荷格式2、有效载荷长度3、客户端要访问的资源类型4、修改响应写法5、处理不同的请求6、跳转 3、请求方法&#xff08;GET/POST&#xff09;4、HTTP状态码&#xff08;实现3和4开头的&a…

uniapp中用户登录数据的存储方法探究

Hello大家好&#xff01;我是咕噜铁蛋&#xff01;作为一个博主&#xff0c;我们经常需要在应用程序中实现用户登录功能&#xff0c;并且需要将用户的登录数据进行存储&#xff0c;以便在多次使用应用程序时能够方便地获取用户信息。铁蛋通过科技手段帮大家收集整理了些知识&am…

每天五分钟计算机视觉:揭秘迁移学习

本文重点 随着人工智能的迅速发展,深度学习已经成为了许多领域的关键技术。然而,深度学习模型的训练需要大量的标注数据,这在很多情况下是不现实的。迁移学习作为一种有效的方法,可以在已有的数据和模型上进行训练,然后将其应用于新的任务。这种方法大大降低了对新任务的…

书香之家 国学启智——学夫堂幼儿国学托管永嘉上塘实验店启航

在教育创新的道路上&#xff0c;学夫堂幼儿国学托管永嘉上塘实验店迎来了一个重要的时刻。经过三个多月的精心筹备和试运营&#xff0c;今天正式宣布学夫堂幼儿国学托管在永嘉县城北街道景和佳苑8幢105号开门迎客。 学夫堂深信&#xff0c;国学智慧不仅是中华文化的瑰宝&#x…

阿赵UE学习笔记——7、导入资源

阿赵UE学习笔记目录 大家好&#xff0c;我是阿赵。   继续学习虚幻引擎的使用。这次将会把一个带动作和贴图的钢铁侠模型&#xff0c;导入的UE的项目中。 1、准备的资源 这里有2个fbx文件&#xff0c;都是带着网格和动画的&#xff0c;模型网格和骨骼是一样的&#xff0c;只…

MySQL是如何做到可以恢复到半个月内任意一秒的状态的?

MySQL的逻辑架构图 MySQL中两个重要的日志模块&#xff1a;redo log&#xff08;重做日志&#xff09;和binlog&#xff08;归档日志&#xff09; 我们先来看redo log&#xff1a; 介绍一个MySQL里经常说到的WAL技术&#xff0c;即Write-Ahead-Logging&#xff0c;它的关键点…