copy is all you need前向绘图 和疑惑标记

疑惑的起因
在这里插入图片描述
简化前向图
在这里插入图片描述
GPT4解释

这段代码实现了一个神经网络模型,包含了BERT、GPT-2和MLP等模块。主要功能是给定一个文本序列和一个查询序列,预测查询序列中的起始和结束位置,使其对应文本序列中的一个短语。具体实现细节如下:

首先,使用BERT将文本序列编码成一个张量作为输入,并从GPT-2模型中提取查询序列的表示。然后,将查询序列的表示拆分为起始和结束两部分,并使用MLP将短语的起始和结束点表示投影到另一个空间中。
接下来,利用生成式模型计算每个查询起始和结束位置之间的所有令牌的概率分布,并根据此生成预测位置。使用交叉熵损失函数来计算这些预测和真实值之间的差距。
最后,计算起始和结束位置的准确性和整个短语的准确性,并返回损失和准确性的值来优化模型的性能。

import torch
from transformers import AutoModel, AutoTokenizer, GPT2LMHeadModel
from torch import nn
from torch.nn import functional as Fclass Copyisallyouneed(nn.Module):def __init__(self, **args):super(Copyisallyouneed, self).__init__()self.args = args# bert-encoder modelself.phrase_encoder = AutoModel.from_pretrained(self.args['phrase_encoder_model'][self.args['lang']])self.bert_tokenizer = AutoTokenizer.from_pretrained(self.args['phrase_encoder_tokenizer'][self.args['lang']])self.bert_tokenizer.add_tokens(['<|endoftext|>', '[PREFIX]'])self.prefix_token_id = self.bert_tokenizer.convert_tokens_to_ids('[PREFIX]')self.phrase_encoder.resize_token_embeddings(self.phrase_encoder.config.vocab_size + 2)# model and tokenizerself.tokenizer = AutoTokenizer.from_pretrained(self.args['prefix_encoder_tokenizer'][self.args['lang']])self.vocab_size = len(self.tokenizer)self.pad = self.tokenizer.pad_token_id if self.args['lang'] == 'zh' else self.tokenizer.bos_token_idself.model = GPT2LMHeadModel.from_pretrained(self.args['prefix_encoder_model'][self.args['lang']])self.token_embeddings = nn.Parameter(list(self.model.lm_head.parameters())[0])# MLP: mapping bert phrase start representationsself.s_proj = nn.Sequential(nn.Dropout(p=args['dropout']),nn.Tanh(),nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size // 2))# MLP: mapping bert phrase end representationsself.e_proj = nn.Sequential(nn.Dropout(p=args['dropout']),nn.Tanh(),nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size // 2))self.gen_loss_fct = nn.CrossEntropyLoss(ignore_index=self.pad)@torch.no_grad()def get_query_rep(self, ids):self.eval()output = self.model(input_ids=ids, output_hidden_states=True)['hidden_states'][-1][:, -1, :]return outputdef get_token_loss(self, ids, hs, ids_mask):# no pad tokenlabel = ids[:, 1:]logits = torch.matmul(hs[:, :-1, :],self.token_embeddings.t())# TODO: inner loss function remove the temperature factorlogits /= self.args['temp']loss = self.gen_loss_fct(logits.view(-1, logits.size(-1)), label.reshape(-1))chosen_tokens = torch.max(logits, dim=-1)[1]gen_acc = (chosen_tokens.reshape(-1) == label.reshape(-1)).to(torch.long)valid_mask = (label != self.pad).reshape(-1)valid_tokens = gen_acc & valid_maskgen_acc = valid_tokens.sum().item() / valid_mask.sum().item()return loss, gen_accdef forward(self, batch):## gpt2 query encoderids, ids_mask = batch['gpt2_ids'], batch['gpt2_mask']last_hidden_states = \self.model(input_ids=ids, attention_mask=ids_mask, output_hidden_states=True).hidden_states[-1]# get token lossloss_0, acc_0 = self.get_token_loss(ids, last_hidden_states, ids_mask)## encode the document with the BERT encoder modeldids, dids_mask = batch['bert_ids'], batch['bert_mask']output = self.phrase_encoder(dids, dids_mask, output_hidden_states=True)['hidden_states'][-1]  # [B, S, E]# collect the phrase start representations and phrase end representationss_rep = self.s_proj(output)e_rep = self.e_proj(output)s_rep = s_rep.reshape(-1, s_rep.size(-1))e_rep = e_rep.reshape(-1, e_rep.size(-1))  # [B_doc*S_doc, 768//2]# collect the query representationsquery = last_hidden_states[:, :-1].reshape(-1, last_hidden_states.size(-1))query_start = query[:, :self.model.config.hidden_size // 2]query_end = query[:, self.model.config.hidden_size // 2:]# training the representations of the start tokenscandidate_reps = torch.cat([self.token_embeddings[:, :self.model.config.hidden_size // 2],s_rep], dim=0)logits = torch.matmul(query_start, candidate_reps.t())logits /= self.args['temp']# build the padding mask for query sidequery_padding_mask = ids_mask[:, :-1].reshape(-1).to(torch.bool)# build the padding mask: 1 for valid and 0 for maskattention_mask = (dids_mask.reshape(1, -1).to(torch.bool)).to(torch.long)padding_mask = torch.ones_like(logits).to(torch.long)# Santiy check overpadding_mask[:, self.vocab_size:] = attention_mask# build the position mask: 1 for valid and 0 for maskpos_mask = batch['pos_mask']start_labels, end_labels = batch['start_labels'][:, 1:].reshape(-1), batch['end_labels'][:, 1:].reshape(-1)position_mask = torch.ones_like(logits).to(torch.long)query_pos = start_labels > self.vocab_size# ignore the padding maskposition_mask[query_pos, self.vocab_size:] = pos_maskassert padding_mask.shape == position_mask.shape# overall maskoverall_mask = padding_mask * position_mask## remove the position mask# overall_mask = padding_masknew_logits = torch.where(overall_mask.to(torch.bool), logits, torch.tensor(-1e4).to(torch.half).cuda())mask = torch.zeros_like(new_logits)mask[range(len(new_logits)), start_labels] = 1.loss_ = F.log_softmax(new_logits[query_padding_mask], dim=-1) * mask[query_padding_mask]loss_1 = (-loss_.sum(dim=-1)).mean()## split the token accuaracy and phrase accuracyphrase_indexes = start_labels > self.vocab_sizephrase_indexes_ = phrase_indexes & query_padding_maskphrase_start_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == start_labels[phrase_indexes_]phrase_start_acc = phrase_start_acc.to(torch.float).mean().item()phrase_indexes_ = ~phrase_indexes & query_padding_masktoken_start_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == start_labels[phrase_indexes_]token_start_acc = token_start_acc.to(torch.float).mean().item()# training the representations of the end tokenscandidate_reps = torch.cat([self.token_embeddings[:, self.model.config.hidden_size // 2:],e_rep], dim=0)logits = torch.matmul(query_end, candidate_reps.t())  # [Q, B*]  logits /= self.args['temp']new_logits = torch.where(overall_mask.to(torch.bool), logits, torch.tensor(-1e4).to(torch.half).cuda())mask = torch.zeros_like(new_logits)mask[range(len(new_logits)), end_labels] = 1.loss_ = F.log_softmax(new_logits[query_padding_mask], dim=-1) * mask[query_padding_mask]loss_2 = (-loss_.sum(dim=-1)).mean()# split the phrase and token accuracyphrase_indexes = end_labels > self.vocab_sizephrase_indexes_ = phrase_indexes & query_padding_maskphrase_end_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == end_labels[phrase_indexes_]phrase_end_acc = phrase_end_acc.to(torch.float).mean().item()phrase_indexes_ = ~phrase_indexes & query_padding_masktoken_end_acc = new_logits[phrase_indexes_].max(dim=-1)[1] == end_labels[phrase_indexes_]token_end_acc = token_end_acc.to(torch.float).mean().item()return (loss_0,  # token lossloss_1,  # token-head lossloss_2,  # token-tail lossacc_0,  # token accuracyphrase_start_acc,phrase_end_acc,token_start_acc,token_end_acc)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/54087.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Win系统设置开机自启项及自定义自启程序

Win系统设置开机自启项及自定义自启程序 分用户自启动和系统自启动两种形式&#xff1a; 1. 用户自启动目录&#xff1a;C:\Users\Administrator\AppData\Roaming\Microsoft\Windows\Start Menu\Programs\Startup 用快速键打开&#xff1a; Win键R键&#xff0c;输入shell:…

JavaScript 基础知识回顾与复习---闭包

当我们说到闭包&#xff0c;在JavaScript中闭包是一个让人难以理解甚至说是一个近乎神话的概念。闭包往往也是面试必考的题目&#xff0c;如果能够掌握闭包对我们自己来说那也是一种极大的提升。在学习的过程中不要害怕闭包&#xff0c;闭包并不是一个新的语法或者模式&#xf…

sql server 快速安装

目录标题 一、下载二、直接选择基本安装二、下载ssms&#xff08;数据库图形化操作页面&#xff09;三、开启sa账号认证&#xff08;一&#xff09;第一步&#xff1a;更改身份验证模式&#xff08;二&#xff09;第二步&#xff1a;启用 sa 登录四、开启tcp/ip 一、下载 下载…

低通滤波器和高通滤波器

应用于图像低通滤波器和高通滤波器的实现 需要用到傅里叶变换 #include <opencv2/opencv.hpp> #include <Eigen> #include <iostream> #include <vector> #include <cmath> #include <complex>#define M_PI 3.14159265358979323846…

QT5.12.12通过ODBC连接到GBase 8s数据库(CentOS)

本示例使用的环境如下&#xff1a; 硬件平台&#xff1a;x86_64&#xff08;amd64&#xff09;操作系统&#xff1a;CentOS 7.8 2003数据库版本&#xff08;含CSDK&#xff09;&#xff1a;GBase 8s V8.8 3.0.0_1 为什么使用QT 5.12.10&#xff1f;该版本包含QODBC。 1&#…

ES6中promise的使用

ES6中promise的使用 本文目录 ES6中promise的使用基础介绍箭头函数function函数状态 原型方法Promise.prototype.then()Promise.prototype.catch() 静态方法Promise.all()Promise.race()Promise.any() 链式回调 基础介绍 官网&#xff1a;https://promisesaplus.com/ window.…

最新docker多系统安装技术

在Ubuntu操作系统中安装Docker 在Ubuntu操作系统中安装Docker的步骤如下。 1&#xff0e;卸载旧版本Docker 卸载旧版本Docker的命令如下&#xff1a; $ sudo apt-get remove docker docker-engine docker.io 2&#xff0e;使用脚本自动安装 在测试或开发环境中&#xff0…

STM32 进不了main 函数

1. 我用的是STM32L151C8T6 的芯片&#xff0c;在github 上找了个别人的例程&#xff0c;拿来当模板改&#xff0c;由于他用的是HSE 外部晶振&#xff0c;我用的是内部晶振HSI&#xff0c;所以需要改系统时钟&#xff0c;改完后debug&#xff0c; 一直进不了main 函数&#xff0…

c语言练习题29:获得月份天数

获得月份天数 代码&#xff1a; //法一 #include<stdio.h> int main() {int y 0;int m 0;int days[13] { 0,31,28,31,30,31,30,31,31,30,31,30,31 };while (~scanf("%d%*c%d", &y, &m)) {int day days[m];if ((y % 4 0) && ((y % 400 …

图像线段检测几种方法

1、方法一 当我将OpenCV提升到4.1.0时&#xff0c;LineSegmentDetector&#xff08;LSD&#xff09;消失了。 OpenCV-contrib有一个名为FastLineDetector的东西&#xff0c;如果它被用作LSD的替代品似乎很好。如果你有点感动&#xff0c;你会得到与LSD几乎相同的结果。 2、方…

PHP“牵手”拼多多商品详情数据获取方法,拼多多API接口批量获取拼多多商品详情数据说明

拼多多商品详情接口 API 是开放平台提供的一种 API 接口&#xff0c;它可以帮助开发者获取拼多多商品的详细信息&#xff0c;包括商品的标题、描述、图片等信息。在拼多多电商平台的开发中&#xff0c;拼多多详情接口 API 是非常常用的 API&#xff0c;因此本文将详细介绍拼多多…

前端性能优化之js优化

文章目录 引言一、浏览器加载js文件过程二、浏览器加载js和图片的对比三、浏览器加载js资源占总资源加载时间的比例四、v8的编译原理概述五、代码层面优化&#xff0c;提高V8编译效率1. 函数优化1. 减少函数大小和复杂度2. 避免使用动态特性3. 避免使用eval()和with语句4. 使用…

【C++】C++ 引用详解 ⑤ ( 函数 “ 引用类型返回值 “ 当左值被赋值 )

文章目录 一、函数返回值不能是 " 局部变量 " 的引用或指针1、函数返回值常用用法2、分析函数 " 普通返回值 " 做左值的情况3、分析函数 " 引用返回值 " 做左值的情况 函数返回值 能作为 左值 , 是很重要的概念 , 这是实现 " 链式编程 &quo…

淘宝API技术解析,实现关键词搜索淘宝商品(商品详情接口等)

淘宝提供了开放平台接口&#xff08;API&#xff09;来实现按图搜索淘宝商品的功能。您可以通过以下步骤来实现&#xff1a; 获取开放平台的访问权限&#xff1a;首先&#xff0c;您需要在淘宝开放平台创建一个应用&#xff0c;获取访问淘宝API的权限。具体的申请步骤和要求可以…

Redis中 为什么Lua脚本可以保证原子性?

Redis中 为什么Lua脚本可以保证原子性&#xff1f;

LabVIEW开发灭火器机器人

LabVIEW开发灭火器机器人 如今&#xff0c;自主机器人在行业中有着巨大的需求。这是因为它们根据不同情况的适应性。由于消防员很难进入高风险区域&#xff0c;自主机器人出现了。该机器人具有自行检测火灾的能力&#xff0c;并通过自己的决定穿越路径。 由于消防安全是主要问…

java八股文面试[java基础]——如何实现不可变的类

知识来源&#xff1a; 【23版面试突击】如何实现不可变的类&#xff1f;_哔哩哔哩_bilibili 【2023年面试】怎样声明一个类不会被继承&#xff0c;什么场景下会用_哔哩哔哩_bilibili

python 把 易语言转成python 易语言和Python通信的实现

python 把 易语言转成python Python作为一种高效的编程语言&#xff0c;已经越来越受到开发者的欢迎。易语言是一种极为流行的编程语言&#xff0c;也有非常多的用户。然而&#xff0c;由于易语言语法比较简单&#xff0c;对于一些高级编程需求可能无法满足&#xff0c;对于需…

SystemVerilog中的Program的学习笔记

1、SystemVerilog中的Program的作用&#xff1f; 将验证部分与设计部分进行隔离&#xff08;实现方式就是将软件验证部分放置program中&#xff09;2、SystemVerilog中的Program结束方式&#xff1f; Program结束方式分为两种&#xff1a;1、隐式结束 2、显式结束 1、隐式结束…

JVM 之字节码(.class)文件

本文中的内容参考B站尚硅谷宋红康JVM全套教程 你将获得&#xff1a; 1、掌握字节码文件的结构 2、掌握Java源代码如何在JVM中执行 3、掌握一些虚拟机指令 4、回答一些面试题 课程介绍 通过几个面试题初始字节码文件为什么学习class字节码文件什么是class字节码文件分析c…