基于RNN和Transformer的词级语言建模 代码分析 log_softmax

基于RNN和Transformer的词级语言建模 代码分析 log_softmax

flyfish

Word-level Language Modeling using RNN and Transformer

word_language_model

PyTorch 提供的 word_language_model 示例展示了如何使用循环神经网络RNN(GRU或LSTM)和 Transformer 模型进行词级语言建模 。默认情况下,训练使用Wikitext-2数据集,generate.py可以使用训练好的模型来生成新文本。

源码地址
https://github.com/pytorch/examples/tree/main/word_language_model

文件:model.py

F.log_softmax(output, dim=-1) 在 TransformerModel 的 forward 方法的最后一行中用于将模型的输出转换为对数概率分布,既提高了数值计算的稳定性,又与常用的损失函数(如NLLLoss)兼容.

数值计算的稳定性,请参考该文章的后半部分

https://flyfish.blog.csdn.net/article/details/106405099

1. 概率与对数概率

概率(Probability): 概率是某个事件发生的可能性,值在 [0, 1] 之间。比如,一个事件发生的概率为 P。
对数概率(Log Probability): 对数概率是将概率值

P 取对数后的结果,通常使用自然对数(ln)。对数概率可以表示为:log§,其中 log 是自然对数函数。

2. 对数概率的性质

范围: 因为概率 P 总是介于 0 和 1 之间,所以对数概率 log§ 总是小于或等于零。

当 P=1 时,log(1)=0。
当 0<P<1 时,log§<0。
当 P 趋近于 0 时,log§ 趋近于负无穷。
单调性: 对数函数是单调递增函数,这意味着如果两个概率
两个概率 P1和 P2, 满足 P1 > P2,则对应的对数概率也满足 log(P1) > log(P2)

3. 为什么使用对数概率

数值稳定性: 直接使用概率值进行计算时,若概率值非常小(接近于零),可能导致数值下溢问题。对概率取对数可以将乘法转化为加法,从而避免这种数值不稳定性。

例如,计算多个独立事件的联合概率

P(A∩B)=P(A)⋅P(B),使用对数概率可以转换为加法:
log(P(A∩B))=log(P(A))+log(P(B))。
简化计算: 在某些模型(如隐马尔可夫模型和深度学习模型)中,使用对数概率可以简化似然函数和损失函数的计算。

与损失函数的兼容性: 在深度学习中,常用的损失函数如负对数似然损失(Negative Log-Likelihood Loss, NLLLoss)需要对数概率作为输入。因此,模型输出对数概率是直接兼容这些损失函数的。

4. 对数概率在深度学习中的应用

在神经网络模型(特别是用于分类任务的模型)中,输出通常是一个概率分布。在语言建模任务中,我们希望输出每个词的概率。在训练过程中,为了计算损失,我们使用对数概率。
语言模型: 给定一个句子,模型输出每个词的对数概率。
损失函数: 使用 NLLLoss 来计算预测词与真实词之间的损失。

以下是一个示例,展示如何计算对数概率并使用负对数似然损失

import torch
import torch.nn as nn
import torch.nn.functional as F# 定义一个简单的模型
class SimpleModel(nn.Module):def __init__(self, ntoken, ninp):super(SimpleModel, self).__init__()self.embedding = nn.Embedding(ntoken, ninp)self.decoder = nn.Linear(ninp, ntoken)def forward(self, src):embedded = self.embedding(src)output = self.decoder(embedded)return F.log_softmax(output, dim=-1)# 超参数
ntoken = 10  # 词汇表大小
ninp = 512   # 嵌入维度# 创建模型实例
model = SimpleModel(ntoken, ninp)# 生成假数据
src = torch.randint(0, ntoken, (5, 2))  # 序列长度为5,批次大小为2# 前向传播
log_probs = model(src)# 计算损失
criterion = nn.NLLLoss()
target = torch.randint(0, ntoken, (5, 2))  # 生成目标序列
loss = criterion(log_probs.view(-1, ntoken), target.view(-1))print("Log probabilities shape:", log_probs.shape)
print("Log probabilities:", log_probs)
print("Loss:", loss.item())

加入原始代码

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as pltclass PositionalEncoding(nn.Module):def __init__(self, d_model, dropout=0.1, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):r"""Inputs of forward functionArgs:x: the sequence fed to the positional encoder model (required).Shape:x: [sequence length, batch size, embed dim]output: [sequence length, batch size, embed dim]Examples:>>> output = pos_encoder(x)"""x = x + self.pe[:x.size(0), :]return self.dropout(x)class TransformerModel(nn.Transformer):"""Container module with an encoder, a recurrent or transformer module, and a decoder."""def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):super(TransformerModel, self).__init__(d_model=ninp, nhead=nhead, dim_feedforward=nhid, num_encoder_layers=nlayers)self.model_type = 'Transformer'self.src_mask = Noneself.pos_encoder = PositionalEncoding(ninp, dropout)self.input_emb = nn.Embedding(ntoken, ninp)self.ninp = ninpself.decoder = nn.Linear(ninp, ntoken)self.init_weights()def _generate_square_subsequent_mask(self, sz):return torch.log(torch.tril(torch.ones(sz,sz)))def init_weights(self):initrange = 0.1nn.init.uniform_(self.input_emb.weight, -initrange, initrange)nn.init.zeros_(self.decoder.bias)nn.init.uniform_(self.decoder.weight, -initrange, initrange)def forward(self, src, has_mask=True):if has_mask:device = src.deviceif self.src_mask is None or self.src_mask.size(0) != len(src):mask = self._generate_square_subsequent_mask(len(src)).to(device)self.src_mask = maskelse:self.src_mask = Nonesrc = self.input_emb(src) * math.sqrt(self.ninp)src = self.pos_encoder(src)output = self.encoder(src, mask=self.src_mask)output = self.decoder(output)return F.log_softmax(output, dim=-1)# Hyperparameters
ntoken = 10   # size of vocabulary
ninp = 512    # embedding dimension
nhead = 8     # number of heads in the multiheadattention models
nhid = 512    # the dimension of the feedforward network model in nn.Transformer
nlayers = 2   # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
dropout = 0.2 # the dropout value# Create model
model = TransformerModel(ntoken, ninp, nhead, nhid, nlayers, dropout)# Example input (sequence length: 5, batch size: 2)
src = torch.randint(0, ntoken, (5, 2))# Forward pass
output = model(src)
print("Output shape:", output.shape)  # Should be (sequence length, batch size, ntoken)
print("Output:", output)# Visualize the output
output_np = output.detach().numpy()  # Convert to numpy for visualizationplt.figure(figsize=(12, 6))
for i in range(output_np.shape[1]):  # Iterate over batch elementsplt.subplot(1, output_np.shape[1], i+1)plt.imshow(output_np[:, i, :], aspect='auto', cmap='viridis')plt.colorbar()plt.title(f"Batch {i+1}")plt.xlabel("Token Index")plt.ylabel("Sequence Position")
plt.show()

在 TransformerModel 的 forward 方法中,return F.log_softmax(output, dim=-1) 的作用是将模型的最终输出转换为对数概率分布。为了更好地理解其意义和用途,我们需要详细解释以下几个方面:

1. output 的来源

在 forward 方法中,output 是经过嵌入层(embedding layer)、位置编码(positional encoding)、编码器(encoder)、和解码器(decoder)处理后的张量。假设输入 src 的形状为 (sequence_length, batch_size),则:

经过嵌入层后,形状为 (sequence_length, batch_size, ninp)。
经过位置编码后,形状保持不变。
经过编码器后,形状仍然保持不变。
最后经过解码器后,形状为 (sequence_length, batch_size, ntoken),其中 ntoken 是词汇表的大小。

2. F.log_softmax(output, dim=-1) 的作用

F.log_softmax 是 PyTorch 中的一个函数,用于计算张量的对数软最大值。
主要功能:
归一化:将输出转换为概率分布形式。
对数变换:取对数以提高数值稳定性。
为什么在 dim=-1 维度上应用:
dim=-1 表示在最后一个维度上应用 log_softmax,即在 ntoken 维度上。这意味着对于每个时间步和每个批次,模型输出的每个向量都被归一化为一个概率分布,并且这些概率值是通过取对数的方式表示的。

3. 应用

语言建模任务中,将模型输出转换为对数概率分布是非常常见的做法。这是因为:

数值稳定性:计算对数概率可以避免溢出或下溢的问题。
损失函数兼容性:在训练过程中,通常使用负对数似然损失(negative log-likelihood loss,NLLLoss)来优化模型参数。NLLLoss 需要对数概率作为输入,因此在前向传播中计算 log_softmax 是必要的。

4. 示例代码

假设我们有一个训练好的 TransformerModel 实例,并且我们输入一些假数据来运行前向传播。F.log_softmax(output, dim=-1) 的具体效果如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math# Define the TransformerModel and PositionalEncoding classes here (from previous messages)# Example input (sequence length: 5, batch size: 2)
src = torch.randint(0, ntoken, (5, 2))# Forward pass
output = model(src)
log_probs = F.log_softmax(output, dim=-1)# Output shape
print("Output shape:", log_probs.shape)  # Should be (sequence length, batch size, ntoken)# Output values (log probabilities)
print("Log probabilities:", log_probs)
Output shape: torch.Size([5, 2, 10])
Log probabilities: tensor([[[-1.8122, -1.6211, -2.8076, -2.8982, -3.4530, -1.0481, -3.1035,-3.0695, -6.4251, -3.0296],[-7.1732, -1.0471, -4.7220, -0.9092, -4.7615, -3.8586, -2.5530,-2.2406, -4.5940, -4.3775]],[[-1.8474, -2.3659, -4.0811, -3.3230, -1.9491, -1.2751, -2.2046,-3.1314, -3.8996, -2.3072],[-3.8504, -3.6711, -1.3957, -1.1146, -2.9621, -2.0949, -3.4236,-3.6456, -2.7213, -2.5475]],[[-3.3242, -3.7939, -3.4796, -4.5979, -1.9281, -1.7997, -3.2005,-1.9959, -3.0253, -1.0088],[-4.1581, -0.9709, -4.3932, -0.9737, -5.2762, -2.2979, -3.6968,-3.6558, -4.9326, -2.9538]],[[-2.0125, -1.7920, -2.7189, -3.7525, -2.9609, -2.4254, -3.8162,-2.4056, -4.5775, -1.0567],[-3.5996, -2.0966, -3.8215, -1.6972, -4.0127, -3.1117, -3.2421,-2.7508, -2.3296, -0.9629]],[[-1.7974, -2.0685, -2.4899, -2.8838, -1.9412, -1.7804, -4.3274,-4.6523, -1.6798, -3.0407],[-1.8604, -1.2988, -2.9066, -3.4268, -3.1218, -3.0153, -3.0892,-4.5497, -1.1045, -5.5781]]], grad_fn=<LogSoftmaxBackward0>)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/23409.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

三丰云免费虚拟主机及免费云服务器评测

三丰云是一家专业的云服务提供商&#xff0c;其免费虚拟主机和免费云服务器备受好评。三丰云提供稳定可靠的服务&#xff0c;完全免费的虚拟主机和云服务器让用户可以轻松搭建自己的网站或应用。自从开始使用三丰云的免费虚拟主机和免费云服务器后&#xff0c;我的网站访问速度…

thinkphp3.1中怎么使model查询以其中一个字段为key,另一个字段为值的数组?

在ThinkPHP 3.1中&#xff0c;如果你想要以一个字段作为键&#xff08;key&#xff09;&#xff0c;另一个字段作为值&#xff08;value&#xff09;来获取数组&#xff0c;可以通过查询结果集然后手动构建数组来实现。这里有一个简单的示例&#xff1a; // 假设我们有一个名为…

bash、zsh、fish三种流行Unix shell的区别

bash、zsh、fish三种流行Unix shell的区别 一、功能上的区别二、使用体验上的区别三、以下是每种 Shell 的常用命令行示例&#xff1a;BashZshFish 一、功能上的区别 bash&#xff1a;bash 是 Bourne Again SHell 的缩写&#xff0c;是 Linux 系统中默认的 Shell。bash 的特点是…

SQL性能优化 ——OceanBase SQL 性能调优实践分享(3)

相比较之前的两篇《连接调优》和《索引调优》&#xff0c;本篇文章主要是对先前两篇内容的整理与应用&#xff0c;这里不仅归纳了性能优化的策略&#xff0c;也通过具体的案例&#xff0c;详细展示了如何分析并定位性能瓶颈的步骤。 SQL 调优 先给出性能优化方法和分析性能瓶…

为什么基于 Django 和 Scrapy 的项目需要 @sync_to_async 装饰器

在现代 web 开发中&#xff0c;异步编程正变得越来越重要&#xff0c;特别是对于需要处理大量 I/O 操作的应用程序。Scrapy 是一个用于 web 抓取的异步框架&#xff0c;而 Django 是一个流行的 web 框架&#xff0c;主要采用同步编程模型。将这两个框架结合在一个项目中时&…

YT-DLP 超好用的开源视频下载工具

YT-DLP 是一个功能丰富的命令行音频/视频下载器&#xff0c;是 youtube-dl 的一个分支。由于 youtube-dl 已经停止更新&#xff0c;YT-DLP 不仅继承了其功能&#xff0c;还进行了多项改进和扩展。YT-DLP 不仅可以下载 YouTube 视频&#xff0c;还支持众多站点&#xff0c;包括国…

# RocketMQ 实战:模拟电商网站场景综合案例(二)

RocketMQ 实战&#xff1a;模拟电商网站场景综合案例&#xff08;二&#xff09; 一、SpringBoot 整合 Dubbo &#xff1a;dubbo 概述 1、dubbo 概述 Dubbo &#xff1a;是阿里巴巴公司开源的一款高性能、轻量级的 Java RPC 框架&#xff0c;它提供了三大核心能力&#xff1a…

Ubuntu系统本地搭建WordPress网站并发布公网实现远程访问

文章目录 前言1. 搭建网站&#xff1a;安装WordPress2. 搭建网站&#xff1a;创建WordPress数据库3. 搭建网站&#xff1a;安装相对URL插件4. 搭建网站&#xff1a;内网穿透发布网站4.1 命令行方式&#xff1a;4.2. 配置wordpress公网地址 5. 固定WordPress公网地址5.1. 固定地…

阿里云安装python依赖报错 Requirements should be satisfied by a PEP 517 installer.

Collecting basicsr1.4.2 (from -r requirements.txt (line 16))Downloading http://mirrors.cloud.aliyuncs.com/pypi/packages/86/41/00a6b000f222f0fa4c6d9e1d6dcc9811a374cabb8abb9d408b77de39648c/basicsr-1.4.2.tar.gz (172 kB)━━━━━━━━━━━━━━━━━━━━…

功能安全TSC

TSC 与 FSR 的基本概念 一、引言 在功能安全领域中,TSC(Technical Safety Concept,技术安全概念)和 FSR(Functional Safety Requirements,功能安全要求)是两个至关重要的概念。它们对于确保系统的安全性和可靠性起着关键作用。本文将详细阐述 TSC 和 FSR 的定义、内涵,…

QQ号码采集器

寅甲QQ号码采集软件, 一款采集QQ号、QQ邮件地址&#xff0c;采集QQ群成员、QQ好友的软件。可以按关键词采集&#xff0c;如可以按地区、年龄、血型、生日、职业等采集。采集速度非常快且操作很简单。

电能质量在线监测装置

安科瑞电气股份有限公司 祁洁 15000363176 一、装置概述 APView500电能质量在线监测装置采用了高性能多核平台和嵌入式操作系统&#xff0c;遵照IEC61000-4-30《测试和测量技术-电能质量测量方法》中规定的各电能质量指标的测量方法进行测量&#xff0c;集谐波分析、波形采…

如何应对Android面试官 -> 玩转 MVx(MVC、MVP、MVVM、MVI)

前言 本章主要基于以下几个方向进行 MVx 的讲解&#xff0c;带你玩转 MVx&#xff1b; MVC、MVP、MVVM、MVI 它们到底是什么&#xff1f; 分文件、分模块、分模式 一个文件打天下 为什么不要用一个页面打天下&#xff1f; 页面是给用户看的&#xff0c;随着版本的迭代&…

6.6小结

Problem - A - Codeforces 思路&#xff1a; 一次最多只能走一步或者两步&#xff0c;只需要判断后面两个是不是都是*就行 #include<bits/stdc.h> using namespace std; char a[1010]; int main() {int t;cin >> t;while (t--){int n, flag0;int ans 0;cin >…

kali扩容

通过wmware虚拟机–>设置–>添加40G容量的硬盘。 ──(root㉿kali)-[~/桌面] fdisk -lDisk /dev/sda: 40 GiB, 42949672960 bytes, 83886080 sectors …

DevOps的原理及应用详解(二)

本系列文章简介&#xff1a; 在当今快速变化的商业环境中&#xff0c;企业对于软件交付的速度、质量和安全性要求日益提高。传统的软件开发和运维模式已经难以满足这些需求&#xff0c;因此&#xff0c;DevOps&#xff08;Development和Operations的组合&#xff09;应运而生&a…

Qt5学习笔记

一、基础知识 1、基本控件类型 水平弹簧与垂直弹簧的父类都是QSpaceItem。关于PushButton相关的控件类型&#xff1a; QPushButton&#xff1a;最基础的按钮类型。QToolButton&#xff1a;可以控制图片、文字任意组合的显示方式的按钮类型。QRadioButton&#xff1a;就像rad…

Java多线程-初阶1

博主主页: 码农派大星. 数据结构专栏:Java数据结构 数据库专栏:MySQL数据库 JavaEE专栏:JavaEE 关注博主带你了解更多数据结构知识 1. 认识线程&#xff08;Thread&#xff09; 1.线程是什么 ⼀个线程就是⼀个 "执⾏流". 每个线程之间都可以按照顺序执⾏⾃⼰的代…

高并发数据处理中心服务器设计

涉及的相关框架Spring Cloud、RabbitMQ、Redis 和 MySQL&#xff1b; Spring Cloud&#xff1a;用于微服务的开发&#xff0c;确保服务间的通信和协作。 RabbitMQ&#xff1a;用于异步消息队列&#xff0c;确保系统的高可用性和扩展性。 Redis&#xff1a;用作缓存&#xff…

计算机图形学入门07:光栅化中的采样与走样

1.什么是光栅化&#xff1f; 在前面的章节里提过&#xff0c;光栅化(Rasterization)就是将物体投影在屏幕上的图形&#xff0c;依据像素打散&#xff0c;每一个像素中填充不同的颜色。 如下图中的老虎&#xff0c;可以看到屏幕上有各种多边形&#xff0c;这些多边形经过各种变换…