Pytorch模型层简单介绍

模型层layers

深度学习模型一般由各种模型层组合而成。

torch.nn中内置了非常丰富的各种模型层。它们都属于nn.Module的子类,具备参数管理功能。

例如:

nn.Linear, nn.Flatten, nn.Dropout, nn.BatchNorm2d

nn.Conv2d,nn.AvgPool2d,nn.Conv1d,nn.ConvTranspose2d

nn.Embedding,nn.GRU,nn.LSTM

nn.Transformer

如果这些内置模型层不能够满足需求,我们也可以通过继承nn.Module基类构建自定义的模型层。

实际上,pytorch不区分模型和模型层,都是通过继承nn.Module进行构建。

因此,我们只要继承nn.Module基类并实现forward方法即可自定义模型层。

一,内置模型层

头文件:

import numpy as np 
import torch 
from torch import nn

基础层

nn.Linear:全连接层。参数个数 = 输入层特征数× 输出层特征数(weight)+ 输出层特征数(bias)

nn.Flatten:压平层,用于将多维张量样本压成一维张量样本。

nn.BatchNorm1d:一维批标准化层。通过线性变换将输入批次缩放平移到稳定的均值和标准差。可以增强模型对输入不同分布的适应性,加快模型训练速度,有轻微正则化效果。一般在激活函数之前使用。可以用afine参数设置该层是否含有可以训练的参数。

nn.BatchNorm2d:二维批标准化层。

nn.BatchNorm3d:三维批标准化层。

nn.Dropout:一维随机丢弃层。一种正则化手段。

nn.Dropout2d:二维随机丢弃层。

nn.Dropout3d:三维随机丢弃层。

nn.Threshold:限幅层。当输入大于或小于阈值范围时,截断之。

nn.ConstantPad2d: 二维常数填充层。对二维张量样本填充常数扩展长度。

nn.ReplicationPad1d: 一维复制填充层。对一维张量样本通过复制边缘值填充扩展长度。

nn.ZeroPad2d:二维零值填充层。对二维张量样本在边缘填充0值.

nn.GroupNorm:组归一化。一种替代批归一化的方法,将通道分成若干组进行归一。不受batch大小限制,据称性能和效果都优于BatchNorm。

nn.LayerNorm:层归一化。较少使用。

nn.InstanceNorm2d: 样本归一化。较少使用。

卷积网络相关层

nn.Conv1d:普通一维卷积,常用于文本。参数个数 = 输入通道数×卷积核尺寸(如3)×卷积核个数 + 卷积核尺寸(如3)

nn.Conv2d:普通二维卷积,常用于图像。参数个数 = 输入通道数×卷积核尺寸(如3乘3)×卷积核个数 + 卷积核尺寸(如3乘3) 通过调整dilation参数大于1,可以变成空洞卷积,增大卷积核感受野。 通过调整groups参数不为1,可以变成分组卷积。分组卷积中不同分组使用相同的卷积核,显著减少参数数量。 当groups参数等于通道数时,相当于tensorflow中的二维深度卷积层tf.keras.layers.DepthwiseConv2D。 利用分组卷积和1乘1卷积的组合操作,可以构造相当于Keras中的二维深度可分离卷积层tf.keras.layers.SeparableConv2D。

nn.Conv3d:普通三维卷积,常用于视频。参数个数 = 输入通道数×卷积核尺寸(如3乘3乘3)×卷积核个数 + 卷积核尺寸(如3乘3乘3) 。

nn.MaxPool1d: 一维最大池化。

nn.MaxPool2d:二维最大池化。一种下采样方式。没有需要训练的参数。

nn.MaxPool3d:三维最大池化。

nn.AdaptiveMaxPool2d:二维自适应最大池化。无论输入图像的尺寸如何变化,输出的图像尺寸是固定的。 该函数的实现原理,大概是通过输入图像的尺寸和要得到的输出图像的尺寸来反向推算池化算子的padding,stride等参数。

nn.FractionalMaxPool2d:二维分数最大池化。普通最大池化通常输入尺寸是输出的整数倍。而分数最大池化则可以不必是整数。分数最大池化使用了一些随机采样策略,有一定的正则效果,可以用它来代替普通最大池化和Dropout层。

nn.AvgPool2d:二维平均池化。

nn.AdaptiveAvgPool2d:二维自适应平均池化。无论输入的维度如何变化,输出的维度是固定的。

nn.ConvTranspose2d:二维卷积转置层,俗称反卷积层。并非卷积的逆操作,但在卷积核相同的情况下,当其输入尺寸是卷积操作输出尺寸的情况下,卷积转置的输出尺寸恰好是卷积操作的输入尺寸。在语义分割中可用于上采样。

nn.Upsample:上采样层,操作效果和池化相反。可以通过mode参数控制上采样策略为"nearest"最邻近策略或"linear"线性插值策略。

nn.Unfold:滑动窗口提取层。其参数和卷积操作nn.Conv2d相同。实际上,卷积操作可以等价于nn.Unfold和nn.Linear以及nn.Fold的一个组合。 其中nn.Unfold操作可以从输入中提取各个滑动窗口的数值矩阵,并将其压平成一维。利用nn.Linear将nn.Unfold的输出和卷积核做乘法后,再使用 nn.Fold操作将结果转换成输出图片形状。

nn.Fold:逆滑动窗口提取层。

循环网络相关层

nn.Embedding:嵌入层。一种比Onehot更加有效的对离散特征进行编码的方法。一般用于将输入中的单词映射为稠密向量。嵌入层的参数需要学习。

nn.LSTM:长短记忆循环网络层【支持多层】。最普遍使用的循环网络层。具有携带轨道,遗忘门,更新门,输出门。可以较为有效地缓解梯度消失问题,从而能够适用长期依赖问题。设置bidirectional = True时可以得到双向LSTM。需要注意的时,默认的输入和输出形状是(seq,batch,feature), 如果需要将batch维度放在第0维,则要设置batch_first参数设置为True。

nn.GRU:门控循环网络层【支持多层】。LSTM的低配版,不具有携带轨道,参数数量少于LSTM,训练速度更快。

nn.RNN:简单循环网络层【支持多层】。容易存在梯度消失,不能够适用长期依赖问题。一般较少使用。

nn.LSTMCell:长短记忆循环网络单元。和nn.LSTM在整个序列上迭代相比,它仅在序列上迭代一步。一般较少使用。

nn.GRUCell:门控循环网络单元。和nn.GRU在整个序列上迭代相比,它仅在序列上迭代一步。一般较少使用。

nn.RNNCell:简单循环网络单元。和nn.RNN在整个序列上迭代相比,它仅在序列上迭代一步。一般较少使用。

Transformer相关层

nn.Transformer:Transformer网络结构。Transformer网络结构是替代循环网络的一种结构,解决了循环网络难以并行,难以捕捉长期依赖的缺陷。它是目前NLP任务的主流模型的主要构成部分。Transformer网络结构由TransformerEncoder编码器和TransformerDecoder解码器组成。编码器和解码器的核心是MultiheadAttention多头注意力层。

nn.TransformerEncoder:Transformer编码器结构。由多个 nn.TransformerEncoderLayer编码器层组成。

nn.TransformerDecoder:Transformer解码器结构。由多个 nn.TransformerDecoderLayer解码器层组成。

nn.TransformerEncoderLayer:Transformer的编码器层。

nn.TransformerDecoderLayer:Transformer的解码器层。

nn.MultiheadAttention:多头注意力层。

二,自定义模型层

如果Pytorch的内置模型层不能够满足需求,我们也可以通过继承nn.Module基类构建自定义的模型层。

实际上,pytorch不区分模型和模型层,都是通过继承nn.Module进行构建。

因此,我们只要继承nn.Module基类并实现forward方法即可自定义模型层。

下面是Pytorch的nn.Linear层的源码,我们可以仿照它来自定义模型层。

import torch
from torch import nn
import torch.nn.functional as Fclass Linear(nn.Module):__constants__ = ['in_features', 'out_features']def __init__(self, in_features, out_features, bias=True):super(Linear, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = nn.Parameter(torch.Tensor(out_features, in_features))if bias:self.bias = nn.Parameter(torch.Tensor(out_features))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))if self.bias is not None:fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)bound = 1 / math.sqrt(fan_in)nn.init.uniform_(self.bias, -bound, bound)def forward(self, input):return F.linear(input, self.weight, self.bias)def extra_repr(self):return 'in_features={}, out_features={}, bias={}'.format(self.in_features, self.out_features, self.bias is not None)
linear = nn.Linear(20, 30)
inputs = torch.randn(128, 20)
output = linear(inputs)
print(output.size())

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/389353.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

有效沟通的技能有哪些_如何有效地展示您的数据科学或软件工程技能

有效沟通的技能有哪些What is the most important thing to do after you got your skills to be a data scientist? It has to be to show off your skills. Otherwise, there is no use of your skills. If you want to get a job or freelance or start a start-up, you ha…

java.net.SocketException: Software caused connection abort: socket write erro

场景:接口测试 编辑器:eclipse 版本:Version: 2018-09 (4.9.0) testng版本:TestNG version 6.14.0 执行testng.xml时报错信息: 出现此报错原因之一:网上有人说是testng版本与eclipse版本不一致造成的&#…

[博客..配置?]博客园美化

博客园搞定时间 -> 18年6月27日 [让我歇会儿 搞这个费脑子 代码一个都看不懂] 转载于:https://www.cnblogs.com/Steinway/p/9235437.html

使用K-Means对美因河畔法兰克福的社区进行聚类

介绍 (Introduction) This blog post summarizes the results of the Capstone Project in the IBM Data Science Specialization on Coursera. Within the project, the districts of Frankfurt am Main in Germany shall be clustered according to their venue data using t…

Pytorch损失函数losses简介

一般来说,监督学习的目标函数由损失函数和正则化项组成。(Objective Loss Regularization) Pytorch中的损失函数一般在训练模型时候指定。 注意Pytorch中内置的损失函数的参数和tensorflow不同,是y_pred在前,y_true在后,而Ten…

读取Mc1000的 唯一 ID 机器号

先引用Symbol.ResourceCoordination 然后引用命名空间 using System;using System.Security.Cryptography;using System.IO; 以下为类程序 /// <summary> /// 获取设备id /// </summary> /// <returns></returns> public static string GetDevi…

样本均值的抽样分布_抽样分布样本均值

样本均值的抽样分布One of the most important concepts discussed in the context of inferential data analysis is the idea of sampling distributions. Understanding sampling distributions helps us better comprehend and interpret results from our descriptive as …

玩转ceph性能测试---对象存储(一)

笔者最近在工作中需要测试ceph的rgw&#xff0c;于是边测试边学习。首先工具采用的intel的一个开源工具cosbench&#xff0c;这也是业界主流的对象存储测试工具。 1、cosbench的安装&#xff0c;启动下载最新的cosbench包wget https://github.com/intel-cloud/cosbench/release…

[BZOJ 4300]绝世好题

Description 题库链接 给定一个长度为 \(n\) 的数列 \(a_i\) &#xff0c;求 \(a_i\) 的子序列 \(b_i\) 的最长长度&#xff0c;满足 \(b_i\wedge b_{i-1}\neq 0\) &#xff08; \(\wedge\) 表示按位与&#xff09; \(1\leq n\leq 100000\) Solution 令 \(f_i\) 为二进制第 \(i…

因果关系和相关关系 大数据_数据科学中的相关性与因果关系

因果关系和相关关系 大数据Let’s jump into it right away.让我们马上进入。 相关性 (Correlation) Correlation means relationship and association to another variable. For example, a movement in one variable associates with the movement in another variable. For…

Pytorch构建模型的3种方法

这个地方一直是我思考的地方&#xff01;因为学的代码太多了&#xff0c;构建的模型各有不同&#xff0c;这里记录一下&#xff01; 可以使用以下3种方式构建模型&#xff1a; 1&#xff0c;继承nn.Module基类构建自定义模型。 2&#xff0c;使用nn.Sequential按层顺序构建模…

vue取数据第一个数据_我作为数据科学家的第一个月

vue取数据第一个数据A lot.很多。 I landed my first job as a Data Scientist at the beginning of August, and like any new job, there’s a lot of information to take in at once.我于8月初找到了数据科学家的第一份工作&#xff0c;并且像任何新工作一样&#xff0c;一…

Flask-SocketIO 简单使用指南

Flask-SocketIO 使 Flask 应用程序能够访问客户端和服务器之间的低延迟双向通信。客户端应用程序可以使用 Javascript&#xff0c;C &#xff0c;Java 和 Swift 中的任何 SocketIO 官方客户端库或任何兼容的客户端来建立与服务器的永久连接。 安装 直接使用 pip 来安装&#xf…

STL-开篇

基本概念 STL&#xff1a; Standard Template Library&#xff0c;标准模板库 定义&#xff1a; c引入的一个标准类库 特点&#xff1a;1&#xff09;数据结构和算法的 c实现&#xff08; 采用模板类和模板函数&#xff09;2&#xff09;数据的存储和算法的分离3&#xff09;高…

Symbol Mc1000 声音的设置以及播放

首先引用Symbol.Audio 加一命名空间using Symbol.Audio; /声音设备的设置 //Select Device from device list Symbol.Audio.Device MyDevice (Symbol.Audio.Device)Symbol.StandardForms.SelectDevice.Select( Symbol.Audio.Controller.Title, Symbol.Audio.Devic…

/bin/bash^M: 坏的解释器: 没有那个文件或目录

在win下编辑的时候&#xff0c;换行结尾是\n\r &#xff0c; 而在linux下 是\n&#xff0c;所以会多出来一个\r&#xff0c;这样会出现错误 此时执行 sed -i s/\r$// file.sh 将file.sh中的\r都替换为空白&#xff0c;问题解决转载于:https://www.cnblogs.com/zzdbullet/p/9890…

rcp rapido_为什么气流非常适合Rapido

rcp rapidoBack in 2019, when we were building our data platform, we started building the data platform with Hadoop 2.8 and Apache Hive, managing our own HDFS. The need for managing workflows whether it’s data pipelines, i.e. ETL’s, machine learning predi…

pandas处理丢失数据与数据导入导出

3.4pandas处理丢失数据 头文件&#xff1a; import numpy as np import pandas as pd丢弃数据部分&#xff1a; dates pd.date_range(20130101,periods6) df pd.DataFrame(np.random.randn(6,4),indexdates,columns[A,B,C,D]) df.iloc[0,1] np.nan df.iloc[1,2] np.nanp…

Mysql5.7开启远程

2019独角兽企业重金招聘Python工程师标准>>> 1.注掉bind-address #bind-address 127.0.0.1 2.开启远程访问权限 grant all privileges on *.* to root"xxx.xxx.xxx.xxx" identified by "密码"; 或 grant all privileges on *.* to root"%…

分类结果可视化python_可视化分类结果的另一种方法

分类结果可视化pythonI love good data visualizations. Back in the days when I did my PhD in particle physics, I was stunned by the histograms my colleagues built and how much information was accumulated in one single plot.我喜欢出色的数据可视化。 早在我获得…