机器学习深度学习——torch.nn模块

机器学习&&深度学习——torch.nn模块

  • 卷积层
  • 池化层
  • 激活函数
  • 循环层
  • 全连接层

torch.nn模块包含着torch已经准备好的层,方便使用者调用构建网络。

卷积层

卷积就是输入和卷积核之间的内积运算,如下图:
在这里插入图片描述
容易发现,卷积神经网络中通过输入卷积核来进行卷积操作,使输入单元(图像或特征映射)和输出单元(特征映射)之间的连接时稀疏的,能够减少需要的训练参数的数量,从而加快网络计算速度。
卷积的分类如下所示,大体分为一维卷积、二维卷积、三位卷积以及转置卷积(简单理解为卷积操作的逆操作)

层对应的类功能作用
torch.nn.Conv1d()针对输入信号上应用1D卷积
torch.nn.Conv2d()针对输入信号上应用2D卷积
torch.nn.Conv3d()针对输入信号上应用3D卷积
torch.nn.ConvTranspose1d()在输入信号上应用1D转置卷积
torch.nn.ConvTranspose2d()在输入信号上应用2D转置卷积
torch.nn.ConvTranspose3d()在输入信号上应用3D转置卷积

以torch.nn.Conv2d()为例,介绍卷积在图像上的使用方法,其调用方式为:

torch.nn.Conv2d(in_channels,outchannels,kernel_size,stride=1,padding=0,dilation=1,groups=1,bias=True)

直接说前面三个参数吧,这三个是必选的参数,其他的参数作用可以看下面的这个文章:
torch.nn.Conv2d() 用法讲解
必选参数:

in_channels:输入的通道数目
out_channels:输出的通道数目
kernel_size:卷积核的大小,类型为int或元组,当卷积为方形时,只需要一个整形边长即可,否则要输入一个元组表示高和宽

现在我们针对一个二维图像来做卷积并观察结果:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image# 使用PIL包读取图像数据,使用matplotlib包来可视化图像和卷积后的结果
# 读取图像->转化为灰度图像->转化为Numpy数组
myim = Image.open("data/chap2/yier.jpg")
myimgray = np.array(myim.convert("L"), dtype=np.float32)
# 可视化图片
plt.figure(figsize=(6, 6))
plt.imshow(myimgray, cmap=plt.cm.gray)
plt.axis("off")
# plt.show()# 上述操作得到一个512×512的数组,在卷积前,需要转化为1×1×512×512的张量
imh, imw = myimgray.shape
myimgray_t = torch.from_numpy(myimgray.reshape(1, 1, imh, imw))
# print(myimgray_t.shape)# 卷积时需将图像转化为四维来表示[batch,channel,h,w],卷积后得到两个特征映射:
# 第一个特征映射使用图像轮廓提取卷积核获取,第二个特征映射使用的卷积核为随机数
# 卷积核大小为5×5,且不使用0填充,则卷积后输出特征映射的尺寸为508×508
# 下面进行卷积,且对卷积后的两个特征映射进行可视化kersize = 5  # 定义边缘检测卷积核,并将维度处理为1*1*5*5
ker = torch.ones(kersize, kersize, dtype=torch.float32) * -1
ker[2, 2] = 24
ker = ker.reshape((1, 1, kersize, kersize))
# 此时ker矩阵为:
# tensor([[[[-1., -1., -1., -1., -1.],
#           [-1., -1., -1., -1., -1.],
#           [-1., -1., 24., -1., -1.],
#           [-1., -1., -1., -1., -1.],
#           [-1., -1., -1., -1., -1.]]]])
#           用意还是很好理解的,如果不再边缘上,那么乘积之和就是0,否则看结果正负也容易知道边缘所在的大概位置# 进行卷积操作
conv2d = nn.Conv2d(1, 2, (kersize, kersize), bias=False)
# 设置卷积时使用的核,第一个核使用边缘检测核
conv2d.weight.data[0] = ker
# 对灰度图像进行卷积操作
imconv2dout = conv2d(myimgray_t)
# 对卷积后的输出进行维度压缩
imconv2dout_im = imconv2dout.data.squeeze()
# print(imconv2dout_im.shape)
# 可视化卷积后的图像
plt.figure(figsize=(12, 12))
plt.subplot(2, 2, 1)
plt.imshow(myim)
plt.axis("off")
plt.subplot(2, 2, 2)
plt.imshow(myimgray, cmap=plt.cm.gray)
plt.axis("off")
plt.subplot(2, 2, 3)
plt.imshow(imconv2dout_im[0], cmap=plt.cm.gray)
plt.axis("off")
plt.subplot(2, 2, 4)
plt.imshow(imconv2dout_im[1], cmap=plt.cm.gray)
plt.axis("off")
plt.show()

结果:
在这里插入图片描述
可以看出,使用的边缘特征提取卷积核很好地提取出了图像的边缘信息。而使用随机数的卷积核得到的卷积结果与原始图像很相似。

池化层

池化的一个重要目的是对卷积后得到的特征进行进一步处理(主要是降维),池化层可以对数据进一步浓缩,从而缓解内存压力。
池化会选取一定大小区域,将该区域内的像素值用一个代表元素表示,如下图表示滑动窗口2×2,且步长为2时的最大值池化和平均值池化:
在这里插入图片描述
在pytorch中有多种池化的类,分别是最大值池化(MaxPool)、最大值池化的逆过程(MaxUnPool)、平均值池化(AvgPool)与自适应池化(AdaptiveMaxPool、AdaptiveAvgPool)等,且都提供了一二三维的池化操作。
如果对上一个卷积后的图像进行池化,并且使用步长为2的最大值池化或平均值池化以后,所得到的尺寸将会变为254×254。
如果使用nn.AdaptiveAvgPool2d()函数,构造时可以指定其池化后的大小。
池化后,特征映射的尺寸变小,图像变得更模糊。

激活函数

下面的一些函数感觉也都是很基础的,一些关于双曲正余弦函数、双曲正切函数、梯度的概念给搞懂了就没什么问题。

层对应的类功能
torch.nn.SigmoidSigmoid激活函数
torch.nn.TanhTanh激活函数
torch.nn.ReLUReLU激活函数
torch.nn.SoftplusReLU激活函数的平滑近似

torch.nn.Sigmoid()
其对应的Sigmoid激活函数,又叫logistic激活函数:
f ( x ) = 1 1 + e − x f(x)=\frac{1}{1+e^{-x}} f(x)=1+ex1
其输出在(0,1)这个开区间内,该函数在神经网络早期也是很常用的激活函数之一,但是当输入远离坐标源点时,函数的梯度就会变得很小,几乎为0,因此会影响参数的更新速度
torch.nn.Tanh()
对应双曲正切函数:
f ( x ) = e x − e − x e x + e − x f(x)=\frac{e^x-e^{-x}}{e^x+e^{-x}} f(x)=ex+exexex
其输出区间在(-1,1)之间,整个函数以0为中心,虽然与Sigmoid一样,当输入很大或很小时,梯度很小,不利于权重的更新,但毕竟是以0为对称,使用效果会比Sigmoid好很多
torch.nn.ReLU()
其对应的ReLU函数又叫修正线性单元,计算方式为:
f ( x ) = m a x ( 0 , x ) f(x)=max(0,x) f(x)=max(0,x)
其只保留大于0的输出。而在输入正数时,不会存在梯度饱和的问题,计算速度会更快,而且因为ReLU函数只有线性关系,所以不管是前向传播还是反向传播都很快。
torch.nn.Softplus()
对应的平滑近似ReLU的激活函数,计算公式:
f ( x ) = 1 β l o g ( 1 + e β x ) f(x)=\frac{1}{β}log(1+e^{βx}) f(x)=β1log(1+eβx)
β默认为1。这个函数可以在任何位置求导数,且尽可能保留了ReLU函数的优点。

循环层

pytorch提供三种循环层实现:

层对应的类功能
torch.nn.RNN()多层RNN单元
torch.nn.LSTM()多层长短期记忆LSTM单元
torch.nn.GRU()多层门限循环GRU单元
torch.nn.RNNCell()一个RNN循环层单元
torch.nn.LSTMCell()一个长短期记忆LSTM单元
torch.nn.GRUCell()一个门限循环GRU单元

几个循环层函数的原理将在之后更新。

全连接层

指一个由多个神经元所组成的层,其所有的输出和该层的所有输入都有连接,即每个输入都会影响所有神经元的输出。
在pytorch中,nn.Linear()表示线性变换,全连接层可以看作是nn.Linear()表示线性变层再加上一个激活函数层所构成的结构。
nn.Linear()全连接操作及相关参数:
torch.nn.Linear(in_features,out_features,bias=True)
参数说明如下:
in_feature:每个输入样本的特征数量
out_feature:每个输出样本的特征数量
bias:若设置为False,则该层不会设置偏置,默认为True

torch.nn.Linear()的输入为(N,in_feature)的张量,输出为(N,out_feature)的张量。
全连接层的应用广泛,只有全连接层组成的网络是全连接神经网络,可用于数据的分类或回归预测,卷积神经网络和循环神经网络的末端,通常多个连接层组成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/7327.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

uniapp 微信小程序 placeholder字体、颜色自定义

效果图&#xff1a; 1、template <input type"text" placeholder"搜索标题" placeholder-class"placeholder-style"></input>2、style .placeholder-style{color: #2D94FF; }

微服务探索之路06篇k8s配置文件Yaml部署Redis使用Helm部署MongoDB和kafka

1 安装Redis 1.1创建配置文件redis.conf 切换到自己的目录下如本文是放在/home/ubuntu下 cd /home/ubuntuvim redis.conf bind 0.0.0.0 protected-mode yes port 6379 requirepass qwe123456 tcp-backlog 511 timeout 0 tcp-keepalive 300 daemonize no pidfile /var/run/r…

生产者消费者模型

生产者消费者模型 文章目录 生产者消费者模型概念原则优点 基于BlockingQueue的生产者消费者模型BlockingQueue模拟实现单生产者消费者模型基于计算任务和存储任务的生产者消费者模型 概念 生产者消费者模式就是通过一个容器来解决生产者和消费者的强耦合问题生产者和消费者彼…

代码随想录| 图论02●695岛屿最大面积 ●1020飞地的数量 ●130被围绕的区域 ●417太平洋大西洋水流问题

#695岛屿最大面积 模板题&#xff0c;很快.以下两种dfs&#xff0c;区别是看第一个点放不放到dfs函数中处理&#xff0c;那么初始化的area一个是1一个是0 int dir[4][2]{0,1,0,-1,1,0,-1,0};void dfs(int x, int y,int n, int m, int &area,vector<vector<bool>…

2023最新谷粒商城笔记之Sentinel概述篇(全文总共13万字,超详细)

Sentinel概述 服务流控、熔断和降级 什么是熔断 当扇出链路的某个微服务不可用或者响应时间太长时&#xff0c;会进行服务的降级&#xff0c;**进而熔断该节点微服务的调用&#xff0c;快速返回错误的响应信息。**检测到该节点微服务调用响应正常后恢复调用链路。A服务调用B服…

构建高效供应商管理体系,提升企业采购能力

随着企业采购规模的不断扩大和全球化竞争的加剧&#xff0c;供应商管理变得越来越重要。构建一个高效的供应商管理体系是企业提升采购能力、降低采购成本的关键一环。本文将重点探讨供应商管理体系的意义和作用&#xff0c;并介绍如何构建一个高效的供应商管理体系。 一、供应商…

SpringBoot复习:(1)常用的SpringApplication.run返回的容器的具体类型是哪个?

run方法中调用了createApplicationContext方法 createApplicationContext方法代码如下&#xff1a; 其中create代码如下&#xff1a; 可见返回的是AnnotationConfigServletWebServerApplicationContext()

【搜索引擎Solr】配置 Solr 以获得最佳性能

Apache Solr 是广泛使用的搜索引擎。有几个著名的平台使用 Solr&#xff1b;Netflix 和 Instagram 是其中的一些名称。我们在 tajawal 的应用程序中一直使用 Solr 和 ElasticSearch。在这篇文章中&#xff0c;我将为您提供一些关于如何编写优化的 Schema 文件的技巧。我们不会讨…

基于Python+WaveNet+CTC+Tensorflow智能语音识别与方言分类—深度学习算法应用(含全部工程源码)

目录 前言总体设计系统整体结构图系统流程图 运行环境Python 环境Tensorflow 环境 模块实现1. 方言分类数据下载及预处理模型构建模型训练及保存 2. 语音识别数据预处理模型构建模型训练及保存 3. 模型测试功能选择界面语言识别功能实现界面方言分类功能实现界面 系统测试1. 训…

【RabbitMQ(day1)】RabbitMQ的概述和安装

入门RabbitMQ 一、RabbitMQ的概述二、RabbitMQ的安装三、RabbitMQ管理命令行四、RabbitMQ的GUI界面 一、RabbitMQ的概述 MQ&#xff08;Message Queue&#xff09;翻译为消息队列&#xff0c;通过典型的【生产者】和【消费者】模型&#xff0c;生产者不断向消息队列中生产消息&…

【DDD】业务领域定义

文章目录 前言一、什么是业务子领域&#xff1f;二、子领域的类型有哪些&#xff1f;2.1、核心子领域2.2、通用子领域2.3、支撑子领域 三、子领域差异对比3.1、竞争优势比较3.2、复杂性比较3.3、易变性比较3.4、实时策略比较 总结 前言 一个业务领域是一个公司的主要活动领域的…

redis(11):springboot中使用redis

1 创建springboot项目 2 创建pom文件 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http:/…

vue3+Luckysheet实现表格的在线预览编辑(electron可用)

前言&#xff1a; 整理中 官方资料&#xff1a; 1、github 项目地址https://github.com/oy-paddy/luckysheet-vue-importAndExport/tree/master/https://github.com/oy-paddy/luckysheet-vue-importAndExport/tree/master/ 2、xlsx vue3 json数据导出excel_vue3导出excel_羊…

【SpirngCloud】分布式事务解决方案

【SpirngCloud】分布式事务解决方案 文章目录 【SpirngCloud】分布式事务解决方案1. 理论基础1.1 CAP 理论1.2 BASE 理论1.3 分布式事务模型 2. Seata 架构2.1 项目引入 Seata 3. 强一致性分布式事务解决方案3.1 XA 模式3.1.1 seata的XA模式3.1.2 XA 模式实践3.1.3 总结 4. 最终…

React AntDesign表批量操作时的selectedRowKeys回显选中

不知道大家是不是在AntDesign的某一个列表想要做一个批量导出或者操作的时候&#xff0c;发现只要选择下一页&#xff0c;即使选中的ids 都有记录下面&#xff0c;但是就是不回显 后来问了chatGPT&#xff0c;对方的回答是&#xff1a; 在Ant Design的DataTable组件中&#xf…

什么是框架?为什么要学框架?

一、什么是框架 框架是整个或部分应用的可重用设计&#xff0c;是可定制化的应用骨架。它可以帮开发人员简化开发过程&#xff0c;提高开发效率。 项目里有一部分代码&#xff1a;和业务无关&#xff0c;而又不得不写的代码>框架 项目里剩下的部分代码&#xff1a;实现业务…

基于C++的QT基础教程学习笔记

文章目录&#xff1a; 来源 教程社区 一&#xff1a;QT下载安装 二&#xff1a;注意事项 1.在哪里写程序 2.如何看手册 3.技巧 三&#xff1a;常用函数 1.窗口 2.相关 3.按钮 4.信号与槽函数 5.常用栏 菜单栏 工具栏 状态栏 6.铆接部件 7.文本编辑 8…

Docker Compose(九)

一、背景&#xff1a; 对于现代应用来说&#xff0c;大多数都是通过很多的微服务互相协同组成一个完整的应用。例如&#xff0c;订单管理、用户管理、品类管理、缓存服务、数据库服务等&#xff0c;他们构成了一个电商平台的应用。而部署和管理大量的服务容器是一件非常繁琐的事…

【时间复杂度】

旋转数组 题目 给定一个整数数组 nums&#xff0c;将数组中的元素向右轮转 k 个位置&#xff0c;其中 k 是非负数。 /* 解题思路&#xff1a;使用三次逆转法&#xff0c;让数组旋转k次 1. 先整体逆转 // 1,2,3,4,5,6,7 // 7 6 5 4 3 2 1 2. 逆转子数组[0, k - 1] // 5 6 7 4 3…

疲劳驾驶检测和识别2:Pytorch实现疲劳驾驶检测和识别(含疲劳驾驶数据集和训练代码)

疲劳驾驶检测和识别2&#xff1a;Pytorch实现疲劳驾驶检测和识别(含疲劳驾驶数据集和训练代码) 目录 疲劳驾驶检测和识别2&#xff1a;Pytorch实现疲劳驾驶检测和识别(含疲劳驾驶数据集和训练代码) 1.疲劳驾驶检测和识别方法 2.疲劳驾驶数据集 &#xff08;1&#xff09;疲…