LLM PreTraining from scratch -- 大模型从头开始预训练指北

最近做了一些大模型训练相关的训练相关的技术储备,在内部平台上完成了多机多卡的llm 预训练的尝试,具体的过程大致如下:

数据准备:

大语言模型的训练依赖于与之匹配的语料数据,在开源社区有一群人在自发的整理高质量的语料数据,可以通过 以下的一些链接获取

liwu/MNBVC at main

Skywork/SkyPile-150B · Datasets at Hugging Face

预训练框架:

选用了百川智能的开源框架

原始版本代码训练准备:

根据README 里面的介绍,需要准备以下几样东西:

  • 训练数据,按照训练的卡的数目分成多个文件,每个文件的每一行为一整句的语料,类似这样的文件

添加图片注释,不超过 140 字(可选)

  • 分词器(tokenizer) ,下载 分词器 到当前目录下。

  • 修改hostfile训练脚本,单机训练情况下,不依赖于多机多卡的hostfile, 修改启动脚本 添加启动项 --num_nodes 即可完成单机多卡的训练

 

#!/bin/bash

deepspeed --hostfile config/hostfile --num_nodes=1 \ train.py \ --deepspeed \ --deepspeed_config config/deepspeed.json

原版的训练中几个要处理的问题

  • deepspeed.zero.Init. 错误

这个错误是发生在 deepspeed 设置优化等级不为3的情况下,调用deepspeed.zero.Init 函数会报错,需要在初始化的时候判断一下优化等级是不是3,因此修改代码如下:

 
//train.py
def prepare_model():ds_config = json.load(open(args.deepspeed_config))# print(type(ds_config["zero_optimization"]['stage']))with deepspeed.zero.Init(config_dict_or_path=args.deepspeed_config,enabled=ds_config["zero_optimization"]['stage']==3,mem_efficient_linear=False,mpu=None):model = BaiChuanForCausalLM(smallconfig)
  • 数据不可无限读取

    def get_data(self):# todo 循环读取#data = self.data.pop(0)

在原版本的实现中,数据是从队列中pop 出来的,导致当数据读完了之后会报一个错误导致训练中断

此外原版代码中所有的语料是读取到内存中再进行操作,但是随着语料的量级达到T级别,基本无法全部用内存hold 住所有的语料,另外读取语料的到内存的时间也会很长,基于以上几点考虑,重新选择tfrecord 作为新的数据的存储方式

TFRecord

tfrecord 是一种在tensorflow 中常用的数据格式,数据基于protobuf 完成序列化存储,配合对应的index 可以实现高速的数据读取从而减少数据读取造成的性能瓶颈,原版的训练代码基于的pytorch的框架,可以pip 安装

pip install tfrecord

来使用这个数据结构, 注意这个库可能会遇到protobuf 版本库的问题,通过pip 重新安装 protobuf==3.19 可以解决,

编写对应的代码完成将原来的jsonl 数据转换成tfrecord

//tools/jsonlmutiltfrecord.py
import tfrecord
import os
from tqdm import tqdm
import json
import torch
from tfrecord.torch.dataset import MultiTFRecordDataset, TFRecordDatasetori_path = "/workspace/mnt/storage/zhaozhijian/silk-debug/Baichuan-7B/data_dir_ori"
out_path = "/workspace/mnt/storage/zhaozhijian/silk-debug/Baichuan-7B/data_dir_mutil_test"if not os.path.exists(out_path):os.mkdir(out_path)numgpu = 16 
fidlist = []
for i in range(numgpu):writer = tfrecord.TFRecordWriter(os.path.join(out_path,"data" +str(i)+".tfrecord"))fidlist.append(writer)if 1:files = os.listdir(ori_path)count = 0for file in files:with open(os.path.join(ori_path, file)) as f:for line in tqdm(f.readlines()):dict_ = json.loads(line)fidlist[count%numgpu].write({"text":(dict_["text"].encode('utf-8'), "byte")})count +=1for writer in fidlist:writer.close()os.system("python3 -m tfrecord.tools.tfrecord2idx " + os.path.join(out_path))tfrecord_path = os.path.join(out_path,"data{}.tfrecord")
index_path = os.path.join(out_path,"data{}.tfindex")
splits = {"1": 1,
}
description = {"text": "byte"}
dataset = MultiTFRecordDataset(tfrecord_path, index_path, splits, description, infinite=False)
loader = torch.utils.data.DataLoader(dataset, batch_size=1)for item in loader:print(item['text'].decode('utf-8'))

这里需要注意,tfrecord 的写入的数据只有int,float, byte 3种形式,因此string 格式的数据数据需要通过utf-8的编码写入到tfrecord 中,再读取的时候通过utf-8的解码才能还原为写入的string数据,对应修改train.py 文件,

from tfrecord.torch.dataset import TFRecordDataset, MultiTFRecordDataset...
class DataEngine():...def load_tfrecode_data_mutil(self):splits = {}for file_path in self.local_input_paths:   splits[file_path.replace('.tfrecord', '')] = 1.0/len(self.local_input_paths)tfrecord_path = "{}.tfrecord"index_path = "{}.tfindex"description = {"text": "byte"}dataset = MultiTFRecordDataset(tfrecord_path, index_path, splits, description, infinite=False)self.loader = torch.utils.data.DataLoader(dataset, batch_size=1)return
...
def prepare_data():data_dir = args.data_dir....#    data_engine.load_data()data_engine.load_tfrecode_data_mutil()return data_engine
...
def train(data_engine, model_engine):model_engine.train()step = 0data =[]for item in data_engine.loader:while 1:line = item['text'].decode('utf-8')cc = data_engine.sp.EncodeAsIds(line.strip()) + [data_engine.EOS_TOKEN_ID]if len(cc) < data_engine.MIN_TEXT_LEN:continuedata.extend(cc)if len(data) >= data_engine.micro_batch_size * (data_engine.max_length + 1):index = data_engine.micro_batch_size * (data_engine.max_length + 1)data = data[:index]breakseq = np.asarray(data).reshape(data_engine.micro_batch_size, data_engine.max_length + 1)data = torch.LongTensor(seq)data = data.cuda(non_blocking=True)loss = model_engine(data, labels=data).lossmodel_engine.backward(loss)model_engine.step()step += 1data =[]return
成功解决数据加载中内存和读取速度的问题

多机多卡训练

原版本使用的hostfile 做为启动器,这个有一个前提条件需要各个机器之间可以通过ssh协议互相通信,但是在我们的内部ATOM的环境中无法做到这个,所以启动多机多卡的训练的时候会出现启动两个单机训练和无法启动训练两种情况,这些和我们的多机多卡训练不符

经过摸索后,我们采用了torchrun的启动方式,利用master_addr 等环境变量,用torchstyle 的方式启动多机多卡训练,解决了deepspeed 启动器对于ssh 通信的依赖

NUM_GPUS=8
torchrun --nnodes=$WORLD_SIZE --nproc-per-node=$NUM_GPUS --master-addr=$MASTER_ADDR \--master-port=$MASTER_PORT --node-rank=$RANK \train.py \--deepspeed \--deepspeed_config config/deepspeed.json >log$RANK.txt

对应的修改train.py 中的一些内容:

//train.py
###
deepspeed.init_distributed()
args.local_rank=int(os.environ['LOCAL_RANK'])
###
def prepare_data():...model = BaiChuanForCausalLM(smallconfig)torch.cuda.set_device(args.local_rank)...def train(data_engine, model_engine):model_engine.train()local_rank = int(os.environ['LOCAL_RANK'])...data = data.cuda(non_blocking=True).to(local_rank)...

一些遗留的BUG:

启动训练会卡住: 原因特别傻,就是现在在数据目录下会有tfrecord 和 index 两种后缀的文件,在按照radnk分的时候由于不够随机,会有loader 读取不到文件,导致计算loss 时候卡住,修改 DataEngine

files = [x for _, x in enumerate(self.global_input_paths)if x.find('.tfrecord') != -1]self.local_input_paths = [x for i, x inenumerate(files)if i % dist.get_world_size() == dist.get_rank()]

即可。

最终的训练loss 如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/731207.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

jeecgboot 新建子模块 使用@EXCEL实现实现导入导出功能

一&#xff0c;用框架生成增删改查模块 二&#xff0c;在实体类entity 需要导入导出的字段上加上注解Excel 三&#xff0c;在controller类上继承jeecgboot通用controller JeecgController 并且在JeecgController里增加导出模板的方法 /*** 导出excel空模板** param req…

专业140+总430+电子科技大学858信号与系统考研经验成电电子信息与通信工程,电科大,真题,大纲,参考书。

今年考研成绩出来&#xff0c;初试专业课858信号与系统140&#xff0c;总分430&#xff0c;其余各门分数都比较平稳&#xff0c;总分好于自己估分&#xff0c;应群里很多同学要求&#xff0c;我总结一下自己的复习经验。首先我是一个大冤种&#xff0c;专业课资料学长给了一套&…

挑战杯 基于深度学习的视频多目标跟踪实现

文章目录 1 前言2 先上成果3 多目标跟踪的两种方法3.1 方法13.2 方法2 4 Tracking By Detecting的跟踪过程4.1 存在的问题4.2 基于轨迹预测的跟踪方式 5 训练代码6 最后 1 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 基于深度学习的视频多目标跟踪实现 …

软考高级:系统工程生命周期方法(计划驱动方法、渐进迭代式方法等)概念和例子

作者&#xff1a;明明如月学长&#xff0c; CSDN 博客专家&#xff0c;大厂高级 Java 工程师&#xff0c;《性能优化方法论》作者、《解锁大厂思维&#xff1a;剖析《阿里巴巴Java开发手册》》、《再学经典&#xff1a;《Effective Java》独家解析》专栏作者。 热门文章推荐&am…

如何在Ubuntu系统部署DbGate数据库管理工具并结合cpolar内网穿透远程访问

文章目录 1. 安装Docker2. 使用Docker拉取DbGate镜像3. 创建并启动DbGate容器4. 本地连接测试5. 公网远程访问本地DbGate容器5.1 内网穿透工具安装5.2 创建远程连接公网地址5.3 使用固定公网地址远程访问 本文主要介绍如何在Linux Ubuntu系统中使用Docker部署DbGate数据库管理工…

web组态

演示地址 &#xff1a;by组态[web组态插件] 这是一款可以嵌入到任何项目组态插件&#xff0c;功能全面&#xff0c;可根据自己的项目需要进行二次开发&#xff0c;能大大的节省在组态上的开发时间&#xff0c;代码简单易懂。 一、数据流向图及嵌入原理 数据流向 嵌入原理 …

深度神经网络 基本知识 记录

资料&#xff1a;https://www.bilibili.com/video/BV1K94y1Z7wn/?spm_id_from333.337.search-card.all.click&vd_source14a476de9132ba6b2c3cbc2221750b99 计划&#xff1a;3~4天 注&#xff1a;网课讲的内容比较糅杂&#xff0c;记录的内容可能会出现重复 杂 人工智能…

<商务世界>《第8课 Leads——MQL——SQL——商机——成交》

1 各种概念 英文缩写概念Traffic流量Leads潜在客户&#xff0c;销售线索&#xff1b;简称潜在线索MQLMarketing-Qualified Leads市场认可线索SQLSales-Qualified Leads销售认可线索OPPOpportunity商机Account成单客户 2 线索到商机 一般企业会把自身线索进行如下的划分&…

【电工学笔记】上册第一、二章

电工学 上次考试败在了单位&#xff0c;这次单位 一定要记熟。 第一章 电源或信号源的电压或电流称为激励,它推动电路工作; 由激励所产生的电压和电流称为响应。 复杂电路中,一般无法事先判断某个支路电流的 实际方向或者某个电路元件电压的实际方向 140V/4算不出总电阻的 …

数据结构面试常见问题

数据结构面试常见问题 什么是 AVL 树&#xff1f;什么是红黑树&#xff1f;AVL 树和红黑树的区别&#xff1f;B 树和B 树的区别&#xff1f;排序有哪些分类&#xff1f;直接插入排序的原理&#xff1f;希尔排序的原理&#xff1f;直接选择排序的原理&#xff1f;堆排序的原理&a…

vue3的开发小技巧

「总之岁月漫长&#xff0c;然而值得等待。」 目录 父组件调用子组件函数如何访问全局api 父组件调用子组件函数 ref, defineExpose //父组件 代码 <child ref"ch">this.$refs.ch.fn();//子组件 函数抛出 const fn () > { }; defineExpose({ fn });如何…

考研复习C语言初阶(3)

目录 一.函数是什么? 二.C语言中函数的分类 2.1库函数 2.2自定义函数 三.函数的参数 3.1实际参数&#xff08;实参&#xff09; 3.2 形式参数&#xff08;形参&#xff09; 四.函数的调用 4.1 传值调用 4.2 传址调用 五. 函数的嵌套调用和链式访问 5.1 嵌套调用 5…

瑞芯微 | I2S-音频基础分享

1. 音频常用术语 名称含义ADC&#xff08;Analog to Digit Conversion&#xff09;模拟信号转换为数字信号AEC&#xff08;Acoustic Echo Cancellor&#xff09;回声消除AGC&#xff08;Automatic Gain Control&#xff09;自动增益补偿&#xff0c;调整MIC收音量ALSA&#xf…

Jmeter常用组件的使用场景

一.在一段时间内持续发送请求 此场景可以用于稳定性测试&#xff0c;在稳定性测试中&#xff0c;通常需要持续压测几个小时甚至几天时间&#xff0c;查看接口是否有报错&#xff0c;或者cpu、内存会上涨&#xff0c;此时就需要通过控制持续时间来达到此目的。 1.创建线程组&am…

基于SSM的校园疫情管理系统的设计与实现(有报告)。Javaee项目。ssm项目。

演示视频&#xff1a; 基于SSM的校园疫情管理系统的设计与实现&#xff08;有报告&#xff09;。Javaee项目。ssm项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系结构&#xff0c;…

【网络层】IP多播技术的相关基本概念(湖科大慕课自学笔记)

IP多播 1&#xff1a;IP多播技术的相关基本概念 我们简单举例&#xff0c;如下图所示&#xff1a; 一共有60个主机要接受来自视频服务器的同一个节目&#xff0c;如果采用单播方式&#xff0c;则视频服务器要发送60份&#xff0c;这些视频节目通过路由器的转发&#xff0c;最…

CentOS7 利用remi yum源安装php8.1

目录 前言remi yum源remi yum源 支持的操作系统remi yum源 支持的php版本 安装epel源安装remi源安装 php8.1查看php版本查看php-fpm服务启动php-fpm服务查看php-fpm服务运行状态查看php-fpm服务占用的端口查看 php8.1 相关的应用 前言 CentOS Linux release 7.9.2009 (Core) …

[Angular 基础] - Observable

[Angular 基础] - Observable 之前的笔记&#xff1a; [Angular 基础] - service 服务[Angular 基础] - routing 路由(上)[Angular 基础] - routing 路由(下) 我以前对 Observable 的理解是 Promise 的一个超集&#xff0c;重新了解了一下&#xff0c;感觉这个说法不太对。更…

2024最新版CleanMyMac X 4.15.1 Crack+激活码下载

CleanMyMac X 为您喜爱的事物腾出空间。 CleanMyMac 具有一系列巧妙的新功能&#xff0c;可让您安全、智能地扫描和清理整个系统、删除大型未使用的文件、减小 iPhoto 图库的大小、卸载不需要的应用程序或修复开始工作不正常的应用程序、管理所有应用程序您可以从一个地方进行扩…

【牛客】HJ87 密码强度等级 CM62 井字棋

题目一:密码强度等级 题目链接&#xff1a;密码强度等级_牛客题霸_牛客网 (nowcoder.com) 本题主要考察C语言中逻辑分支语句&#xff0c;基本语句以及对各种特殊字符 &#xff0c;ASCII值以及条件表达中的逻辑运算符关系运算符各自功能的理解&#xff0c;以及基本使用&#x…