huggingface的self.state与self.control来源(TrainerState与TrainerControl)

文章目录

  • 前言
  • 一、huggingface的trainer的self.state与self.control初始化调用
  • 二、TrainerState源码解读(self.state)
    • 1、huggingface中self.state初始化参数
    • 2、TrainerState类的Demo
  • 三、TrainerControl源码解读(self.control)
  • 总结


前言

在 Hugging Face 中,self.state 和 self.control 这两个对象分别来源于 TrainerState 和 TrainerControl,它们提供了对训练过程中状态和控制流的访问和管理。通过这些对象,用户可以在训练过程中监视和调整模型的状态,以及控制一些重要的决策点。


一、huggingface的trainer的self.state与self.control初始化调用

trainer函数初始化调用代码如下:

# 定义Trainer对象
trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,)

在Trainer()类的初始化的self.state与self.control初始化调用,其代码如下:

class Trainer:def __init__(self,model: Union[PreTrainedModel, nn.Module] = None,args: TrainingArguments = None,data_collator: Optional[DataCollator] = None,train_dataset: Optional[Dataset] = None,eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,tokenizer: Optional[PreTrainedTokenizerBase] = None,model_init: Optional[Callable[[], PreTrainedModel]] = None,compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,callbacks: Optional[List[TrainerCallback]] = None,optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,):...self.state = TrainerState(is_local_process_zero=self.is_local_process_zero(),is_world_process_zero=self.is_world_process_zero(),)self.control = TrainerControl()...

二、TrainerState源码解读(self.state)

1、huggingface中self.state初始化参数

这里多解读一点huggingface的self.state初始化调用参数方法,

 self.state = TrainerState(is_local_process_zero=self.is_local_process_zero(),is_world_process_zero=self.is_world_process_zero(),)

而TrainerState的内部参数由trainer的以下2个函数提供,可知道这里通过self.args.local_process_index与self.args.process_index的值来确定TrainerState方法的参数。

 def is_local_process_zero(self) -> bool:"""Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on severalmachines) main process.这个过程是否是本地主进程(例如,如果在多台机器上以分布式方式进行训练,则是在一台机器上)。"""return self.args.local_process_index == 0def is_world_process_zero(self) -> bool:"""Whether or not this process is the global main process (when training in a distributed fashion on severalmachines, this is only going to be `True` for one process).这个过程是否是全局主进程(在多台机器上以分布式方式进行训练时,只有一个进程会返回True)。"""# Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global# process index.if is_sagemaker_mp_enabled():return smp.rank() == 0else:return self.args.process_index == 0

self.args.local_process_index与self.args.process_index来源self.args

2、TrainerState类的Demo

介于研究state,我写了一个Demo来探讨使用方法,class TrainerState来源huggingface。该类实际就是一个存储变量的方式,变量包含epoch: Optional[float] = None, global_step: int = 0, max_steps: int = 0等内容,也进行了默认参数赋值,其Demo如下:

from dataclasses import dataclass
import dataclasses
import json
from typing import Dict, List, Optional, Union
@dataclass
class TrainerState:epoch: Optional[float] = Noneglobal_step: int = 0max_steps: int = 0num_train_epochs: int = 0total_flos: float = 0log_history: List[Dict[str, float]] = Nonebest_metric: Optional[float] = Nonebest_model_checkpoint: Optional[str] = Noneis_local_process_zero: bool = Trueis_world_process_zero: bool = Trueis_hyper_param_search: bool = Falsetrial_name: str = Nonetrial_params: Dict[str, Union[str, float, int, bool]] = Nonedef __post_init__(self):if self.log_history is None:self.log_history = []def save_to_json(self, json_path: str):"""Save the content of this instance in JSON format inside `json_path`."""json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"with open(json_path, "w", encoding="utf-8") as f:f.write(json_string)@classmethoddef load_from_json(cls, json_path: str):"""Create an instance from the content of `json_path`."""with open(json_path, "r", encoding="utf-8") as f:text = f.read()return cls(**json.loads(text))if __name__ == '__main__':state = TrainerState()state.save_to_json('state.json')state_new = state.load_from_json('state.json')

我这里使用state = TrainerState()方法对TrainerState()类实例化,使用state.save_to_json('state.json')进行json文件保存(如下图),若修改里面参数,使用state_new = state.load_from_json('state.json')方式载入会得到新的state_new实例化。
在这里插入图片描述

三、TrainerControl源码解读(self.control)

该类实际就是一个存储变量的方式,变量包含 should_training_stop: bool = False, should_epoch_stop: bool = False, should_save: bool = False, should_evaluate: bool = False, should_log: bool = False内容,也进行了默认参数赋值,其源码如下:

@dataclass
class TrainerControl:"""A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate someswitches in the training loop.Args:should_training_stop (`bool`, *optional*, defaults to `False`):Whether or not the training should be interrupted.If `True`, this variable will not be set back to `False`. The training will just stop.should_epoch_stop (`bool`, *optional*, defaults to `False`):Whether or not the current epoch should be interrupted.If `True`, this variable will be set back to `False` at the beginning of the next epoch.should_save (`bool`, *optional*, defaults to `False`):Whether or not the model should be saved at this step.If `True`, this variable will be set back to `False` at the beginning of the next step.should_evaluate (`bool`, *optional*, defaults to `False`):Whether or not the model should be evaluated at this step.If `True`, this variable will be set back to `False` at the beginning of the next step.should_log (`bool`, *optional*, defaults to `False`):Whether or not the logs should be reported at this step.If `True`, this variable will be set back to `False` at the beginning of the next step."""should_training_stop: bool = Falseshould_epoch_stop: bool = Falseshould_save: bool = Falseshould_evaluate: bool = Falseshould_log: bool = Falsedef _new_training(self):"""Internal method that resets the variable for a new training."""self.should_training_stop = Falsedef _new_epoch(self):"""Internal method that resets the variable for a new epoch."""self.should_epoch_stop = Falsedef _new_step(self):"""Internal method that resets the variable for a new step."""self.should_save = Falseself.should_evaluate = Falseself.should_log = False

总结

本文主要介绍huggingface的trainer中的self.control与self.state的来源。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/19092.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【芯片验证方法】

术语——中文术语 大陆与台湾的一些术语存在差别: 验证常用的英语术语: 验证:尽量模拟实际应用场景,比对芯片的所需要的目标功能和实现的功能 影响验证的要素:应用场景、目标功能、比对应用场景、目标功能&#xff…

华发股份:加强业务协同 新政下项目热销

“5.17”楼市政策出台后,各地密集落地执行。5月27—28日,上海、广州、深圳三个一线城市跟进落地“517”新政。上海发布《关于优化本市房地产市场平稳健康发展政策措施的通知》,共计9条调整政策,涵盖外地户籍、人才、单身、婚否、企…

一个生动的例子——通过ERC20接口访问Tether合约

生动的例子 USDT:符合ERC20标准的美元稳定币,Tether合约获得测试网上Tether合约地址通过自己写的ERC20接口访问这个合约 Tether合约地址:0xdAC17F958D2ee523a2206206994597C13D831ec7 IERC20.sol // SPDX-License-Identifier: GPL-3.0pra…

今日分享站

同志们,字符函数和字符串函数已经全部学习完啦,笔记也已经上传完毕,大家可以去看啦。字符函数和字符串函数and模拟函数 加油!!!!!

Dinky MySQLCDC 整库同步到 Doris

资源:flink 1.17.0、dinky 1.0.2、doris-2.0.1-rc04 问题:Cannot deserialize value of type int from String ,detailMessageunknowndatabases ,not a valid int value 2024-05-29 16:52:20.136 ERROR org.apache.doris.flink.…

MS Excel: 高亮当前行列 - 保持原有格式不被改变

本文使用条件格式VBA的方法实现高亮当前行列,因为纯VBA似乎会清除原有的高亮格式。效果如下:本文图省事就使用同一种颜色了。 首先最重要的,【选中你期望高亮的单元格区域】,比如可以全选当前sheet的全部区域 然后点击【开始】-【…

06.深入学习Java 线程

1 线程的状态/生命周期 Java 的 Thread 类对线程状态进行了枚举: public class Thread implements Runnable {public enum State {NEW,RUNNABLE,BLOCKED,WAITING,TIMED_WAITING,TERMINATED;} } 初始(NEW):新创建了一个线程对象,但还没有调用…

数据库学习笔记1-数据库实验1

文章目录 创建表格的时候出现的一些错误查询所有的表格实验一查询单个表格分块修改大学数据库表格创建大学数据库表格系课程教师课程段授课学生选课注意吐槽 修改大学数据库表格2(英文版本)abcde 自建项目-在线书店数据库 创建表格的时候出现的一些错误 …

子集和问题(回溯法)

目录 ​​​​ 前言 一、算法思路 二、分析过程 三、代码实现 伪代码: C: 总结 前言 【问题描述】考虑定义如下的PARTITION问题中的一个变型。给定一个n个整数的集合X{x1,x2,…,xn}和整数y,找出和等于y的X的子集Y。 一、算法思路 基本思想&am…

【STL】C++ stack(栈) 基本使用

目录 一 stack常见构造 1 空容器构造函数(默认构造函数) 2. 使用指定容器构造 3 拷贝构造函数 二 其他操作 1 empty 2 size 3 top 4 push && pop 5 emplace 6 swap 三 总结 一 stack常见构造 1 空容器构造函数(默认构造…

云计算OpenStack基础

1.什么是虚拟化? •虚拟化是云计算的基础。 •虚拟化是指计算元件在虚拟的而不是真实的硬件基础上运行。 •虚拟化将物理资源转变为具有可管理性的逻辑资源,以消除物理结构之间的隔离,将物理资源融为一个整体。虚拟化是一种简化管理和优化…

最长递增子序列,交错字符串

第一题&#xff1a; 代码如下&#xff1a; int lengthOfLIS(vector<int>& nums) {//dp[i]表示以第i个元素为结尾的最长子序列的长度int n nums.size();int res 1;vector<int> dp(n, 1);for (int i 1; i < n; i){for (int j 0; j < i; j){if (nums[i]…

Spring-注解

Spring 注解分类 Spring 注解驱动模型 Spring 元注解 Documented Retention() Target() // 可以继承相关的属性 Inherited Repeatable()Spirng 模式注解 ComponentScan 原理 ClassPathScanningCandidateComponentProvider#findCandidateComponents public Set<BeanDefin…

动态规划part03 Day43

LC343整数拆分&#xff08;未掌握&#xff09; 未掌握分析&#xff1a;dp数组的含义没有想清楚&#xff0c;dp[i]表示分解i能够达到的最大乘积&#xff0c;i能够如何分解呢&#xff0c;从1开始遍历&#xff0c;直到i-1&#xff1b;每次要不是j和i-j两个数&#xff0c;要不是j和…

【传知代码】自监督高效图像去噪(论文复现)

前言&#xff1a;在数字化时代&#xff0c;图像已成为我们生活、工作和学习的重要组成部分。然而&#xff0c;随着图像获取方式的多样化&#xff0c;图像质量问题也逐渐凸显出来。噪声&#xff0c;作为影响图像质量的关键因素之一&#xff0c;不仅会降低图像的视觉效果&#xf…

串口通信问题排查总结

串口通信问题排查 排查原则&#xff1a; 软件从发送处理到接收处理&#xff0c;核查驱动、控制及发送接收数据是否正常。硬件从发送到接收&#xff0c;针对信号经过的各段&#xff0c;分段核对信号是否正常。示波器、逻辑分析仪。用万用表、示波器、逻辑分析仪等工具&#xf…

JRT性能演示

演示视频 君生我未生&#xff0c;我生君已老&#xff0c;这里是java信创频道JRT&#xff0c;真信创-不糊弄。 基础架构决定上层建筑&#xff0c;和给有些品种的植物种植一样&#xff0c;品种不对&#xff0c;施肥浇水再多&#xff0c;也是不可能长成参天大树的。JRT吸收了各方…

基于文本来推荐相似酒店

基于文本来推荐相似酒店 查看数据集基本信息 import pandas as pd import numpy as np from nltk.corpus import stopwords from sklearn.metrics.pairwise import linear_kernel from sklearn.feature_extraction.text import CountVectorizer from sklearn.feature_extrac…

C++:类和对象

一、前言 C是面向对象的语言&#xff0c;本文将通过上、中、下三大部分&#xff0c;带你深入了解类与对象。 目录 一、前言 二、部分&#xff1a;上 1.面向过程和面向对象初步认识 2.类的引入 3.类的定义 4.类的访问限定符及封装 5.类的作用域 6.类的实例化 7.类的…

【busybox记录】【shell指令】readlink

目录 内容来源&#xff1a; 【GUN】【readlink】指令介绍 【busybox】【readlink】指令介绍 【linux】【readlink】指令介绍 使用示例&#xff1a; 打印符号链接或规范文件名的值 - 默认输出 打印符号链接或规范文件名的值 - 打印规范文件的全路径 打印符号链接或规范文…