Pytorch nn.Linear()的基本用法与原理详解及全连接层简介

主要引用参考:
https://blog.csdn.net/zhaohongfei_358/article/details/122797190
https://blog.csdn.net/weixin_43135178/article/details/118735850

nn.Linear的基本定义

nn.Linear定义一个神经网络的线性层,方法签名如下:

torch.nn.Linear(in_features, # 输入的神经元个数out_features, # 输出神经元个数bias=True # 是否包含偏置)

Linear其实就是对输入 X n × i X_{n\times i} Xn×i执行了一个线性变换,即:
Y n × o = X n × i W i × o + b Y_{n\times o}=X_{n\times i}W_{i\times o}+b Yn×o=Xn×iWi×o+b
其中 W W W是模型想要学习的参数, W W W的维度为 W i × o W_{i\times o} Wi×o,b是o维的向量偏置,n为输入向量的行数(例如,你想一次输入10个样本,即batch_size为10,则n=10),i为输入神经元的个数(例如你的样本特征为5,则i=5),o为输出神经元的个数。

示例:

from torch import nn
import torchmodel = nn.Linear(2, 1) # 输入特征数为2,输出特征数为1
input = torch.Tensor([1, 2]) # 给一个样本,该样本有2个特征(这两个特征的值分别为1和2)
output = model(input)
output
tensor([-1.4166], grad_fn=<AddBackward0>)

我们的输入为[1,2],输出了[-1.4166]。可以查看模型参数验证一下上述的式子:

# 查看模型参数
for param in model.parameters():print(param)
Parameter containing:
tensor([[ 0.1098, -0.5404]], requires_grad=True)
Parameter containing:
tensor([-0.4456], requires_grad=True)

可以看到,模型有3个参数,分别为两个权重和一个偏执。计算可得:
y = [ 1 , 2 ] ∗ [ 0.1098 , − 0.5404 ] T − 0.4456 = − 1.4166 y=[1,2]*[0.1098,-0.5404]^T-0.4456=-1.4166 y=[1,2][0.1098,0.5404]T0.4456=1.4166


实战

假设我们的一次输入三个样本A,B,C(即batch_size为3),每个样本的特征数量为5:

A: [0.1,0.2,0.3,0.3,0.3]
B: [0.4,0.5,0.6,0.6,0.6]
C: [0.7,0.8,0.9,0.9,0.9]

则我们的输入向量 X 3 × 5 X_{3\times 5} X3×5为:

X = torch.Tensor([[0.1,0.2,0.3,0.3,0.3],[0.4,0.5,0.6,0.6,0.6],[0.7,0.8,0.9,0.9,0.9],
])
X
tensor([[0.1000, 0.2000, 0.3000, 0.3000, 0.3000],[0.4000, 0.5000, 0.6000, 0.6000, 0.6000],[0.7000, 0.8000, 0.9000, 0.9000, 0.9000]])

定义线性层,我们的输入特征为5,所以in_feature=5,我们想让下一层的神经元个数为10,所以out_feature=10,则模型参数为: W 5 × 10 W_{5\times 10} W5×10

model = nn.Linear(in_features=5, out_features=10, bias=True)

经过线性层,其实就是做了一件事,即:
Y 3 × 10 = X 3 × 5 W 5 × 10 + b Y_{3\times 10}=X_{3\times 5}W_{5\times 10}+b Y3×10=X3×5W5×10+b
具体表示为:
[ Y 00 Y 01 ⋯ Y 08 Y 09 Y 10 Y 11 ⋯ Y 18 Y 19 Y 20 Y 21 ⋯ Y 28 Y 29 ] = [ X 00 X 01 X 02 X 03 X 04 X 10 X 11 X 12 X 13 X 14 X 20 X 21 X 22 X 23 X 23 ] [ W 00 W 01 ⋯ W 08 W 09 W 10 W 11 ⋯ W 18 W 19 W 20 W 21 ⋯ W 28 W 29 W 30 W 31 ⋯ W 38 W 39 W 40 W 41 ⋯ W 48 W 49 ] + b \begin{equation} \left[ \begin{array}{ccc} Y_{00} & Y_{01} &\cdots & Y_{08} &Y_{09} \\ Y_{10} & Y_{11} &\cdots & Y_{18} &Y_{19} \\ Y_{20} & Y_{21} &\cdots & Y_{28} &Y_{29} \end{array} \right] =\left[ \begin{array}{ccc} X_{00} & X_{01} &X_{02} & X_{03} &X_{04} \\ X_{10} & X_{11} &X_{12} & X_{13} &X_{14} \\ X_{20} & X_{21} &X_{22} & X_{23} &X_{23} \end{array}\nonumber \right] \left[ \begin{array}{ccc} W_{00} & W_{01} &\cdots & W_{08} &W_{09} \\ W_{10} & W_{11} &\cdots & W_{18} &W_{19} \\ W_{20} & W_{21} &\cdots & W_{28} &W_{29} \\ W_{30} & W_{31} &\cdots & W_{38} &W_{39} \\ W_{40} & W_{41} &\cdots & W_{48} &W_{49} \\ \end{array} \right] +b \end{equation}\nonumber Y00Y10Y20Y01Y11Y21Y08Y18Y28Y09Y19Y29 = X00X10X20X01X11X21X02X12X22X03X13X23X04X14X23 W00W10W20W30W40W01W11W21W31W41W08W18W28W38W48W09W19W29W39W49 +b

个人的理解:比如 X X X第一行和 W W W矩阵的第一列相乘就相当于对样本A做了全局卷积,最后得到了1个特征,因为 W W W有10列,所以最后得到10个特征,也就是把5个特征转变为了10个特征。

其中 X i . X_i. Xi.就表示第i个样本, W . j W_{.j} W.j表示所有输入神经元到第j个输出神经元的权重。
在这里插入图片描述

注意:这里图有点问题,应该是 W 00 , W 01 , W 02 , . . . , W 07 , W 08 , W 09 W_{00}, W_{01}, W_{02}, ..., W_{07}, W_{08},W_{09} W00,W01,W02,...,W07,W08,W09(我没觉得图有问题)

因为有三个样本,所以相当于依次进行了三次 Y 3 × 10 = X 3 × 5 W 5 × 10 + b Y_{3\times 10}=X_{3\times 5}W_{5\times 10}+b Y3×10=X3×5W5×10+b,然后再将三个 Y 1 × 10 Y_{1\times 10} Y1×10叠在一起

经过线性层后,我们最终的到了3×10维的矩阵,即 输入3个样本,每个样本维度为5,输出为3个样本,将每个样本扩展成了10维

model(X).size()
torch.Size([3, 10])

全连接层

概述
全连接层 Fully Connected Layer 一般位于整个卷积神经网络的最后,负责将卷积输出的二维特征图转化成一维的一个向量,由此实现了端到端的学习过程(即:输入一张图像或一段语音,输出一个向量或信息)。全连接层的每一个结点都与上一层的所有结点相连因而称之为全连接层。由于其全相连的特性,一般全连接层的参数也是最多的。

主要作用
全连接层的主要作用就是将前层(卷积、池化等层)计算得到的特征空间映射样本标记空间。简单的说就是将特征表示整合成一个值,其优点在于减少特征位置对于分类结果的影响,提高了整个网络的鲁棒性。
全连接在整个网络卷积神经网络中起到“分类器”的作用,如果说卷积层、池化层和激活函数等操作是将原始数据映射到隐层特征空间的话,全连接层则起到将学到的特征表示映射到样本的标记空间的作用。其实,就是把特征整合到一起,方便交给最后的分类器或者回归。

实际操作
在实际使用中,全连接层可由卷积操作实现:对前层是全连接的全连接层可以转换为卷积核为1*1的卷积;而前层是卷积层的全连接层可以转换为卷积核为前层卷积输出结果的高和宽一样大小的全局卷积。

一个通俗的例子:
以VGG-16为例,对224x224x3的输入,最后一层卷积可得输出为7x7x512,如后层是一层含4096个神经元的FC,则可用卷积核为7x7x512x4096的全局卷积来实现这一全连接运算过程,其中该卷积核参数如下:“filter size = 7, padding = 0, stride = 1, D_in = 512, D_out = 4096”经过此卷积操作后可得输出为1x1x4096。如需再次叠加一个2048的FC,则可设定参数为“filter size = 1, padding = 0, stride = 1, D_in = 4096, D_out = 2048”的卷积层操作(个人理解就是用1×1×4096×2048的卷积核来卷积)。(参考:https://www.zhihu.com/question/41037974/answer/150522307)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/232342.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

生日蜡烛C语言

分析&#xff1a;假设这个人只能活到100岁&#xff0c;如果不这样规定的话&#xff0c;那么这个人就可以假设活到老236岁&#xff0c;直接一次吹236个蜡烛&#xff0c;我们就枚举出所以情况&#xff0c;从一岁开始。 #include <stdio.h> int f(int a,int b){//计算从a到…

视频素材网站全新上线,海量高清视频等你来探索~

亲爱的视频制作爱好者们&#xff0c;好消息来啦&#xff01;我们的视频素材网站全新上线啦&#xff01;这次我们为大家带来了海量的高清视频素材&#xff0c;无论是风景、城市、人物、动物还是各种特效、背景等&#xff0c;应有尽有&#xff0c;满足您在视频制作过程中的各种需…

【神器】wakatime代码时间追踪工具

文章目录 wakatime简介支持的IDE安装步骤API文档插件费用写在最后 wakatime简介 wakatime就是一个IDE插件&#xff0c;一个代码时间追踪工具。可自动获取码编码时长和度量指标&#xff0c;以产生很多的coding图形报表。这些指标图形可以为开发者统计coding信息&#xff0c;比如…

【MySQL】:复合查询

复合查询 一.多表查询二.自连接三.子查询1.单行子查询2.多行子查询3.多列子查询4.在from语句里使用子查询5.合并查询 准备三张表 emp表 dept表 salgrade表 一.多表查询 实际开发中往往数据来自不同的表&#xff0c;所以需要多表查询。我们用一个简单的公司管理系统&#xff0c…

HPM6750系列--第十一篇 Uart讲解(轮询模式)

一、目的 在介绍完GPIO的相关内容下一个必须介绍的就是uart了&#xff0c;因为串口一个主要用途就是用于调试信息打印。 HPM6750在uart的配置上也是相当炸裂&#xff0c;有17个串口&#xff1b;结合HPM6750的高主频高内存&#xff0c;完全可以作为一个串口服务器。 ​​​​​​…

智能优化算法应用:基于天牛须算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于天牛须算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于天牛须算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.天牛须算法4.实验参数设定5.算法结果6.参考文…

MLOps在极狐GitLab 的现状和前瞻

什么是 MLOps 首先我们可以这么定义机器学习&#xff08;Machine Learning&#xff09;&#xff1a;通过一组工具和算法&#xff0c;从给定数据集中提取信息以进行具有一定程度不确定性的预测&#xff0c;借助于这些预测增强用户体验或推动内部决策。 同一般的软件研发流程比…

【lesson17】MySQL表的基本操作--表去重、聚合函数和group by

文章目录 MySQL表的基本操作介绍插入结果查询&#xff08;表去重&#xff09;建表插入数据操作 聚合函数建表插入数据操作 group by&#xff08;分组&#xff09;建表插入数据操作 MySQL表的基本操作介绍 CRUD : Create(创建), Retrieve(读取)&#xff0c;Update(更新)&#x…

【TB作品】STM32 PWM之实现呼吸灯,STM32F103RCT6,晨启

文章目录 完整工程参考资料实验过程 实验任务&#xff1a; 1&#xff1a;实现PWM呼吸灯&#xff0c;定时器产生PWM&#xff0c;控制实验板上的LED灯亮灭&#xff1b; 2&#xff1a;通过任意两个按键切换PWM呼吸灯输出到两个不同的LED灯&#xff0c;实现亮灭效果&#xff1b; 3&…

Axure的案例演示

增删改查&#xff1a; 在中继器里面展示照片

创建型模式之抽象工厂模式

一、概述 1、抽象工厂模式&#xff1a;提供一个创建一系列相关或相互依赖对象的接口&#xff0c;而无需指定它们具体的类。 2、抽象工厂模式&#xff1a;一个工厂可以生产一系列产品&#xff08;一族产品&#xff09;&#xff0c;极大减少了工厂类的数量 3、抽象工厂模式&am…

众和策略:加强经济监测预测预警 加大宏观调控力度

12月17日至18日&#xff0c;全国展开和革新作业会议在京举行&#xff0c;整理总结2023年展开革新作业&#xff0c;组织布置2024年展开革新关键使命。会议指出&#xff0c;中心经济作业会议对本年经济作业作了全面体系总结&#xff0c;侧重我国经济全体上升向好&#xff0c;全年…

选择合适教育管理软件:必须考虑的10个关键问题

随着教育行业的迅速数字化&#xff0c;学校要能够提供最新的管理和教育方法。大家逐渐意识到技术让运营变得更容易、更有效率。 不过首先我们需要找到一个能满足需求的应用程序。面对众多的选择&#xff0c;你该如何选择一个合适的平台呢&#xff1f;当然&#xff0c;没有人想…

MYSQL中使用IN,在xml文件中怎么写?

MYSQL&#xff1a; Spring中&#xff1a; mysql中IN后边的集合&#xff0c;在后端中使用集合代替&#xff0c;其他的没有什么注意的&#xff0c;还需要了解foreach 语法即可。

Spark编程实验一:Spark和Hadoop的安装使用

目录 一、目的与要求 二、实验内容 三、实验步骤 1、安装Hadoop和Spark 2、HDFS常用操作 3、Spark读取文件系统的数据 四、结果分析与实验体会 一、目的与要求 1、掌握在Linux虚拟机中安装Hadoop和Spark的方法&#xff1b; 2、熟悉HDFS的基本使用方法&#xff1b; 3、掌…

SCADA助力食品加工数字化变革:未来产业的智慧引擎

一、背景介绍 当前&#xff0c;在国际市场竞争加剧、消费者个性化需求突出的背景下&#xff0c;我国食品加工行业面临着诸多挑战&#xff1a;越发严苛的食品安全标准、追求供应链的透明度和效率、进一步提高产品质量和降低成本等等。 为了应对上述挑战&#xff0c;我国食品加…

亚马逊,速卖通,shein卖家如何准确有效的测评补单

一、合理规划测评时间和数量 卖家需要合理规划测评的时间和数量。如果卖家过于频繁地进行测评&#xff0c;或者在短时间内进行大量的测评&#xff0c;这可能会被视为恶意行为&#xff0c;从而触犯风控机制。因此&#xff0c;卖家需要根据自己的销售情况和市场需求&#xff0c;…

如何确保对称密钥管理的存储安全?

确保对称密钥管理的存储安全是保障信息安全的重要一环。以下是一些建议&#xff0c;以确保对称密钥管理的存储安全&#xff1a; 使用安全存储设备&#xff1a;选择使用经过验证的安全存储设备来存储对称密钥。这些设备通常具有高度的物理安全性&#xff0c;可以防止未经授权的访…

vp与vs联合开发-通过CogAcqFifoTool工具连接相机

1.完成相机硬件配置后 2.完成vp与vs联合开发配置功能后 1.创建winform 项目 目的 : 搭建 界面应用 2. 1. vpp文件存入 项目的debug 目录中 目的&#xff1a; 在项目中加载本地vpp文件 读取相机工具 1.控件CogRecordDisplay 用于显示相机拍摄照片和实施显示的窗口 2和3 …

aidd【人工智能技术及在生物分子活性预测、药物发现中的应用】

人工智能技术在生物分子活性预测和药物发现中具有广泛的应用。以下是一些具体的应用方式&#xff1a; 生物分子活性预测&#xff1a;利用机器学习算法&#xff0c;可以对生物分子的活性进行预测。这些算法可以学习并识别与生物分子活性相关的模式&#xff0c;并基于这些模式对…