【Python】科研代码学习:三 PreTrainedModel, PretrainedConfig, PreTrainedTokenizer

【Python】科研代码学习:三 PreTrainedModel, PretrainedConfig, PreTrainedTokenizer

  • 前言
  • Models : PreTrainedModel
    • PreTrainedModel 中重要的方法
  • tensorflow & pytorch 简单对比
  • Configuration : PretrainedConfig
    • PretrainedConfig 中重要的方法
  • Tokenizer : PreTrainedTokenizer
    • PreTrainedTokenizer 中重要的方法

前言

  • HF 官网API
    本文主要从官网API与源代码中学习调用HF的关键模组

Models : PreTrainedModel

  • HF 提供的基础模型类有 PreTrainedModel, TFPreTrainedModel, and FlaxPreTrainedModel
  • 这三者有什么区别呢
    PreTrainedModel 指的是用 torch 的框架
    在这里插入图片描述
    TFPreTrainedModel 指的是用 tensorflow 框架
    在这里插入图片描述
    FlaxPreTrainedModel 指的是用 flax 框架,是用 jax 做的
    在这里插入图片描述
    (哈哈,搜了好久都没搜到,去看源码导包瞬间明白了,也可能是我比较笨)
  • Transformers的大部分模型都会继承PretrainedModel基类。PretrainedModel主要负责管理模型的配置,模型的参数加载、下载和保存。
  • PretrainedModel继承自 nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin
    在初始化时需要提供给它一个 config: PretrainedConfig
  • 所以,我们可以视为它是所有模型的基类
    可以看到很多其他代码在判断模型类型时,一般写 model: Union[PreTrainedModel, nn.Module]

PreTrainedModel 中重要的方法

  • push_to_hub:将模型传到HF hub
from transformers import AutoModelmodel = AutoModel.from_pretrained("google-bert/bert-base-cased")# Push the model to your namespace with the name "my-finetuned-bert".
model.push_to_hub("my-finetuned-bert")# Push the model to an organization with the name "my-finetuned-bert".
model.push_to_hub("huggingface/my-finetuned-bert")
  • from_pretrained:根据config实例化预训练pytorch模型(Instantiate a pretrained pytorch model from a pre-trained model configuration.)
    默认使用评估模式 .eval()
    可以打开训练模式 .train()

    看下面的例子,可以从官方加载,也可以从本地模型参数加载。如果本地参数是tf的,转pytorch需要设置 from_tf=True,并且会慢些;本地参数是flax的话类似同理。
from transformers import BertConfig, BertModel# Download model and configuration from huggingface.co and cache.
model = BertModel.from_pretrained("google-bert/bert-base-uncased")
# Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
model = BertModel.from_pretrained("./test/saved_model/")
# Update configuration during loading.
model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
assert model.config.output_attentions == True
# Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
# Loading from a Flax checkpoint file instead of a PyTorch model (slower)
model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)

可以给 torch_dtype 设置数据类型。若不给,则默认为 torch.float16。也可以给 torch_dtype="auto"

  • get_input_embeddings:获得输入的词嵌入在这里插入图片描述
    对应还有 get_output_embeddings
  • init_weights:设置参数初始化
    如果需要自己调整参数初始化的,在 _init_weights_initialize_weights 中设置
  • save_pretrained:把模型和配置参数保存在文件夹中
    保存完后,便可以通过 from_pretrained 再次加载模型了
    在这里插入图片描述

tensorflow & pytorch 简单对比

  • 知乎:Tensorflow 到底比 Pytorch 好在哪里?
    下面截取了比较重要的图
    在这里插入图片描述
    在这里插入图片描述
    在这里插入图片描述
  • 里面还提到了一个内容叫做 Keras

Keras是一个由Python编写的开源人工神经网络库,可以作为Tensorflow、Microsoft-CNTK和Theano的高阶应用程序接口,进行深度学习模型的设计、调试、评估、应用和可视化

Configuration : PretrainedConfig

  • 刚才看了,对于 PretrainedModel 初始化提供的参数是 PretrainedConfig 类型的参数。
    它主要为不同的任务,提供了不同的重要参数
    HF官网:PretrainedConfig
  • 列一下对于NLP中比较重要的参数吧,所有的就看官方文档吧
返回信息
output_hidden_states (bool, optional, defaults to False) — Whether or not the model should return all hidden-states.
output_attentions (bool, optional, defaults to False) — Whether or not the model should returns all attentions.
return_dict (bool, optional, defaults to True) — Whether or not the model should return a ModelOutput instead of a plain tuple.
output_scores (bool, optional, defaults to False) — Whether the model should return the logits when used for generation.
return_dict_in_generate (bool, optional, defaults to False) — Whether the model should return a ModelOutput instead of a torch.LongTensor.序列生成
max_length (int, optional, defaults to 20) — Maximum length that will be used by default in the generate method of the model.
min_length (int, optional, defaults to 0) — Minimum length that will be used by default in the generate method of the model.
do_sample (bool, optional, defaults to False) — Flag that will be used by default in the generate method of the model. Whether or not to use sampling ; use greedy decoding otherwise.
num_beams (int, optional, defaults to 1) — Number of beams for beam search that will be used by default in the generate method of the model. 1 means no beam search.
diversity_penalty (float, optional, defaults to 0.0) — Value to control diversity for group beam search. that will be used by default in the generate method of the model. 0 means no diversity penalty. The higher the penalty, the more diverse are the outputs.
temperature (float, optional, defaults to 1.0) — The value used to module the next token probabilities that will be used by default in the generate method of the model. Must be strictly positive.
top_k (int, optional, defaults to 50) — Number of highest probability vocabulary tokens to keep for top-k-filtering that will be used by default in the generate method of the model.
top_p (float, optional, defaults to 1) — Value that will be used by default in the generate method of the model for top_p. If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
epetition_penalty (float, optional, defaults to 1) — Parameter for repetition penalty that will be used by default in the generate method of the model. 1.0 means no penalty.
length_penalty (float, optional, defaults to 1) — Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log likelihood of the sequence (i.e. negative), length_penalty > 0.0 promotes longer sequences, while length_penalty < 0.0 encourages shorter sequences.
bad_words_ids (List[int], optional) — List of token ids that are not allowed to be generated that will be used by default in the generate method of the model. In order to get the tokens of the words that should not appear in the generated text, use tokenizer.encode(bad_word, add_prefix_space=True).tokenizer相关
bos_token_id (int, optional) — The id of the beginning-of-stream token.
pad_token_id (int, optional) — The id of the padding token.
eos_token_id (int, optional) — The id of the end-of-stream token.PyTorch相关
torch_dtype (str, optional) — The dtype of the weights. This attribute can be used to initialize the model to a non-default dtype (which is normally float32) and thus allow for optimal storage allocation. For example, if the saved model is float16, ideally we want to load it back using the minimal amount of memory needed to load float16 weights. Since the config object is stored in plain text, this attribute contains just the floating type string without the torch. prefix. For example, for torch.float16 `torch_dtype is the "float16" string.常见参数
vocab_size (int) — The number of tokens in the vocabulary, which is also the first dimension of the embeddings matrix (this attribute may be missing for models that don’t have a text modality like ViT).
hidden_size (int) — The hidden size of the model.
num_attention_heads (int) — The number of attention heads used in the multi-head attention layers of the model.
num_hidden_layers (int) — The number of blocks in the model.

PretrainedConfig 中重要的方法

  • push_to_hub:依然是上传到 HF hub
  • from_dict:把一个 dict 类型转到 PretrainedConfig 类型
  • from_json_file:把一个 json 文件转到 PretrainedConfig 类型,传入的是文件路径
  • to_dict:转成 dict 类型
  • to_json_file:保存到 json 文件
  • to_json_string:转成 json 字符串
  • from_pretrained:从预训练模型配置文件中直接获取配置
    可以是HF模型,也可以是本地模型,见下方例子
# We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
# derived class: BertConfig
config = BertConfig.from_pretrained("google-bert/bert-base-uncased"
)  # Download configuration from huggingface.co and cache.
config = BertConfig.from_pretrained("./test/saved_model/"
)  # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
config = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False)
assert config.output_attentions == True
config, unused_kwargs = BertConfig.from_pretrained("google-bert/bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
)
assert config.output_attentions == True
assert unused_kwargs == {"foo": False}
  • save_pretrained:把配置文件保存到文件夹中,方便下次 from_pretrained 直接读取

Tokenizer : PreTrainedTokenizer

  • HF官网:PreTrainedTokenizer
    Tokenizer 是用来把输入的字符串,转成 id 数组用的
    先来看一下其中相关的类的继承关系
    在这里插入图片描述
  • PreTrainedTokenizer 的初始化方法是直接给了 **kwargs
    调几个重要的列在下面,可以看到大部分都是设置一些token的含义。
bos_token (str or tokenizers.AddedToken, optional) — A special token representing the beginning of a sentence. Will be associated to self.bos_token and self.bos_token_id.
eos_token (str or tokenizers.AddedToken, optional) — A special token representing the end of a sentence. Will be associated to self.eos_token and self.eos_token_id.
unk_token (str or tokenizers.AddedToken, optional) — A special token representing an out-of-vocabulary token. Will be associated to self.unk_token and self.unk_token_id.
sep_token (str or tokenizers.AddedToken, optional) — A special token separating two different sentences in the same input (used by BERT for instance). Will be associated to self.sep_token and self.sep_token_id.
pad_token (str or tokenizers.AddedToken, optional) — A special token used to make arrays of tokens the same size for batching purpose. Will then be ignored by attention mechanisms or loss computation. Will be associated to self.pad_token and self.pad_token_id.
cls_token (str or tokenizers.AddedToken, optional) — A special token representing the class of the input (used by BERT for instance). Will be associated to self.cls_token and self.cls_token_id.
mask_token (str or tokenizers.AddedToken, optional) — A special token representing a masked token (used by masked-language modeling pretraining objectives, like BERT). Will be associated to self.mask_token and self.mask_token_id.

PreTrainedTokenizer 中重要的方法

  • add_tokens:添加一些新的token
    它强调了,添加新token需要确保 token 嵌入矩阵与tokenizer是匹配的,即多调用一下 resize_token_embeddings 方法
    在这里插入图片描述
# Let's see how to increase the vocabulary of Bert model and tokenizer
tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased")
model = BertModel.from_pretrained("google-bert/bert-base-uncased")num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
print("We have added", num_added_toks, "tokens")
# Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))
  • add_special_tokens:添加特殊tokens,比如之前的 eos,pad 等,与之前普通的tokens是不大一样的,但要确保该token不在词汇表里
# Let's see how to add a new classification token to GPT-2
tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2")
model = GPT2Model.from_pretrained("openai-community/gpt2")special_tokens_dict = {"cls_token": "<CLS>"}num_added_toks = tokenizer.add_special_tokens(special_tokens_dict)
print("We have added", num_added_toks, "tokens")
# Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e., the length of the tokenizer.
model.resize_token_embeddings(len(tokenizer))assert tokenizer.cls_token == "<CLS>"
  • encode, decode:字符串转id数组,id数组转字符串,即词嵌入
    encodeself.convert_tokens_to_ids(self.tokenize(text)) 等价
    decodeself.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids)) 等价
  • tokenize:把字符串转成token序列,即分词 str → list[str]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/730614.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java基础面试题(day 01)

&#x1f4d1;前言 本文主要是【Java】——Java基础面试题的文章&#xff0c;如果有什么需要改进的地方还请大佬指出⛺️ &#x1f3ac;作者简介&#xff1a;大家好&#xff0c;我是听风与他&#x1f947; ☁️博客首页&#xff1a;CSDN主页听风与他 &#x1f304;每日一句&am…

C++ 篇 数组

数组是含有多个数据项的数据结构&#xff0c;并且这些数据项都具有相同的数据类型。这些数据项称为数组的元素&#xff0c;我们可以根据元素在数组中的位置来选取元素。 最简单的数组就是一维数组。数组元素在内存中是依次排列的&#xff0c;如下图所示&#xff1a; 声明一个…

C++之创建与使用dll

目录 1、创建dll test.h test.cpp Source.def 2、使用dll testdll.cpp DLL&#xff0c;全称“Dynamic Link Library”&#xff0c;中文名为“动态链接库”&#xff0c;是一种在Windows操作系统中常见的库文件格式。它包含了可以由多个程序同时使用的代码和数据。与静态链接…

09 函数和存储过程

文章目录 函数和存储过程函数创建存储过程创建存储过程和存储函数操作函数和存储过程区别 函数和存储过程 存储过程和函数是事先经过编译并存储在数据库中的一段sql语句集合&#xff0c;调用存储过程和函数可以简化应用开发工作&#xff0c;提高数据处理的效率。 函数创建 d…

人工智能|机器学习——k-近邻算法(KNN分类算法)

1.简介 k-最近邻算法&#xff0c;也称为 kNN 或 k-NN&#xff0c;是一种非参数、有监督的学习分类器&#xff0c;它使用邻近度对单个数据点的分组进行分类或预测。虽然它可以用于回归问题&#xff0c;但它通常用作分类算法&#xff0c;假设可以在彼此附近找到相似点。 对于分类…

五个与iOS基础开发相关的案例:

iOS是由苹果公司开发的移动操作系统&#xff0c;专为iPhone、iPad和iPod touch等设备设计。iOS系统以其流畅的用户体验、丰富的功能和强大的安全性而著称&#xff0c;成为全球最受欢迎的移动操作系统之一。iOS基础开发则是构建在这些设备上的应用程序的过程&#xff0c;涉及多个…

JavaScript—— 运算符总结(超全)

JavaScript—— 运算符总结(超全) 1.小括号运算符 ​ ()在我们js执行代码的过程中&#xff0c;一行代码内&#xff0c;优先执行小括号里面的内容; 2. 自增和自减运算符&#xff08;一元运算符&#xff09; 自增&#xff1a;让当前变量1的意思 let num 2 num // 3 num // …

Linux shell 列举当前所有网卡的IPV4地址及网卡名。

命令一&#xff1a; ip -4 addr show | grep inet | awk { printf "%s ", $2; for (i5; i<NF; i) { printf "%s ", $i }; printf "\n" } | awk {print $1, $NF} 命令二&#xff1a; 忽略 lo 环路网卡 ip -4 addr show | grep inet | awk …

标志寄存器

文章目录 标志寄存器是什么ZF标志PF标志SF标志CF标志OF标志adc指令sbb指令cmp指令有条件的转移指令DF标志和串传送指令pushf和popf 标志寄存器是什么 在8086CPU中标志寄存器是一个特殊的寄存器&#xff0c;具有以下3中功能&#xff1a; 1.用来存储相关指令的某些执行结果 2.用…

.SVN 信息泄露漏洞原理以及修复方法

漏洞名称&#xff1a;.SVN信息泄露、版本管理工具文件信息泄漏 漏洞描述&#xff1a;据介绍&#xff0c;SVN&#xff08;subversion&#xff09;是程序员常用的源代码版本管理软件。一旦网站出现SVN 漏洞&#xff0c;其危害远比SQL注入等其它常见网站漏洞更为致命&#xff0c;…

餐饮行业新风口:社区店的成功案例与经营秘诀

在竞争激烈的餐饮行业中&#xff0c;社区店正成为一个新的风口。作为一名90后的鲜奶吧创业者&#xff0c;我在社区开店已经5年时间&#xff0c;下面我将分享一些成功的社区店案例&#xff0c;并揭示其经营秘诀。 1、案例一&#xff1a;特色小吃店 这家小吃店以地方特色美食为…

MySQL安装与卸载

安装 1). 双击官方下来的安装包文件 2). 根据安装提示进行安装(全部默认就可以) 安装MySQL的相关组件&#xff0c;这个过程可能需要耗时几分钟&#xff0c;耐心等待。 输入MySQL中root用户的密码,一定记得记住该密码 配置 安装好MySQL之后&#xff0c;还需要配置环境变量&am…

平台总线--ID匹配和设备树匹配

一、ID匹配之框架代码 id匹配&#xff08;可想象成八字匹配&#xff09;&#xff1a;一个驱动可以对应多个设备 ------优先级次低 注意事项&#xff1a; device模块中&#xff0c;id的name成员必须与struct platform_device中的name成员内容一致&#xff0c;因此device模块中…

数据结构与算法-插值查找

引言 在计算机科学的广阔天地中&#xff0c;数据结构和算法扮演着至关重要的角色。它们优化了信息处理的方式&#xff0c;使得我们在面对海量数据时能够高效、准确地进行检索与分析。本文将聚焦于一种基于有序数组且利用元素分布规律的查找算法——插值查找&#xff08;Interpo…

C++面向对象程序设计-北京大学-郭炜【课程笔记(五)】

C面向对象程序设计-北京大学-郭炜【课程笔记&#xff08;五&#xff09;】 1、常量对象、常量成员函数1.1、常量对象1.2、常量成员函数1.3、常引用 2、友元&#xff08;friends&#xff09;2.1、友元函数2.2、友元类 3、运算符重载的基本概念3.1、运算符重载 4、赋值运算符的重…

二维码门楼牌管理系统应用场景:推动旅游与文化产业的智慧化升级

文章目录 前言一、二维码门楼牌管理系统在旅游领域的应用二、二维码门楼牌管理系统在文化产业的应用三、结语 前言 随着信息技术的不断发展&#xff0c;二维码门楼牌管理系统作为一种创新的信息化手段&#xff0c;正在逐渐渗透到旅游和文化领域。它通过为文化景点、旅游景点和…

ARM系统控制和管理接口System Control and Management Interface

本文档描述了一个可扩展的独立于操作系统的软件接口,用于执行各种系统控制和管理任务,包括电源和性能管理。 本文档描述了系统控制和管理接口(SCMI),它是一组操作系统无关的软件接口,用于系统管理。SCMI 是可扩展的,目前提供了以下接口: • 支持的接口的发现和自描述…

Java Map接口实现类之 HashMap

定义 public class HashMap<K,V> extends AbstractMap<K,V> implements Map<K,V>, Cloneable, Serializable{static final int DEFAULT_INITIAL_CAPACITY 1 << 4; //默认初始化容积&#xff0c;就是默认数组的长度为 16static final int MAXIMUM_CAP…

keep-alive 页面切换不触发onActivated和onDeactivated方法周期

<keep-alive :include"tagList"><component :is"Component" /></keep-alive>const tagList computed(() > {return $store.state.tagList })原因&#xff1a; store.state.app.tagList返回的是一个 Proxy&#xff0c; 代理了数组&am…

openxml获取xlsx的Excel.Validation

在 Open XML SDK 中&#xff0c;无法直接使用 Excel.Range 和 Excel.Validation&#xff0c;因为这是 VSTO (Visual Studio Tools for Office) 的概念&#xff0c;而不是 Open XML SDK 的概念。Open XML SDK 提供了对 Office Open XML (OOXML) 文件格式的低级访问&#xff0c;而…