PyTorch框架学习九——网络模型的构建

PyTorch框架学习九——网络模型的构建

  • 一、概述
  • 二、nn.Module
  • 三、模型容器Container
    • 1.nn.Sequential
    • 2.nn.ModuleList
    • 3.nn.ModuleDict()
    • 4.总结

笔记二到八主要介绍与数据有关的内容,这次笔记将开始介绍网络模型有关的内容,首先我们不追求网络内部各层的具体内容,重点关注模型的构建,学会了如何构建模型,然后再开始一些具体网络层的学习。

一、概述

模型有关的内容主要如下图所示:
在这里插入图片描述
主要是模型的搭建权值的初始化两个问题,而模型的搭建里,首先需要构建单独的网络层,然后将这些网络层按顺序拼接起来,就构成了一个模型,然后进行某种权值初始化,就可以用于训练数据。

今天介绍PyTorch中是如何实现模型创建的,具体内部的卷积、池化、激活函数等知识下次笔记介绍。上述的所有内容,在PyTorch中都有一个叫nn.Module的模块来实现。

看一个LeNet模型的例子:
LeNet网络结构
从上图可以看出LeNet模型经过了这样一个网络层的流程:
在这里插入图片描述
那我们要来搭建这个模型的话,就要先单独构建卷积层Conv,池化层pool,全连接层fc,然后按照上面的顺序进行拼接,拼接后的整体才是一个构建好的网络模型。

看一下LeNet的模型构建的代码:

class LeNet(nn.Module):def __init__(self, classes):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, classes)def forward(self, x):out = F.relu(self.conv1(x))out = F.max_pool2d(out, 2)out = F.relu(self.conv2(out))out = F.max_pool2d(out, 2)out = out.view(out.size(0), -1)out = F.relu(self.fc1(out))out = F.relu(self.fc2(out))out = self.fc3(out)return out

可以看出__init__()函数实现了对每一个单独的网络层的构建,forward()函数实现了子网络层的拼接。

二、nn.Module

介绍nn.Module之前先看一下torch.nn里四个重要的模块:

  1. torch.nn.Parameter:张量的子类,表示可学习的参数,如weight、bias。
  2. torch.nn.Module:所有网络层的基类,管理网络属性。
  3. torch.nn.functional:函数具体实现,如卷积、池化、激活函数等。
  4. torch.nn.init:参数初始化的方法。

这里重点介绍nn.Parameter和nn.Module。

nn.Module来构建网络层时会创建8个字典管理它的不同属性,分别如下所示:

  • parameters:存储管理nn.Parameter类。
  • modules:存储管理nn.Module类。
  • buffers:存储管理缓冲属性,如BN层中的running_mean。
  • ×××_hooks(5个):存储管理钩子函数(目前不了解)。

下面的代码是创建一个module时对8个字典的初始化:

    def __init__(self):"""Initializes internal Module state, shared by both nn.Module and ScriptModule."""torch._C._log_api_usage_once("python.nn_module")self.training = Trueself._parameters = OrderedDict()self._buffers = OrderedDict()self._backward_hooks = OrderedDict()self._forward_hooks = OrderedDict()self._forward_pre_hooks = OrderedDict()self._state_dict_hooks = OrderedDict()self._load_state_dict_pre_hooks = OrderedDict()self._modules = OrderedDict()

注意:

  1. 一个module可以包含多个子module,如LeNet是一个module,它包含了conv、fc等子module。
  2. 一个module相当于一个运算,必须实现forward函数。
  3. 每个module都有8个字典管理它的属性。

三、模型容器Container

模型容器有三种,如下图所示:
在这里插入图片描述

1.nn.Sequential

功能:是nn.Module的容器,用于按顺序包装一组网络层。

还是以LeNet为例,我们将LeNet分成features和classifier两部分,每个部分都是一个sequential:
在这里插入图片描述
代码如下:

class LeNetSequential(nn.Module):def __init__(self, classes):super(LeNetSequential, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 6, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(16*5*5, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU(),nn.Linear(84, classes),)def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return x

但是,这种构建网络的方式有一个小问题,每一层网络层都会自动按顺序编一个号作为name,如features这个Sequential里每层网络层在module属性内部是这样的:

在这里插入图片描述
这里只有六个网络层,所以还可以在短时间内找到你需要的那一个,但是当层数非常多的时候,这种数字命名的方式就很不友好,而Sequential也有相应的应对方法,即为每一层网络命名,具体代码如下所示:

class LeNetSequentialOrderDict(nn.Module):def __init__(self, classes):super(LeNetSequentialOrderDict, self).__init__()self.features = nn.Sequential(OrderedDict({'conv1': nn.Conv2d(3, 6, 5),'relu1': nn.ReLU(inplace=True),'pool1': nn.MaxPool2d(kernel_size=2, stride=2),'conv2': nn.Conv2d(6, 16, 5),'relu2': nn.ReLU(inplace=True),'pool2': nn.MaxPool2d(kernel_size=2, stride=2),}))self.classifier = nn.Sequential(OrderedDict({'fc1': nn.Linear(16*5*5, 120),'relu3': nn.ReLU(),'fc2': nn.Linear(120, 84),'relu4': nn.ReLU(inplace=True),'fc3': nn.Linear(84, classes),}))def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return x

与原来不同的地方就是,构建了一个OrderedDict字典来存放键值对,key就是每一层网络的名字,value就是具体的网络层实现,看一下此时的module属性内部:

在这里插入图片描述
这样就很好寻找所需要的某一层网络。

综上,Sequential的特点:

  1. 顺序性:各网络层之间严格按照顺序构建。
  2. 自带forward():通过for循环依次执行前向传播运算。

2.nn.ModuleList

也是nn.module的容器,用于包装一组网络层,以迭代方式调用网络层。

主要方法:

  1. append():在ModuleList后面添加网络层
  2. extend():拼接两个ModuleList
  3. insert():指定在ModuleList中位置插入网络层

这种容器比较适合构建大量重复的网络层,因为利用了迭代的方法,下面就是构建20个线性层的例子

class ModuleList(nn.Module):def __init__(self):super(ModuleList, self).__init__()self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])def forward(self, x):for i, linear in enumerate(self.linears):x = linear(x)return x

在这里插入图片描述

3.nn.ModuleDict()

也是nn.module的容器,用于包装一组网络层,以索引方式调用网络层。

主要方法:

  1. clear():清空ModuleDict
  2. items():返回可迭代的键值对
  3. keys():返回字典的键
  4. values():返回字典的值
  5. pop():返回一对键值,并从字典中删除

这种容器的特点是,因为键值对可以索引的特性,可用于选择网络层:

class ModuleDict(nn.Module):def __init__(self):super(ModuleDict, self).__init__()self.choices = nn.ModuleDict({'conv': nn.Conv2d(10, 10, 3),'pool': nn.MaxPool2d(3)})self.activations = nn.ModuleDict({'relu': nn.ReLU(),'prelu': nn.PReLU()})def forward(self, x, choice, act):x = self.choices[choice](x)x = self.activations[act](x)return xnet = ModuleDict()fake_img = torch.randn((4, 10, 32, 32))output = net(fake_img, 'conv', 'relu')print(output)

我们构建了conv、pool以及relu、prelu,然后我们选择使用conv和relu。

4.总结

对于上述提及的三种容器,它们各自的特点以及适用范围如下所示:

  1. nn.Sequential:顺序性,各层之间按顺序执行,常用于block的构建。
  2. nn.ModuleList:迭代性,常用于大量重复网络层的构建,通过for循环实现重复构建。
  3. nn.ModuleDict:索引性,常用于可选择的网络层的构建,通过字典的键值对实现选择。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/491695.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

中国17种稀土有啥军事用途?没它们,美军技术优势将归零

来源:陶慕剑观察 稀土就是化学元素周期表中镧系元素——镧(La)、铈(Ce)、镨(Pr)、钕(Nd)、钷(Pm)、钐(Sm)、铕(Eu)、钆(Gd)、铽(Tb)、镝(Dy)、钬(Ho)、铒(Er)、铥(Tm)、镱(Yb)、镥(Lu),再加上钪(Sc)和钇(Y)共17种元素。中国稀土占据着众多的世界第一&…

PyTorch框架学习十——基础网络层(卷积、转置卷积、池化、反池化、线性、激活函数)

PyTorch框架学习十——基础网络层(卷积、转置卷积、池化、反池化、线性、激活函数)一、卷积层二、转置卷积层三、池化层1.最大池化nn.MaxPool2d2.平均池化nn.AvgPool2d四、反池化层最大值反池化nn.MaxUnpool2d五、线性层六、激活函数层1.nn.Sigmoid2.nn.…

PyTorch框架学习十一——网络层权值初始化

PyTorch框架学习十一——网络层权值初始化一、均匀分布初始化二、正态分布初始化三、常数初始化四、Xavier 均匀分布初始化五、Xavier正态分布初始化六、kaiming均匀分布初始化前面的笔记介绍了网络模型的搭建,这次将介绍网络层权值的初始化,适当的初始化…

W3C 战败:无权再制定 HTML 和 DOM 标准!

来源:CSDN历史性时刻!——近日,W3C正式宣告战败:HTML和DOM标准制定权将全权移交给浏览器厂商联盟WHATWG。由苹果、Google、微软和Mozilla四大浏览器厂商组成的WHATWG已经与万维网联盟(World Wide Web Consortium&#…

PyTorch框架学习十二——损失函数

PyTorch框架学习十二——损失函数一、损失函数的作用二、18种常见损失函数简述1.L1Loss(MAE)2.MSELoss3.SmoothL1Loss4.交叉熵CrossEntropyLoss5.NLLLoss6.PoissonNLLLoss7.KLDivLoss8.BCELoss9.BCEWithLogitsLoss10.MarginRankingLoss11.HingeEmbedding…

化合物半导体的机遇

来源:国盛证券半导体材料可分为单质半导体及化合物半导体两类,前者如硅(Si)、锗(Ge)等所形成的半导体,后者为砷化镓(GaAs)、氮化镓(GaN)、碳化硅(…

PyTorch框架学习十三——优化器

PyTorch框架学习十三——优化器一、优化器二、Optimizer类1.基本属性2.基本方法三、学习率与动量1.学习率learning rate2.动量、冲量Momentum四、十种常见的优化器(简单罗列)上次笔记简单介绍了一下损失函数的概念以及18种常用的损失函数,这次…

最全芯片产业报告出炉,计算、存储、模拟IC一文扫尽

来源:智东西最近几年, 半导体产业风起云涌。 一方面, 中国半导体异军突起, 另一方面, 全球产业面临超级周期,加上人工智能等新兴应用的崛起,中美科技摩擦频发,全球半导体现状如何&am…

python向CSV文件写内容

f open(r"D:\test.csv", w) f.write(1,2,3\n) f.write(4,5,6\n) f.close() 注意:上面例子中的123456这6个数字会分别写入不同的单元格里,即以逗号作为分隔符将字符串内容分开放到不同单元格 上面例子的图: 如果要把变量的值放入…

PyTorch框架学习十四——学习率调整策略

PyTorch框架学习十四——学习率调整策略一、_LRScheduler类二、六种常见的学习率调整策略1.StepLR2.MultiStepLR3.ExponentialLR4.CosineAnnealingLR5.ReduceLRonPlateau6.LambdaLR在上次笔记优化器的内容中介绍了学习率的概念,但是在整个训练过程中学习率并不是一直…

JavaScript数组常用方法

转载于:https://www.cnblogs.com/kenan9527/p/4926145.html

蕨叶形生物刷新生命史,动物界至少起源于5.7亿年前

来源 :newsweek.com根据发表于《古生物学》期刊(Palaeontology)的一项研究,动物界可能比科学界所知更加古老。研究人员发现,一种名为“美妙春光虫”(Stromatoveris psygmoglena)的海洋生物在埃迪…

PyTorch框架学习十五——可视化工具TensorBoard

PyTorch框架学习十五——可视化工具TensorBoard一、TensorBoard简介二、TensorBoard安装及测试三、TensorBoard的使用1.add_scalar()2.add_scalars()3.add_histogram()4.add_image()5.add_graph()之前的笔记介绍了模型训练中的数据、模型、损失函数和优化器,下面将介…

CNN、RNN、DNN的内部网络结构有什么区别?

来源:AI量化百科神经网络技术起源于上世纪五、六十年代,当时叫感知机(perceptron),拥有输入层、输出层和一个隐含层。输入的特征向量通过隐含层变换达到输出层,在输出层得到分类结果。早期感知机的推动者是…

L2级自动驾驶量产趋势解读

来源:《国盛计算机组》L2 级自动驾驶离我们比想象的更近。18 年下半年部分 L2 车型已面世,凯迪拉克、吉利、长城、长安、上汽等均已推出了 L2 自动驾驶车辆。国内目前在售2872个车型,L2级功能渗透率平均超过25%,豪华车甚至超过了6…

PyTorch框架学习十六——正则化与Dropout

PyTorch框架学习十六——正则化与Dropout一、泛化误差二、L2正则化与权值衰减三、正则化之Dropout补充:这次笔记主要关注防止模型过拟合的两种方法:正则化与Dropout。 一、泛化误差 一般模型的泛化误差可以被分解为三部分:偏差、方差与噪声…

HDU 5510 Bazinga 暴力匹配加剪枝

Bazinga Time Limit: 20 Sec Memory Limit: 256 MB 题目连接 http://acm.hdu.edu.cn/showproblem.php?pid5510 Description Ladies and gentlemen, please sit up straight.Dont tilt your head. Im serious.For n given strings S1,S2,⋯,Sn, labelled from 1 to n, you shou…

PyTorch框架学习十七——Batch Normalization

PyTorch框架学习十七——Batch Normalization一、BN的概念二、Internal Covariate Shift(ICS)三、BN的一个应用案例四、PyTorch中BN的实现1._BatchNorm类2.nn.BatchNorm1d/2d/3d(1)nn.BatchNorm1d(2)nn.Bat…

人工智能影响未来娱乐的31种方式

来源:资本实验室 技术改变生活,而各种新技术每天都在重新定义我们的生活状态。技术改变娱乐,甚至有了互联网时代“娱乐至死”的警语。当人工智能介入我们的生活,特别是娱乐的时候,一切又将大为不同。尽管很多时候我们很…

素数与量子物理的结合能带来解决黎曼猜想的新可能吗?

来源:中国科学院数学与系统科学研究院翻译:墨竹校对:杨璐1972年,物理学家弗里曼戴森(Freeman Dyson)写了一篇名为《错失的机会》(Missed Opportunities)的文章。在该文中&#xff0c…