深入理解Softmax函数及其在PyTorch中的实现

Softmax函数简介

Softmax函数在机器学习和深度学习中,被广泛用于多分类问题的输出层。它将一个实数向量转换为概率分布,使得每个元素介于0和1之间,且所有元素之和为1。

Softmax函数的定义

给定一个长度为 K K K的输入向量 z = [ z 1 , z 2 , … , z K ] \boldsymbol{z} = [z_1, z_2, \dots, z_K] z=[z1,z2,,zK],Softmax函数 σ ( z ) \sigma(\boldsymbol{z}) σ(z)定义为:

σ ( z ) i = e z i ∑ j = 1 K e z j , 对于所有  i = 1 , 2 , … , K \sigma(\boldsymbol{z})_i = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}, \quad \text{对于所有 } i = 1, 2, \dots, K σ(z)i=j=1Kezjezi,对于所有 i=1,2,,K

其中:

  • e e e是自然对数的底数,约为2.71828。
  • σ ( z ) i \sigma(\boldsymbol{z})_i σ(z)i是输入向量第 i i i个分量对应的Softmax输出。

Softmax函数的特点

  1. 将输出转换为概率分布:Softmax的输出向量中的每个元素都在 ( 0 , 1 ) (0, 1) (0,1)之间,并且所有元素的和为1,这使得输出可以视为各类别的概率。

  2. 强调较大的值:Softmax函数会放大输入向量中较大的元素对应的概率,同时压缩较小的元素对应的概率。这种特性有助于突出模型认为更有可能的类别。

  3. 可微性:Softmax函数是可微的,这对于基于梯度的优化算法(如反向传播)非常重要。


数值稳定性的问题

在实际计算中,为了防止指数函数计算过程中可能出现的数值溢出,通常会对输入向量进行调整。常见的做法是在计算Softmax之前,从输入向量的每个元素中减去向量的最大值:

σ ( z ) i = e z i − z max ∑ j = 1 K e z j − z max \sigma(\boldsymbol{z})_i = \frac{e^{z_i - z_{\text{max}}}}{\sum_{j=1}^{K} e^{z_j - z_{\text{max}}}} σ(z)i=j=1Kezjzmaxezizmax

其中, z max = max ⁡ { z 1 , z 2 , … , z K } z_{\text{max}} = \max\{z_1, z_2, \dots, z_K\} zmax=max{z1,z2,,zK}。这种调整不会改变Softmax的输出结果,但能提高计算的数值稳定性。


Softmax函数的应用场景

  1. 多分类问题:在神经网络的最后一层,Softmax函数常用于将模型的线性输出转换为概率分布,以进行多分类预测。

  2. 注意力机制:在深度学习中的注意力模型中,Softmax用于计算注意力权重,以突显重要的输入特征。

  3. 语言模型:在自然语言处理任务中,Softmax函数用于预测下一个词的概率分布。


Softmax函数的示例计算

假设有一个三类别分类问题,神经网络的输出为一个长度为3的向量:

z = [ z 1 , z 2 , z 3 ] = [ 2.0 , 1.0 , 0.1 ] \boldsymbol{z} = [z_1, z_2, z_3] = [2.0, 1.0, 0.1] z=[z1,z2,z3]=[2.0,1.0,0.1]

我们想使用Softmax函数将其转换为概率分布。

步骤1:计算每个元素的指数

e z 1 = e 2.0 = 7.3891 e z 2 = e 1.0 = 2.7183 e z 3 = e 0.1 = 1.1052 \begin{align*} e^{z_1} &= e^{2.0} = 7.3891 \\ e^{z_2} &= e^{1.0} = 2.7183 \\ e^{z_3} &= e^{0.1} = 1.1052 \end{align*} ez1ez2ez3=e2.0=7.3891=e1.0=2.7183=e0.1=1.1052

步骤2:计算指数和

sum = e z 1 + e z 2 + e z 3 = 7.3891 + 2.7183 + 1.1052 = 11.2126 \text{sum} = e^{z_1} + e^{z_2} + e^{z_3} = 7.3891 + 2.7183 + 1.1052 = 11.2126 sum=ez1+ez2+ez3=7.3891+2.7183+1.1052=11.2126

步骤3:计算Softmax输出

σ 1 = e z 1 sum = 7.3891 11.2126 = 0.6590 σ 2 = e z 2 sum = 2.7183 11.2126 = 0.2424 σ 3 = e z 3 sum = 1.1052 11.2126 = 0.0986 \begin{align*} \sigma_1 &= \frac{e^{z_1}}{\text{sum}} = \frac{7.3891}{11.2126} = 0.6590 \\ \sigma_2 &= \frac{e^{z_2}}{\text{sum}} = \frac{2.7183}{11.2126} = 0.2424 \\ \sigma_3 &= \frac{e^{z_3}}{\text{sum}} = \frac{1.1052}{11.2126} = 0.0986 \end{align*} σ1σ2σ3=sumez1=11.21267.3891=0.6590=sumez2=11.21262.7183=0.2424=sumez3=11.21261.1052=0.0986

因此,经过Softmax函数后,输出概率分布为:

σ ( z ) = [ 0.6590 , 0.2424 , 0.0986 ] \sigma(\boldsymbol{z}) = [0.6590, 0.2424, 0.0986] σ(z)=[0.6590,0.2424,0.0986]

这表示模型预测第一个类别的概率约为65.9%,第二个类别约为24.24%,第三个类别约为9.86%。


使用PyTorch实现Softmax函数

在PyTorch中,可以通过多种方式实现Softmax函数。以下将通过示例演示如何使用torch.nn.functional.softmaxtorch.nn.Softmax

创建输入数据

首先,创建一个示例输入张量:

import torch
import torch.nn as nn
import torch.nn.functional as F# 创建一个输入张量,形状为 (batch_size, features)
input_tensor = torch.tensor([[2.0, 1.0, 0.1],[1.0, 3.0, 0.2]])
print("输入张量:")
print(input_tensor)

输出:

输入张量:
tensor([[2.0000, 1.0000, 0.1000],[1.0000, 3.0000, 0.2000]])

方法一:使用torch.nn.functional.softmax

利用PyTorch中torch.nn.functional.softmax函数直接对输入数据应用Softmax。

# 在维度1上(即特征维)应用Softmax
softmax_output = F.softmax(input_tensor, dim=1)
print("\nSoftmax输出:")
print(softmax_output)

输出:

Softmax输出:
tensor([[0.6590, 0.2424, 0.0986],[0.1065, 0.8726, 0.0209]])

方法二:使用torch.nn.Softmax模块

也可以使用torch.nn中的Softmax模块。

# 创建一个Softmax层实例
softmax = nn.Softmax(dim=1)# 对输入张量应用Softmax层
softmax_output_module = softmax(input_tensor)
print("\n使用nn.Softmax模块的输出:")
print(softmax_output_module)

输出:

使用nn.Softmax模块的输出:
tensor([[0.6590, 0.2424, 0.0986],[0.1065, 0.8726, 0.0209]])

在神经网络模型中应用Softmax

构建一个简单的神经网络模型,在最后一层使用Softmax激活函数。

class SimpleNetwork(nn.Module):def __init__(self, input_size, num_classes):super(SimpleNetwork, self).__init__()self.layer1 = nn.Linear(input_size, 5)self.layer2 = nn.Linear(5, num_classes)# 使用LogSoftmax提高数值稳定性self.softmax = nn.LogSoftmax(dim=1)def forward(self, x):x = F.relu(self.layer1(x))x = self.layer2(x)x = self.softmax(x)return x# 定义输入大小和类别数
input_size = 3
num_classes = 3# 创建模型实例
model = SimpleNetwork(input_size, num_classes)# 查看模型结构
print("\n模型结构:")
print(model)

输出:

模型结构:
SimpleNetwork((layer1): Linear(in_features=3, out_features=5, bias=True)(layer2): Linear(in_features=5, out_features=3, bias=True)(softmax): LogSoftmax(dim=1)
)

前向传播:

# 将输入数据转换为浮点型张量
input_data = input_tensor.float()# 前向传播
output = model(input_data)
print("\n模型输出(对数概率):")
print(output)

输出:

模型输出(对数概率):
tensor([[-1.2443, -0.7140, -1.2645],[-1.3689, -0.6535, -1.5142]], grad_fn=<LogSoftmaxBackward0>)

转换为概率:

# 取指数,转换为概率
probabilities = torch.exp(output)
print("\n模型输出(概率):")
print(probabilities)

输出:

模型输出(概率):
tensor([[0.2882, 0.4898, 0.2220],[0.2541, 0.5204, 0.2255]], grad_fn=<ExpBackward0>)

预测类别:

# 获取每个样本概率最大的类别索引
predicted_classes = torch.argmax(probabilities, dim=1)
print("\n预测的类别:")
print(predicted_classes)

输出:

预测的类别:
tensor([1, 1])

torch.nn.functional.softmaxtorch.nn.Softmax的区别

函数式API与模块化API的设计理念

PyTorch提供了两种API:

  1. 函数式API (torch.nn.functional)

    • 特点:无状态(Stateless),不包含可学习的参数。
    • 使用方式:直接调用函数。
    • 适用场景:需要在forward方法中灵活应用各种操作。
  2. 模块化API (torch.nn.Module)

    • 特点:有状态(Stateful),可能包含可学习的参数,即使某些模块没有参数(如Softmax),但继承自nn.Module
    • 使用方式:需要先实例化,再在前向传播中调用。
    • 适用场景:构建模型时,统一管理各个层和操作。

具体到Softmax的实现

  • torch.nn.functional.softmax(函数)

    • 使用示例

      import torch.nn.functional as F
      output = F.softmax(input_tensor, dim=1)
      
    • 特点:直接调用,简洁灵活。

  • torch.nn.Softmax(模块)

    • 使用示例

      import torch.nn as nn
      softmax = nn.Softmax(dim=1)
      output = softmax(input_tensor)
      
    • 特点:作为模型的一层,便于与其他层组合,保持代码结构一致。

为什么存在两个实现?

提供两种实现方式是为了满足不同开发者的需求和编程风格。

  • 使用nn.Softmax的优势

    • 在模型定义阶段明确各层,结构清晰。
    • 便于使用nn.Sequential构建顺序模型。
    • 统一管理模型的各个部分。
  • 使用F.softmax的优势

    • 代码简洁,直接调用函数。
    • 适用于需要在forward中进行灵活操作的情况。

使用示例

使用nn.Softmax
import torch
import torch.nn as nn# 定义模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.layer = nn.Linear(10, 5)self.softmax = nn.Softmax(dim=1)def forward(self, x):x = self.layer(x)x = self.softmax(x)return x# 实例化和使用
model = MyModel()
input_tensor = torch.randn(2, 10)
output = model(input_tensor)
print(output)
使用F.softmax
import torch
import torch.nn as nn
import torch.nn.functional as F# 定义模型
class MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.layer = nn.Linear(10, 5)def forward(self, x):x = self.layer(x)x = F.softmax(x, dim=1)return x# 实例化和使用
model = MyModel()
input_tensor = torch.randn(2, 10)
output = model(input_tensor)
print(output)

总结

Softmax函数在深度学习中起着关键作用,尤其在多分类任务中。PyTorch为了满足不同的开发需求,提供了torch.nn.functional.softmaxtorch.nn.Softmax两种实现方式。

  • F.softmax:函数式API,灵活简洁,适合在forward方法中直接调用。

  • nn.Softmax:模块化API,便于模型结构的统一管理,适合在模型初始化时定义各个层。

在实际开发中,选择适合你的项目和团队的方式。如果更喜欢模块化的代码结构,使用nn.Softmax;如果追求简洁和灵活,使用F.softmax。同时,要注意数值稳定性的问题,尤其是在计算损失函数时,建议使用nn.LogSoftmaxnn.NLLLoss结合使用。


参考文献

  • PyTorch官方文档 - Softmax函数
  • PyTorch官方文档 - nn.Softmax
  • PyTorch官方教程 - 构建神经网络
  • PyTorch论坛 - Softmax激活函数:nn.Softmax vs F.softmax

希望本文能帮助读者深入理解Softmax函数及其在PyTorch中的实现和应用。如有任何疑问,欢迎交流讨论!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/77608.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Vue 3 响应式更新问题解析

在 Vue 3 中&#xff0c;即使使用 reactive 或 ref 创建的响应式数据&#xff0c;当数据量很大时也可能出现更新不及时的情况。以下是原因和解决方案&#xff1a; 核心原因 ​​响应式系统优化机制​​&#xff1a; Vue 3 使用 Proxy 实现响应式&#xff0c;比 Vue 2 更高效但为…

异形遮罩之QML中的 `OpacityMask` 实战

文章目录 &#x1f327;️ 传统实现的问题&#x1f449; 效果图 &#x1f308; 使用 OpacityMask 的理想方案&#x1f449;代码如下&#x1f3af; 最终效果&#xff1a; ✨ 延伸应用&#x1f9e0; 总结 在 UI 设计中&#xff0c;经常希望实现一些“异形区域”拥有统一透明度或颜…

数据可视化 —— 堆形图应用(大全)

一、案例一&#xff1a;温度堆积图 # 导入 matplotlib 库中的 pyplot 模块&#xff0c;这个模块提供了类似于 MATLAB 的绘图接口&#xff0c; # 方便我们创建各种类型的可视化图表&#xff0c;比如折线图、柱状图、散点图等 import matplotlib.pyplot as plt # 导入 numpy 库&…

python工程中的包管理(requirements.txt)

pip install -r requirements.txtpython工程通过requirements.txt来管理依赖库版本&#xff0c;上述命令&#xff0c;可以一把安装依赖库&#xff0c;类似java中maven的pom.xml文件。 参考 [](

操作系统 3.4-段页结合的实际内存管理

段与页结合的初步思路 虚拟内存的引入&#xff1a; 为了结合段和页的优势&#xff0c;操作系统引入了虚拟内存的概念。虚拟内存是一段地址空间&#xff0c;它映射到物理内存上&#xff0c;但对用户程序是透明的。 段到虚拟内存的映射&#xff1a; 用户程序中的段首先映射到虚…

【Amazon EC2】为何基于浏览器的EC2 Instance Connect 客户端连接不上EC2实例

文章目录 前言&#x1f4d6;一、报错先知❌二、问题复现&#x1f62f;三、解决办法&#x1f3b2;四、验证结果&#x1f44d;五、参考链接&#x1f517; 前言&#x1f4d6; 这篇文章将讲述我在 Amazon EC2 上使用 RHEL9 AMI 时无法连接到 EC2 实例时所遇到的麻烦&#x1f616; …

Python学习笔记(二)(字符串)

文章目录 编写简单的程序一、标识符 (Identifiers)及关键字命名规则&#xff1a;命名惯例&#xff1a;关键字 二、变量与赋值 (Variables & Assignment)变量定义&#xff1a;多重赋值&#xff1a;变量交换&#xff1a;&#xff08;很方便哟&#xff09; 三、输入与输出 (In…

Hydra Columnar:一个开源的PostgreSQL列式存储引擎

Hydra Columnar 是一个 PostgreSQL 列式存储插件&#xff0c;专为分析型&#xff08;OLAP&#xff09;工作负载设计&#xff0c;旨在提升大规模分析查询和批量更新的效率。 Hydra Columnar 以扩展插件的方式提供&#xff0c;主要特点包括&#xff1a; 采用列式存储&#xff0c…

es的告警信息

Elasticsearch&#xff08;ES&#xff09;是一个开源的分布式搜索和分析引擎&#xff0c;在运行过程中可能会产生多种告警信息&#xff0c;以提示用户系统中存在的潜在问题或异常情况。以下是一些常见的 ES 告警信息及其含义和处理方法&#xff1a; 集群健康状态告警 信息示例…

健康与好身体笔记

文章目录 保证睡眠饭后百步走&#xff0c;活到九十九补充钙质一副好肠胃肚子咕咕叫 健康和工作的取舍 以前对健康没概念&#xff0c;但是随着年龄增长&#xff0c;健康问题凸显出来。 持续维护该文档&#xff0c;健康是个永恒的话题。 保证睡眠 一是心态要好&#xff0c;沾枕…

vue实现在线进制转换

vue实现在线进制转换 主要功能包括&#xff1a; 1.支持2-36进制之间的转换。 2.支持整数和浮点数的转换。 3.输入验证&#xff08;虽然可能存在不严格的情况&#xff09;。 4.错误提示。 5.结果展示&#xff0c;包括大写字母。 6.用户友好的界面&#xff0c;包括下拉菜单、输…

智体知识库:poplang编程语言是什么?

问&#xff1a;poplang语言是什么 Poplang 语言简介 Poplang&#xff08;OPCode-Oriented Programming Language&#xff09;是一种面向操作码&#xff08;Opcode&#xff09;的轻量级编程语言&#xff0c;主要用于智体&#xff08;Agent&#xff09;系统中的自动化任务处理、…

二分查找5:852. 山脉数组的峰顶索引

链接&#xff1a;852. 山脉数组的峰顶索引 - 力扣&#xff08;LeetCode&#xff09; 题解&#xff1a; 事实证明&#xff0c;二分查找不局限于有序数组&#xff0c;非有序的数组也同样适用 二分查找主要思想在于二段性&#xff0c;即将数组分为两段。本体就可以将数组分为ar…

下列软件包有未满足的依赖关系: python3-catkin-pkg : 冲突: catkin 但是 0.8.10-

下列软件包有未满足的依赖关系: python3-catkin-pkg : 冲突: catkin 但是 0.8.10- 解决&#xff1a; 1. 确认当前的包状态 首先&#xff0c;运行以下命令来查看当前安装的catkin和python3-catkin-pkg版本&#xff0c;以及它们之间的依赖关系&#xff1a; dpkg -l | grep ca…

深度学习:AI 大模型时代的智能引擎

当 Deepspeek 以逼真到难辨真假的语音合成和视频生成技术横空出世&#xff0c;瞬间引发了全球对 AI 伦理与技术边界的激烈讨论。从伪造名人演讲、制造虚假新闻&#xff0c;到影视行业的特效革新&#xff0c;这项技术以惊人的速度渗透进大众视野。但在 Deepspeek 强大功能的背后…

医学分割新标杆!双路径PGM-UNet:CNN+Mamba实现病灶毫厘级捕捉

一、引言&#xff1a;医学图像分割的挑战与机遇 医学图像分割是辅助疾病诊断和治疗规划的关键技术&#xff0c;但传统方法常受限于复杂病理特征和微小结构。现有深度学习模型&#xff08;如CNN和Transformer&#xff09;虽各有优势&#xff0c;但CNN难以建模长距离依赖&…

CV - 目标检测

物体检测 目标检测和图片分类的区别&#xff1a; 图像分类&#xff08;Image Classification&#xff09; 目的&#xff1a;图像分类的目的是识别出图像中主要物体的类别。它试图回答“图像是什么&#xff1f;”的问题。 输出&#xff1a;通常输出是一个标签或一组概率值&am…

高并发秒杀系统设计:关键技术解析与典型陷阱规避

电商、在线票务等众多互联网业务场景中&#xff0c;高并发秒杀活动屡见不鲜。这类活动往往在短时间内会涌入海量的用户请求&#xff0c;对系统架构的性能、稳定性和可用性提出了极高的挑战。曾经&#xff0c;高并发秒杀架构设计让许多开发者望而生畏&#xff0c;然而&#xff0…

蓝桥杯--结束

冲刺题单 基础 一、简单模拟&#xff08;循环数组日期进制&#xff09; &#xff08;一&#xff09;日期模拟 知识点 1.把月份写为数组&#xff0c;二月默认为28天。 2.写一个判断闰年的方法&#xff0c;然后循环年份的时候判断并更新二月的天数 3.对于星期数的计算&#…

13、nRF52xx蓝牙学习(GPIOTE组件方式的任务配置)

下面再来探讨下驱动库如何实现任务的配置&#xff0c;驱动库的实现步骤应该和寄存器方式对应&#xff0c;关 键点就是如何调用驱动库的函数。 本例里同样的对比寄存器方式编写两路的 GPOITE 任务输出&#xff0c;一路配置为输出翻转&#xff0c;一路设 置为输出低电平。和 …