pytorch版本的bert模型代码(MLM)

魔改bert就必须要知道Bert的结构:

主要解答与BertForMaskedLM(MLM)有关的类:

下面是MLM的分类头:

class BertLMPredictionHead(nn.Module):def __init__(self, config, bert_model_embedding_weights):super(BertLMPredictionHead, self).__init__()self.transform = BertPredictionHeadTransform(config)# The output weights are the same as the input embeddings, but there is# an output-only bias for each token.self.decoder = nn.Linear(bert_model_embedding_weights.size(1),bert_model_embedding_weights.size(0),bias=False)self.decoder.weight = bert_model_embedding_weightsself.bias = nn.Parameter(torch.zeros(bert_model_embedding_weights.size(0)))"""上面是创建一个线性映射层, 把transformer block输出的[batch_size, seq_len, embed_dim]映射为[batch_size, seq_len, vocab_size], 也就是把最后一个维度映射成字典中字的数量, 获取MaskedLM的预测结果, 注意这里其实也可以直接矩阵成embedding矩阵的转置, 但一般情况下我们要随机初始化新的一层参数"""def forward(self, hidden_states):hidden_states = self.transform(hidden_states)hidden_states = self.decoder(hidden_states) + self.biasreturn hidden_statesclass BertOnlyMLMHead(nn.Module):def __init__(self, config, bert_model_embedding_weights):super(BertOnlyMLMHead, self).__init__()self.predictions = BertLMPredictionHead(config, bert_model_embedding_weights)def forward(self, sequence_output):prediction_scores = self.predictions(sequence_output)return prediction_scores

另一个类:

class BertModel(BertPreTrainedModel):"""BERT model ("Bidirectional Embedding Representations from a Transformer").Params:config: a BertConfig class instance with the configuration to build a new modelInputs:`input_ids`: a torch.LongTensor of shape [batch_size, sequence_length]with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts`extract_features.py`, `run_classifier.py` and `run_squad.py`)`token_type_ids`: an optional torch.LongTensor of shape [batch_size, sequence_length] with the tokentypes indices selected in [0, 1]. Type 0 corresponds to a `sentence A` and type 1 corresponds toa `sentence B` token (see BERT paper for more details).`attention_mask`: an optional torch.LongTensor of shape [batch_size, sequence_length] with indicesselected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the maxinput sequence length in the current batch. It's the mask that we typically use for attention whena batch has varying length sentences.`output_all_encoded_layers`: boolean which controls the content of the `encoded_layers` output as described below. Default: `True`.Outputs: Tuple of (encoded_layers, pooled_output)`encoded_layers`: controled by `output_all_encoded_layers` argument:- `output_all_encoded_layers=True`: outputs a list of the full sequences of encoded-hidden-states at the endof each attention block (i.e. 12 full sequences for BERT-base, 24 for BERT-large), eachencoded-hidden-state is a torch.FloatTensor of size [batch_size, sequence_length, hidden_size],- `output_all_encoded_layers=False`: outputs only the full sequence of hidden-states correspondingto the last attention block of shape [batch_size, sequence_length, hidden_size],`pooled_output`: a torch.FloatTensor of size [batch_size, hidden_size] which is the output of aclassifier pretrained on top of the hidden state associated to the first character of theinput (`CLS`) to train on the Next-Sentence task (see BERT's paper).Example usage:```python# Already been converted into WordPiece token idsinput_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])config = modeling.BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072)model = modeling.BertModel(config=config)all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask)```"""def __init__(self, config):super(BertModel, self).__init__(config)self.embeddings = BertEmbeddings(config)self.encoder = BertEncoder(config)self.pooler = BertPooler(config)self.apply(self.init_bert_weights)def forward(self, input_ids, positional_enc, token_type_ids=None, attention_mask=None,output_all_encoded_layers=True, get_attention_matrices=False):if attention_mask is None:# torch.LongTensor# attention_mask = torch.ones_like(input_ids)attention_mask = (input_ids > 0)# attention_mask [batch_size, length]if token_type_ids is None:token_type_ids = torch.zeros_like(input_ids)# We create a 3D attention mask from a 2D tensor mask.# Sizes are [batch_size, 1, 1, to_seq_length]# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]# this attention mask is more simple than the triangular masking of causal attention# used in OpenAI GPT, we just need to prepare the broadcast dimension here.extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)# 注意力矩阵mask: [batch_size, 1, 1, seq_length]# Since attention_mask is 1.0 for positions we want to attend and 0.0 for# masked positions, this operation will create a tensor which is 0.0 for# positions we want to attend and -10000.0 for masked positions.# Since we are adding it to the raw scores before the softmax, this is# effectively the same as removing these entirely.extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibilityextended_attention_mask = (1.0 - extended_attention_mask) * -10000.0# 给注意力矩阵里padding的无效区域加一个很大的负数的偏置, 为了使softmax之后这些无效区域仍然为0, 不参与后续计算# embedding层embedding_output = self.embeddings(input_ids, positional_enc, token_type_ids)# 经过所有定义的transformer block之后的输出encoded_layers, all_attention_matrices = self.encoder(embedding_output,extended_attention_mask,output_all_encoded_layers=output_all_encoded_layers,get_attention_matrices=get_attention_matrices)# 可输出所有层的注意力矩阵用于可视化if get_attention_matrices:return all_attention_matrices# [-1]为最后一个transformer block的隐藏层的计算结果sequence_output = encoded_layers[-1]# pooled_output为隐藏层中#CLS#对应的token的一条向量pooled_output = self.pooler(sequence_output)if not output_all_encoded_layers:encoded_layers = encoded_layers[-1]return encoded_layers, pooled_output

参考:https://blog.51cto.com/u_15060462/4254056

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/835636.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C++】-------反向迭代器的模拟实现(补充)

目录 前言 一、反向迭代器接口(用户层) 二、模拟实现 三、以vector模拟实现为例 四、总结 前言 在vector和list的接口中我们实际上有说明过反向迭代器的用法,这里就有个问题,并不是只有这两个容器存在反向迭代器的。那么对于他…

点云DBSCAN聚类,同时获取最多点数量的类,同时删除其他的类并显示

代码的主要目的是处理一个点云文件(从某个巷道或类似环境中获取的),并尝试识别并可视化其中的主要结构(比如墙壁),同时去除可能的噪声和异常点。它首先读取一个点云文件,进行降采样和异常点移除,然后使用DBSCAN聚类算法对剩余的点云进行聚类,最后选择并可视化包含最多…

力扣爆刷第133天之动态规划收尾(距离编辑与回文子串)

力扣爆刷第133天之动态规划收尾(距离编辑与回文子串) 文章目录 力扣爆刷第133天之动态规划收尾(距离编辑与回文子串)一、72. 编辑距离二、647. 回文子串三、516. 最长回文子序列 一、72. 编辑距离 题目链接:https://l…

应用案例 | 商业电气承包商借助Softing NetXpert XG2节省网络验证时间

一家提供全方位服务的电气承包商通过使用Softing NetXpert XG2顺利完成了此次工作任务——简化了故障排查的同时,还在很大程度上减少了不必要的售后回访。 对已经安装好的光纤或铜缆以太网网络进行认证测试可能会面临不同的挑战,这具体取决于网络的规模、…

示例五、气敏传感器

通过以下几个示例来具体展开学习,了解气敏传感器原理及特性,学习气敏传感器的应用: 示例五、气敏传感器 一、基本原理:随着人们生活水平的不断提高,人们对环境和健康问题越来越重视。各种燃气的广泛使用,使生产效率和…

socket编程 学习笔记 理解

在使用socket(也就是套接字)编程的时候,其实是工作于应用层和传输层之间 如果使用的是基于TCP的socket,那每个数据包的发送的过程大致为: 数据通过socket套接字构造符合TCP协议的数据包在屏蔽底层协议的情况下&#…

多模态CLIP和BLIP

一、CLIP 全称为Contrastive Language-Image Pre-Training用于做图-文匹配,部署在预训练阶段,最终理解为图像分类器。 1.背景 以前进行分类模型时,存在类别固定和训练时要进行标注。因此面对这两个问题提出CLIP,通过这个预训练…

显式Intent

activity.xml <?xml version"1.0" encoding"utf-8"?> <androidx.constraintlayout.widget.ConstraintLayout xmlns:android"http://schemas.android.com/apk/res/android"xmlns:app"http://schemas.android.com/apk/res-auto&q…

B 站评论系统架构设计难点

更多大厂面试内容可见 -> http://11come.cn B 站评论系统架构设计难点 这里整理一下在哔哩哔哩技术公众号看到的 B 站评论系统的架构设计文章&#xff0c;自己在学习过程中&#xff0c;对其中感觉比较有帮助的点做一下整理&#xff0c;方便以后查阅&#xff0c;详细版可以点…

你要顺着毛撸Rust——简评LogLogGames放弃Rust游戏开发

庄晓立/LIIGO&#xff0c;2024年5月11日。 上个月底&#xff0c;游戏开发工作室LogLogGames发文《Leaving Rust gamedev after 3 years》&#xff0c;声明在经历3年磨难后决定放弃用Rust语言开发游戏。万字长文&#xff0c;开启吐槽模式&#xff0c;引发国内外大量争论。 我尊…

Portforge:一款功能强大的轻量级端口混淆工具

关于Portforge Portforge是一款功能强大的轻量级端口混淆工具&#xff0c;该工具使用Crystal语言开发&#xff0c;可以帮助广大研究人员防止网络映射&#xff0c;这样一来&#xff0c;他人就无法查看到你设备正在运行&#xff08;或没有运行&#xff09;的服务和程序了。简而言…

XML属性

XML属性是XML元素的附加信息&#xff0c;它们提供有关元素的更多细节或定义其行为的方式。属性通常被包含在开始标签中&#xff0c;并使用关键字“键值对”的形式表示。 下面是几个示例以说明XML属性的使用&#xff1a; HTML元素的属性&#xff1a; <p id"paragraph…

邂逅Linux--常见指令,万物为文件(一)

引子&#xff1a;在之前&#xff0c;我们经常听到Linux&#xff0c;那什么是Linux呢&#xff1f;Linux是一种免费使用和自由传播的类UNIX操作系统&#xff0c;其内核由林纳斯本纳第克特托瓦兹&#xff08;Linus Benedict Torvalds&#xff09;于1991年10月5日首次发布&#xff…

力扣每日一题-统计已测试设备-2024.5.10

力扣题目&#xff1a;统计已测试设备 题目链接: 2960.统计已测试设备 题目描述 代码思路 根据题目内容&#xff0c;第一感是根据题目模拟整个过程&#xff0c;在每一步中修改所有设备的电量百分比。但稍加思索&#xff0c;发现可以利用已测试设备的数量作为需要减少的设备电…

Spring底层入门(十)

1、内嵌tomcat boot框架是默认内嵌tomcat的&#xff0c;不需要手动安装和配置外部的 Servlet 容器。 简单的介绍一下tomcat服务器的构成&#xff1a; Catalina&#xff1a; Catalina 是 Tomcat 的核心组件&#xff0c;负责处理 HTTP 请求、响应以及管理 Servlet 生命周期。它包…

Vitis HLS 学习笔记--理解串流Stream(1)

目录 1. 介绍 2. 示例 2.1 代码解析 2.2 串流数据类型 2.3 综合报告 3. 总结 1. 介绍 在Vitis HLS中&#xff0c;hls::stream是一个用于在C/C中进行高级合成的关键数据结构。它类似于C标准库中的std::stream&#xff0c;但是专门设计用于硬件描述语言&#xff08;如Veri…

基于springboot实现贸易行业crm系统项目【项目源码+论文说明】计算机毕业设计

基于springboot实现贸易行业crm系统演示 摘要 随着信息技术在管理上越来越深入而广泛的应用&#xff0c;管理信息系统的实施在技术上已逐步成熟。本文介绍了基于springboot的贸易行业crm系统的开发全过程。通过分析基于springboot的贸易行业crm系统管理的不足&#xff0c;创建…

【栈】Leetcode 字符串解码

题目讲解 394. 字符串解码 算法讲解 这道题有四种情况&#xff1a;1.遍历的时候遇到数字&#xff0c;我们计算并保存数字&#xff0c;将它加入到数字栈中&#xff1b;2.遍历的时候遇到[&#xff0c;我们就把字符保存&#xff0c;加入到字符栈中&#xff1b;3.当遇到]&#x…

如何远程控制另一部手机:远程控制使用方法

在现今高科技的社会中&#xff0c;远程控制手机的需求在某些情境下变得越来越重要。不论是为了协助远在他乡的家人解决问题&#xff0c;还是为了确保孩子的在线安全&#xff0c;了解如何实现这一功能都是有益的。本文将为您简要介绍几种远程控制手机的方法及其使用要点。 KKVi…

探索前端技术的未来:新兴工具与框架的引领

随着互联网的迅速发展&#xff0c;前端技术也在不断演进。作为前端开发者&#xff0c;我们时刻都要保持对新兴工具和框架的关注&#xff0c;以便跟上技术的脚步&#xff0c;同时也为自己的职业发展做好准备。在这篇文章中&#xff0c;我们将探索前端技术的未来趋势&#xff0c;…