前言
在 Hugging Face `Trainer` 里给模型的不同部分设置不同的学习率,这种教程怎么这么难找?
代码
参考:https://github.com/X-PLUG/mPLUG-Owl/blob/main/mPLUG-Owl2/mplug_owl2/train/mplug_owl2_trainer.py#L133C33-L133C33
先定义你自己的 Trainer,继承 Hugging Face 的 `Trainer`:
from transformers.trainer import (
    ALL_LAYERNORM_LAYERS,
    Trainer,
    get_parameter_names,
)
class MyTrainer(Trainer):
    """Custom Trainer subclass; method overrides (e.g. ``create_optimizer``) go here."""
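然后在 `MyTrainer` 中重写 `create_optimizer` 方法:按参数名把可训练参数分组,名字含 `mm_projector` 的(投影层)用 `linear_lr`,含 `video_tower` 的(视觉编码器)用 `vision_lr`,其余(LLM 部分)用 `learning_rate`。注意 `linear_lr` 和 `vision_lr` 不是 `TrainingArguments` 自带的字段,需要在自定义的参数类中自行添加。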
class MyTrainer(Trainer):
    """Hugging Face ``Trainer`` that trains model parts at different learning rates.

    Trainable parameters are bucketed by name:

    * names containing ``"mm_projector"`` use ``args.linear_lr``;
    * names containing ``"video_tower"`` use ``args.vision_lr``;
    * all other parameters (the language model) use ``args.learning_rate``.

    ``linear_lr`` and ``vision_lr`` are not standard ``TrainingArguments``
    fields; add them to your own arguments class. A bucket whose field is
    missing or ``None`` falls back to ``args.learning_rate``.
    """

    # (substring of the parameter name, ``self.args`` attribute holding that
    # bucket's learning rate). Checked in order and the first match wins, so a
    # parameter can never be placed in two optimizer groups.
    _LR_OVERRIDES = (
        ("mm_projector", "linear_lr"),
        ("video_tower", "vision_lr"),
    )

    def _lr_attr_for(self, name):
        """Return the ``self.args`` attribute name holding *name*'s learning rate."""
        for substring, lr_attr in self._LR_OVERRIDES:
            if substring in name:
                return lr_attr
        return "learning_rate"

    def create_optimizer(self):
        """Build ``self.optimizer`` with per-bucket learning rates and return it.

        Each bucket (LLM, ``mm_projector``, ``video_tower``) gets up to two
        parameter groups: one using ``args.weight_decay`` and one with no
        weight decay. Decay applies to the names returned by
        ``get_parameter_names(model, ALL_LAYERNORM_LAYERS)`` that do not
        contain ``"bias"``. Everything else (biases, norm-layer weights) goes to
        the no-decay group *of its own bucket*, so it still trains at that
        bucket's learning rate. Groups with no trainable parameters (e.g. a
        frozen tower) are omitted, so the group count depends on which parts
        of the model are trainable.

        Returns:
            The optimizer stored in ``self.optimizer``. An optimizer that is
            already set (for example one passed via ``optimizers=``) is
            returned unchanged, as the stock ``Trainer.create_optimizer`` does.
        """
        if self.optimizer is not None:
            return self.optimizer

        opt_model = self.model
        decay_parameters = {
            name
            for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            if "bias" not in name
        }

        # Pre-create every (lr attribute, decays?) bucket so the group order is
        # fixed: LLM first, then each override; decay before no-decay.
        lr_attrs = ["learning_rate"] + [lr_attr for _, lr_attr in self._LR_OVERRIDES]
        buckets = {
            (lr_attr, decays): [] for lr_attr in lr_attrs for decays in (True, False)
        }

        # Single pass over the parameters; frozen ones are not optimized.
        for name, param in opt_model.named_parameters():
            if param.requires_grad:
                key = (self._lr_attr_for(name), name in decay_parameters)
                buckets[key].append(param)

        optimizer_grouped_parameters = []
        for (lr_attr, decays), params in buckets.items():
            if not params:
                continue  # nothing trainable in this bucket
            lr = getattr(self.args, lr_attr, None)
            optimizer_grouped_parameters.append(
                {
                    "params": params,
                    "weight_decay": self.args.weight_decay if decays else 0.0,
                    "lr": self.args.learning_rate if lr is None else lr,
                }
            )

        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
        return self.optimizer