【Python】科研代码学习:七 TrainingArguments,Trainer

【Python】科研代码学习:七 TrainingArguments,Trainer

  • TrainingArguments
    • 重要的方法
  • Trainer
    • 重要的方法
    • 使用 Trainer 的简单例子

TrainingArguments

  • HF官网API:Training
    众所周知,推理是一个大头,训练是另一个大头
    之前的很多内容,都是为训练这里做了一个小铺垫
    如何快速有效地调用代码,训练大模型,才是重中之重(不然学那么多HF库感觉怪吃苦的)
  • 首先看训练参数,再看训练器吧。
    首先,它的头文件是 transformers.TrainingArguments
    再看它源码的参数,我勒个去,太多了吧。
    ※ 我这里挑重要的讲解,全部请看API去。
  • output_dir (str)设置模型输出预测,或者中继点 (checkpoints) 的输出目录。模型训练到一半,肯定需要有中继点文件的嘛,就相当于游戏存档有很多一样,防止跑一半直接程序炸了,还要从头训练
  • overwrite_output_dir (bool, optional, defaults to False):把这个参数设置成 True,就会覆盖其中 output_dir 中的文档。一般在从中继点继续训练时需要这么用
  • do_train (bool, optional, defaults to False):指明我在做训练集的训练任务
  • do_eval (bool, optional):指明我在做验证集的评估任务
  • do_predict (bool, optional, defaults to False) :指明我在做测试集的预测任务
  • evaluation_strategy :评估策略:训练时不评估 / 每 eval_steps 步评估,或者每 epoch 评估
"no": No evaluation is done during training.
"steps": Evaluation is done (and logged) every eval_steps.
"epoch": Evaluation is done at the end of each epoch.
  • per_device_train_batch_size :训练时每张卡的batch大小,默认为8
    per_device_eval_batch_size :评估时每张卡的batch大小,默认为8
  • learning_rate (float, optional, defaults to 5e-5):学习率,里面使用的是 AdamW optimizer
    其他相应的 AdamW Optimizer 的参数还有:
    weight_decay adam_beta1adam_beta2adam_epsilon
  • num_train_epochs:训练的 epoch 个数,默认为3,可以设置小数。
  • lr_scheduler_type:具体作用要查看 transformers 里的 Scheduler 是干什么用的
  • warmup_ratiowarmup_steps :让一开始的学习率从0逐渐升到 learning_rate 用的
  • logging_dir :设置 logging 输出的文档
    除此之外还有一些和 logging相关的参数:
    logging_strategy ,logging_first_step ,logging_steps ,logging_nan_inf_filter 设置日志的策略
  • 与保存模型中继文件相关的参数:
    save_strategy :不保存中继文件 / 每 epoch 保存 / 每 save_steps 步保存
"no": No save is done during training.
"epoch": Save is done at the end of each epoch.
"steps": Save is done every save_steps.

save_steps :如果是整数,表示多少步保存一次;小数,则是按照总训练步,多少比例之后保存一次
save_total_limit :最多中继文件的保存上限,如果超过上限,会先把最旧的那个中继文件删了再保存新的
save_safetensors :使用 savetensor来存储和加载 tensors,默认为 True
push_to_hub :是否保存到 HF hub

  • use_cpu (bool, optional, defaults to False):是否用 cpu 训练
  • seed (int, optional, defaults to 42) :训练的种子,方便复现和可重复实验
  • data_seed :数据采样的种子
  • 数据精读相关的一些参数:
    FP32、TF32、FP16、BF16、FP8、FP4、NF4、INT8
    bf16 (bool, optional, defaults to False)fp16 (bool, optional, defaults to False)
    tf32 (bool, optional)
  • run_name :展示在 wandb and mlflow logging 中的描述
  • load_best_model_at_end (bool, optional, defaults to False):是否保存效果最好的中继点作为最终模型,与 save_total_limit 有些交互操作
    如果上述设置成 True 的话,考虑 metric_for_best_model ,即如何评估效果最好。默认为 loss 即损失最小
    如果你修改了 metric_for_best_model 的话,考虑 greater_is_better ,即指标越大越好还是越小越好
  • 一些加速相关的参数,貌似都比较麻烦
    fsdp
    fsdp_config
    deepspeed
    accelerator_config
  • optim :设置 optimizer,默认为 adamw_torch
    也可以设置成 adamw_hf, adamw_torch, adamw_torch_fused, adamw_apex_fused, adamw_anyprecision or adafactor.
  • resume_from_checkpoint :传入中继点文件的目录,从中继点继续训练

重要的方法

  • ※ 那我怎么访问或者修改上述参数呢?
    由于这个需要实例化,所以我们需要使用OO的方法修改
    下面讲一下其中重要的方法
  • set_dataloader:设置 dataloader
    在这里插入图片描述
from transformers import TrainingArgumentsargs = TrainingArguments("working_dir")
args = args.set_dataloader(train_batch_size=16, eval_batch_size=64)
args.per_device_train_batch_size
  • 设置 logging 相关的参数
    在这里插入图片描述
  • 设置 optimizer
    在这里插入图片描述
  • 设置保存策略
    在这里插入图片描述
  • 设置训练策略
    在这里插入图片描述
  • 设置评估策略
    在这里插入图片描述
  • 设置测试策略
    在这里插入图片描述

Trainer

  • 终于到大头了。Trainer 是主要用 pt 训练的,主要支持 GPUs (NVIDIA GPUs, AMD GPUs)/ TPUs
  • 看下源码,它要的东西不少,讲下重要参数:
  • model:要么是 transformers.PretrainedModel 类型的,要么是简单的 torch.nn.Module 类型的
  • argsTrainingArguments 类型的训练参数。如果不提供的话,默认使用 output_dir/tmp_trainer 里面的那个训练参数
  • data_collator DataCollator 类型参数,给训练集或验证集做数据分批和预处理用的,如果没有tokenizer默认使用 default_data_collator,否则默认使用 DataCollatorWithPadding (Will default to default_data_collator() if no tokenizer is provided, an instance of DataCollatorWithPadding otherwise.)
  • train_dataset (torch.utils.data.Dataset or torch.utils.data.IterableDataset, optional) :提供训练的数据集,当然也可以是 Datasets 类型的数据
  • eval_dataset :类似的验证集的数据集
  • tokenizer :提供 tokenizer 分词器
  • compute_metrics :验证集使用时候的计算指标,具体得参考 EvalPrediction 类型
  • optimizers :可以提供 Tuple(optimizer, scheduler)。默认使用 AdamW 以及 get_linear_schedule_with_warmup() controlled by args
    在这里插入图片描述

重要的方法

  • compute_loss:设置如何计算损失
    在这里插入图片描述
  • train:设置训练集训练任务,第一个参数可以设置是否从中继点开始训练
    在这里插入图片描述
  • evaluate:设置验证集评估任务,需要提供验证集
    在这里插入图片描述
  • predict:设置测试集任务
    在这里插入图片描述
  • save_model:保存模型参数到 output_dir在这里插入图片描述
  • training_step:设置每一个训练的 step,把一个batch的输入经过了何种操作,得到一个 torch.Tensor
    在这里插入图片描述

使用 Trainer 的简单例子

  • 主要就是加载一些参数,传进去即可
    模型、训练参数、训练集、验证集、计算指标
    调用训练方法 .train()
    最后保存模型即可 .save_model()
from transformers import (Trainer,)
trainer = Trainer(model=model,args=training_args,train_dataset=small_train_dataset,eval_dataset=small_eval_dataset,compute_metrics=compute_metrics,
)trainer.train()
trainer.save_model(outputdir="./xxx")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/736228.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

XSS-Labs靶场“11-13、15-16”关通关教程

君衍. 一、第十一关 referer参数注入二、第十二关 user-agent参数注入三、第十三关 cookie参数绕过四、第十五关 ng-include文件包含五、第十六关 回车代替空格 点击跳转: XSS-Labs靶场“1-5”关通关教程 XSS-Labs靶场“6-10”关通关教程 一、第十一关 referer参数…

[uni-app ] createAnimation锚点旋转 及 二次失效问题处理

记录一下: 锚点定位到左下角, 旋转动画 必须沿Z轴,转动 但是,此时会出现 后续动画在微信小程序失效问题 解决: 清空 this.animationData

201912青少年软件编程(Scratch)等级考试试卷(一级)

201912 青少年软件编程(Scratch)等级考试试卷(一级) 第1题:【 单选题】 关于造型和背景,下面说法不正确的是? A:造型编号从1开始 B:有四个背景,删除第二个背景,背景编…

11_Http

文章目录 HttpHttp协议网络模型Http协议的工作流程Http请求报文请求行请求方法请求资源协议版本 请求头空行请求体抓包软件:Fiddler Http响应报文响应行状态码 响应头响应体 请求完整的处理流程 Https 整体流程图: 前端:负责获取数据&#xf…

雷赛控制卡获取轴当前位置的值不正确问题处理

现像 从雷赛控制卡中获取当前轴位置值时发现轴在向零点的右边走时显示的值是负数。正常来就一般是要反馈正数的。一般轴零点右边是正方向,限位是正限位,反馈的位置也应该是正数。 如果雷赛软件中的【单轴参数】中的基本设置中的【脉冲模式】设置的是对的…

【C语言基础】:深入理解指针(终篇)

文章目录 深入理解指针一、函数指针变量4.1 函数指针变量的创建4.2 函数指针变量的使用4.3 typedef关键字 二、函数指针数组三、转移表四、回调函数4.1 什么是回调函数4.2 qsort使用举例4.2.1 使用qsort函数排序整形数据4.2.2 使用qsort排序结构数据4.2.3 qsort函数的模拟实现 …

elasticsearch 深度分页查询 Search_after(图文教程)

Search_after使用 一. 简介二. 不带PIT的search_after查询2.1 构造数据2.2 search_after分页查询2.2 问题 三. 带PIT的search_after查询3.1 构建第一次查询条件3.2 进行下一页查询3.3 删除PIT 四.参考文章 前言 这是我在这个网站整理的笔记,有错误的地方请指出,关注…

傅里叶变换pytorch使用

参考视频:1 傅里叶变换原理_哔哩哔哩_bilibili 傅里叶变换是干嘛的: 傅里叶得到低频、高频信息,针对低频、高频处理能够实现不同的目的。 傅里叶过程是可逆的,图像经过傅里叶变换、逆傅里叶变换后,能够恢复到原始图像…

【管理干部竞聘上岗】某星级酒店中层干部竞聘上岗管理咨询项目纪实

在这次项目合作中,我们的目的主要是设计一次公开、透明的竞聘活动,通过科学、公正的方法选拔出公司管理级岗位的最佳候选人。基于华恒智信的专业性,我们再次选择与其合作开展项目。在项目合作中,专家团队为我们进行了专业性的培训…

AIGC实战——GPT(Generative Pre-trained Transformer)

AIGC实战——GPT 0. 前言1. GPT 简介2. 葡萄酒评论数据集3. 注意力机制3.1 查询、键和值3.2 多头注意力3.3 因果掩码 4. Transformer4.1 Transformer 块4.2 位置编码 5. 训练GPT6. GPT 分析6.1 生成文本6.2 注意力分数 小结系列链接 0. 前言 注意力机制能够用于构建先进的文本…

【网络原理】TCP 协议中比较重要的一些特性(一)

目录 1、TCP 协议 2、确认应答 2.1、确认序号 3、超时重传 4、连接管理 4.1、建立连接(三次握手) 4.2、断开连接(四次挥手) 1、TCP 协议 TCP 是工作中最常用到的协议,也是面试中最常考的协议,具有面…

Electron程序如何在MacOS下获取相册访问权限

1.通过entitiment.plist,在electron-builder签名打包时,给app包打上签名。最后可以通过codesign命令进行验证。 TestPhotos.plist electron-builder配置文件中加上刚刚的plist文件。 通过codesign命令验证,若出现这个,则说明成…

Fortran语法介绍(三)

个人专栏—ABAQUS专栏 Abaqus2023的用法教程——与VS2022、oneAPI 2024子程序的关联方法 Abaqus2023的用法教程——与VS2022、oneAPI 2024子程序的关联方法Abaqus有限元分析——有限元网格划分基本原则 Abaqus有限元分析——有限元网格划分基本原则各向同性线弹性材料本构模型…

《手把手教你》系列技巧篇(二十七)-java+ selenium自动化测试- quit和close的区别(详解教程)

1.简介 尽管有的小伙伴或者童鞋们觉得很简单,不就是关闭退出浏览器,但是宏哥还是把两个方法的区别说一下,不然遇到坑后根本不会想到是这里的问题。 2.源码 本文介绍webdriver中关于浏览器退出操作。driver中有两个方法是关于浏览器关闭&…

SQL28 计算用户8月每天的练题数量

👨‍💻 大唐coding:个人主页 🎁 个人专栏: 《力扣高频刷题宝典》《SQL刷题记录》 ⛵ 既然选择远方,当不负青春,砥砺前行! 大家好,我是大唐,今天我们来做一道牛客题库SQL…

MySQL-----存储过程

▶ 介绍 存储过程是事先经过编译并存储在数据库中的一段SQL语句的集合,调用存储过程可以简化应用开发人员的很多工作,减少数据在数据库和应用服务器之间的传输,对于提高数据处理的效率是有好处的。 存储过程思想上很简单,…

C switch 语句

一个 switch 语句允许测试一个变量等于多个值时的情况。每个值称为一个 case,且被测试的变量会对每个 switch case 进行检查。 语法 C 语言中 switch 语句的语法: switch(expression){case constant-expression :statement(s);break; /* 可选的 */ca…

C语言中的UTF-8编码转换处理

C语言UTF-8编码的转换 1.C语言简介2.什么是UTF-8编码?2.1 UTF-8编码特点: 3.C语言中的UTF-8编码转换处理步骤1:获取UTF-8编码的字节流步骤2:解析UTF-8编码步骤3:Unicode码点转换为汉字 4.总结 1.C语言简介 C语言是一门…

【面试精讲】Java线程6种状态和工作原理详解,Java创建线程的4种方式

Java线程6种状态和工作原理详解,Java创建线程的4种方式 目录 一、Java线程的六种状态 二、Java线程是如何工作的? 三、BLOCKED 和 WAITING 的区别 四、start() 和 run() 源码分析 五、Java创建线程的所有方式和代码详解 1. 继承Thread类 2. 实现…

Node-RED在Linux二次开发网关中能源数据实时采集与优化

智能电网与分布式能源系统已成为推动绿色能源转型的重要载体。为了更好地应对多样化的能源供给与需求挑战,以及实现更高效的能源管理,Linux二次开发网关与Node-RED这一创新组合应运而生。 Linux二次开发网关作为高度定制化的硬件平台,其开源特…