Transformer实现的一个Demo

RT,直接上代码,可以跑通:

#encoding:utf-8

import torch

import torch.nn as nn

import numpy as np

import math


 

class Config(object):

    def __init__(self):

        self.vocab_size = 6

        self.d_model = 512

        self.n_heads = 4

        assert self.d_model % self.n_heads == 0

        self.dim_k = self.d_model // self.n_heads

        self.dim_v = self.d_model // self.n_heads

        self.padding_size = 30

        self.UNK = 5

        self.PAD = 4

        self.N = 6

        self.p = 0.1

config = Config()


 

class Embedding(nn.Module):

    def __init__(self, vocab_size):

        super(Embedding, self).__init__()

        self.embedding = nn.Embedding(vocab_size, config.d_model, padding_idx=config.PAD)


 

    def forward(self, x):

        # print("type(x):", type(x))

        # for i in range(len(x)):

        #     if len(x[i]) < config.padding_size:

        #         x[i].extend([config.UNK] * (config.padding_size - len(x[i])))

        #     else:

        #         tmp = x[i][:config.padding_size]

        #         print("tmp.shape:", tmp.shape)

        #         print("type(x[i]):", type(x[i]))

        #         print("x[i].shape:", x[i].shape)

        #         x[i,:] = x[i][:config.padding_size]

        x = self.embedding(torch.tensor(x))

        return x

   

class Positional_Encoding(nn.Module):

    def __init__(self, d_model):

        super(Positional_Encoding, self).__init__()

        self.d_model = d_model

    def forward(self, seq_len, embedding_dim):

        positional_encoding = np.zeros((seq_len, embedding_dim))

        for pos in range(positional_encoding.shape[0]):

            for i in range(positional_encoding.shape[1]):

                positional_encoding[pos][i] = math.sin(pos / (10000**(2*i/self.d_model))) if i%2==0 else math.cos(pos/(10000**(2*i/self.d_model)))

        return torch.from_numpy(positional_encoding)

   

class Multihead_Attention(nn.Module):

    def __init__(self, d_model, dim_k, dim_v, n_heads):

        super(Multihead_Attention, self).__init__()

        self.dim_v = dim_v

        self.dim_k = dim_k

        self.n_heads = n_heads

        self.q = nn.Linear(d_model, dim_k)

        self.k = nn.Linear(d_model, dim_k)

        self.v = nn.Linear(d_model, dim_v)

        self.o = nn.Linear(dim_v, d_model)

        self.norm_fact = 1 / math.sqrt(d_model)

    def generate_mask(self, dim):

        matrix = np.ones((dim, dim))

        mask = torch.BoolTensor(np.tril(matrix).astype(np.bool_))

        return mask


 

    def forward(self, x, y, requires_mask=False):

        Q = self.q(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k//self.n_heads)

        K = self.k(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k//self.n_heads)

        V = self.v(y).reshape(-1, x.shape[0], x.shape[1], self.dim_v//self.n_heads)

        attention_score = torch.matmul(Q, K.permute(0,1,3,2)) * self.norm_fact

        if requires_mask:

            mask = self.generate_mask(x.shape[1])

            attention_score.masked_fill(mask, value=float("-inf"))

        output = torch.matmul(attention_score, V).reshape(y.shape[0], y.shape[1], -1)

        output = self.o(output)

        return output

   

class Feed_Forward(nn.Module):

    def __init__(self, input_dim, hidden_dim=2048):

        super(Feed_Forward, self).__init__()

        self.L1 = nn.Linear(input_dim, hidden_dim)

        self.L2 = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):

        output = nn.ReLU()(self.L1(x))

        output = self.L2(output)

        return output

   

class Add_Norm(nn.Module):

    def __init__(self):

        super(Add_Norm, self).__init__()

        self.dropout = nn.Dropout(config.p)

    def forward(self, x, sub_layer, **kwargs):

        sub_output = sub_layer(x, **kwargs)

        x = self.dropout(x+sub_output)

        layer_norm = nn.LayerNorm(x.size()[1:])

        out = layer_norm(x)

        return out

   

class Encoder(nn.Module):

    def __init__(self):

        super(Encoder, self).__init__()

        self.positional_encoding = Positional_Encoding(config.d_model)

        self.muti_atten = Multihead_Attention(config.d_model, config.dim_k, config.dim_v, config.n_heads)

        self.feed_forward = Feed_Forward(config.d_model)

        self.add_norm = Add_Norm()


 

    def forward(self, x):

        x += self.positional_encoding(x.shape[1], config.d_model)

        output = self.add_norm(x, self.muti_atten, y=x)

        output = self.add_norm(output, self.feed_forward)

        return output

   

class Decoder(nn.Module):

    def __init__(self):

        super(Decoder, self).__init__()

        self.positional_encoding = Positional_Encoding(config.d_model)

        self.muti_atten = Multihead_Attention(config.d_model, config.dim_k, config.dim_v, config.n_heads)

        self.feed_forward = Feed_Forward(config.d_model)

        self.add_norm = Add_Norm()

    def forward(self, x, encoder_output):

        x += self.positional_encoding(x.shape[1], config.d_model)

        output = self.add_norm(x, self.muti_atten, y=x, requires_mask=True)

        output = self.add_norm(output, self.muti_atten, y=encoder_output, requires_mask=True)

        output = self.add_norm(output, self.feed_forward)

        return output

   

class Transformer_layer(nn.Module):

    def __init__(self):

        super(Transformer_layer, self).__init__()

        self.encoder = Encoder()

        self.decoder = Decoder()

    def forward(self, x):

        x_input, x_output = x

        encoder_output = self.encoder(x_input)

        decoder_output = self.decoder(x_output, encoder_output)

        return (encoder_output, decoder_output)

   

class Transformer(nn.Module):

    def __init__(self, N, vocab_size, output_dim):

        super(Transformer, self).__init__()

        self.embedding_input = Embedding(vocab_size=vocab_size)

        self.embedding_output = Embedding(vocab_size=vocab_size)

        self.output_dim = output_dim

        self.linear = nn.Linear(config.d_model, output_dim)

        self.softmax = nn.Softmax(dim=-1)

        self.model = nn.Sequential(*[Transformer_layer() for _ in range(N)])

    def forward(self, x):

        x_input, x_output = x, x

        x_input = self.embedding_input(x_input)

        x_output = self.embedding_output(x_output)

        _, output = self.model((x_input, x_output))

        output = self.linear(output)

        output = self.softmax(output)

        return output

   

def main():

    transformer = Transformer(4, 1024, 512)

    input = (np.random.rand(8, 512) * 512).astype(np.int64)

    output = transformer(input)

    print("output.shape:", output.shape)


 

if __name__ == '__main__':

    main()

---------------------

输出:

output.shape: torch.Size([8, 512, 512])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/582085.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

UI自动化Selenium 元素定位之Xpath

一、元素定位方式 selenium中定位元素&#xff0c;通常有几种方式&#xff1a; 1、通过id定位&#xff1a;By.ID 2、通过Name定位&#xff1a;By.Name 3、通过元素其他属性定位&#xff0c;如class、type、text文本。。。。。。等等&#xff0c;如果要用属性定位那就需要使…

图论 经典例题

1 拓扑排序 对有向图的节点排序&#xff0c;使得对于每一条有向边 U-->V U都出现在V之前 *有环无法拓扑排序 indegree[], nxs[];//前者表示节点 i 的入度&#xff0c;后者表示节点 i 指向的节点 queue [] for i in range(n):if indege[i] 0: queue.add(i)// 入度为0的节…

虚析构和纯虚析构

多态使用时&#xff0c;如果子类中有属性开辟到堆区&#xff0c;那么父类的指针在释放时无法调用到子类的析构代码 解决方式&#xff1a;将父类中的析构代码函数改为虚析构或者纯虚析构 虚析构和纯虚析构共性&#xff1a; 可以解决父类指针释放子类对象 都需要有具体的函数…

[SWPUCTF 2021 新生赛]finalrce

[SWPUCTF 2021 新生赛]finalrce wp 注&#xff1a;本文参考了 NSSCTF Leaderchen 师傅的题解&#xff0c;并修补了其中些许不足。 此外&#xff0c;参考了 命令执行(RCE)面对各种过滤&#xff0c;骚姿势绕过总结 题目代码&#xff1a; <?php highlight_file(__FILE__); …

【算法练习】leetcode链表算法题合集

链表总结 增加表头元素倒数节点&#xff0c;使用快慢指针环形链表&#xff08;快慢指针&#xff09;合并有序链表&#xff0c;归并排序LRU缓存 算法题 删除链表元素 删除链表中的节点 LeetCode237. 删除链表中的节点 复制后一个节点的值&#xff0c;删除后面的节点&#x…

verilog 通过DPI-C调用C 流水灯模拟

verilog 通过DPI-C调用C简单示例&#xff0c; verillator模拟 ledloop.v module ledloop(input wire clk,output wire[3:0] LED );reg[31:0] cnt 32h00000000;always (posedge clk)cnt < cnt 1;assign LED 4b0001 << cnt[21:20]; endmodule电脑模拟较慢&#xff…

如何解决服务器CA证书过期的问题

一、问题的提出 最近在学习VPS&#xff0c;在Linux系统里给服务器安装某项服务时&#xff0c;在服务的log里看到下面的错误信息&#xff1a; failed to verify certificate: x509: certificate has expired or is not yet valid: current time 2023-12-25T04:42:38-05:00 is a…

java球队信息管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目

一、源码特点 java Web球队信息管理系统是一套完善的java web信息管理系统&#xff0c;对理解JSP java编程开发语言有帮助&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发&#xff0c;数据库为Mysql5…

深度学习之RNN

1.循环神经网络 在时间t的时候&#xff0c;对于单个神经元来讲它的输出y(t)如下 wx是对于输入x的权重&#xff0c;wy是对于上一时刻输出的权重 所以循环神经网络有两个权重。 如果有很多这样的神经元并排在一起 则在t时刻的输出y为 这时输入输出都是向量 2.记忆单元 由于循…

java系列-CountDownLatch

CountDownLatch 不是一种锁&#xff0c;而是一种同步工具类&#xff0c;用于协调多个线程之间的操作。它并不是像 ReentrantLock 或 synchronized 关键字那样实现了锁定机制&#xff0c;而是通过一个计数器来实现线程的等待和通知。 具体来说&#xff0c;CountDownLatch 维护了…

车队试验的远程实时显示方案

风丘科技推出的数据远程实时显示方案更好地满足了客户对于试验车队远程实时监控的需求&#xff0c;并真正实现了试验车队的远程管理。随着新的数据记录仪软件IPEmotion RT和相应的跨平台显示解决方案的引入&#xff0c;让我们的客户端不仅可在线访问记录器系统状态&#xff0c;…

LaTeX 不同章的图片放在不同的文件夹

需求&#xff1a;在写长文档的时候&#xff0c;比如学位论文&#xff0c;每一章都有好几张图片&#xff0c;整个文档一共几十张甚至上百张图片&#xff0c;如果不分开放&#xff0c;想修改某一张图片的时候&#xff0c;找起来比较困难。所以&#xff0c;想把每一章的图片单独放…

git unable to create temporary file: No space left on device(git报错)

1.问题 1.1 vscode中npm run serve跑项目的时候&#xff0c;进度达到95%的时候一直卡着无进度&#xff1b; 1.2 git命令提交代码报错&#xff1b; 2.具体解决 这个错误通常表示你的磁盘空间已经满了&#xff0c;导致 Git 无法在临时目录中创建文件。2.1 清理磁盘空间&#xf…

LeetCode75| 区间集合

目录 435 无重叠区间 452 用最少的箭引爆气球 435 无重叠区间 class Solution { public:static bool cmp(vector<int>&a,vector<int>&b){return a[0] < b[0];}int eraseOverlapIntervals(vector<vector<int>>& intervals) {int res …

低代码平台在金融银行中的应用场景

随着数字化转型的推进&#xff0c;商业银行越来越重视技术在业务发展中的作用。在这个背景下&#xff0c;白码低代码平台作为一种新型的开发方式&#xff0c;正逐渐受到广大商业银行的关注和应用。白码低代码平台能够快速构建各类应用程序&#xff0c;提高开发效率&#xff0c;…

WebGoat 指定端口号

文章目录 新版本的 WebGoat旧版本 WebGoat 新版本的 WebGoat 使用 WEBGOAT_PORT 指定 WebGoat 的端口号 使用 WEBWOLF_PORT 指定 WebWolf 的端口号 java -DWEBGOAT_PORT8081 -jar webgoat-2023.8.jar java -DWEBGOAT_PORT8081 -DWEBWOLF_PORT9091 -jar webgoat-2023.8.jar W…

跨境电商引流真的很难吗?了解一下这些技巧!

随着全球电商市场的不断扩大&#xff0c;越来越多的企业开始涉足跨境电商领域&#xff0c;然而&#xff0c;与国内电商相比&#xff0c;跨境电商面临着诸多挑战&#xff0c;其中最大的难题之一就是如何有效地吸引潜在客户。 很多卖家觉得跨境电商引流非常困难&#xff0c;但实…

解析数据时代----驱动变革与重塑商业的力量

随着科技的飞速发展&#xff0c;我们正处在一个信息爆炸的时代。数据&#xff0c;作为这个时代的核心要素&#xff0c;已经渗透到各个领域&#xff0c;深刻影响着我们的生活、工作和商业模式。本文将深入解析数据时代的特点、影响以及如何应对数据带来的挑战&#xff0c;以适应…

springBoot整合redis做缓存

一、Redis介绍 Redis是当前比较热门的NOSQL系统之一&#xff0c;它是一个开源的使用ANSI c语言编写的key-value存储系统&#xff08;区别于MySQL的二维表格的形式存储。&#xff09;。和Memcache类似&#xff0c;但很大程度补偿了Memcache的不足。和Memcache一样&#xff0c;R…

BAT log-yyyy-mm-dd.log

日志文件 文件名 日期格式化 https://download.csdn.net/download/spencer_tseng/88673832 https://download.csdn.net/download/spencer_tseng/88673716