from transformers.modeling_utils import PreTrainedModel
是用于导入 Hugging Face Transformers 库中的 PreTrainedModel
类。这个类是所有预训练模型的基类,提供了许多通用功能和方法,适用于不同类型的模型(如BERT、GPT、Transformer-XL等)。下面是导入这个包的一些具体用途和功能:
主要功能和用途
-
通用功能:
- 加载和保存预训练模型:
PreTrainedModel
提供了from_pretrained()
和save_pretrained()
方法,可以方便地加载和保存预训练模型。 - 配置管理:管理模型的配置文件,确保模型初始化时使用正确的参数。
- 加载和保存预训练模型:
-
模型初始化:
- 权重初始化:帮助初始化模型的权重,并处理不同权重初始化策略。
- 模型架构定义:定义和初始化模型的架构,使得子类只需专注于具体模型的实现。
-
模型转换:
- 框架转换:支持将模型转换为不同框架(如 PyTorch 和 TensorFlow),使得模型可以在不同的深度学习框架之间无缝切换。
-
检查点管理:
- 断点续训:支持保存和加载模型的训练断点,方便训练过程的中断和恢复。
-
from transformers import BertModel, BertConfig# 初始化模型配置 config = BertConfig()# 从预训练模型加载 BERT model = BertModel.from_pretrained('bert-base-uncased')# 打印模型架构 print(model)# 保存模型 model.save_pretrained('./saved_model')# 加载模型 loaded_model = BertModel.from_pretrained('./saved_model')
-
继承模型并进行修改
from transformers import PreTrainedModel, BertConfig import torch.nn as nnclass MyCustomModel(PreTrainedModel):def __init__(self, config):super().__init__(config)self.bert = BertModel(config)self.classifier = nn.Linear(config.hidden_size, 2) # 假设二分类任务def forward(self, input_ids, attention_mask=None):outputs = self.bert(input_ids, attention_mask=attention_mask)logits = self.classifier(outputs.pooler_output)return logits# 初始化自定义模型 config = BertConfig() model = MyCustomModel(config)# 打印模型架构 print(model)