文章目录
- 参数量计算
本文是 LLMBox 和 YuLan-Chat 的预训练示例代码。此示例基于 Transformers 和 DeepSpeed 进行训练。在下面的示例代码中,train() 函数涵盖了预训练过程中的主要步骤,包括模型与分词器的初始化、训练数据的准备等;然后调用 Trainer 类来执行模型训练并保存训练状态。
参数量计算
1 from dataclasses import dataclass
2 from dataset.pt_dataset import PTDataset
3 from transformers import (
4 AutoModelForCausalLM,
5 AutoTokenizer,
6 HfArgumentParser,
7 TrainingArguments,
8 Trainer,
9 )
10 from transformers.hf_argparser import HfArg
11
12
13 # 用户输入超参数
14 @dataclass
15 class Arguments(TrainingArguments):
16 # 模型结构
17 model_name_or_path: str = HfArg(
18 default=None,
19 help="The model name or path, e.g., `meta-llama/Llama-2-7b-hf`",
20 )
21 # 训练数据集
22 dataset: str = HfArg(
23 default="",
24 help="Setting the names of data file.",
25 )
26 # 上下文窗口大小
27 model_max_length: int = HfArg(
28 default=2048,
29 help&