3.PyTorch——常用神经网络层

import numpy as np
import pandas as pd
import torch as t
from PIL import Image
from torchvision.transforms import ToTensor, ToPILImaget.__version__
'2.1.1'

3.1 图像相关层

图像相关层主要包括卷积层(Conv)、池化层(Pool)等,这些层在实际使用中可分为一维(1D)、二维(2D)、三维(3D),池化方式又分为平均池化(AvgPool)、最大值池化(MaxPool)、自适应池化(AdaptiveAvgPool)等。而卷积层除了常用的前向卷积之外,还有逆卷积(TransposeConv)。

除了这里的使用,图像的卷积操作还有各种变体,具体可以参照此处动图[^2]介绍。 [^2]: https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

to_tensor = ToTensor()
to_pil = ToPILImage()
lena = Image.open('imgs/lena.png')
lena

在这里插入图片描述

# layer对输入形状都有假设:输入的不是单个数据,而是一个batch。
# 这里输入一个数据,就必须调用tensor.unsqueeze(0)增加一个维度,伪装成batch_size=1的batch
input = to_tensor(lena).unsqueeze(0)# 锐化卷积核
kernel = t.ones(3, 3) / -9
kernel[1][1] = 1
conv = t.nn.Conv2d(1, 1, (3, 3), 1, bias=False)
conv.weight.data = kernel.view(1, 1, 3, 3)out = conv(input)
to_pil(out.data.squeeze(0))

在这里插入图片描述

池化层:可视为一种特殊的卷积层,用来下采样。注意池化层是没有可学习参数的,其weight是固定的。

pool = t.nn.AvgPool2d(2, 2)
out = pool(input)
to_pil(out.data.squeeze(0))

在这里插入图片描述

除了卷积层和池化层,深度学习中还将常用到以下几个层:

  • Linear:全连接层。
  • BatchNorm:批规范化层,分为1D、2D和3D。除了标准的BatchNorm之外,还有在风格迁移中常用到的InstanceNorm层。
  • Dropout:dropout层,用来防止过拟合,同样分为1D、2D和3D。
    下面通过例子来说明它们的使用。
# 输入batch_size=2, 维度3
input = t.rand(2,3)
linear = t.nn.Linear(3, 4)
h = linear(input)
h
tensor([[-0.2314, -0.2245,  0.0966,  0.7610],[-0.2679, -0.2403,  0.0086,  0.5799]], grad_fn=<AddmmBackward0>)
# 4 channel,初始化标准差4,均值0
bn = t.nn.BatchNorm1d(4)
bn.weight.data = t.ones(4) * 4
bn.bias.data = t.zeros(4)bn_out = bn(h)
print(bn_out)
bn_out.mean(), bn_out.var(0, unbiased=False)       # 由于计算无偏方差分母会减1, 使用unbiased=1分母不减一
tensor([[ 3.9415,  3.7136,  3.9897,  3.9976],[-3.9415, -3.7136, -3.9897, -3.9976]],grad_fn=<NativeBatchNormBackward0>)(tensor(-1.8775e-06, grad_fn=<MeanBackward0>),tensor([15.5355, 13.7908, 15.9179, 15.9805], grad_fn=<VarBackward0>))
# 每个元素以0.5的概率舍弃
dropout = t.nn.Dropout(0.5)
o = dropout(bn_out)
o           
tensor([[ 7.8830,  7.4272,  0.0000,  7.9951],[-7.8830, -7.4272, -7.9794, -7.9951]], grad_fn=<MulBackward0>)

3.2 激活函数

PyTorch实现了常见的激活函数,其具体的接口信息可参见官方文档1,这些激活函数可作为独立的layer使用。这里将介绍最常用的激活函数ReLU,其数学表达式为:
R e L U ( x ) = m a x ( 0 , x ) ReLU(x)=max(0,x) ReLU(x)=max(0,x)

ReLU函数有个inplace参数,如果设为True,它会把输出直接覆盖到输入中,这样可以节省内存/显存。之所以可以覆盖是因为在计算ReLU的反向传播时,只需根据输出就能够推算出反向传播的梯度。但是只有少数的autograd操作支持inplace操作(如tensor.sigmoid_()),除非你明确地知道自己在做什么,否则一般不要使用inplace操作。

relu = t.nn.ReLU(inplace=True)
input = t.randn(2, 3)
print(input)
output = relu(input)
print(output)        # 负数都被截断为0
tensor([[-0.4064, -0.1886,  0.4812],[ 0.8996, -0.3606,  0.6127]])
tensor([[0.0000, 0.0000, 0.4812],[0.8996, 0.0000, 0.6127]])

对于此类网络如果每次都写复杂的forward函数会有些麻烦,在此就有两种简化方式,ModuleList和Sequential。其中Sequential是一个特殊的module,它包含几个子Module,前向传播时会将输入一层接一层的传递下去。ModuleList也是一个特殊的module,可以包含几个子module,可以像用list一样使用它,但不能直接把输入传给ModuleList。下面举例说明。

# Sequential
net = t.nn.Sequential(t.nn.Conv2d(3, 3, 3),t.nn.BatchNorm2d(3),t.nn.ReLU()
)
print('net:', net)
net: Sequential((0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))(1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU()
)
# 可根据名字或序号取出module
net[2]
ReLU()
input = t.randn(1, 3, 4, 4)
output = net(input)
output
tensor([[[[1.2239, 0.0000],[0.0000, 0.6354]],[[0.1855, 0.0000],[0.7218, 0.7777]],[[1.3686, 0.0000],[0.4861, 0.0000]]]], grad_fn=<ReluBackward0>)
# modellist
modellist = t.nn.ModuleList([t.nn.Linear(3, 4), t.nn.ReLU(), t.nn.Linear(4, 2)])
input = t.randn(1, 3)for model in modellist:input = model(input)print(input)
tensor([[-0.1817,  0.3852,  1.3656, -0.5643]], grad_fn=<AddmmBackward0>)
tensor([[0.0000, 0.3852, 1.3656, 0.0000]], grad_fn=<ReluBackward0>)
tensor([[-0.0151, -0.0309]], grad_fn=<AddmmBackward0>)

3.3 RNN循环神经网络

关于RNN的基础知识,推荐阅读colah的文章2入门。PyTorch中实现了如今最常用的三种RNN:RNN(vanilla RNN)、LSTM和GRU。此外还有对应的三种RNNCell。

RNN和RNNCell层的区别在于前者一次能够处理整个序列,而后者一次只处理序列中一个时间点的数据,前者封装更完备更易于使用,后者更具灵活性。实际上RNN层的一种后端实现方式就是调用RNNCell来实现的。

t.manual_seed(1000)
# 输入:batch_size=3, 序列长度为2,序列中每个元素占4维
input = t.randn(2, 3, 4)
# lstm输入向量4维,隐藏元3. 1层
lstm = t.nn.LSTM(4, 3, 1)
# 初始状态:1层,batch_size=3, 3个隐藏元
h0 = t.randn(1, 3, 3)
c0 = t.randn(1, 3, 3)
out, hn = lstm(input, (h0, c0))
out
tensor([[[-0.3610, -0.1643,  0.1631],[-0.0613, -0.4937, -0.1642],[ 0.5080, -0.4175,  0.2502]],[[-0.0703, -0.0393, -0.0429],[ 0.2085, -0.3005, -0.2686],[ 0.1482, -0.4728,  0.1425]]], grad_fn=<MkldnnRnnLayerBackward0>)
t.manual_seed(1000)
input = t.randn(2, 3, 4)
# 一个LSTMCell对应的层数只能是一层
lstm = t.nn.LSTMCell(4, 3)
hx = t.randn(3, 3)
cx = t.randn(3, 3)
out = []
for i_ in input:hx, cx=lstm(i_, (hx, cx))out.append(hx)
t.stack(out)
tensor([[[-0.3610, -0.1643,  0.1631],[-0.0613, -0.4937, -0.1642],[ 0.5080, -0.4175,  0.2502]],[[-0.0703, -0.0393, -0.0429],[ 0.2085, -0.3005, -0.2686],[ 0.1482, -0.4728,  0.1425]]], grad_fn=<StackBackward0>)

3.4 损失函数

损失函数可看作是一种特殊的layer,PyTorch也将这些损失函数实现为nn.Module的子类。然而在实际使用中通常将这些loss function专门提取出来,和主模型互相独立。详细的loss使用请参照文档3,这里以分类中最常用的交叉熵损失CrossEntropyloss为例说明。

# batch_size = 3, 计算对应每个类别的分数(只有两个类别)
score = t.randn(3, 2)
# 三个样本分别属于1, 0, 1类,label必须是LongTensor
label = t.Tensor([1, 0, 1]).long()# loss与普通的layer无差异
criterion = t.nn.CrossEntropyLoss()
loss = criterion(score, label)
loss
tensor(1.8772)

  1. http://pytorch.org/docs/nn.html#non-linear-activations ↩︎

  2. http://colah.github.io/posts/2015-08-Understanding-LSTMs/ ↩︎

  3. http://pytorch.org/docs/nn.html#loss-functions ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/206333.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

车联网通信安全之DDS

前面笔者写过一篇文章简单介绍了下工业级通信协议DDS&#xff0c;DDS在车联网领域特别是智能驾驶领域也有着十分广泛的应用&#xff0c;如&#xff1a; 1. 车辆实时数据传输 DDS可以用于车辆之间和车辆与云端之间的实时数据传输&#xff0c;包括车辆状态、传感器数据、位置信息…

node.js和浏览器之间的区别

node.js是什么 Node.js是一种基于Chrome V8引擎的JavaScript运行环境&#xff0c;可以在服务器端运行JavaScript代码 Node.js 在浏览器之外运行 V8 JavaScript 引擎。 这使得 Node.js 非常高效。 浏览器如何运行js代码 nodejs运行环境 在浏览器中&#xff0c;大部分时间你所…

26.Python 网络爬虫

目录 1.网络爬虫简介2.使用urllib3.使用request4.使用BeautifulSoup 1.网络爬虫简介 网络爬虫是一种按照一定的规则&#xff0c;自动爬去万维网信息的程序或脚本。一般从某个网站某个网页开始&#xff0c;读取网页的内容&#xff0c;同时检索页面包含的有用链接地址&#xff0…

Linux系统调试课:USB 常用调试方法

文章目录 一、USB调试工具有哪些二、USB相关节点2.1、USB枚举成功标志2.2、USB speed查询2.3、USB 查询PID、VID沉淀、分享、成长,让自己和他人都能有所收获!😄 📢本篇章主要 介绍 USB 常用调试方法。 一、USB调试工具有哪些

Qt开发 之 记一次安装 Qt5.12.12 安卓环境的失败案例

文章目录 1、安装Qt2、安卓开发的组合套件2.1、CSDN地址2.2、官网地址2.3、发现老方法不适用了3、尝试用新方法解决3.1、先安装JDK,搞定JDK环境变量3.1.1、安装jdk3.1.2、确定jdk安装路径3.1.3、打开系统环境变量配置3.1.4、配置系统环境变量3.1.5、验证JDK环境变量是否配置成…

Educational Codeforces Round 158 (Rated for Div. 2) A-D

文章目录 A. Line TripB. Chip and RibbonC. Add, Divide and FloorD. Yet Another Monster Fight A. Line Trip 签到 #include <bits/stdc.h>using namespace std; const int N 2e5 5; typedef long long ll; typedef pair<ll, ll> pll; typedef array<ll,…

Sanic:一个极速Python Web框架

更多Python学习内容&#xff1a;ipengtao.com 大家好&#xff0c;我是彭涛&#xff0c;今天为大家分享 Sanic&#xff1a;一个极速Python Web框架&#xff0c;全文3500字&#xff0c;阅读大约12分钟。 随着 Web 应用的日益复杂&#xff0c;选择一个高性能的 Web 框架变得尤为…

扫描器的使用

漏扫器 注意事项 扫描器会给客户的业务造成影响。比如&#xff0c;如果存在sql注入漏洞&#xff08;重大的漏洞&#xff09;的话&#xff0c;会给客户的数据库插入脏数据&#xff0c;后果很严重 主机漏扫 针对IP地址和网段的漏洞扫描&#xff0c;例如&#xff1a;22端口弱口…

LCM-LoRA:a universal stable-diffusion acceleration module

Consistency is All You Need - wrong.wang什么都不用做生成却快了十倍其实也并非完全不可能https://wrong.wang/blog/20231111-consistency-is-all-you-need/ 1.Stable diffusion实在预训练VAE空间训练diffusion model的结果。 2.consistency decoder是用consistency model技…

ISIS默认路由下发的各种机制

作者简介&#xff1a;大家好&#xff0c;我是Asshebaby&#xff0c;热爱网工&#xff0c;有网络方面不懂的可以加我一起探讨 :1125069544 个人主页&#xff1a;Asshebaby博客 当前专栏&#xff1a; 网络HCIP内容 特色专栏&#xff1a; 常见的项目配置 本文内容&am…

017 OpenCV 向量机SVM

目录 一、环境 二、SVM原理 三、完整代码 一、环境 本文使用环境为&#xff1a; Windows10Python 3.9.17opencv-python 4.8.0.74 二、SVM原理 OpenCV中的向量机&#xff08;SVM&#xff09;是一种监督学习算法&#xff0c;用于分类和回归分析。它通过找到一个最优的超平…

振弦采集仪在岩土工程中的探索与应用

振弦采集仪在岩土工程中的探索与应用 振弦采集仪是一种常用的测量仪器&#xff0c;在岩土工程中具有重要的应用价值。它主要利用振弦原理&#xff0c;通过测量振动信号的特征参数来分析地下土体的力学特性以及工程中的变形情况。 振弦采集仪早期主要用于建筑物、桥梁、堤坝等…

手机拍照的图片,如何传到电脑上?

手机受性能和屏幕限制&#xff0c;其应用功能也多少会因此而受到影响&#xff0c;比如在金鸣识别的电脑客户端&#xff0c;用户可一次性提交100张的图片进行识别&#xff0c;而在移动端&#xff0c;则最多只能一次三张&#xff0c;如何破这个“局”呢&#xff1f; 一、有扫描仪…

唯创知音WTN6040F语音芯片声音提示IC在家用雾化器中的应用

近年来&#xff0c;随着空气质量的恶化和呼吸道疾病的增多&#xff0c;家用雾化器成为了越来越多家庭的必备健康设备。然而&#xff0c;对于许多用户来说&#xff0c;正确操作雾化器并准确掌握药物的使用时机和方式是一个挑战。为了解决这一问题&#xff0c;唯创知音WTN6040F语…

RT-DETR优化:Backbone改进 | UniRepLKNet,通用感知大内核卷积网络,RepLK改进版本 | 2023.11

🚀🚀🚀本文改进: UniRepLKNet,通用感知大内核卷积网络,ImageNet-22K预训练,精度和速度SOTA,ImageNet达到88%, COCO达到56.4 box AP,ADE20K达到55.6 mIoU 🚀🚀🚀RT-DETR改进创新专栏:http://t.csdnimg.cn/vuQTz 学姐带你学习YOLOv8,从入门到创新,轻轻松松…

C++入门【4-C++ 变量作用域】

C 变量作用域 一般来说有三个地方可以定义变量&#xff1a; 在函数或一个代码块内部声明的变量&#xff0c;称为局部变量。在函数参数的定义中声明的变量&#xff0c;称为形式参数。在所有函数外部声明的变量&#xff0c;称为全局变量。 作用域是程序的一个区域&#xff0c;…

Linux centos8安装JDK1.8、tomcat

一、安装jdk 1.如果之前安装过jdk&#xff0c;先卸载掉旧的 rpm -qa | grep -i jdk 2.检查yum中有没有java1.8的包 yum list java-1.8* 3.yum安装jdk yum install java-1.8.0-openjdk* -y 4.验证 二、安装tomcat Index of /tomcat 可以在这里选择你想要安装的tomcat版本…

假设检验(三)(单侧假设检验)

在 《假设检验&#xff08;二&#xff09;&#xff08;正态总体参数的假设检验&#xff09;》中我们讨论了形如 H 0 : θ θ 0 ↔ H 1 : θ ≠ θ 0 H_0:\theta\theta_0 \leftrightarrow H_1:\theta \neq \theta_0 H0​:θθ0​↔H1​:θθ0​ 的假设检验问题&#xff0c;其…

Centos7部署Graylog5.2日志系统

Graylog5.2部署 Graylog 5.2适配MongoDB 5.x~6.x&#xff0c;MongoDB5.0要求CPU支持AVX指令集。 主机说明localhost部署Graylog&#xff0c;需要安装mongodb-org-6.0、 Elasticsearch7.10.2 参考&#xff1a; https://blog.csdn.net/qixiaolinlin/article/details/129966703 …

洛谷(md版)

小知识点 1.printf()一行一个双引号“” 2.double->%lf 3.例题 ​​​​​​​​​​​​​​ ​​​4. 这两者不一样 上行&#xff1a;先转化成了浮点数&#xff0c;再运算 下行&#xff1a;先运算的整数&#xff0c;得到结果&#xff0c;再转化成浮点数 no1 no / (…