理解深度学习pytorch框架中的线性层

在神经网络或机器学习的线性层(Linear Layer / Fully Connected Layer)中,经常会见到两种形式的公式:

  • 数学文献或传统线性代数写法: y = W x + b \displaystyle y = W\,x + b y=Wx+b
  • 一些深度学习代码中写法: y = x W T + b \displaystyle y = x\,W^T + b y=xWT+b

初次接触时,很多人会觉得两者“方向”不太一样,不知该如何对照理解;再加上矩阵维度 ( in_features , out_features ) (\text{in\_features},\, \text{out\_features}) (in_features,out_features) ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features) 的各种写法常常让人疑惑不已。本文将从数学角度和编程实现角度剖析它们的关系,并结合实际示例指出一些常见的坑与需要特别留意的下标对应问题。

1. 数学角度: y = W x + b \displaystyle y = W\,x + b y=Wx+b

在线性代数中,如果我们假设输入 x x x 是一个列向量,通常会写作 x ∈ R ( in_features ) \displaystyle x\in\mathbb{R}^{(\text{in\_features})} xR(in_features)(或者在更严格的矩阵形状记法下写作 ( in_features , 1 ) (\text{in\_features},\,1) (in_features,1))。那么一个最常见的全连接层可以表示为:

y = W x + b , y = W\,x + b, y=Wx+b,

其中:

  • W W W 是一个大小为 ( out_features , in_features ) \bigl(\text{out\_features},\,\text{in\_features}\bigr) (out_features,in_features) 的矩阵;
  • b b b 是一个 out_features \text{out\_features} out_features-维的偏置向量(形状 ( out_features , 1 ) (\text{out\_features},\,1) (out_features,1));
  • y y y 则是输出向量,大小为 out_features \text{out\_features} out_features

示例

假设 in_features = 3 \text{in\_features}=3 in_features=3 out_features = 2 \text{out\_features}=2 out_features=2。那么:
W ∈ R 2 × 3 , x ∈ R 3 × 1 , b ∈ R 2 × 1 . W \in \mathbb{R}^{2\times 3},\quad x \in \mathbb{R}^{3\times 1},\quad b \in \mathbb{R}^{2\times 1}. WR2×3,xR3×1,bR2×1.

矩阵写开来就是:

W = [ w 11 w 12 w 13 w 21 w 22 w 23 ] , x = [ x 1 x 2 x 3 ] , b = [ b 1 b 2 ] . W = \begin{bmatrix} w_{11} & w_{12} & w_{13} \\[5pt] w_{21} & w_{22} & w_{23} \end{bmatrix},\quad x = \begin{bmatrix} x_{1}\\ x_{2}\\ x_{3} \end{bmatrix},\quad b = \begin{bmatrix} b_{1}\\ b_{2} \end{bmatrix}. W=[w11w21w12w22w13w23],x= x1x2x3 ,b=[b1b2].

那么线性变换结果 W x + b Wx + b Wx+b 可以展开为:

W x + b = [ w 11 x 1 + w 12 x 2 + w 13 x 3 w 21 x 1 + w 22 x 2 + w 23 x 3 ] + [ b 1 b 2 ] = [ w 11 x 1 + w 12 x 2 + w 13 x 3 + b 1 w 21 x 1 + w 22 x 2 + w 23 x 3 + b 2 ] . \begin{aligned} Wx + b &= \begin{bmatrix} w_{11}x_1 + w_{12}x_2 + w_{13}x_3 \\ w_{21}x_1 + w_{22}x_2 + w_{23}x_3 \end{bmatrix} + \begin{bmatrix} b_1 \\ b_2 \end{bmatrix} \\ &= \begin{bmatrix} w_{11}x_1 + w_{12}x_2 + w_{13}x_3 + b_1 \\ w_{21}x_1 + w_{22}x_2 + w_{23}x_3 + b_2 \end{bmatrix}. \end{aligned} Wx+b=[w11x1+w12x2+w13x3w21x1+w22x2+w23x3]+[b1b2]=[w11x1+w12x2+w13x3+b1w21x1+w22x2+w23x3+b2].

这就是最传统、在数学文献或线性代数课程中最常见的表示方法。


2. 编程实现角度: y = x W T + b \displaystyle y = x\,W^T + b y=xWT+b

在实际的深度学习代码(例如 PyTorch、TensorFlow)中,经常看到的却是下面这种写法:

y = x @ W.T + b

注意这里 W.shape 通常被定义为 ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features),而 x.shape 在批量处理时则是 ( batch_size , in_features ) (\text{batch\_size},\, \text{in\_features}) (batch_size,in_features)。于是 (x @ W.T) 的结果是 ( batch_size , out_features ) (\text{batch\_size},\, \text{out\_features}) (batch_size,out_features)

为什么会出现转置?
因为在数学里我们通常把 x x x 当作“列向量”放在右边,于是公式变成 y = W x + b y = Wx + b y=Wx+b
但在编程里,尤其是处理批量输入时,x 常写成“行向量”的形式 ( batch_size , in_features ) (\text{batch\_size},\, \text{in\_features}) (batch_size,in_features),这就造成了在进行矩阵乘法时,需要将 W(大小 ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features))转置成 ( in_features , out_features ) (\text{in\_features},\, \text{out\_features}) (in_features,out_features),才能满足「行×列」的匹配关系。

从结果上来看,

( batch_size , in_features ) × ( in_features , out_features ) = ( batch_size , out_features ) . (\text{batch\_size}, \text{in\_features}) \times (\text{in\_features}, \text{out\_features}) = (\text{batch\_size}, \text{out\_features}). (batch_size,in_features)×(in_features,out_features)=(batch_size,out_features).

所以,在代码里就写成 x @ W.T,再加上偏置 b(通常会广播到 batch_size \text{batch\_size} batch_size 那个维度)。

本质上这和数学公式里 y = W x + b y = W\,x + b y=Wx+b 并无冲突,只是一个“列向量”和“行向量”的转置关系。只要搞清楚最终你想让输出 y y y 的 shape 是多少,就能明白在代码里为什么要写 .T


3. 常见错误与易混点解析

有些教程或文档,会不小心写成:“如果我们有一个形状为 ( in_features , out_features ) (\text{in\_features},\text{out\_features}) (in_features,out_features) 的权重矩阵 W W W……”——然后又要做 W x Wx Wx,想得到一个 out_features \text{out\_features} out_features-维的结果。但按照线性代数的常规写法,行数必须和输出维度匹配、列数必须和输入维度匹配。所以 正确 的说法应该是

W ∈ R ( out_features ) × ( in_features ) . W\in\mathbb{R}^{(\text{out\_features}) \times (\text{in\_features})}. WR(out_features)×(in_features).

否则从矩阵乘法次序来看就对不上。
但这又可能让人迷惑:为什么深度学习框架 torch.nn.Linear(in_features, out_features) 却给出 weight.shape == (out_features, in_features) 其实正是同一个道理,它和上面“数学文献里”用到的 W W W 形状完全一致。


4. 小结

  1. 从数学角度
    最传统的记号是
    y = W x + b , W ∈ R ( out_features ) × ( in_features ) , x ∈ R ( in_features ) , y ∈ R ( out_features ) . y = W\,x + b, \quad W \in \mathbb{R}^{(\text{out\_features})\times(\text{in\_features})},\, x \in \mathbb{R}^{(\text{in\_features})},\, y \in \mathbb{R}^{(\text{out\_features})}. y=Wx+b,WR(out_features)×(in_features),xR(in_features),yR(out_features).

  2. 从深度学习代码角度

    • 由于批量数据常被视为行向量,每一行代表一个样本特征,因此形状通常是 ( batch_size , in_features ) (\text{batch\_size},\, \text{in\_features}) (batch_size,in_features)
    • 对应的权重 W 定义为 ( out_features , in_features ) (\text{out\_features},\, \text{in\_features}) (out_features,in_features)。为了完成行乘以列的矩阵运算,需要对 W 做转置:
      y = x @ W.T + b
      
    • 得到的 y.shape ( batch_size , out_features ) (\text{batch\_size},\, \text{out\_features}) (batch_size,out_features)
  3. 避免踩坑

    • 写公式时,仔细确认 in_features \text{in\_features} in_features out_features \text{out\_features} out_features 的位置以及矩阵行列顺序。
    • 编程实践中理解“为什么要 .T”非常重要:那只是为了匹配「行×列」的矩阵乘法规则,本质上还是和 y = W x + b y = Wx + b y=Wx+b 相同。

通过理解并区分“列向量”与“行向量”的不同惯例,避免因为矩阵维度或转置不当而导致莫名其妙的错误或 bug。


参考链接

  • PyTorch 文档:torch.nn.Linear
  • 深度学习中的矩阵运算初步 —— batch_size 与矩阵乘法
  • 常见线性代数符号:行向量与列向量

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/66895.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C#Object类型的索引,序列化和反序列化

前言 最近在编写一篇关于标准Mes接口框架的文章。其中有一个非常需要考究的内容时如果实现数据灵活和可使用性强。因为考虑数据灵活性,所以我一开始选取了Object类型作为数据类型,Object作为数据Value字段,String作为数据Key字段&#xff0c…

大模型应用与部署 技术方案

大模型应用与部署 技术方案 一、引言 人工智能蓬勃发展,Qwen 大模型在自然语言处理领域地位关键,其架构优势尽显,能处理文本创作等多类复杂任务,提供优质交互。Milvus 向量数据库则是向量数据存储检索利器,有高效索引算法(如 IVF_FLAT、HNSWLIB 等)助力大规模数据集相似…

【Prometheus】Prometheus如何监控Haproxy

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

C# 控制打印机:从入门到实践

在开发一些涉及打印功能的应用程序时,使用 C# 控制打印机是一项很实用的技能。这篇文章就来详细介绍下如何在 C# 中实现对打印机的控制。 一、准备工作 安装相关库:在 C# 中操作打印机,我们可以借助System.Drawing.Printing命名空间&#x…

Go语言中的值类型和引用类型特点

一、值类型 值类型的数据直接包含值,当它们被赋值给一个新的变量或者作为参数传递给函数时,实际上是创建了原值的一个副本。这意味着对新变量的修改不会影响原始变量的值。 Go中的值类型包括: 基础类型:int,float64…

GPT 结束语设计 以nanogpt为例

GPT 结束语设计 以nanogpt为例 目录 GPT 结束语设计 以nanogpt为例 1、简述 2、分词设计 3、结束语断点 1、简述 在手搓gpt的时候,可能会遇到一些性能问题,即关于是否需要全部输出或者怎么节约资源。 在输出语句被max_new_tokens 限制&#xff0c…

《探秘:人工智能如何为鸿蒙Next元宇宙网络传输与延迟问题破局》

在元宇宙的宏大愿景中,流畅的网络传输和低延迟是保障用户沉浸式体验的关键。鸿蒙Next结合人工智能技术,为解决这些问题提供了一系列创新思路和方法。 智能网络监测与预测 人工智能可以实时监测鸿蒙Next元宇宙中的网络状况,包括带宽、延迟、…

深入MapReduce——计算模型设计

引入 通过引入篇,我们可以总结,MapReduce针对海量数据计算核心痛点的解法如下: 统一编程模型,降低用户使用门槛分而治之,利用了并行处理提高计算效率移动计算,减少硬件瓶颈的限制 优秀的设计&#xff0c…

macOS安装Gradle环境

文章目录 说明安装JDK安装Gradle 说明 gradle8.5最高支持jdk21,如果使用jdk22建议使用gradle8.8以上版本 安装JDK mac系统安装最新(截止2024.9.13)Oracle JDK操作记录 安装Gradle 下载Gradle,解压将其存放到资源java/env目录…

五国十五校联合巨献!仿人机器人运动与操控:控制、规划与学习的最新突破与挑战

作者: Zhaoyuan Gu, Junheng Li, Wenlan Shen, Wenhao Yu, Zhaoming Xie, Stephen McCrory, Xianyi Cheng, Abdulaziz Shamsah, Robert Griffin, C. Karen Liu, Abderrahmane Kheddar, Xue Bin Peng, Yuke Zhu, Guanya Shi, Quan Nguyen, Gordon Cheng, Huijun Gao,…

CVPR 2024 无人机/遥感/卫星图像方向总汇(航空图像和交叉视角定位)

1、UAV、Remote Sensing、Satellite Image(无人机/遥感/卫星图像) Unleashing Unlabeled Data: A Paradigm for Cross-View Geo-Localization ⭐codeRethinking Transformers Pre-training for Multi-Spectral Satellite Imagery ⭐codeAerial Lifting: Neural Urban Semantic …

【BQ3568HM开发板】如何在OpenHarmony上通过校园网的上网认证

引言 前面已经对BQ3568HM开发板进行了初步测试,后面我要实现MQTT的工作,但是遇到一个问题,就是开发板无法通过校园网的认证操作。未认证的话会,学校使用的深澜软件系统会屏蔽所有除了认证用的流量。好在我们学校使用的认证系统和…

(Java版本)基于JAVA的网络通讯系统设计与实现-毕业设计

源码 论文 下载地址: ​​​​c​​​​​​c基于JAVA的网络通讯系统设计与实现(源码系统论文)https://download.csdn.net/download/weixin_39682092/90299782https://download.csdn.net/download/weixin_39682092/90299782 第1章 绪论 1.1 课题选择的…

kafka学习笔记4-TLS加密 —— 筑梦之路

1. 准备证书文件 mkdir /opt/kafka/pkicd !$# 生成CA证书 openssl req -x509 -nodes -days 3650 -newkey rsa:4096 -keyout ca.key -out ca.crt -subj "/CNKafka-CA"# 生成私钥 openssl genrsa -out kafka.key 4096# 生成证书签名请求 (CSR) openssl req -new -key …

Node.js NativeAddon 构建工具:node-gyp 安装与配置完全指南

Node.js NativeAddon 构建工具:node-gyp 安装与配置完全指南 node-gyp Node.js native addon build tool [这里是图片001] 项目地址: https://gitcode.com/gh_mirrors/no/node-gyp 项目基础介绍及主要编程语言 Node.js NativeAddon 构建工具(node-gyp…

SpringCloud微服务Gateway网关简单集成Sentinel

Sentinel是阿里巴巴开源的一款面向分布式服务架构的轻量级流量控制、熔断降级组件。Sentinel以流量为切入点,从流量控制、熔断降级、系统负载保护等多个维度来帮助保护服务的稳定性。 官方文档:https://sentinelguard.io/zh-cn/docs/introduction.html …

vscode环境中用仓颉语言开发时调出覆盖率的方法

在vscode中仓颉语言想得到在idea中利用junit和jacoco的覆盖率,需要如下几个步骤: 1.在vscode中搭建仓颉语言开发环境; 2.在源代码中右键运行[cangjie]coverage. 思路1:编写了测试代码的情况(包管理工具) …

pikachu靶场-敏感信息泄露概述

敏感信息泄露概述 由于后台人员的疏忽或者不当的设计,导致不应该被前端用户看到的数据被轻易的访问到。 比如: ---通过访问url下的目录,可以直接列出目录下的文件列表; ---输入错误的url参数后报错信息里面包含操作系统、中间件、开发语言的版…

安卓动态设置Unity图形API

命令行方式 Unity图像api设置为自动,安卓动态设置Vulkan、OpenGLES Unity设置 安卓设置 创建自定义活动并将其设置为应用程序入口点。 在自定义活动中,覆盖字符串UnityPlayerActivity。updateunitycommandlineararguments (String cmdLine)方法。 在该方法中,将cmdLine…

CICD集合(五):Jenkins+Git+Allure实战(自动化测试)

CICD集合(五):Jenkins+Git+Allure实战(自动化测试) 前提: 已安装好Jenkins安装好git,maven,allure报告插件配置好Git,Maven,allure参考:CICD集合(一至四) https://blog.csdn.net/fen_fen/article/details/131476093 https://blog.csdn.net/fen_fen/article/details/1213…