SwissArmyTransformer瑞士军刀工具箱使用手册

Introduction sat(SwissArmyTransformer)是一个灵活而强大的库,用于开发您自己的Transformer变体。
sat是以“瑞士军刀”命名的,这意味着所有型号(例如BERT、GPT、T5、GLM、CogView、ViT…)共享相同的backone代码,并通过一些超轻量级的mixin满足多种用途。
sat由deepspeed ZeRO和模型并行性提供支持,旨在为大模型(100M\~20B参数)的预训练和微调提供最佳实践。

从 SwissArmyTransformer 0.2.x 迁移到 0.3.x

  1. 导入时将包名称从 SwissArmyTransformer 更改为 sat,例如从 sat 导入 get_args。
  2. 删除脚本中的所有--sandwich-ln,使用layernorm-order='sandwich'。
  3. 更改顺序 from_pretrained(args, name) => from_pretrained(name, args)。
  4. 我们可以直接使用 from sat.model import AutoModel;model, args = AutoModel.from_pretrained('roberta-base') 以 仅模型模式 加载模型,而不是先初始化 sat。

安装

pip install SwissArmyTransformer

特征

添加与模型无关的组件,例如前缀调整,只需一行!

前缀调整(或 P 调整)通过在每个注意力层中添加可训练参数来改进微调。使用我们的库可以轻松地将其应用于 GLM 分类(或任何其他)模型。

class ClassificationModel(GLMModel): # can also be BertModel, RobertaModel, etc. def __init__(self, args, transformer=None, **kwargs):super().__init__(args, transformer=transformer, **kwargs)self.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))# Arm an arbitrary model with Prefix-tuning with this line!self.add_mixin('prefix-tuning', PrefixTuningMixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.prefix_len))

GPT 和其他自回归模型在训练和推理过程中的行为有所不同。在推理过程中,文本是逐个令牌生成的,我们需要缓存以前的状态以提高效率。使用我们的库,您只需要考虑训练期间的行为(教师强制),并通过添加 mixin 将其转换为缓存的自回归模型:

model, args = AutoModel.from_pretrained('glm-10b-chinese', args)
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
# Generate a sequence with beam search
from sat.generation.autoregressive_sampling import filling_sequence
from sat.generation.sampling_strategies import BeamSearchStrategy
output, *mems = filling_sequence(model, input_seq,batch_size=args.batch_size,strategy=BeamSearchStrategy(args.batch_size))

使用最少的代码构建基于 Transformer 的模型。我们提到了 GLM,它与标准转换器(称为 BaseModel)仅在位置嵌入(和训练损失)上有所不同。我们在编码的时候只需要关注相关的部分就可以了。

扩展整个定义:

class BlockPositionEmbeddingMixin(BaseMixin):# Here define parameters for the mixindef __init__(self, max_sequence_length, hidden_size, init_method_std=0.02):super(BlockPositionEmbeddingMixin, self).__init__()self.max_sequence_length = max_sequence_lengthself.hidden_size = hidden_sizeself.block_position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)torch.nn.init.normal_(self.block_position_embeddings.weight, mean=0.0, std=init_method_std)# Here define the method for the mixindef position_embedding_forward(self, position_ids, **kwargs):position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]position_embeddings = self.transformer.position_embeddings(position_ids)block_position_embeddings = self.block_position_embeddings(block_position_ids)return position_embeddings + block_position_embeddingsclass GLMModel(BaseModel):def __init__(self, args, transformer=None, parallel_output=True):super().__init__(args, transformer=transformer, parallel_output=parallel_output)self.add_mixin('block_position_embedding', BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size)) # Add the mixin for GLM

全方位的培训支持。 sat 旨在提供预训练和微调的最佳实践,您只需要完成forward_step 和 create_dataset_function,但可以使用超参数来更改有用的训练配置。
通过指定 --num_nodes、--num_gpus 和一个简单的主机文件,将训练扩展到多个 GPU 或节点。
DeepSpeed 和模型并行性。
ZeRO-2 和激活检查点的更好集成。
自动扩展和改组训练数据和内存映射。
成功支持CogView2和CogVideo的训练。
目前唯一支持在 GPU 上微调 T5-10B 的开源代码库。

快速浏览

在 sat 中使用 Bert(用于推理)的最典型的 python 文件如下:

# @File: inference_bert.py
from sat import get_args, get_tokenizer, AutoModel
# Parse args, initialize the environment. This is necessary.
args = get_args() 
# Automatically download and load model. Will also dump model-related hyperparameters to args.
model, args = AutoModel.from_pretrained('bert-base-uncased', args) 
# Get the BertTokenizer according to args.tokenizer_type (automatically set).
tokenizer = get_tokenizer(args) 
# Here to use bert as you want!
# ...

然后我们可以通过以下方式运行代码

SAT_HOME=/path/to/download python inference_bert.py --mode inference

所有官方支持的模型名称都在 urls.py 中。

# @File: finetune_bert.py
from sat import get_args, get_tokenizer, AutoModel
from sat.model.mixins import MLPHeadMixindef create_dataset_function(path, args):# Here to load the dataset# ...assert isinstance(dataset, torch.utils.data.Dataset)return datasetdef forward_step(data_iterator, model, args, timers):inputs = next(data_iterator) # from the dataset of create_dataset_function.loss, *others = model(inputs)return loss# Parse args, initialize the environment. This is necessary.
args = get_args() 
model, args = AutoModel.from_pretrained('bert-base-uncased', args) 
tokenizer = get_tokenizer(args) 
# Here to use bert as you want!
model.del_mixin('bert-final')
model.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
# ONE LINE to train! 
# args already includes hyperparams such as lr, train-iters, zero-stage ...
training_main(args, model_cls=model, forward_step_function=forward_step, # user definecreate_dataset_function=create_dataset_function # user define
)

然后我们可以通过以下方式运行代码

deepspeed --include localhost:0,1 finetune_bert.py \--experiment-name ftbert \--mode finetune --train-iters 1000 --save /path/to/save \--train-data /path/to/train --valid-data /path/to/valid \--lr 0.00002 --batch-size 8 --zero-stage 1 --fp16

这里我们在 GPU 0,1 上使用数据并行。我们还可以通过 --hostfile/path/to/hostfile 在许多互连的机器上启动训练。请参阅教程了解更多详细信息。
要编写自己的模型,您只需要考虑与标准 Transformer 的差异。例如,如果你有一个改进注意力操作的想法:

from sat.model import BaseMixin
class MyAttention(BaseMixin):def __init__(self, hidden_size):super(MyAttention, self).__init__()# MyAttention may needs some new params, e.g. a learnable alpha.self.learnable_alpha = torch.nn.Parameter(torch.ones(hidden_size))# This is a hook function, the name `attention_fn` is special.def attention_fn(q, k, v, mask, dropout=None, **kwargs):# Code for my attention.# ...return attention_results

这里的attention_fn是一个钩子函数,用新函数替换默认动作。所有可用的钩子都在transformer_defaults.py中。现在我们可以使用 add_mixin 将更改应用到所有转换器,例如 BERT、Vit 和 CogView。请参阅教程了解更多详细信息。

教程

  • How to use pretrained models collected in sat?
  • Why and how to train models in sat?

Citation

Currently we don't have a paper, so you don't need to formally cite us!~

If this project helps your research or engineering, use \footnote{https://github.com/THUDM/SwissArmyTransformer} to mention us and recommend SwissArmyTransformer to others.

The tutorial for contributing sat is on the way!

The project is based on (a user of) DeepSpeed, Megatron-LM and Huggingface transformers. Thanks for their awesome work.

训练指导

The Training API

我们提供了一个简单但功能强大的训练APItraining_main(),它不仅限于我们的Transformer模型,还适用于任何torch.nn.Module

from sat import get_args, training_main
from sat.model import AutoModel, BaseModel
args = get_args()
# to pretrain from scratch, give a class obj
model = BaseModel
# to finetuned from a given model, give a torch.nn.Module
model = AutoModel.from_pretrained('bert-base-uncased', args)training_main(args, model_cls=model,forward_step_function=forward_step,create_dataset_function=dataset_func,handle_metrics_function=None,init_function=None
)

以上是使用 sat 的标准训练计划的(不完整)示例。 Training_main 接受 5 个参数:(必需)model_cls:继承 torch.nn.Module 的类型对象,或我们训练的 torch.nn.Module 对象。
(必需)forward_step_function:一个自定义函数,输入 data_iterator、model、args、timers、returns loss、{'metric0': m0, ...}。
(必填)create_dataset_function:返回一个torch.utils.data.Dataset用于加载。我们的库会自动将数据分配给多个worker,并将数据迭代器交给forward_step_function。
(可选)handle_metrics_function:在评估过程中处理特殊指标。
(可选)init_function:在训练之前更改模型的钩子,对于继续训练很有用。
有关完整示例,请参阅 Finetune BERT 示例。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/136968.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

[git] cherry pick 将某个分支的某次提交应用到当前分支

功能:将某个分支的某次提交应用到当前分支 应用场景: 在合并分支时,是将源分支的所有内容都合并到目标分支上,有的时候我们可能只需要合并源分支的某次或某几次的提交,这个时候我们就需要使用到git的cherry-pick操作…

边缘计算多角色智能计量插座:用电监测和资产管理的未来智能化引擎

目前主流的智能插座涵盖了红外遥控(控制空调和电视等带有红外标准的电器),配备着测温、测湿等仓库应用场景,配备了人体红外或者毫米波雷达作为联动控制,但是大家有没有思考一个问题,就是随着对接的深入&…

易点易动固定资产管理系统:实现全生命周期闭环式管理和快速盘点

固定资产管理对于企业来说至关重要,它涉及到资产的采购、领用、使用、维护和报废等各个环节。然而,传统的固定资产管理方式往往繁琐、耗时,容易导致信息不准确和资源浪费。为了解决这些问题,我们引入易点易动固定资产管理系统&…

PHP+Swoole应用示例

**Swoole是一个C编写的基于异步事件驱动和协程的并行网络通信引擎,为PHP提供高性能网络编程支持** ## ⚙️ 快速启动 可以直接使用 [Docker](https://github.com/swoole/docker-swoole) 来执行Swoole的代码,例如: bash docker run --rm php…

DevOps简介

DevOps简介 1、DevOps的起源2、什么是DevOps3、DevOps的发展现状4、DevOps与虚拟化、容器 1、DevOps的起源 上个世纪40年代,世界上第一台计算机诞生。计算机离不开程序(Program)驱动,而负责编写程序的人,被称为程序员&…

Kotlin基础数据类型和运算符

原文链接 Kotlin Types and Operators Kotlin是新一代的基于JVM的静态多范式编程语言,功能强大,语法简洁,前面已经做过Kotlin的基本的介绍,今天就来深入的学习一下它的数据类型和运算操作符。 数据类型 与大部分语言不同的是&am…

socket编程中的EINTR是什么?

socket编程中的EINTR是什么? 在socket编程中&#xff0c;我们时常在accept/read/write等接口调用的异常处理的部分看到对于EINTR的处理&#xff0c;例如下面这样的语句&#xff1a; repeat: if(read(fd, buff, size) < 0) {if(errno EINTR)goto repeat;elseprintf("…

三菱FX3U系列-定位指令

目录 一、简介 二、指令形式 1、相对定位[DRVI、DDRVI] 2、绝对定位[DRVA、DDRVA] 三、总结 一、简介 定位指令用于控制伺服电机或步进电机的位置移动。可以通过改变脉冲频率和脉冲数量来控制电机的移动速度和移动距离&#xff0c;同时还可以指定移动的方向。 二、指令形…

Linux下的环境变量【详解】

Linux下的环境变量 一&#xff0c;环境变量的概念1 概述2 环境变量的分类3 常见的环境变量4 查看环境变量4.1 shell变量4.2 查看环境变量 5 添加和删除环境变量5.1 添加环境变量5.2 删除环境变量 6. 通过代码如何获取环境变量6.1 命令行的第三个参数6.2 通过第三方变量environ获…

Linux 的热插拔机制通过 Udev(用户空间设备)实现、守护进程

一、Udev作用概述 udev机制简介udev工作流程图 二、Linux的热拔插UDEV机制 三、守护进程 守护进程概念守护进程在后台运行基本特点 四、守护进程和后台进程的区别 一、Udev作用概述 udev机制简介 Udev&#xff08;用户空间设备&#xff09;是一个 Linux 系统中用于动态管…

如何防止IP和账户关联?

在当今信息时代&#xff0c;个人隐私安全变得尤为重要。保护个人IP地址和账户的隐私是防止隐私泄露、信息泄漏以及支付安全等问题的关键。VMLogin虚拟浏览器作为一种隐私工具&#xff0c;可以帮助您解决问题。本文将为您介绍如何使用它来保护隐私安全和防止IP和账户关联。 一、…

微信号绑定50个开发者小程序以后超额如何删除不用的

我们在开发微信小程序的时候&#xff0c;当前开发者工具登录的必须是该小程序的开发者才能进行小程序的开发&#xff0c;添加开发者的步骤是&#xff1a; 添加开发者 1、进入微信开放平台&#xff0c;然后扫码进入管理平台 2、找到下图所示位置 3:、输入要添加的微信账号&am…

LCD英文字模库(16x8)模拟测试程序

字模 字模&#xff0c;就是把文字符号转换为LCD能识别的像素点阵信息。 电子发烧友可能都熟悉字模的用途。就是调用者通过向LCD模块发送字模数据&#xff0c;LCD根据字模数据在LCD面板上相应的像素描绘出图形或文字。 现在&#xff0c;大部分的LCD都内置了字模库&#xff0c…

Web学习笔记-Vue3(用户动态页面设计)

笔记内容转载自 AcWing 的 Web 应用课讲义&#xff0c;课程链接&#xff1a;AcWing Web 应用课。 CONTENTS 1. 实现用户信息模块2. 实现关注用户功能3. 实现历史动态模块4. 实现发动态模块 1. 实现用户信息模块 用户动态页面可以划分为三个模块&#xff1a;用户信息部分、发动…

11-09 周四 CNN 卷积神经网络基础知识

11-09 周四 CNN 卷积神经网络 时间版本修改人描述2023年11月9日09:38:12V0.1宋全恒新建文档 简介 学习一下CNN&#xff0c;卷积神经网络。使用的视频课程。视觉相关的任务&#xff1a; 人脸识别 卷积网络与传统网络的区别&#xff1a; <img altimage-20231109094400591 s…

电脑怎么录制视频,录制的视频怎么剪辑?

在现今数字化的时代&#xff0c;视频成为了人们日常生活中不可或缺的一部分。因此&#xff0c;对于一些需要制作视频教程、录制游戏或者是进行视频演示的人来说&#xff0c;电脑录屏已经成为了一个必不可少的工具。那么&#xff0c;对于这些人来说&#xff0c;如何选择一个好用…

Zigbee—网络层地址分配机制

&#x1f3ac;慕斯主页&#xff1a;修仙—别有洞天 ♈️今日夜电波&#xff1a;孤雏 0:21━━━━━━️&#x1f49f;──────── 4:14 &#x1f504; ◀️ ⏸ ▶️ ☰ &#x1f497;关注…

maven配置自定义下载路径,以及阿里云下载镜像

1.配置文件 <?xml version"1.0" encoding"UTF-8"?> <settings xmlns"http://maven.apache.org/SETTINGS/1.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation"http://maven.apache.org…

Vue真实技术面试题解析【兄弟组件、vue-router、增量部署】

兄弟组件的传值方式&#xff0c;有两种方式&#xff0c;把你尽可能知道的告诉我 我的答案&#xff1a;使用父组件传值 和 状态管理传值 使用事件总线&#xff08;Event Bus&#xff09;&#xff1a;创建一个空的 Vue 实例作为事件总线&#xff0c;在其中定义事件和对应的处理函…

大厂面试题-MySQL中的RR隔离级别,到底有没有解决幻读问题?

就MySQL中的RR(Repeatable Reads)事务隔离级别&#xff0c;到底有没有解决幻读问题发起了激烈的讨论。 一部分人说有&#xff0c;一部分人说没有。 结论&#xff0c;MySQL中的RR事务隔离级别&#xff0c;在特定的情况下会出现幻读的问题。 所谓的幻读&#xff0c;表示在同一…