block-recurrent-transformer-pytorch study notes

Contents

With dependencies (1)

No dependencies, no usage example

No dependencies (2)


With dependencies (1):

GitHub - dashstander/block-recurrent-transformer: Pytorch implementation of "Block Recurrent Transformers" (Hutchins & Schlag et al., 2022)

No dependencies, no usage example:

GitHub - jskinn/pytorch-block-recurrent-transformer: Pytorch implementation of the Block-Recurrent Transformer. Official JAX implementation here: https://github.com/google-research/meliad

No dependencies (2):

GitHub - lucidrains/block-recurrent-transformer-pytorch: Implementation of Block Recurrent Transformer - Pytorch

import math
from random import random
from functools import wraps, partial
from itertools import zip_longest
from collections import namedtuple, defaultdict

from packaging import version

import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange

from beartype import beartype
from beartype.typing import Optional, List, Tuple

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def is_empty(t: torch.Tensor):
    return t.numel() == 0

def cast_tuple(t, length = 1):
    return t if isinstance(t, tuple) else ((t,) * length)

def all_unique(arr):
    return len(arr) == len(set(arr))

def eval_decorator(fn):
    def inner(self, *args, **kwargs):
        was_training = self.training
        self.eval()
        out = fn(self, *args, **kwargs)
        self.train(was_training)
        return out
    return inner

def once(fn):
    called = False
    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner

print_once = once(print)

def compact(arr):
    return [*filter(exists, arr)]

def and_reduce(arr: List[torch.Tensor]):
    if len(arr) == 0:
        return None
    head, *rest = arr
    for t in rest:
        head = head & t
    return head

def safe_cat(*args, dim = 1):
    args = compact(args)
    if len(args) == 0:
        return None
    return torch.cat(args, dim = dim)

def divisible_by(numer, denom):
    return (numer % denom) == 0

def l2norm(t):
    return F.normalize(t, dim = -1)

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]

def pad_at_dim(t, pad, dim = -1, value = 0.):
    dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1)
    zeros = ((0, 0) * dims_from_right)
    return F.pad(t, (*zeros, *pad), value = value)

# bias-less layernorm

class LayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

# sampling helpers

def log(t, eps = 1e-20):
    return torch.log(t.clamp(min = eps))

def gumbel_noise(t):
    noise = torch.zeros_like(t).uniform_(0, 1)
    return -log(-log(noise))

def gumbel_sample(t, temperature = 1., dim = -1):
    return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)

def top_k(logits, thres = 0.9):
    k = math.ceil((1 - thres) * logits.shape[-1])
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

# rotary positional embedding w/ xpos
# https://arxiv.org/abs/2104.09864
# https://arxiv.org/abs/2212.10554v1

class RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        width,
        scale_base = 512,
        theta = 10000
    ):
        super().__init__()
        self.width = width

        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent = False)

        self.scale_base = scale_base
        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
        self.register_buffer('scale', scale, persistent = False)

        self.register_buffer('cached_freqs', None, persistent = False)
        self.register_buffer('cached_scales', None, persistent = False)

    @property
    def device(self):
        return next(self.buffers()).device

    def forward(self):
        device, seq_len = self.device, self.width

        if exists(self.cached_freqs):
            cached_seq_len = self.cached_freqs.shape[-2]
            if cached_seq_len >= seq_len:
                return self.cached_freqs[:seq_len], self.cached_scales[:seq_len]

        t = torch.arange(seq_len, device = device).type_as(self.inv_freq)

        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)

        power = (t - (seq_len // 2)) / self.scale_base
        scale = self.scale ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

        self.register_buffer('cached_freqs', freqs, persistent = False)
        self.register_buffer('cached_scales', scale, persistent = False)

        return freqs, scale

def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)

def apply_rotary_pos_emb(t, pos, scale = 1.):
    scale = default(scale, 1.)

    seq_len = t.shape[-2]

    assert pos.shape[-2] >= seq_len
    pos = pos[-seq_len:]

    if isinstance(scale, torch.Tensor):
        assert scale.shape[-2] >= seq_len
        scale = scale[-seq_len:]

    return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)

# memory management

class MemoryManager(nn.Module):
    def __init__(
        self,
        dim,
        *,
        layers = 1,
        mem_lengths = 512,
        compress_factors = 1
    ):
        super().__init__()
        mem_lengths = cast_tuple(mem_lengths)
        compress_factors = cast_tuple(compress_factors)

        assert all([mem_length > 0 for mem_length in mem_lengths])
        assert len(mem_lengths) == len(compress_factors)
        assert layers >= 1

        self.mem_lengths = mem_lengths
        self.compress_factors = compress_factors

        self.layers = nn.ModuleList([])

        for _ in range(layers):
            compress_fns = nn.ModuleList([])

            for compress_factor in compress_factors:
                compress_fn = nn.Identity()

                if compress_factor > 1:
                    compress_fn = nn.Sequential(
                        Rearrange('b n d -> b d n'),
                        nn.Conv1d(
                            dim * 2,
                            dim * 2,
                            compress_factor,
                            stride = compress_factor,
                            groups = 2
                        ),
                        Rearrange('b d n -> b n d'),
                    )

                compress_fns.append(compress_fn)

            self.layers.append(compress_fns)

    def forward(
        self,
        past_memories: List[torch.Tensor],
        new_memories: List[torch.Tensor]
    ):
        next_memories = []

        for past_memory, new_memory, compress_fns in zip_longest(past_memories, new_memories, self.layers):

            # edge case if neither memories exist

            if not (exists(past_memory) or exists(new_memory)):
                next_memories.append(None)
                continue

            next_memory = None

            for mem_length, compress_factor, compress_fn in zip(self.mem_lengths, self.compress_factors, compress_fns):

                # first get the memories for the given compression factor "current_memory"

                current_memory = None
                if exists(past_memory):
                    past_memory, current_memory = past_memory[..., :-mem_length, :], past_memory[..., -mem_length:, :]

                # compress the new memories coming in, based on the compression factors set at init

                if (not is_empty(new_memory)) and compress_factor > 1:
                    # make sure memory length is divisible by compression factor

                    new_mem_length = new_memory.shape[-2]
                    curtailed_length = (new_mem_length // compress_factor) * compress_factor
                    curtailed_slice = slice(-curtailed_length, None) if curtailed_length > 0 else slice(0, 0)
                    new_memory = new_memory[..., curtailed_slice, :]

                    # compress the memory pushed to the next stage

                    if new_memory.shape[-2] > 0:
                        new_memory = rearrange(new_memory, 'm b n d -> b n (m d)')
                        new_memory = compress_fn(new_memory)
                        new_memory = rearrange(new_memory, 'b n (m d) -> m b n d', m = 2)

                # fifo memory queue
                # add the new memory on the right

                current_memory = safe_cat(current_memory, new_memory, dim = -2)

                # "new" memory is new with respect to the next compressed segment

                new_memory, current_memory = current_memory[..., :-mem_length, :], current_memory[..., -mem_length:, :]

                # concat the new memory to the left into the past

                next_memory = safe_cat(current_memory, next_memory, dim = -2)

            next_memories.append(next_memory)

        return next_memories

# maybe flash attention, if using pytorch 2.0

# constants

Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

# state container

class StateContainer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_state_vectors,
        dim_head = 64,
        heads = 8,
        qk_rmsnorm = False,
        qk_rmsnorm_scale = 8,
        use_flash_attn = False
    ):
        super().__init__()
        assert num_state_vectors > 0
        self.heads = heads
        inner_dim = dim_head * heads

        self.state_norm = LayerNorm(dim)

        self.q_to_state = nn.Linear(dim, inner_dim, bias = False)
        self.q_from_state = nn.Linear(dim, inner_dim, bias = False)

        self.state_to_q = nn.Linear(dim, inner_dim, bias = False)
        self.state_to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.init_state = nn.Parameter(torch.randn(num_state_vectors, dim))
        self.state_pos_ids = nn.Parameter(torch.randn(num_state_vectors, dim))

        self.to_state_out = nn.Linear(inner_dim * 2, dim, bias = False)

        self.to_state_cross_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

        self.state_self_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)
        self.from_state_cross_attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

        # gating related parameters - using the fixed simple config

        self.state_out_to_gate = nn.Linear(dim, dim)
        self.learned_ema_beta = nn.Parameter(torch.randn(dim))

        # since each read should be followed by a write, just store cache in the container

        self.cache = None
        self.next_read_state = None

    def set_next_read_state(
        self,
        states
    ):
        if not exists(states):
            states = self.init_state

        self.next_read_state = (states,)

    def read(self, x):
        assert exists(self.next_read_state), 'states to be read must be set with .set_next_read_state'

        states, = self.next_read_state
        self.next_read_state = None

        # pre norm state for attention

        normed_states = self.state_norm(states)

        # add the positional ids, as stated in the paper critical for it to work

        normed_states = normed_states + self.state_pos_ids

        # get queries for cross attention, which they do not share, although they share key / values. another intriguing detail

        q_to_state = self.q_to_state(x)
        q_to_state = rearrange(q_to_state, '... n (h d) -> ... h n d', h = self.heads)

        # self attention qkv for states

        state_k, state_v = self.state_to_kv(normed_states).chunk(2, dim = -1)

        # cross attend to the past states key values

        to_state_out = self.to_state_cross_attn(q_to_state, state_k, state_v)

        to_state_out = rearrange(to_state_out, 'b h n d -> b n (h d)')

        # cache for next write

        self.cache = (states, normed_states, state_k, state_v)

        return to_state_out

    def write(
        self,
        *,
        memories
    ):
        assert exists(self.cache)

        k, v = memories
        batch = k.shape[0]

        # get cached values from the previous read

        states, normed_states, state_k, state_v = self.cache
        self.cache = None

        # derive queries

        q_from_state = self.q_from_state(normed_states)
        q_from_state = rearrange(q_from_state, '... n (h d) -> ... h n d', h = self.heads)

        state_q = self.state_to_q(normed_states)
        state_q_einsum = 'n (h d)' if state_q.ndim == 2 else 'b n (h d)'
        state_q = repeat(state_q, f'{state_q_einsum} -> b h n d', h = self.heads, b = batch)

        # states must also undergo self attention

        if q_from_state.ndim == 3:
            q_from_state = repeat(q_from_state, '... -> b ...', b = batch)

        state_out = self.state_self_attn(state_q, state_k, state_v)

        from_state_out = self.from_state_cross_attn(q_from_state, k, v)

        state_out = torch.cat((state_out, from_state_out), dim = -1)
        state_out = rearrange(state_out, 'b h n d -> b n (h d)')

        state_out = self.to_state_out(state_out)

        # use the best performing configuration
        # fixed simple gate - nothing more than a learned EMA with some resemblance to highway networks

        z = self.state_out_to_gate(state_out)
        learned_ema_decay = self.learned_ema_beta.sigmoid()

        # set new state with the learned EMA gating

        return learned_ema_decay * z + (1 - learned_ema_decay) * states

    def forward(self, x):
        raise NotImplementedError

# main class

class Attend(nn.Module):
    def __init__(
        self,
        causal = False,
        use_flash_attn = False
    ):
        super().__init__()
        self.causal = causal
        self.register_buffer("mask", None, persistent = False)

        self.use_flash_attn = use_flash_attn
        assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

        # determine efficient attention configs for cuda and cpu

        self.cpu_config = Config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash_attn:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
            self.cuda_config = Config(True, False, False)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
            self.cuda_config = Config(False, True, True)

    def get_mask(self, n, device):
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device = device, dtype = torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent = False)
        return mask

    def flash_attn(self, q, k, v, mask = None):
        _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda

        # Recommended for multi-query single-key-value attention by Tri Dao
        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

        if k.ndim == 3:
            k = repeat(k, 'b ... -> b h ...', h = q.shape[1])

        if v.ndim == 3:
            v = repeat(v, 'b ... -> b h ...', h = q.shape[1])

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L

        masks = []

        if self.causal:
            i, j = q_len, k_len
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
            masks.append(~causal_mask)

        if exists(mask):
            if mask.ndim != 2:
                mask = repeat(mask, 'w ... -> (b w) ...', b = q.shape[0] // mask.shape[0])

            masks.append(mask)

        attn_mask = and_reduce(masks)

        # Check if there is a compatible device for flash attention

        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = attn_mask
            )

        return out

    def forward(self, q, k, v, mask = None, use_flash_attn = None):
        use_flash_attn = default(use_flash_attn, self.use_flash_attn)

        b, n, device = q.shape[0], q.shape[-2], q.device

        q, ps = pack_one(q, '* h n d')
        k, _ = pack_one(k, '* n d')
        v, _ = pack_one(v, '* n d')

        if use_flash_attn:
            out = self.flash_attn(q, k, v, mask = mask)
            return unpack_one(out, ps, '* h n d')

        scale = q.shape[-1] ** -0.5

        k_einsum = 'b j d' if k.ndim == 3 else 'b h j d'
        v_einsum = 'b j d' if v.ndim == 3 else 'b h j d'

        # similarity

        sim = einsum(f"b h i d, {k_einsum} -> b h i j", q, k) * scale

        # key padding mask

        if exists(mask):
            if mask.ndim != 2:
                mask = repeat(mask, 'w ... -> (b w) ...', b = b)

            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask

        if self.causal:
            i, j = sim.shape[-2:]
            causal_mask = torch.ones((i, j), dtype = torch.bool, device = q.device).triu(j - i + 1)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention

        attn = sim.softmax(dim = -1)

        # aggregate values

        out = einsum(f"b h i j, {v_einsum} -> b h i d", attn, v)

        return unpack_one(out, ps, '* h n d')

# geglu feedforward

class GEGLU(nn.Module):
    def forward(self, x):
        x, gate = x.chunk(2, dim = -1)
        return F.gelu(gate) * x

def FeedForward(dim, mult = 4):
    inner_dim = int(dim * mult * 2 / 3)
    return nn.Sequential(
        LayerNorm(dim),
        nn.Linear(dim, inner_dim * 2, bias = False),
        GEGLU(),
        nn.Linear(inner_dim, dim, bias = False)
    )

# attention

class Attention(nn.Module):
    def __init__(
        self,
        dim_head,
        causal = False,
        qk_rmsnorm = False,
        qk_rmsnorm_scale = 8,
        use_flash_attn = False
    ):
        super().__init__()
        self.causal = causal

        self.qk_rmsnorm = qk_rmsnorm
        self.qk_rmsnorm_scale = qk_rmsnorm_scale

        self.attend = Attend(causal = causal, use_flash_attn = use_flash_attn)

        if qk_rmsnorm:
            self.q_scale = nn.Parameter(torch.ones(dim_head))
            self.k_scale = nn.Parameter(torch.ones(dim_head))

    def forward(
        self,
        q, k, v,
        mask = None,
        rotary_pos_emb = None,
        xpos_scale = None
    ):
        scale = q.shape[-1] ** -0.5

        if self.qk_rmsnorm:
            q, k = map(l2norm, (q, k))
            scale = self.qk_rmsnorm_scale

        if self.qk_rmsnorm:
            q = q * self.q_scale
            k = k * self.k_scale

        # rotary positional embedding with xpos for length extrapolation

        if exists(rotary_pos_emb):
            q = apply_rotary_pos_emb(q, rotary_pos_emb, xpos_scale)
            k = apply_rotary_pos_emb(k, rotary_pos_emb, xpos_scale ** -1)

        # attention

        out = self.attend(q, k, v, mask = mask)

        return out

class AttentionBlock(nn.Module):
    def __init__(
        self,
        dim,
        block_width,
        dim_head = 64,
        heads = 8,
        qk_rmsnorm = False,
        qk_rmsnorm_scale = 8,
        use_flash_attn = False,
        num_state_vectors = 0,
        num_external_state_reads = 0,
        state_read_before_write = True  # this will be defaulted to on as in the paper, but will be turned off in the case the researcher wants to test out reading the state at a lower layer
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads

        self.norm = LayerNorm(dim)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)

        self.attn = Attention(dim_head, qk_rmsnorm = qk_rmsnorm, qk_rmsnorm_scale = qk_rmsnorm_scale, use_flash_attn = use_flash_attn)

        self.block_width = block_width
        self.is_recurrent_layer = num_state_vectors > 0

        # decide how many states this attention layer is going to read from

        num_state_reads = int(self.is_recurrent_layer and state_read_before_write) + num_external_state_reads

        self.to_out = nn.Linear(inner_dim * (1 + num_state_reads), dim, bias = False)

        if not self.is_recurrent_layer:
            return

        self.state_read_before_write = state_read_before_write

        self.state_container = StateContainer(
            dim,
            dim_head = dim_head,
            heads = heads,
            num_state_vectors = num_state_vectors,
            qk_rmsnorm = qk_rmsnorm,
            qk_rmsnorm_scale = qk_rmsnorm_scale,
            use_flash_attn = use_flash_attn
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        x,
        rotary_pos_emb = None,
        xpos_scale = None,
        attn_mask = None,
        xl_memories: Optional[torch.Tensor] = None,
        read_from_state_containers: List[StateContainer] = []
    ):
        batch, seq_len, _, width, device = *x.shape, self.block_width, self.device

        # pre normalization

        x = self.norm(x)

        # queries, keys, values and split out heads

        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))

        split_head = partial(rearrange, pattern = 'b n (h d) -> b h n d', h = self.heads)

        q = split_head(q)

        # save the last key / values as memories for recurrence

        memories = torch.stack((k, v))

        mem_len = 0

        if exists(xl_memories):
            # if past memories are passed in, concat as the first bucket
            mem_len = xl_memories.shape[-2]
            past_k, past_v = xl_memories
            k = torch.cat((past_k, k), dim = 1)
            v = torch.cat((past_v, v), dim = 1)

        # handle cropping of attention mask and positional embeddings

        if exists(attn_mask):
            attn_mask = attn_mask[:seq_len, :seq_len]
            attn_mask = F.pad(attn_mask, (mem_len, 0), value = True)

        # attention, but of course

        out = self.attn(
            q, k, v,
            rotary_pos_emb = rotary_pos_emb,
            xpos_scale = xpos_scale,
            mask = attn_mask
        )

        # merge heads

        out = rearrange(out, 'b h n d -> b n (h d)')

        # early return if not a recurrent layer

        if not self.is_recurrent_layer and len(read_from_state_containers) == 0:
            return self.to_out(out), memories, None

        # whether to read from own state container, default to on, but may pass in more

        if self.is_recurrent_layer and self.state_read_before_write:
            read_from_state_containers = [self.state_container, *read_from_state_containers]

        for read_state_container in read_from_state_containers:
            # read from the states ...
            to_state_out = read_state_container.read(x)

            # and concat it to the output of self-attention
            out = torch.cat((out, to_state_out), dim = -1)

        new_states = None

        if self.is_recurrent_layer:
            # then write to the states as well if need be
            new_states = self.state_container.write(memories = memories)

        return self.to_out(out), memories, new_states

# classes

@beartype
class BlockRecurrentTransformer(nn.Module):
    def __init__(
        self,
        *,
        num_tokens,
        dim,
        depth,
        dim_head = 64,
        heads = 8,
        all_layers_qk_rmsnorm = False,
        ff_mult = 4,
        max_seq_len = 1024,
        block_width = 512,
        recurrent_layers: Optional[Tuple[int, ...]] = None,
        read_recurrent_layers: Optional[Tuple[int, ...]] = None,
        num_state_vectors = None,
        ignore_index = -100,
        use_flash_attn = False,
        use_compressed_mem = False,
        compressed_mem_factor = 4
    ):
        super().__init__()
        num_state_vectors = default(num_state_vectors, block_width)

        # set recurrent layers

        recurrent_layers = default(recurrent_layers, (depth // 2,))  # default to one recurrent layer at middle of the network

        assert all([0 < layer <= depth for layer in recurrent_layers]), f'recurrent layers must range from 1 to the depth {depth}'
        assert all_unique(recurrent_layers), 'recurrent layers must be all unique. no duplicate layers'

        self.recurrent_layers = recurrent_layers

        # set read recurrent layers

        read_recurrent_layers = default(read_recurrent_layers, recurrent_layers)

        assert all([read_layer <= write_layer for read_layer, write_layer in zip(read_recurrent_layers, recurrent_layers)]), 'the recurrent read layer must be always less than or equal to the write layer'
        assert all([0 < layer <= depth for layer in read_recurrent_layers])
        assert len(read_recurrent_layers) == len(recurrent_layers)

        self.read_recurrent_layers = read_recurrent_layers

        # token embedding

        self.token_emb = nn.Embedding(num_tokens, dim)

        self.rotary_pos_emb = RotaryEmbedding(dim = dim_head, width = (2 if not use_compressed_mem else 3) * block_width)

        self.layers = nn.ModuleList([])

        self.write_to_read_map = {write_layer: read_layer for write_layer, read_layer in zip(recurrent_layers, read_recurrent_layers)}
        self.read_state_router = defaultdict(list)

        for layer in range(1, depth + 1):
            is_recurrent_layer = layer in self.recurrent_layers

            layer_num_state_vectors = num_state_vectors if is_recurrent_layer else 0

            num_external_state_reads = sum([int(layer == read_layer) for read_layer in read_recurrent_layers])

            # only layers with xl memories
            # or has recurrence in horizontal direction
            # use qk rmsnorm (in paper, they use cosine sim attention, but i think qk rmsnorm is more proven given Vit 22B paper)
            # one can also override to use all qk rmsnorm by setting all_layers_qk_rmsnorm = True

            qk_rmsnorm = all_layers_qk_rmsnorm or is_recurrent_layer

            attn_block = AttentionBlock(
                dim,
                block_width = block_width,
                dim_head = dim_head,
                heads = heads,
                qk_rmsnorm = qk_rmsnorm,
                num_state_vectors = layer_num_state_vectors,
                use_flash_attn = use_flash_attn,
                num_external_state_reads = num_external_state_reads,
                state_read_before_write = False,
            )

            ff_block = FeedForward(dim, mult = ff_mult)

            if is_recurrent_layer:
                read_layer = self.write_to_read_map[layer]
                self.read_state_router[read_layer].append(attn_block.state_container)

            self.layers.append(nn.ModuleList([
                attn_block,
                ff_block
            ]))

        # (compressed) memory management

        self.mem_manager = MemoryManager(
            dim = dim_head,
            layers = depth,
            mem_lengths = block_width if not use_compressed_mem else (block_width, block_width // 2),
            compress_factors = 1 if not use_compressed_mem else (1, compressed_mem_factor)
        )

        # to logits

        self.to_logits = nn.Sequential(
            LayerNorm(dim),
            nn.Linear(dim, num_tokens, bias = False)
        )

        self.max_seq_len = max_seq_len
        self.block_width = block_width

        assert divisible_by(max_seq_len, block_width)

        self.ignore_index = ignore_index

        self.register_buffer('cached_causal_attn_mask', None, persistent = False)

    @property
    def device(self):
        return next(self.parameters()).device

    def get_causal_attn_mask(self, width):
        if exists(self.cached_causal_attn_mask):
            cached_mask = self.cached_causal_attn_mask
            cached_width = cached_mask.shape[-2]
            padding = (width - cached_width) // 2
            j_slice = Ellipsis if padding == 0 else slice(padding, -padding)
            return cached_mask[:cached_width, j_slice]

        device = self.device
        causal_mask = torch.ones((width, width), device = device, dtype = torch.bool).triu(1)
        return ~causal_mask

    @torch.no_grad()
    @eval_decorator
    def generate(
        self,
        prime,
        length = None,
        xl_memories: List[torch.Tensor] = [],
        states: List[torch.Tensor] = [],
        temperature = 1.,
        filter_thres = 0.9,
        return_memories_and_states = False
    ):
        length = default(length, self.max_seq_len + 1)
        start_len = prime.shape[-1]

        assert start_len < self.max_seq_len
        assert length <= (self.max_seq_len + 1)
        assert start_len < length

        output = prime

        memories = []

        for ind in range(length - start_len):

            logits, next_memories, next_states = self.forward(
                output,
                xl_memories = xl_memories,
                states = states
            )

            logits = logits[:, -1]

            filtered_logits = top_k(logits, thres = filter_thres)
            sampled = gumbel_sample(filtered_logits, temperature = temperature)
            sampled = rearrange(sampled, 'b -> b 1')

            output = torch.cat((output, sampled), dim = -1)

            if divisible_by(output.shape[-1] - 1, self.max_seq_len):  # on the sampling of the last token in the current window, set new memories and states
                memories = next_memories
                states = next_states

        output = output[:, start_len:]

        if return_memories_and_states:
            return output, memories, states

        return output

    def forward(
        self,
        x,
        return_loss = False,
        xl_memories: List[torch.Tensor] = [],
        states: List[torch.Tensor] = [],
        return_memories_and_states = None  # can force to either return memory + state or not. by default will only return when number of tokens == max_seq_len
    ):
        device = x.device

        if return_loss:
            x, labels = x[:, :-1], x[:, 1:]

        # get sequence length i and j for dynamic pos bias

        assert x.shape[-1] <= self.max_seq_len

        w = self.block_width

        # token embedding

        x = self.token_emb(x)

        # dynamic pos bias

        attn_mask = self.get_causal_attn_mask(w)
        rotary_pos_emb, xpos_scale = self.rotary_pos_emb()

        # only return memories and state if at the full block width, but can be overridden

        return_memories_and_states = default(return_memories_and_states, self.max_seq_len == x.shape[-2])

        # ready output tensor, to be concatted to block by block

        batch, _, dim = x.shape
        out = torch.empty(batch, 0, dim, dtype = x.dtype, device = self.device)

        # split input into blocks of width w

        input_blocks = x.split(w, dim = -2)

        # process each block at a time

        for input_block in input_blocks:
            input_block_length = input_block.shape[-2]

            # ready xl memories and states

            iter_xl_memories = iter(xl_memories)
            iter_states = iter(states)

            next_xl_memories = []
            next_states = []

            # set the states on the appropriate state containers

            for attn, _ in self.layers:
                if not attn.is_recurrent_layer:
                    continue

                attn.state_container.set_next_read_state(next(iter_states, None))

            # go through layers

            for ind, (attn, ff) in enumerate(self.layers):

                # determine if the layer requires transformer xl memories

                layer = ind + 1

                # whether to pass in xl memories

                attn_kwargs = dict(
                    rotary_pos_emb = rotary_pos_emb,
                    xpos_scale = xpos_scale,
                    attn_mask = attn_mask,
                    xl_memories = next(iter_xl_memories, None),
                    read_from_state_containers = self.read_state_router[layer]
                )

                # attention layer

                residual = input_block
                attn_branch_out, layer_xl_memories, layer_next_states = attn(input_block, **attn_kwargs)

                if exists(layer_xl_memories):
                    next_xl_memories.append(layer_xl_memories)

                if exists(layer_next_states):
                    next_states.append(layer_next_states)

                input_block = attn_branch_out + residual

                # feedforward layer

                input_block = ff(input_block) + input_block

            # concat to output

            out = torch.cat((out, input_block), dim = -2)

            # set new xl memories and states

            states = next_states

            if input_block_length == w:
                xl_memories = self.mem_manager(xl_memories, next_xl_memories)

        # project to logits

        logits = self.to_logits(out)

        # detach the states and memories

        returned_next_states = list(map(torch.detach, states)) if return_memories_and_states else None
        returned_next_xl_memories = list(map(torch.detach, xl_memories)) if return_memories_and_states else None

        # whether to return logits

        if not return_loss:
            return logits, returned_next_xl_memories, returned_next_states

        # cross entropy loss

        logits = rearrange(logits, 'b n c -> b c n')
        loss = F.cross_entropy(logits, labels, ignore_index = self.ignore_index)

        return loss, returned_next_xl_memories, returned_next_states

# recurrent trainer wrapper

@beartype
class RecurrentTrainerWrapper(nn.Module):
    def __init__(
        self,
        transformer: BlockRecurrentTransformer,
        xl_memories_dropout = 0.,
        state_dropout = 0.
    ):
        super().__init__()
        self.transformer = transformer
        self.seq_len = transformer.max_seq_len

        self.xl_memories_dropout = xl_memories_dropout
        self.state_dropout = state_dropout

    @eval_decorator
    @torch.no_grad()
    def generate(
        self,
        prime,
        length,
        **kwargs
    ):
        seq_len = self.seq_len
        start_len = prime.shape[-1]
        assert start_len < length

        output = prime
        current_len = start_len

        memories = []
        states = []

        # determine lengths

        has_remainder = not divisible_by(length, seq_len)
        remainder_amount = length % seq_len
        total_segments = math.ceil(length / seq_len)

        if not has_remainder:
            lengths = (*((seq_len + 1,) * (total_segments - 1)), seq_len)
        elif remainder_amount == 1:
            lengths = (seq_len + 1,) * (total_segments - 1)
        else:
            lengths = (*((seq_len + 1,) * (total_segments - 1)), remainder_amount)

        # loop through lengths

        for next_length in lengths:
            segment_output, memories, states = self.transformer.generate(
                output[:, -current_len:],
                length = next_length,
                xl_memories = memories,
                states = states,
                return_memories_and_states = True,
                **kwargs
            )

            output = torch.cat((output, segment_output), dim = -1)
            current_len = 1

        return output[:, start_len:]

    def forward(
        self,
        x,
        return_memories_and_states = False
    ):
        total_seq_len, seq_len = x.shape[1], self.seq_len

        assert divisible_by(total_seq_len - 1, seq_len), f'length of sequence ({total_seq_len}) must be equal to a multiple of {seq_len} + 1 (one extra token) during training'
        segments = total_seq_len // seq_len

        total_loss = 0.

        memories = []
        states = []

        for ind in range(segments):
            start = ind * seq_len
            end = start + seq_len + 1

            if self.training and random() < self.xl_memories_dropout:
                memories.clear()

            if self.training and random() < self.state_dropout:
                states.clear()

            loss, memories, states = self.transformer(
                x[:, start:end],
                xl_memories = memories,
                states = states,
                return_loss = True
            )

            total_loss = total_loss + (loss / segments)

        if return_memories_and_states:
            return total_loss, memories, states

        return total_loss

if __name__ == '__main__':
    model = BlockRecurrentTransformer(
        num_tokens = 20000,          # vocab size
        dim = 512,                   # model dimensions
        depth = 6,                   # depth
        dim_head = 64,               # attention head dimensions
        heads = 8,                   # number of attention heads
        max_seq_len = 1024,          # the total receptive field of the transformer, in the paper this was 2 * block size
        block_width = 512,           # block size - total receptive field is max_seq_len, 2 * block size in paper. the block furthest forwards becomes the new cached xl memories, which is a block size of 1 (please open an issue if i am wrong)
        num_state_vectors = 512,     # number of state vectors, i believe this was a single block size in the paper, but can be any amount
        recurrent_layers = (4,),     # where to place the recurrent layer(s) for states with fixed simple gating
        use_compressed_mem = False,  # whether to use compressed memories of a single block width, from https://arxiv.org/abs/1911.05507
        compressed_mem_factor = 4,   # compression factor of compressed memories
        use_flash_attn = False       # use flash attention, if on pytorch 2.0
    )

    # sequence must be the full max_seq_len, otherwise forward returns None
    # for the memories and states and the second call below would fail
    seq = torch.randint(0, 2000, (1, 1024))

    out, mems1, states1 = model(seq)
    out, mems2, states2 = model(seq, xl_memories = mems1, states = states1)
    out, mems3, states3 = model(seq, xl_memories = mems2, states = states2)
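The recurrent state update in StateContainer.write above is nothing more than a learned per-channel EMA (the "fixed simple gate" from the paper's ablations). A stripped-down torch sketch of just that gating step, with hypothetical toy shapes (the extra `state_out_to_gate` linear projection is omitted here):

```python
import torch

# toy shapes for illustration: batch 2, 3 state vectors, model dim 4
states    = torch.randn(2, 3, 4)  # previous block's state vectors
state_out = torch.randn(2, 3, 4)  # candidate update produced by attention
ema_beta  = torch.randn(4)        # learned gate logits, one per channel

gate = ema_beta.sigmoid()         # in (0, 1): how much new signal to admit

# per-channel convex combination of old state and candidate update
new_states = gate * state_out + (1 - gate) * states

# a gate near 0 keeps the old state (slow-changing memory);
# a gate near 1 overwrites it with the fresh attention output
assert new_states.shape == states.shape
```

Because the gate is a sigmoid, each channel of the new state always lies between the old state and the candidate, which is what gives the state its slowly-drifting, highway-network-like behavior.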
