copy is all you need前向绘图 和疑惑标记

疑惑的起因
在这里插入图片描述
简化前向图
在这里插入图片描述
GPT4解释

这段代码实现了一个神经网络模型,包含了BERT、GPT-2和MLP等模块。主要功能是给定一个文本序列和一个查询序列,预测查询序列中的起始和结束位置,使其对应文本序列中的一个短语。具体实现细节如下:

首先,使用BERT将文本序列编码成一个张量作为输入,并从GPT-2模型中提取查询序列的表示。然后,将查询序列的表示拆分为起始和结束两部分,并使用MLP将短语的起始和结束点表示投影到另一个空间中。
接下来,利用生成式模型计算每个查询起始和结束位置之间的所有令牌的概率分布,并根据此生成预测位置。使用交叉熵损失函数来计算这些预测和真实值之间的差距。
最后,计算起始和结束位置的准确性和整个短语的准确性,并返回损失和准确性的值来优化模型的性能。

import torch
from transformers import AutoModel, AutoTokenizer, GPT2LMHeadModel
from torch import nn
from torch.nn import functional as Fclass Copyisallyouneed(nn.Module):def __init__(self, **args):super(Copyisallyouneed, self).__init__()self.args = args# bert-encoder modelself.phrase_encoder = AutoModel.from_pretrained(self.args['phrase_encoder_model'][self.args['lang']])self.bert_tokenizer = AutoTokenizer.from_pretrained(self.args['phrase_encoder_tokenizer'][self.args['lang']])self.bert_tokenizer.add_tokens(['<|endoftext|>', '[PREFIX]'])self.prefix_token_id = self.bert_tokenizer.convert_tokens_to_ids('[PREFIX]')self.phrase_encoder.resize_token_embeddings(self.phrase_encoder.config.vocab_size + 2)# model and tokenizerself.tokenizer = AutoTokenizer.from_pretrained(self.args['prefix_encoder_tokenizer'][self.args['lang']])self.vocab_size = len(self.tokenizer)self.pad = self.tokenizer.pad_token_id if self.args['lang'] == 'zh' else self.tokenizer.bos_token_idself.model = GPT2LMHeadModel.from_pretrained(self.args['prefix_encoder_model'][self.args['lang']])self.token_embeddings = nn.Parameter(list(self.model.lm_head.parameters())[0])# MLP: mapping bert phrase start representationsself.s_proj = nn.Sequential(nn.Dropout(p=args['dropout']),nn.Tanh(),nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size // 2))# MLP: mapping bert phrase end representationsself.e_proj = nn.Sequential(nn.Dropout(p=args['dropout']),nn.Tanh(),nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size // 2))self.gen_loss_fct = nn.CrossEntropyLoss(ignore_index=self.pad)@torch.no_grad()def get_query_rep(self, ids):self.eval()output = self.model(input_ids=ids, output_hidden_states=True)['hidden_states'][-1][:, -1, :]return outputdef get_token_loss(self, ids, hs, ids_mask):# no pad tokenlabel = ids[:, 1:]logits = torch.matmul(hs[:, :-1, :],self.token_embeddings.t())# TODO: inner loss function remove the temperature factorlogits /= self.args['temp']loss = self.gen_loss_fct(logits.view(-1, logits.size(-1)), label.reshape(-1))chosen_tokens = torch.max(logits, dim=-1)[1]gen_acc = (chosen_tokens.reshape(-1) == label.reshape(-1)).to(torch.long)valid_mask = (label != self.pad).reshape(-1)valid_tokens = gen_acc & valid_maskgen_acc = valid_tokens.sum().item() / valid_mask.sum().item()return loss, gen_accdef forward(self, batch):## gpt2 query encoderids, ids_mask = batch['gpt2_ids'], batch['gpt2_mask']last_hidden_states = \self.model(input_ids=ids, attention_mask=ids_mask, output_hidden_states=True).hidden_states[-1]# get token lossloss_0, acc_0 = self.get_token_loss(ids, last_hidden_states, ids_mask)## encode the document with the BERT encoder modeldids, dids_mask = batch['bert_ids'], batch['bert_mask']output = self.phrase_encoder(dids, dids_mask, output_hidden_states=True)['hidden_states'][-1]  # [B, S, E]# collect the phrase start representations and phrase end representationss_rep = self.s_proj(output)e_rep = self.e_proj(output)s_rep = s_rep.reshape(-1, s_rep.size(-1))e_rep = e_rep.reshape(-1, e_rep.size(-1))  # [B_doc*S_doc, 768//2]# collect the query representationsquery = last_hidden_states[:, :-1].reshape(-1, last_hidden_states.size(-1))query_start = query[:, :self.model.config.hidden_size // 2]query_end = query[:, self.model.config.hidden_size // 2:]# training the representations of the start tokenscandidate_reps = torch.cat([self.token_embeddings[:, :self.model.config.hidden_size // 2],s_rep], dim=0)logits = torch.matmul(query_start, candidate_reps.t())logits /= self.args['temp']# build the padding mask for query sidequery_padding_mask = ids_mask[:, :-1].reshape(-1).to(torch.bool)# build the padding mask: 1 for valid and 0 for maskattention_mask = (dids_mask.reshape(1, -1).to(torch.bool)).to(torch.long)padding_mask = torch.ones_like(logits).to(torch.long)# Santiy check overpadding_mask[:, self.vocab_size:] = attention_mask# build the position mask: 1 for valid and 0 for maskpos_mask = batch['pos_mask']start_labels, end_labels = batch['start_labels'][:, 1:].reshape(-1), batch['end_labels'][:, 1:].reshape(-1)position_mask = torch.ones_like(logits).to(torch.long)query_pos = start_labels > self.vocab_size# ignore the padding maskposition_mask[query_pos, self.vocab_size:] = pos_maskassert padding_mask.shape == position_mask.shape# overall maskoverall_mask = padding_mask * position_mask## remove the position mask# overall_mask = padding_masknew_logits = torch.where(overall_mask.to(torch.bool), logits, torch.tensor(-1e4).to(torch.half).cuda())mask = torch.zeros_like(new_logits)mask[range(len(new_logits)), start_labels] = 1.loss_ = F.log_softmax(new_logits[query_padding_mask], dim=-1) * mask[query_padding_mask]loss_1 = (-loss_.sum(dim=-1)).mean()## split the token accuaracy and phrase accuracyphrase_indexes = start_labels > self.vocab_sizephrase_indexes_ = phrase_indexes & query_padding_maskphrase_start_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == start_labels[phrase_indexes_]phrase_start_acc = phrase_start_acc.to(torch.float).mean().item()phrase_indexes_ = ~phrase_indexes & query_padding_masktoken_start_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == start_labels[phrase_indexes_]token_start_acc = token_start_acc.to(torch.float).mean().item()# training the representations of the end tokenscandidate_reps = torch.cat([self.token_embeddings[:, self.model.config.hidden_size // 2:],e_rep], dim=0)logits = torch.matmul(query_end, candidate_reps.t())  # [Q, B*]  logits /= self.args['temp']new_logits = torch.where(overall_mask.to(torch.bool), logits, torch.tensor(-1e4).to(torch.half).cuda())mask = torch.zeros_like(new_logits)mask[range(len(new_logits)), end_labels] = 1.loss_ = F.log_softmax(new_logits[query_padding_mask], dim=-1) * mask[query_padding_mask]loss_2 = (-loss_.sum(dim=-1)).mean()# split the phrase and token accuracyphrase_indexes = end_labels > self.vocab_sizephrase_indexes_ = phrase_indexes & query_padding_maskphrase_end_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == end_labels[phrase_indexes_]phrase_end_acc = phrase_end_acc.to(torch.float).mean().item()phrase_indexes_ = ~phrase_indexes & query_padding_masktoken_end_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == end_labels[phrase_indexes_]token_end_acc = token_end_acc.to(torch.float).mean().item()return (loss_0,  # token lossloss_1,  # token-head lossloss_2,  # token-tail lossacc_0,  # token accuracyphrase_start_acc,phrase_end_acc,token_start_acc,token_end_acc)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/54087.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Win系统设置开机自启项及自定义自启程序

Win系统设置开机自启项及自定义自启程序 分用户自启动和系统自启动两种形式&#xff1a; 1. 用户自启动目录&#xff1a;C:\Users\Administrator\AppData\Roaming\Microsoft\Windows\Start Menu\Programs\Startup 用快速键打开&#xff1a; Win键R键&#xff0c;输入shell:…

sql server 快速安装

目录标题 一、下载二、直接选择基本安装二、下载ssms&#xff08;数据库图形化操作页面&#xff09;三、开启sa账号认证&#xff08;一&#xff09;第一步&#xff1a;更改身份验证模式&#xff08;二&#xff09;第二步&#xff1a;启用 sa 登录四、开启tcp/ip 一、下载 下载…

低通滤波器和高通滤波器

应用于图像低通滤波器和高通滤波器的实现 需要用到傅里叶变换 #include <opencv2/opencv.hpp> #include <Eigen> #include <iostream> #include <vector> #include <cmath> #include <complex>#define M_PI 3.14159265358979323846…

QT5.12.12通过ODBC连接到GBase 8s数据库(CentOS)

本示例使用的环境如下&#xff1a; 硬件平台&#xff1a;x86_64&#xff08;amd64&#xff09;操作系统&#xff1a;CentOS 7.8 2003数据库版本&#xff08;含CSDK&#xff09;&#xff1a;GBase 8s V8.8 3.0.0_1 为什么使用QT 5.12.10&#xff1f;该版本包含QODBC。 1&#…

ES6中promise的使用

ES6中promise的使用 本文目录 ES6中promise的使用基础介绍箭头函数function函数状态 原型方法Promise.prototype.then()Promise.prototype.catch() 静态方法Promise.all()Promise.race()Promise.any() 链式回调 基础介绍 官网&#xff1a;https://promisesaplus.com/ window.…

最新docker多系统安装技术

在Ubuntu操作系统中安装Docker 在Ubuntu操作系统中安装Docker的步骤如下。 1&#xff0e;卸载旧版本Docker 卸载旧版本Docker的命令如下&#xff1a; $ sudo apt-get remove docker docker-engine docker.io 2&#xff0e;使用脚本自动安装 在测试或开发环境中&#xff0…

STM32 进不了main 函数

1. 我用的是STM32L151C8T6 的芯片&#xff0c;在github 上找了个别人的例程&#xff0c;拿来当模板改&#xff0c;由于他用的是HSE 外部晶振&#xff0c;我用的是内部晶振HSI&#xff0c;所以需要改系统时钟&#xff0c;改完后debug&#xff0c; 一直进不了main 函数&#xff0…

PHP“牵手”拼多多商品详情数据获取方法,拼多多API接口批量获取拼多多商品详情数据说明

拼多多商品详情接口 API 是开放平台提供的一种 API 接口&#xff0c;它可以帮助开发者获取拼多多商品的详细信息&#xff0c;包括商品的标题、描述、图片等信息。在拼多多电商平台的开发中&#xff0c;拼多多详情接口 API 是非常常用的 API&#xff0c;因此本文将详细介绍拼多多…

【C++】C++ 引用详解 ⑤ ( 函数 “ 引用类型返回值 “ 当左值被赋值 )

文章目录 一、函数返回值不能是 " 局部变量 " 的引用或指针1、函数返回值常用用法2、分析函数 " 普通返回值 " 做左值的情况3、分析函数 " 引用返回值 " 做左值的情况 函数返回值 能作为 左值 , 是很重要的概念 , 这是实现 " 链式编程 &quo…

淘宝API技术解析,实现关键词搜索淘宝商品(商品详情接口等)

淘宝提供了开放平台接口&#xff08;API&#xff09;来实现按图搜索淘宝商品的功能。您可以通过以下步骤来实现&#xff1a; 获取开放平台的访问权限&#xff1a;首先&#xff0c;您需要在淘宝开放平台创建一个应用&#xff0c;获取访问淘宝API的权限。具体的申请步骤和要求可以…

LabVIEW开发灭火器机器人

LabVIEW开发灭火器机器人 如今&#xff0c;自主机器人在行业中有着巨大的需求。这是因为它们根据不同情况的适应性。由于消防员很难进入高风险区域&#xff0c;自主机器人出现了。该机器人具有自行检测火灾的能力&#xff0c;并通过自己的决定穿越路径。 由于消防安全是主要问…

java八股文面试[java基础]——如何实现不可变的类

知识来源&#xff1a; 【23版面试突击】如何实现不可变的类&#xff1f;_哔哩哔哩_bilibili 【2023年面试】怎样声明一个类不会被继承&#xff0c;什么场景下会用_哔哩哔哩_bilibili

JVM 之字节码(.class)文件

本文中的内容参考B站尚硅谷宋红康JVM全套教程 你将获得&#xff1a; 1、掌握字节码文件的结构 2、掌握Java源代码如何在JVM中执行 3、掌握一些虚拟机指令 4、回答一些面试题 课程介绍 通过几个面试题初始字节码文件为什么学习class字节码文件什么是class字节码文件分析c…

2022年03月 C/C++(四级)真题解析#中国电子学会#全国青少年软件编程等级考试

第1题&#xff1a;拦截导弹 某国为了防御敌国的导弹袭击&#xff0c; 发展出一种导弹拦截系统。 但是这种导弹拦截系统有一个缺陷&#xff1a; 虽然它的第一发炮弹能够到达任意的高度&#xff0c;但是以后每一发炮弹都不能高于前一发的高度。 某天&#xff0c; 雷达捕捉到敌国的…

Vue3.0极速入门- 目录和文件说明

目录结构 以下文件均为npm create helloworld自动生成的文件目录结构 目录截图 目录说明 目录/文件说明node_modulesnpm 加载的项目依赖模块src这里是我们要开发的目录&#xff0c;基本上要做的事情都在这个目录里assets放置一些图片&#xff0c;如logo等。componentsvue组件…

SFM structure from motion

struction就是空间三维点的位置 motion 就是相机每帧的位移 https://www.youtube.com/watch?vUhkb8Zq-dnM&listPL2zRqk16wsdoYzrWStffqBAoUY8XdvatV&index9

西部AI小镇-构建自主虚拟世界

背景 未来曜文有接入市场上所有面向chatGPT开发的应用&#xff0c;例如开源聊天组件&#xff0c;西部小镇等 内容介绍 生成代理起床&#xff0c;做早餐&#xff0c;然后去上班&#xff1b;艺术家作画&#xff0c;作家写作&#xff1b;他们形成意见、互相关注并发起对话&…

Linux线程 --- 生产者消费者模型(C语言)

在学习完线程相关的概念之后&#xff0c;本节来认识一下Linux多线程相关的一个重要模型----“ 生产者消费者模型” 本文参考&#xff1a; Linux多线程生产者与消费者_红娃子的博客-CSDN博客 Linux多线程——生产者消费者模型_linux多线程生产者与消费者_两片空白的博客-CSDN博客…

基于Java+SpringBoot+Vue前后端分离党员教育和管理系统设计和实现

博主介绍&#xff1a;✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专…

ServiceManager接收APP的跨进程Binder通信流程分析

现在一起来分析Server端接收&#xff08;来自APP端&#xff09;Binder数据的整个过程&#xff0c;还是以ServiceManager这个Server为例进行分析,这是一个至下而上的分析过程。 在分析之前先思考ServiceManager是什么&#xff1f;它其实是一个独立的进程&#xff0c;由init解析i…