huggingface的self.state与self.control来源(TrainerState与TrainerControl)

文章目录

  • 前言
  • 一、huggingface的trainer的self.state与self.control初始化调用
  • 二、TrainerState源码解读(self.state)
    • 1、huggingface中self.state初始化参数
    • 2、TrainerState类的Demo
  • 三、TrainerControl源码解读(self.control)
  • 总结


前言

在 Hugging Face 中,self.state 和 self.control 这两个对象分别来源于 TrainerState 和 TrainerControl,它们提供了对训练过程中状态和控制流的访问和管理。通过这些对象,用户可以在训练过程中监视和调整模型的状态,以及控制一些重要的决策点。


一、huggingface的trainer的self.state与self.control初始化调用

trainer函数初始化调用代码如下:

# 定义Trainer对象
trainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,)

在Trainer()类的初始化的self.state与self.control初始化调用,其代码如下:

class Trainer:def __init__(self,model: Union[PreTrainedModel, nn.Module] = None,args: TrainingArguments = None,data_collator: Optional[DataCollator] = None,train_dataset: Optional[Dataset] = None,eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,tokenizer: Optional[PreTrainedTokenizerBase] = None,model_init: Optional[Callable[[], PreTrainedModel]] = None,compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,callbacks: Optional[List[TrainerCallback]] = None,optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,):...self.state = TrainerState(is_local_process_zero=self.is_local_process_zero(),is_world_process_zero=self.is_world_process_zero(),)self.control = TrainerControl()...

二、TrainerState源码解读(self.state)

1、huggingface中self.state初始化参数

这里多解读一点huggingface的self.state初始化调用参数方法,

 self.state = TrainerState(is_local_process_zero=self.is_local_process_zero(),is_world_process_zero=self.is_world_process_zero(),)

而TrainerState的内部参数由trainer的以下2个函数提供,可知道这里通过self.args.local_process_index与self.args.process_index的值来确定TrainerState方法的参数。

 def is_local_process_zero(self) -> bool:"""Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on severalmachines) main process.这个过程是否是本地主进程(例如,如果在多台机器上以分布式方式进行训练,则是在一台机器上)。"""return self.args.local_process_index == 0def is_world_process_zero(self) -> bool:"""Whether or not this process is the global main process (when training in a distributed fashion on severalmachines, this is only going to be `True` for one process).这个过程是否是全局主进程(在多台机器上以分布式方式进行训练时,只有一个进程会返回True)。"""# Special case for SageMaker ModelParallel since there process_index is dp_process_index, not the global# process index.if is_sagemaker_mp_enabled():return smp.rank() == 0else:return self.args.process_index == 0

self.args.local_process_index与self.args.process_index来源self.args

2、TrainerState类的Demo

介于研究state,我写了一个Demo来探讨使用方法,class TrainerState来源huggingface。该类实际就是一个存储变量的方式,变量包含epoch: Optional[float] = None, global_step: int = 0, max_steps: int = 0等内容,也进行了默认参数赋值,其Demo如下:

from dataclasses import dataclass
import dataclasses
import json
from typing import Dict, List, Optional, Union
@dataclass
class TrainerState:epoch: Optional[float] = Noneglobal_step: int = 0max_steps: int = 0num_train_epochs: int = 0total_flos: float = 0log_history: List[Dict[str, float]] = Nonebest_metric: Optional[float] = Nonebest_model_checkpoint: Optional[str] = Noneis_local_process_zero: bool = Trueis_world_process_zero: bool = Trueis_hyper_param_search: bool = Falsetrial_name: str = Nonetrial_params: Dict[str, Union[str, float, int, bool]] = Nonedef __post_init__(self):if self.log_history is None:self.log_history = []def save_to_json(self, json_path: str):"""Save the content of this instance in JSON format inside `json_path`."""json_string = json.dumps(dataclasses.asdict(self), indent=2, sort_keys=True) + "\n"with open(json_path, "w", encoding="utf-8") as f:f.write(json_string)@classmethoddef load_from_json(cls, json_path: str):"""Create an instance from the content of `json_path`."""with open(json_path, "r", encoding="utf-8") as f:text = f.read()return cls(**json.loads(text))if __name__ == '__main__':state = TrainerState()state.save_to_json('state.json')state_new = state.load_from_json('state.json')

我这里使用state = TrainerState()方法对TrainerState()类实例化,使用state.save_to_json('state.json')进行json文件保存(如下图),若修改里面参数,使用state_new = state.load_from_json('state.json')方式载入会得到新的state_new实例化。
在这里插入图片描述

三、TrainerControl源码解读(self.control)

该类实际就是一个存储变量的方式,变量包含 should_training_stop: bool = False, should_epoch_stop: bool = False, should_save: bool = False, should_evaluate: bool = False, should_log: bool = False内容,也进行了默认参数赋值,其源码如下:

@dataclass
class TrainerControl:"""A class that handles the [`Trainer`] control flow. This class is used by the [`TrainerCallback`] to activate someswitches in the training loop.Args:should_training_stop (`bool`, *optional*, defaults to `False`):Whether or not the training should be interrupted.If `True`, this variable will not be set back to `False`. The training will just stop.should_epoch_stop (`bool`, *optional*, defaults to `False`):Whether or not the current epoch should be interrupted.If `True`, this variable will be set back to `False` at the beginning of the next epoch.should_save (`bool`, *optional*, defaults to `False`):Whether or not the model should be saved at this step.If `True`, this variable will be set back to `False` at the beginning of the next step.should_evaluate (`bool`, *optional*, defaults to `False`):Whether or not the model should be evaluated at this step.If `True`, this variable will be set back to `False` at the beginning of the next step.should_log (`bool`, *optional*, defaults to `False`):Whether or not the logs should be reported at this step.If `True`, this variable will be set back to `False` at the beginning of the next step."""should_training_stop: bool = Falseshould_epoch_stop: bool = Falseshould_save: bool = Falseshould_evaluate: bool = Falseshould_log: bool = Falsedef _new_training(self):"""Internal method that resets the variable for a new training."""self.should_training_stop = Falsedef _new_epoch(self):"""Internal method that resets the variable for a new epoch."""self.should_epoch_stop = Falsedef _new_step(self):"""Internal method that resets the variable for a new step."""self.should_save = Falseself.should_evaluate = Falseself.should_log = False

总结

本文主要介绍huggingface的trainer中的self.control与self.state的来源。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/19092.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言实现十进制转任意进制(详解)

主要思路:运用一个数组,通过数字每次取任意进制模,存在数组中, 再通过倒取数组中的数值,来实现进制转换,如果遇到十六进制,利用ASCII码值 数字字符和大写字母 相差55的特性来解决 int main() {…

【芯片验证方法】

术语——中文术语 大陆与台湾的一些术语存在差别: 验证常用的英语术语: 验证:尽量模拟实际应用场景,比对芯片的所需要的目标功能和实现的功能 影响验证的要素:应用场景、目标功能、比对应用场景、目标功能&#xff…

内存泄露和内存溢出有什么不同

内存泄露和内存溢出是两个常见的内存问题,它们在程序运行过程中可能导致性能下降、系统不稳定甚至应用崩溃。虽然这两个术语经常被混用,但它们描述的是两种不同的问题: 内存泄露(Memory Leak) 内存泄露是指程序在分配…

华发股份:加强业务协同 新政下项目热销

“5.17”楼市政策出台后,各地密集落地执行。5月27—28日,上海、广州、深圳三个一线城市跟进落地“517”新政。上海发布《关于优化本市房地产市场平稳健康发展政策措施的通知》,共计9条调整政策,涵盖外地户籍、人才、单身、婚否、企…

一个生动的例子——通过ERC20接口访问Tether合约

生动的例子 USDT:符合ERC20标准的美元稳定币,Tether合约获得测试网上Tether合约地址通过自己写的ERC20接口访问这个合约 Tether合约地址:0xdAC17F958D2ee523a2206206994597C13D831ec7 IERC20.sol // SPDX-License-Identifier: GPL-3.0pra…

今日分享站

同志们,字符函数和字符串函数已经全部学习完啦,笔记也已经上传完毕,大家可以去看啦。字符函数和字符串函数and模拟函数 加油!!!!!

Unix环境高级编程--7-进程环境--7.1-7.2main函数-7.3进程退出

1、几个问题 ①main函数如何被调用? ②命令行参数如何传递给新程序?; ③典型储存空间布局是什么样的?; ④进程如何使用环境变量 ?; ⑤进程的各种终止方式? 2、main函数 当内核…

列表推导式(解析式)python

Python中的列表推导式(list comprehension)是一种简洁且强大的语法,用于创建新的列表。它允许你通过对现有列表中的元素进行操作或筛选来快速生成新列表。以下是列表推导式的基本语法和一些示例: 基本语法: new_list…

vue3的组件通信v-model使用

一、组件通信 1.props 》 父向子传值 props 主要用于父组件向子组件通信。再父组件中通过使用:msgmsg绑定需要传给子组件的属性值&#xff0c;然后再在子组件中用props接收该属性值 方法一 普通方式:// 父组件 传值<child :msg1"msg1" :list"list">…

Dinky MySQLCDC 整库同步到 Doris

资源&#xff1a;flink 1.17.0、dinky 1.0.2、doris-2.0.1-rc04 问题&#xff1a;Cannot deserialize value of type int from String &#xff0c;detailMessageunknowndatabases &#xff0c;not a valid int value 2024-05-29 16:52:20.136 ERROR org.apache.doris.flink.…

最长公共子序列问题的求解

假设有两个字符串A和B&#xff0c;A字符串的组成为 A A 0 A 1 A 2 . . . . . . A n − 1 A A_0A_1A_2......A_{n-1} AA0​A1​A2​......An−1​ B B 0 B 1 B 2 . . . . . . B m − 1 BB_0B_1B_2......B_{m-1} BB0​B1​B2​......Bm−1​ 要寻找这两个字符串的公共子序列还…

MS Excel: 高亮当前行列 - 保持原有格式不被改变

本文使用条件格式VBA的方法实现高亮当前行列&#xff0c;因为纯VBA似乎会清除原有的高亮格式。效果如下&#xff1a;本文图省事就使用同一种颜色了。 首先最重要的&#xff0c;【选中你期望高亮的单元格区域】&#xff0c;比如可以全选当前sheet的全部区域 然后点击【开始】-【…

06.深入学习Java 线程

1 线程的状态/生命周期 Java 的 Thread 类对线程状态进行了枚举&#xff1a; public class Thread implements Runnable {public enum State {NEW,RUNNABLE,BLOCKED,WAITING,TIMED_WAITING,TERMINATED;} } 初始(NEW)&#xff1a;新创建了一个线程对象&#xff0c;但还没有调用…

数据库学习笔记1-数据库实验1

文章目录 创建表格的时候出现的一些错误查询所有的表格实验一查询单个表格分块修改大学数据库表格创建大学数据库表格系课程教师课程段授课学生选课注意吐槽 修改大学数据库表格2&#xff08;英文版本&#xff09;abcde 自建项目-在线书店数据库 创建表格的时候出现的一些错误 …

子集和问题(回溯法)

目录 ​​​​ 前言 一、算法思路 二、分析过程 三、代码实现 伪代码&#xff1a; C&#xff1a; 总结 前言 【问题描述】考虑定义如下的PARTITION问题中的一个变型。给定一个n个整数的集合X{x1,x2,…,xn}和整数y&#xff0c;找出和等于y的X的子集Y。 一、算法思路 基本思想&am…

【STL】C++ stack(栈) 基本使用

目录 一 stack常见构造 1 空容器构造函数&#xff08;默认构造函数&#xff09; 2. 使用指定容器构造 3 拷贝构造函数 二 其他操作 1 empty 2 size 3 top 4 push && pop 5 emplace 6 swap 三 总结 一 stack常见构造 1 空容器构造函数&#xff08;默认构造…

云计算OpenStack基础

1.什么是虚拟化&#xff1f; •虚拟化是云计算的基础。 •虚拟化是指计算元件在虚拟的而不是真实的硬件基础上运行。 •虚拟化将物理资源转变为具有可管理性的逻辑资源&#xff0c;以消除物理结构之间的隔离&#xff0c;将物理资源融为一个整体。虚拟化是一种简化管理和优化…

探秘AI艺术:揭开Midjourney绘画的神秘面纱

在当今这个数字化迅速发展的时代&#xff0c;AI技术已经深入到我们生活的方方面面&#xff0c;而最令人着迷的莫过于它在艺术创作领域的应用。“Midjourney绘画”就是这样一个令人惊叹的例子&#xff0c;它通过高级AI技术&#xff0c;能够帮助用户生成独一无二的艺术作品。但是…

如何知道自己电脑的 Shell类型是什么?

在macOS中&#xff0c;你可以通过以下几种方法来确定当前正在使用的shell类型&#xff0c;并了解相关的配置文件&#xff1a; 1. 使用终端命令确定shell类型 打开终端应用程序&#xff08;Terminal&#xff09;。输入以下命令并按回车键&#xff1a;echo $SHELL。该命令会输出…

最长递增子序列,交错字符串

第一题&#xff1a; 代码如下&#xff1a; int lengthOfLIS(vector<int>& nums) {//dp[i]表示以第i个元素为结尾的最长子序列的长度int n nums.size();int res 1;vector<int> dp(n, 1);for (int i 1; i < n; i){for (int j 0; j < i; j){if (nums[i]…