使用 Amazon SageMaker 微调 Llama 2 模型

1c3aea7b4c3c494fd29dcbc241c8a7fd.gif

本篇文章主要介绍如何使用 Amazon SageMaker 进行 Llama 2 模型微调的示例。

这个示例主要包括:

  1. Llama 2 总体介绍

  2. Llama 2 微调介绍

  3. Llama 2 环境设置

  4. Llama 2 微调训练

前言

随着生成式 AI 的热度逐渐升高,国内外各种基座大语言竞相出炉,在其基础上衍生出种类繁多的应用场景。训练优异的基座大语言模型在通用性方面表现较好,但模型可能并未涉及到特定领域的专业术语、领域内的特定用语或上下文等。采用微调技术可以通过在领域特定数据上进行训练,使模型更好地适应目标领域的特殊语言模式和结构;结合基座模型的通用性和领域特定性,使得模型更具实际应用价值。

Llama 2 总体介绍

Llama 2 是 META 最新开源的 LLM,包括 7B、13B 和 70B 三个版本,训练数据集超过了 Llama 2 的 40%,达到 2 万亿 token;上下文长度也提升到 4K,可以极大扩展多轮对话的轮数、提示词输入数据;与此同时,Llama 2 Chat 模型使用基于人类反馈的强化学习(Reinforcement Learning from Human Feedback,RLHF),针对对话场景进行了大幅优化,达到了非常出色的有用性和安全性基准。HuggingFace 的 TGI 和 vLLM 等框架均有针对 Llama 2 的推理优化,进一步强化了 Llama 2 的可用性。

Llama 2 被认为是开源界大语言模型的首选,众多的垂类大模型均采用 Llama 2 作为基座大模型,在此基础上添加行业数据进行模型的预训练或者微调,适配更多的行业场景。

Llama 2 微调介绍

模型微调主要分为 Full Fine-Tune 和 PEFT (Performance-Efficient Fine-Tune),前者模型全部参数都会进行更新,训练时间较长,训练资源较大;而后者会冻结大部分参数、微调训练网络结构,常见的方式是 LoRA 和 P-Tuning v2。

PEFT 微调方式由于参数更新较少,可能导致模型无法学习到全部领域知识,对于特定任务或领域来说会出现推理不稳定的情况,因此大多数生产系统均使用全参数方式进行模型的微调。基于上述原因,本文会以全参数微调方式介绍 Llama 2 在 Amazon SageMaker 上的微调。

Llama 2 环境设置

备注:项目中的示例代码均保存于代码仓库,地址如下: 

https://github.com/aws-samples/llm-workshop-on-amazon-sagemaker

1. 升级 Python SDK 

pip install -U sagemaker

2. 获取运行时资源,包括区域、角色、账号、S3 桶等 

import boto3
import sagemaker
from sagemaker import get_execution_rolesess                     = sagemaker.Session()
role                     = get_execution_role()
sagemaker_default_bucket = sess.default_bucket()account                  = sess.boto_session.client("sts").get_caller_identity()["Account"]
region                   = sess.boto_session.region_name

Llama 2 微调训练

微调准备

克隆代码

  • 采用 lm-sys 团队发布的 FastChat 平台进行 Llama 2 的微调,FastChat 也用于训练了知名的 Vicuna 模型,具有良好的代码规范和性能优化。

git clone https://github.com/lm-sys/FastChat.git
cd FastChat
git reset --hard 974537efbd82093b45e64d07904efe7728193a52

下载 Llama 2 原始模型

from huggingface_hub import snapshot_download
from pathlib import Pathlocal_cache_path = Path("./model")
local_cache_path.mkdir(exist_ok=True)model_name = "TheBloke/Llama-2-13B-fp16"# Only download pytorch checkpoint files
allow_patterns = ["*.json", "*.pt", "*.bin", "*.model", "*.py"]model_download_path = snapshot_download(repo_id=model_name,cache_dir=local_cache_path,allow_patterns=allow_patterns,revision='b2e65e8ad4bb35e5abaee0170ebd5fc2134a50bb'
)# Get the model files path
import os
from glob import globlocal_model_path = Nonepaths = os.walk(r'./model')
for root, dirs, files in paths:for file in files:if file == 'config.json':print(os.path.join(root,file))local_model_path = str(os.path.join(root,file))[0:-11]print(local_model_path)
if local_model_path == None:print("Model download may failed, please check prior step!")

拷贝模型和数据到 Amazon S3

chmod +x ./s5cmd
./s5cmd sync ${local_model_path} s3://${sagemaker_default_bucket}/llm/models/llama2/TheBloke/Llama-2-13B-fp16/ 
rm -rf model

模型微调

  • 模型的微调使用全参数模型,以实现微调后模型的稳定性。

  • 模型的微调使用开源框架 DeepSpeed 进行加速。

准备基础镜像

使用 Amazon SageMaker 定制的深度学习训练镜像作为基础镜像,再安装 Llama 2 训练所需的依赖包。Dockerfile 如下:

%%writefile Dockerfile
## You should change below region code to the region you used, here sample is use us-west-2
From 763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-training:1.13.1-transformers4.26.0-gpu-py39-cu117-ubuntu20.04 ENV LANG=C.UTF-8
ENV PYTHONUNBUFFERED=TRUE
ENV PYTHONDONTWRITEBYTECODE=TRUERUN pip3 uninstall -y deepspeed \&& pip3 install deepspeed==0.10.0 \&& pip3 install transformers==4.30.2## Make all local GPUs visible
ENV NVIDIA_VISIBLE_DEVICES="all"

模型微调代码

模型微调源代码较多,细节可以参考上述 git 仓库。

微调参数

  • 为了节省显存,采用 DeepSpeed Stage-3

  • 训练过程开启 bf16,实现整数范围和精度的平衡

  • 训练数据集采用官方提供的 dummy_conversation.json,也就是典型的 {"instruction"、"input"、"output"} 的格式,同时可以支持多轮对话

DEEPSPEED_OPTS="""FastChat/fastchat/train/train_mem.py --deepspeed ds.json --model_name_or_path "/tmp/llama_pretrain/" --data_path FastChat/data/dummy_conversation.json --output_dir "/tmp/llama_out" --num_train_epochs 1 --per_device_train_batch_size 1 --per_device_eval_batch_size  1 --gradient_accumulation_steps 4 --evaluation_strategy "no" --save_strategy "no" --save_steps 2000 --save_total_limit 1 --learning_rate 2e-5 --weight_decay 0. --warmup_ratio 0.03 --lr_scheduler_type "cosine" --logging_steps 1 --cache_dir '/tmp' --model_max_length 2048 --gradient_checkpointing True --lazy_preprocess True --bf16 True --tf32 True --report_to "none"
"""

微调脚本

  • 微调使用 torchrun + DeepSpeed 进行分布式训练

%%writefile ./src/ds-train-dist.sh
#!/bin/bash
CURRENT_HOST="${SM_CURRENT_HOST}"IFS=',' read -ra hosts_array <<< "${SM_HOSTS}"
NNODES=${#hosts_array[@]}
NODE_RANK=0for i in "${!hosts_array[@]}"; doif [[ "${hosts_array[$i]}" == *${CURRENT_HOST}* ]]; thenecho "host index:$i"NODE_RANK="$i" fi
doneMASTER_PORT="13579"
export NCCL_SOCKET_IFNAME="eth0"#Configure the distributed arguments for torch.distributed.launch.
GPUS_PER_NODE="$SM_NUM_GPUS"
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \--nnodes $NNODES \--node_rank $NODE_RANK \--master_addr $MASTER_ADDR \--master_port $MASTER_PORT"chmod +x ./s5cmd
./s5cmd sync s3://$MODEL_S3_BUCKET/llm/models/llama2/TheBloke/Llama-2-13B-fp16/* /tmp/llama_pretrain/CMD="torchrun ${DISTRIBUTED_ARGS} ${DEEPSPEED_OPTS}"
echo ${CMD}
${CMD} 2>&1 if [[ "${CURRENT_HOST}" == "${MASTER_ADDR}" ]]; then  ./s5cmd sync /tmp/llama_out s3://$MODEL_S3_BUCKET/llm/models/llama2/output/TheBloke/Llama-2-13B-fp16/$(date +%Y-%m-%d-%H-%M-%S)/
fi

启动微调

  • 全参数微调,需要使用至少一台 p4de.12xlarge(8 卡 A100 40GB)作为训练机器。

  • 当微调完成后,训练好的模型自动存储于指定的 S3 桶内,可用于后续的模型部署推理。

import time
from sagemaker.estimator import Estimatorenvironment = {'MODEL_S3_BUCKET': sagemaker_default_bucket # The bucket to store pretrained model and fine-tune model
}base_job_name = 'llama2-13b-finetune'instance_type = 'ml.p4d.24xlarge'estimator = Estimator(role=role,entry_point='ds-train-dist.sh',source_dir='./src',base_job_name=base_job_name,instance_count=1,instance_type=instance_type,image_uri=image_uri,environment=environment,disable_profiler=True,debugger_hook_config=False)estimator.fit()

总结

大语言模型方兴未艾,正在以各种方式改变和影响着整个世界。客户拥抱大语言模型,亚马逊云科技团队同样在深耕客户需求和大语言模型技术,可以在未来更好地协助客户实现需求,提升业务价值。

本篇作者

6ce443c21a564a6109595741d2da8d7c.jpeg

高郁

亚马逊云科技解决方案架构师,主要负责企业客户上云,帮助客户进行云架构设计和技术咨询,专注于智能湖仓、AI/ML 等技术方向。

3c7b3572aea34ef2ee8fc45077824270.gif

星标不迷路,开发更极速!

关注后记得星标「亚马逊云开发者」

952455b80801542c984978e00d9d8e6e.gif

听说,点完下面4个按钮

就不会碰到bug了!

c330e0b3208f52226bf51d8e36d2c4f1.gif

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/763903.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

测试开发工程师(QA)职业到底需要干些什么?part1:移动端QA

概述 移动端QA测试开发工作主要涉及对移动应用程序进行质量保证和测试的开发工作。以下是移动端QA测试开发人员的主要职责和工作内容&#xff1a; 测试计划和策略制定&#xff1a;参与制定移动应用程序的测试计划和策略&#xff0c;确定测试范围、测试目标和测试方法。考虑到…

Mysql---DML

文章目录 目录 一.DML概述 注入数据&#xff08; Insert&#xff09; 替换数据&#xff08;replace&#xff09; 删除数据 &#xff08;delete&#xff09; 修改数据 &#xff08;update&#xff09; 查询数据 &#xff08;select&#xff09; 二. 多表连接查询 内连接 子…

Linux环境JMeter脚本性能测试、easyNmon生成监控报告

一、下载JMeter安装包 Jmeter是Java开发的&#xff0c;需要依赖JDK环境&#xff0c;因此我们需提前安装好JDK。 Jmeter是开源的工具&#xff0c;我们直接到官网下载即可。 最新版本下载地址&#xff1a;Apache JMeter - Download Apache JMeter 二、安装JMeter #新建jmete…

【GIT】最好用的git可视化教程网站推荐

最好用可视化学习git 网站:https://learngitbranching.js.org/?demo&localezh_CN 玩遍所有关卡&#xff0c;花半天时间便能掌握git &#x1f603; 本地仓库 基础命令介绍 git commit 提交 git branch <分支名> 创建分支 git checkout <分支名> 切换分支 git…

鸿蒙Harmony应用开发—ArkTS(@State装饰器:组件内状态)

State装饰的变量&#xff0c;或称为状态变量&#xff0c;一旦变量拥有了状态属性&#xff0c;就和自定义组件的渲染绑定起来。当状态改变时&#xff0c;UI会发生对应的渲染改变。 在状态变量相关装饰器中&#xff0c;State是最基础的&#xff0c;使变量拥有状态属性的装饰器&a…

【C++ leetcode】双指针问题

1. 611. 有效三角形的个数 题目 给定一个包含非负整数的数组 nums &#xff0c;返回其中可以组成三角形三条边的三元组个数。 题目链接 . - 力扣&#xff08;LeetCode&#xff09; 画图 和 文字 分析 判断是否是三角形要得到三边&#xff0c;由于遍历三边要套三层循环&#x…

开发调试、远程访问、内网穿透解决方案集合

开发调试、远程访问、内网穿透解决方案集合 前言Cpolar下载安装配置与使用 Ngrok购买隧道下客户端配置与使用 ZeroTier注册与安装创建虚拟网络加入虚拟网络配置授权 花生壳Centos系统Ubuntu系统使用花生壳控制台SN码登录添加映射 Loophole通过CLI方式安装登录与注销简单使用身份…

网络安全是什么? 为什么要学网络安全 ?网络安全怎么学习?

网络安全是什么&#xff1f; 网络安全是指保护计算机网络、网络设备、应用程序、数据和用户免受非法访问、攻击、破坏或泄漏的过程和技术。网络安全包括多个领域&#xff0c;例如网络防御、漏洞管理、加密技术、身份验证和访问控制等等。 网络安全非常重要&#xff0c;因为现…

省一餐,是减瘦捷径?还是牺牲健康的换取?

肥胖从来不是靠短时间的&#xff0c;每天少吃一餐就能减掉的&#xff0c;需要长期坚持。但三餐不管哪一餐&#xff0c;长期不吃&#xff0c;都不会有好结果。为了瘦&#xff0c;失去健康值不值呢&#xff1f; 长期不吃早饭后果 1、消耗率、吸收率减慢&#xff1a;身体经过一整…

扩散模型零样本分类应用笔记

1 Title Your Diffusion Model is Secretly a Zero-Shot Classifier&#xff08;Alexander C. Li, Mihir Prabhudesai, Shivam Duggal, Ellis Brown, Deepak Pathak&#xff09;【ICCV 2023】 2 Conclusion This paper shows that the density estimates from large-scale tex…

阿里云4核8G服务器多少钱一年?

阿里云4核8G服务器优惠价格955元一年&#xff0c;配置为ECS通用算力型u1实例&#xff08;ecs.u1-c1m2.xlarge&#xff09;4核8G配置、1M到3M带宽可选、ESSD Entry系统盘20G到40G可选&#xff0c;CPU采用Intel(R) Xeon(R) Platinum处理器&#xff0c;阿里云活动链接 aliyunfuwuq…

SSH介绍及检测规则思路分析

一、SSH 1、定义 SSH是安全的加密协议&#xff0c;用于远程连接linux服务器。 2、ssh服务的主要功能&#xff1a; 1&#xff09;提供远程链接服务器的功能&#xff1b; 2&#xff09;对远程链接传输的数据进行加密 3、ssh与telnet的区别&#xff1a; 服务链接方式 服务数据…

DBO优化LSBoost回归预测(matlab代码)

DBO-LSBoost回归预测matlab代码 蜣螂优化算法(Dung Beetle Optimizer, DBO)是一种新型的群智能优化算法&#xff0c;在2022年底提出&#xff0c;主要是受蜣螂的的滚球、跳舞、觅食、偷窃和繁殖行为的启发。 数据为Excel股票预测数据。 数据集划分为训练集、验证集、测试集,比…

【国家计算机二级C语言】高分笔记

二叉树 参考 http://t.csdnimg.cn/ozVwT 数据库 SQL程序语言有四种类型&#xff0c;对数据库的基本操作都属于这四类&#xff0c;它们分别为&#xff1b;数据定义语言(DDL)、数据查询语言&#xff08;DQL&#xff09;、数据操纵语言&#xff08;DML&#xff09;、数据控制语言…

领夹麦配LDR6028,电力持久畅聊畅!

#无线麦克风#麦克风&#xff0c;对于大多数人来说&#xff0c;并不陌生。然而&#xff0c;领夹式麦克风&#xff0c;这个看似小巧的音频设备&#xff0c;或许在日常生活中并不常为我们所见。但在自媒体行业、新闻记者等领域&#xff0c;它却是不可或缺的好帮手。这款领夹式麦克…

管理类联考–复试–英文面试–问题--规划介绍原因做法--纯英文版

借鉴 https://www.bilibili.com/video/BV1Dk4y187zN/?p4&spm_id_from333.880.my_history.page.clickhttps://www.bilibili.com/video/BV1Dk4y187zN/?p4&spm_id_from333.880.my_history.page.click https://ttsreader.com/zh/https://ttsreader.com/zh/ 规划 视频版…

2024年NOC大赛创客智慧(西瓜创客)图形化编程真题模拟试卷包含答案

详细题目看顶部资源 答案解析 一、选择题 1、C 该段代码是将变量的值翻倍,运行之后变量的值是之前的两倍。变量的值是否改变取决于初始值是否为 0,所以船都不正确 2、C A 透项为让角色说话,不可以广播消息: B 选项为播放一段声音,不可以广播消息; C透项为广播消息,正确: …

OCP NVME SSD规范解读-14.Firmware固件升级要求

4.11节 Firmware Update Requirements 描述了数据中心NVMe SSD固件更新的具体要求&#xff0c;确保固件升级过程既安全又可靠&#xff0c;同时充分考虑了设备在升级过程中的可用性和功能性。 FWUP-1: 设备必须记录每一次固件激活过程。这意味着固件升级过程中&#xff0c;设备会…

使用远程工具连接Mysql

&#xff08;若想要远程连接Mysql需要下面解决四个问题&#xff09; 1、目标地址 直接查询 2、端口号 3306 3、防火墙关闭 [rootlocalhost date]# systemctl stop firewalld.service 4、授权mysql数据库root用户权限&#xff08;因为mysql开始不允许其他IP访问&#xff0…

时间减少90%以上!分布式系统的性能优化实战

1背景 分布式批量系统指的是采用分布式数据库架构&#xff0c;主体功能由批量程序实现的系统。分布式系统批量程序的性能测试&#xff0c;除了和联机交易性能测试一样关注服务器资源使用率是否合理、是否存在性能异常外&#xff0c;在测试执行阶段需要关注是否因数据分布不均衡…