samout 最新版本state 逐层控制加速收敛

代码

import torch
import numpy as npclass MaxState(torch.nn.Module):def __init__(self, hidden_dim, heads, win):super(MaxState, self).__init__()assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."self.head_size = hidden_dim // headsself.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.head_num = headsself.win = winself.hidden = hidden_dimself.mask = torch.triu(torch.ones([win, win])).to("cuda")self.layer_nor = torch.nn.LayerNorm(hidden_dim)def forward(self, input_data, state=None):# self.head.to("cuda")b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.winwindow = torch.ones([1, w]).to("cuda")out = self.head(input_data)out = out.unsqueeze(-1) @ windowout = out.permute([0, 2, 1, 3])one_list = []if state is None:state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf")state = state.to("cuda")for i in range(0, s, w):state.reshape([state.shape[0], -1])j = w + ione = out[:, :, i:j]_, _, r, c = one.shapeif r != self.win:one = torch.where(self.mask[:r, :] == 1, one, torch.Tensor([-float('inf')]).to("cuda"))else:one = torch.where(self.mask == 1, one, torch.Tensor([-float('inf')]).to("cuda"))if i == 0:one = torch.concat([one, state @ window], axis=2)state, _ = torch.max(one, axis=2, keepdim=True)else:state1, _ = torch.max(one, axis=2, keepdim=True)# state = torch.sin(self.state(state1.reshape([state1.shape[0], -1]))*state.reshape([state.shape[0], -1]))state1 = self.state(state1.permute([0, 3, 1, 2]).reshape([state1.shape[0], -1, state1.shape[1]]))state = state1.permute([0, 2, 1]).unsqueeze(-2) + state# state = state.reshape(state1.shape)one = torch.concat([one, state], axis=2)state, _ = torch.max(one, axis=2, keepdim=True)one = state.reshape([b, k, h, w])state = state[..., -1:]if r != self.win:one = one[..., :r]one = one.permute([0, 3, 1, 2])one_list.append(one)out = torch.concat(one_list, 1)out = out.reshape([b, s, -1])return out, stateclass FeedForward(torch.nn.Module):def __init__(self, hidden_size):super(FeedForward, self).__init__()self.ffn1 = torch.nn.Linear(hidden_size, hidden_size * 2)self.ffn2 = torch.nn.Linear(hidden_size * 2, hidden_size)self.gate = torch.nn.Linear(hidden_size, hidden_size * 2)self.relu = torch.nn.ReLU()def forward(self, x):x1 = self.ffn1(x)x2 = self.relu(self.gate(x))x = x1 * x2x = self.ffn2(x)return xclass DecoderLayer(torch.nn.Module):def __init__(self, hidden_size, num_heads):super(DecoderLayer, self).__init__()# self.self_attention = MaskMultiHeadAttention(hidden_size, num_heads)self.self_attention = MaxState(hidden_size, num_heads, 8)self.ffn = FeedForward(hidden_size)self.layer_norm = torch.nn.LayerNorm(hidden_size)def forward(self, x, state=None, seq_len=None):x1, state = self.self_attention(x, state)x = self.layer_norm(self.ffn(x1) + x)  # Feed-Forward with residual connectionreturn x, stateclass SamOut(torch.nn.Module):def __init__(self, voc_size, hidden_size, num_heads, num_layers):super(SamOut, self).__init__()self.em = torch.nn.Embedding(voc_size, hidden_size, padding_idx=3)self.pos = torch.nn.Embedding(1024, hidden_size)self.decoder_layers = torch.nn.ModuleList([DecoderLayer(hidden_size, num_heads) for _ in range(num_layers)])self.head = torch.nn.Linear(hidden_size, voc_size)self.head_state = torch.nn.Linear(hidden_size, num_layers)def forward(self, x, state=None, seq_len=None):x = self.em(x)if x.shape[1] >= 1024:pos = self.pos(torch.range(0, x.shape[1] - 1).long() // 1024).unsqueeze(0)pos = self.pos(torch.range(0, x.shape[1] - 1).long() % 1024).unsqueeze(0) + poselse:pos = self.pos(torch.range(0, x.shape[1] - 1).long().to("cuda")).unsqueeze(0)if state is None:state = [None] * len(self.decoder_layers)i = 0for decoder_layer in self.decoder_layers:x1, state[i] = decoder_layer(pos + x, state[i])x = x1 + xi += 1state_data = self.head_state((torch.concat(state, -1).squeeze(-2)).permute([0, 2, 1]))return self.head(x), state, state_dataif __name__ == '__main__':net = SamOut(235, 256, 16, 4)net(torch.randint(0, 200, [2, 3000]))

解析

import torch
import numpy as np

这两行代码导入了PyTorch库和NumPy库,它们分别用于深度学习和数值计算。

class MaxState(torch.nn.Module):def __init__(self, hidden_dim, heads, win):super(MaxState, self).__init__()

这里定义了一个名为MaxState的PyTorch模块。它继承自torch.nn.Module,这是所有自定义模型的基类。

    assert hidden_dim % heads == 0, "Hidden size must be divisible by the number of heads."

这里使用assert语句确保hidden_dim能够被heads整除,这是多头注意力机制的一个要求。

    self.head_size = hidden_dim // headsself.head = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.state = torch.nn.Linear(hidden_dim, hidden_dim, bias=False)self.head_num = headsself.win = winself.hidden = hidden_dimself.mask = torch.triu(torch.ones([win, win])).to("cuda")self.layer_nor = torch.nn.LayerNorm(hidden_dim)

这里初始化了一些类的属性,包括线性层、头数、窗口大小、隐藏层大小、上三角矩阵掩码以及层归一化。

    def forward(self, input_data, state=None):

定义了forward方法,这是模型的前向传播过程。

        b, s, k, h, w = input_data.shape[0], input_data.shape[1], self.head_num, self.head_size, self.win

从输入数据中提取批次大小、序列长度、头数、头大小和窗口大小。

        window = torch.ones([1, w]).to("cuda")

创建一个窗口大小的一维张量,并将其移动到GPU上。

        out = self.head(input_data)

对输入数据进行线性变换。

        out = out.unsqueeze(-1) @ window

将输出数据与窗口张量进行矩阵乘法。

        out = out.permute([0, 2, 1, 3])

调整输出数据的维度顺序。

        one_list = []if state is None:state = torch.ones([out.shape[0], out.shape[1], 1, 1]) * float("-inf")state = state.to("cuda")

如果状态为空,则初始化状态张量,并将其移动到GPU上。

        for i in range(0, s, w):# ... (省略中间代码)

对序列进行迭代,每次迭代处理一个窗口大小的数据。

        return out, state

返回最终输出和状态。
接下来是FeedForwardDecoderLayerSamOut类的定义,它们的结构和MaxState类类似,都是自定义的PyTorch模块。

if __name__ == '__main__':net = SamOut(235, 256, 16, 4)net(torch.randint(0, 200, [2, 3000]))

最后,这段代码实例化了SamOut类,并使用随机生成的输入数据进行了一次前向传播。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/869843.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C++深度学习】多态(概念虚函数抽象类)

✨ 疏影横斜水清浅,暗香浮动月黄昏 🌏 📃个人主页:island1314 🔥个人专栏:C学习 🚀 欢迎关注:👍点赞 &…

第2章 大话 ASP.NET Core 入门

第1章 框架学习的基石与实战策略 链接 第2章 大话 ASP.NET Core 入门 1.什么是ASP.NET Core框架 ASP.NET Core是一个超级棒的框架,它是免费的,你可以在任何主流的系统上,比如Windows、Linux或macOS上使用它,而且它是完全开放源…

appium环境准备

前言: 本系列教程会从软件的基本安装开始,最终目的是通过完成几个案例后, 大家实现自由抓取App中想要的资源。 本系列以后会更的: Appium基本使用及控制真机及安卓模拟器Mitmproxy抓包工具的基本使用Fiddler抓包软件的基本使用 了解了以上的基本操作,我们就可进行手机资源…

Splunk Enterprise路径遍历漏洞风险通告

今日&#xff0c;亚信安全CERT监控到安全社区研究人员发布安全通告&#xff0c;披露了Splunk Enterprise 路径遍历漏洞(CVE-2024-36991)。该漏洞发生在9.2.0<version<9.2.2&#xff0c;9.1.0<version<9.1.5&#xff0c;以及9.0.0<version<9.0.10的windows版本…

3102.力扣每日一题7/9 Java(TreeMap)

博客主页&#xff1a;音符犹如代码系列专栏&#xff1a;算法练习关注博主&#xff0c;后期持续更新系列文章如果有错误感谢请大家批评指出&#xff0c;及时修改感谢大家点赞&#x1f44d;收藏⭐评论✍ 目录 TreeMap详解 解题思路 解题方法 时间复杂度 空间复杂度 Code T…

【Python】 已解决:ModuleNotFoundError: No module named…

文章目录 一、分析问题背景二、可能出错的原因三、错误代码示例四、正确代码示例五、注意事项 已解决&#xff1a;ModuleNotFoundError: No module named… 一、分析问题背景 在使用Python进行开发时&#xff0c;有时会遇到“ModuleNotFoundError: No module named…”这样的…

android gradle开发基础

Android Gradle开发基础涉及多个方面&#xff0c;包括Gradle的基本概念、环境配置、构建脚本的编写、任务与插件的使用等。以下是对这些方面的详细介绍&#xff1a; 一、Gradle基础 1. Gradle简介 Gradle是一个开源的构建自动化系统&#xff0c;专注于灵活性和性能。它支持多…

洞察与理解:自闭症儿童的典型行为特征解析

作为星贝育园自闭症儿童康复中心的一名专业教师&#xff0c;我深知理解自闭症儿童的行为特征对于早期识别、干预和提供恰当支持至关重要。自闭症&#xff0c;或称孤独症谱系障碍&#xff08;Autism Spectrum Disorder, ASD&#xff09;&#xff0c;是一组影响个体社交互动、沟通…

创新设计策略:提升大屏幕可视化设计效果的关键方法

随着科技的不断发展和数据量的快速增长&#xff0c;数据可视化大屏在各个行业中的应用越来越广泛&#xff0c;可以帮助人们更好地理解和分析数据&#xff0c;可视化大屏设计也因此成了众多企业的需求。但很多设计师对可视化大屏设计并不了解&#xff0c;也不知道如何制作可视化…

谁说forEach不支持异步代码,只是你拿不到异步结果而已

在前面探讨 forEach 中异步请求后端接口时&#xff0c;很多人都知道 forEach 中 async/await 实际是无效的&#xff0c;很多文章也说&#xff1a;forEach 不支持异步&#xff0c;forEach 只能同步运行代码&#xff0c;forEach 会忽略 await 直接进行下一次循环… 当时我的理解…

dify/api/models/tool.py文件中的数据表

源码位置&#xff1a;dify/api/models/tool.py ToolProvider 表结构 字段英文名数据类型字段中文名字备注idStringUUIDIDUUID生成tenant_idStringUUID租户ID非空tool_nameString工具名称非空encrypted_credentialsText加密凭证可为空is_enabledBoolean是否启用默认值为 false…

[GICv3] 1.引言Introduction

基本概念 通用中断控制器 (GIC) 从外设获取中断&#xff0c;确定它们的优先级&#xff0c;然后将它们传送到适当的处理器内核。 下图了为一个 GIC 从 n 个不同的外设获取中断&#xff0c;并将它们分配给两个不同的处理器。 ​​ GCI(Generic Interrupt Controller)&#xff0c…

Caused by: java.lang.NoSuchMethodError: com.squareup.javapoet.MethodSpec

导入第三方module运行项目报&#xff1a; Caused by: java.lang.NoSuchMethodError: com.squareup.javapoet.MethodSpec$Builder.addComment(Ljava/lang/String;[Ljava/lang/Object;)Lcom/squareup/javapoet/MethodSpec$Builder; Caused by: java.lang.RuntimeException: Cann…

IPython的交互式命令行:交互式命令行界面

IPython的交互式命令行&#xff1a;交互式命令行界面 介绍 IPython是一款功能强大的交互式命令行工具&#xff0c;它极大地增强了Python编程的体验。通过提供即时反馈和动态探索功能&#xff0c;IPython帮助初学者更快速、更直观地掌握Python编程技能。本指南将详细介绍IPyth…

AI Agent 的发展现状、行业结构与趋势分析

Agent 来自一种哲学概念&#xff0c;是个很古老的哲学术语&#xff0c;从哲学意义上讲&#xff0c;“代理”的概念涉及实体的自主性&#xff0c;具有行使意志、做出选择和采取行动的能力&#xff0c;而不是被动地对外部刺激做出反应。后来人们将这一概念引入计算机科学领域&…

ApiFox或postman怎么用params类型传输json或集合+json的String类型

你是否碰见过这样的接口? post请求然后传输的参数都要和查询时一样以param形式传参数,那String什么的都好说,传就直接进后台了,那json呢,集合呢,是不是直接给你返400呢. 1.传json如何处理 那我们看看怎么实现,如果你要传json数据,那需要将特殊字符转义,也叫url转码,否则传不…

【HarmonyOS】关于官方推荐的组件级路由Navigation的心得体会

前言 最近因为之前的630版本有点忙&#xff0c;导致断更了几天&#xff0c;现在再补上。换换脑子。 目前内测系统的华为应用市场&#xff0c;各种顶级APP陆续都放出来beta版本了&#xff0c;大体上都完成了主流程的开发。欣欣向荣的气息。 学习思路 关于学习HarmonyOS的问题…

热点解读 | 小红书「县城生活」趋势前瞻

“县城婆罗门”、“月薪两万不如县城贵妇”、“北漂打工人回县城被穷笑”……继中产之后&#xff0c;县城成为又一个全网热议的焦点。 县城叙事重返舆论场&#xff0c;本期千瓜将进一步解构「县城」语境下的个体表现&#xff0c;帮助品牌沉淀用户心智&#xff0c;塑造新时代竞争…

​​​防御第一次作业

1、拓扑图及实验要求&#xff1a; 2、配置&#xff1a; 配置终端及服务器IP地址&#xff1a; Pc2&#xff1a; Client1&#xff1a; Pc4&#xff1a; Client2&#xff1a; PC1&#xff1a; Server1&#xff1a; Server2&#xff1a; 防火墙基础配置&#xff1a; [fw1]int g …

maven 依赖冲突

依赖冲突 1、对于 Maven 而言&#xff0c;同一个 groupId 同一个 artifactId 下&#xff0c;只能使用一个 version。 <!-- https://mvnrepository.com/artifact/org.apache.commons/commons-math3 --><dependency><groupId>org.apache.commons</groupId&…