UniversalTransformer with Adaptive Computation Time(ACT)

在这里插入图片描述


原论文链接:https://arxiv.org/abs/1807.03819


Main code

import torch
import numpy as npclass PositionTimestepEmbedding(torch.nn.Module):def forward(self, x, t):device = x.devicesequence_length = x.size(1)d_model = x.size(2)position_embedding = np.array([[pos / np.power(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)] for pos in range(sequence_length)])position_embedding[:, 0::2] = np.sin(position_embedding[:, 0::2])position_embedding[:, 1::2] = np.cos(position_embedding[:, 1::2])timestep_embedding = np.array([[t / np.power(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]])timestep_embedding[:, 0::2] = np.sin(timestep_embedding[:, 0::2])timestep_embedding[:, 1::2] = np.sin(timestep_embedding[:, 1::2])embedding = position_embedding + timestep_embeddingreturn x + torch.tensor(embedding, dtype=torch.float, requires_grad=False, device=device)class MultiHeadAttention(torch.nn.Module):def __init__(self, d_model, num_heads, dropout=0.):super().__init__()self.d_model = d_modelself.num_heads = num_headsself.head_dim = d_model // num_headsassert self.head_dim * num_heads == self.d_model, "d_model must be divisible by num_heads"self.query = torch.nn.Linear(d_model, d_model)self.key = torch.nn.Linear(d_model, d_model)self.value = torch.nn.Linear(d_model, d_model)self.dropout = torch.nn.Dropout(dropout)self.output = torch.nn.Linear(d_model, d_model)self.layer_norm = torch.nn.LayerNorm(d_model)def scaled_dot_product_attention(self, q, k, v, mask=None):scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)if mask is not None:scores = scores.masked_fill(mask, -np.inf)scores = scores.softmax(dim=-1)scores = self.dropout(scores)return torch.matmul(scores, v), scoresdef forward(self, q, k, v, mask=None):batch_size = q.size(0)residual = qif mask is not None:mask = mask.unsqueeze(1)q = self.query(q).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)k = self.key(k).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)v = self.value(v).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)out, scores = self.scaled_dot_product_attention(q, k, v, mask)out = (out.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.head_dim))out = self.output(out)out += residualreturn self.layer_norm(out)class TransitionFunction(torch.nn.Module):def __init__(self, d_model, dim_transition, dropout=0.):super().__init__()self.linear1 = torch.nn.Linear(d_model, dim_transition)self.relu = torch.nn.ReLU()self.linear2 = torch.nn.Linear(dim_transition, d_model)self.dropout = torch.nn.Dropout(dropout)self.layer_norm = torch.nn.LayerNorm(d_model)def forward(self, x):y = self.linear1(x)y = self.relu(y)y = self.linear2(y)y = self.dropout(y)y = y + xreturn self.layer_norm(y)class EncoderBasicLayer(torch.nn.Module):def __init__(self, d_model, dim_transition, num_heads, dropout=0.):super().__init__()self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)self.transition = TransitionFunction(d_model, dim_transition, dropout)def forward(self, block_inputs, enc_self_attn_mask=None):self_attention_outputs = self.self_attention(block_inputs, block_inputs, block_inputs, enc_self_attn_mask)block_outputs = self.transition(self_attention_outputs)return block_outputsclass DecoderBasicLayer(torch.nn.Module):def __init__(self, d_model, dim_transition, num_heads, dropout=0.):super().__init__()self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)self.attention_enc_dec = MultiHeadAttention(d_model, num_heads, dropout)self.transition = TransitionFunction(d_model, dim_transition, dropout)def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask=None, dec_enc_attn_mask=None):dec_query = self.self_attention(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)block_outputs = self.attention_enc_dec(dec_query, enc_outputs, enc_outputs, dec_enc_attn_mask)block_outputs = self.transition(block_outputs)return block_outputsclass RecurrentEncoderBlock(torch.nn.Module):def __init__(self, num_layers, d_model, dim_transition, num_heads, dropout=0.):super().__init__()self.layers = torch.nn.ModuleList([EncoderBasicLayer(d_model,dim_transition,num_heads,dropout) for _ in range(num_layers)])def forward(self, x, enc_self_attn_mask=None):for l in self.layers:x = l(x, enc_self_attn_mask)return xclass RecurrentDecoderBlock(torch.nn.Module):def __init__(self, num_layers, d_model, dim_transition, num_heads, dropout=0.):super().__init__()self.layers = torch.nn.ModuleList([DecoderBasicLayer(d_model,dim_transition,num_heads,dropout) for _ in range(num_layers)])def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):for l in self.layers:dec_inputs = l(dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)return dec_inputsclass AdaptiveNetwork(torch.nn.Module):def __init__(self, d_model, dim_transition, epsilon, max_hop):super().__init__()self.threshold = 1.0 - epsilonself.max_hop = max_hopself.halting_predict = torch.nn.Sequential(torch.nn.Linear(d_model, dim_transition),torch.nn.ReLU(),torch.nn.Linear(dim_transition, 1),torch.nn.Sigmoid())def forward(self, x, mask, pos_time_embed, recurrent_block, encoder_output=None):device = x.devicehalting_probability = torch.zeros((x.size(0), x.size(1)), device=device)remainders = torch.zeros((x.size(0), x.size(1)), device=device)n_updates = torch.zeros((x.size(0), x.size(1)), device=device)previous = torch.zeros_like(x, device=device)step = 0while (((halting_probability < self.threshold) & (n_updates < self.max_hop)).byte().any()):x = x + pos_time_embed(x, step)p = self.halting_predict(x).squeeze(-1)still_running = (halting_probability < 1.0).float()new_halted = (halting_probability + p * still_running > self.threshold).float() * still_runningstill_running = (halting_probability + p * still_running <= self.threshold).float() * still_runninghalting_probability = halting_probability + p * still_runningremainders = remainders + new_halted * (1 - halting_probability)halting_probability = halting_probability + new_halted * remaindersn_updates = n_updates + still_running + new_haltedupdate_weights = p * still_running + new_halted * remaindersif encoder_output is not None:x = recurrent_block(x, encoder_output, mask[0], mask[1])else:x = recurrent_block(x, mask)previous = ((x * update_weights.unsqueeze(-1)) + (previous * (1 - update_weights.unsqueeze(-1))))step += 1return previousclass Encoder(torch.nn.Module):def __init__(self, epsilon, max_hop, num_layers, d_model, dim_transition, num_heads, dropout=0.):super().__init__()assert 0 < epsilon < 1, "0 < epsilon < 1 !!!"self.pos_time_embedding = PositionTimestepEmbedding()self.recurrent_block = RecurrentEncoderBlock(num_layers,d_model,dim_transition,num_heads,dropout)self.adaptive_network = AdaptiveNetwork(d_model, dim_transition, epsilon, max_hop)def forward(self, x, enc_self_attn_mask=None):return self.adaptive_network(x, enc_self_attn_mask, self.pos_time_embedding, self.recurrent_block)class Decoder(torch.nn.Module):def __init__(self, epsilon, max_hop, num_layers, d_model, dim_transition, num_heads, dropout=0.):super().__init__()assert 0 < epsilon < 1, "0 < epsilon < 1 !!!"self.pos_time_embedding = PositionTimestepEmbedding()self.recurrent_block = RecurrentDecoderBlock(num_layers,d_model,dim_transition,num_heads,dropout)self.adaptive_network = AdaptiveNetwork(d_model, dim_transition, epsilon, max_hop)def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):return self.adaptive_network(dec_inputs, (dec_self_attn_mask, dec_enc_attn_mask),self.pos_time_embedding, self.recurrent_block, enc_outputs)class AdaptiveComputationTimeUniversalTransformer(torch.nn.Module):def __init__(self, d_model, dim_transition, num_heads, enc_attn_layers, dec_attn_layers, epsilon, max_hop, dropout=0.):super().__init__()self.encoder = Encoder(epsilon, max_hop, enc_attn_layers, d_model, dim_transition, num_heads, dropout)self.decoder = Decoder(epsilon, max_hop, dec_attn_layers, d_model, dim_transition, num_heads, dropout)def forward(self, src, tgt, enc_self_attn_mask=None, dec_self_attn_mask=None, dec_enc_attn_mask=None):enc_outputs = self.encoder(src, enc_self_attn_mask)return self.decoder(tgt, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)

Mask

# from https://zhuanlan.zhihu.com/p/403433120
def get_attn_subsequence_mask(seq):  # seq: [batch_size, tgt_len]attn_shape = [seq.size(0), seq.size(1), seq.size(1)]subsequence_mask = np.triu(np.ones(attn_shape), k=1)  # 生成上三角矩阵,[batch_size, tgt_len, tgt_len]subsequence_mask = torch.from_numpy(subsequence_mask).bool()  # [batch_size, tgt_len, tgt_len]return subsequence_maskdef get_attn_pad_mask(seq_q, seq_k):  # seq_q: [batch_size, seq_len] ,seq_k: [batch_size, seq_len]batch_size, len_q = seq_q.size()batch_size, len_k = seq_k.size()pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # 判断 输入那些含有P(=0),用1标记 ,[batch_size, 1, len_k]return pad_attn_mask.expand(batch_size, len_q, len_k)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/595245.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Qt界面篇:Qt停靠控件QDockWidget、树控件QTreeWidget及属性控件QtTreePropertyBrowser的使用

1、功能介绍 本篇主要使用Qt停靠控件QDockWidget、树控件QTreeWidget及Qt属性控件QtTreePropertyBrowser来搭建一个简单实用的主界面布局。效果如下所示。 2、控件使用详解 2.1 停靠控件QDockWidget QDockWidget可以停靠在 QMainWindow 内或作为桌面上的顶级窗口浮动。默认值…

基于OpenCV的透视变化

基本概念 透视变换(Perspective Transformation)是仿射变换的一种非线性扩展,是将图片投影到一个新的视平面(Viewing Plane)&#xff0c;也称作投影映射(Projective Mapping)。 原理&#xff1a;将二维的图片投影到一个三维视平面上&#xff0c;然后再转换到二维坐标下&#…

everything 本地文件搜索工具 完胜WIndows搜索 速度99% 超级给力

"Everything" 是一个 Windows 平台上的免费软件&#xff0c;它是一款功能强大的本地文件搜索工具。它允许用户在计算机上快速而准确地搜索文件和文件夹。以下是一些 "Everything" 的主要特点&#xff1a; 实时搜索&#xff1a; "Everything" 提供…

【小沐学NLP】Python实现TF-IDF算法(nltk、sklearn、jieba)

文章目录 1、简介1.1 TF1.2 IDF1.3 TF-IDF2.1 TF-IDF(sklearn)2.2 TF-IDF(nltk)2.3 TF-IDF(Jieba)2.4 TF-IDF(python) 结语 1、简介 TF-IDF&#xff08;term frequency–inverse document frequency&#xff09;是一种用于信息检索与数据挖掘的常用加权技术。TF是词频(Term Fr…

多台西门子PLC对接Oracle数据库,实现PLC与数据库双向数据通讯

智能网关IGT-DSER方便实现多台PLC与数据库之间的数据通讯&#xff0c;既可以读取PLC的数据上报到数据库&#xff0c;也可以从数据库查询数据后写入到PLC的寄存器。 网关安装在设备侧&#xff0c;与设备同时起停&#xff0c;不担心数据丢失&#xff1b;在断网、服务器维护上报数…

霹雳吧啦Wz《pytorch图像分类》-p5ResNet网络

《pytorch图像分类》p5ResNet网络结构 1 网络中的亮点1.1 超深的网络结构1.2 residual模块1.3 Batch Normalization1.4 迁移学习简介 2 模块类代码2.1 BasicBlock&#xff08;18 & 32 layers&#xff09;2.2 Bottleneck&#xff08;50 & 101 & 152layers&#xff0…

爬虫如何获取免费代理IP(二)

89ip代理爬取代码实现 一、代码实现 import requests import time import random from fake_useragent import UserAgent from lxml import etree import os import csv""" 89ip代理爬取 """class IPSipder(object):def __init__(self):self.u…

Python 操作 JMeter 探索:pymeter 实操指南

概要 JMeter 是一个流行的性能测试工具&#xff0c;用于测试 Web 应用程序的性能和负载。它通常与 GUI 一起使用&#xff0c;但如果您想在自动化测试中集成 JMeter&#xff0c;或者以编程方式创建和运行测试计划&#xff0c;那么 pymeter 库将是一个强大的工具。本文将介绍如何…

2023-12-26分割回文串和子集以及子集II

131. 分割回文串 思想&#xff1a;回溯三步骤&#xff01;① 传入参数 ② 回溯结束条件 ③ 单层搜索逻辑&#xff01;抽象成回溯树&#xff0c;树枝上是每次从头部穷举切分出的子串&#xff0c;节点上是待切分的剩余字符串【从头开始每次往后加一】 class Solution:def partiti…

JavaScript中实现页面跳转的多种方法【通俗易懂】

✨前言✨   本篇文章主要在于如何使用JavaScript中的各种实现页面跳转的方式 &#x1f352;欢迎点赞 &#x1f44d; 收藏 ⭐留言评论 &#x1f4dd;私信必回哟&#x1f601; &#x1f352;博主将持续更新学习记录收获&#xff0c;友友们有任何问题可以在评论区留言 在JavaScr…

Fortify漏洞之Sql Injection(sql注入)

Fortify漏洞之Sql Injection&#xff08;sql注入&#xff09; 前言 本篇先对Fortify做个简单的认识&#xff0c;同时总结一下sql注入的漏洞&#xff01; 一、Fortify软件介绍 Fortify是一款能扫描分析代码漏洞的强大工具&#xff0c;是由一家加州软件安全厂商开发而成&#…

为什么要为IP地址申请SSL证书?

在不断发展的互联网世界中&#xff0c;网络安全越来越受到重视&#xff0c;这不仅是因为相关法律法规政策的实施&#xff0c;还因为确保网络安全可以为企业减少财产损失。而确保企业在线业务安全的关键一点&#xff0c;就是SSL证书的部署&#xff0c;SSL证书不仅可以加密数据&a…

Unity中Shader雾效在场景中的调节技巧

文章目录 前言一、修改棋盘格Shader的Cull可以在属性面板控制1、在属性面板定义CullMode2、在SubShader中&#xff0c;使用CullMode3、这样就可以在不同剔除情况下使用棋盘格场景了 二、调节天际线颜色和雾融为一体1、在摄像机设置不渲染天空盒&#xff0c;渲染单一颜色2、采样…

如何解决大模型的「幻觉」问题?

如何解决大模型的「幻觉」问题&#xff1f; 如何解决大模型的「幻觉」问题&#xff1f;幻觉产生原因&#xff1f;模型原因数据层面 幻觉怎么评估&#xff1f;Reference-based&#xff08;基于参考信息&#xff09;基于模型的输入、预先定义的目标输出基于模型的输入 Reference-…

Elasticsearch基本操作之索引操作

本文说下Elasticsearch基本操作之索引操作 文章目录 概述创建索引创建索引示例重复创建索引示例 查看索引查看所有索引查看单个索引 删除索引删除索引 概述 由于是使用命令来操作Elasticsearch&#xff0c;可以使用kibana&#xff0c;postman和apifox等工具 我使用了apifox来执…

【bug】【VSCode】远程终端TERMINAL打不开

【bug】【VSCode】远程终端TERMINAL打不开 可能的原因现象分析解决 可能的原因 昨天晚上vscode在打开多个TERMINAL的情况下&#xff0c;挂了一晚上&#xff0c;今早上来看的时候全都lost connections…。然后关闭再打开就出现了如上现象。 早上一来到实验室就要debug… 现象…

西北工业大学计算机组成原理实验报告——verilog前两次

说明 为了有较好的可读性&#xff0c;报告仅仅粘贴关键代码。该PDF带有大纲功能&#xff0c;点击大纲中的对应标题&#xff0c;可以快速跳转。 实验目标 掌握单周期CPU执行指令的流程和原理&#xff1b;学习使用verilog HDL语言实现单周期CPU, 并通过功能仿真&#xff1b;提…

k8s之pod

pod是k8s中最小的资源管理组件 pod也是最小化运行容器化的应用的资源管理对象 pod是一个抽象的概念&#xff0c;可以理解成一个或者多个容器化应用的集合 pod可以是一个或者多个 在一个pod中运行一个容器&#xff08;最常用的方式&#xff09; 在一个pod中同时运行多个容器…

第二证券:长期布局重要窗口或至 险资看涨A股

新年伊始&#xff0c;稳妥资金对2024年权益商场出资更为达观。多家险资组织告诉上海证券报记者&#xff0c;在经历了2023年的震动调整行情后&#xff0c;2024年A股商场机遇大于危险&#xff0c;商场体现或将显着优于2023年。 详细来看&#xff0c;两方面要素支撑权益商场向好&…

总结MySQL 的一些知识点:MySQL 排序

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…