0基础学会在亚马逊云科技AWS上利用SageMaker、PEFT和LoRA高效微调AI大语言模型(含具体教程和代码)

项目简介:

小李哥今天将继续介绍亚马逊云科技AWS云计算平台上的前沿前沿AI技术解决方案,帮助大家快速了解国际上最热门的云计算平台亚马逊云科技AWS上的AI软甲开发最佳实践,并应用到自己的日常工作里。本次介绍的是如何在Amazon SageMaker上微调(Fine-tune)大语言模型dolly-v2-3b,满足日常生活中不同的场景需求,并将介分享如何在SageMaker上优化模型性能并节省计算资源实现成本控制,最后将部署后的大语言模型URL集成到自己云上的软件应用中。

本方案包括通过Amazon Cloudfront和S3托管前端页面,并通过Amazon API Gateway和AWS Lambda将应用程序与AI模型集成,调用大模型实现推理。本方案的解决方案架构图如下:

利用微调模型创建的对话机器人前端UI

利用本方案小李哥用微调后的模型搭建了一个Q&A对话机器人助手,可以生成代码、文字总结、回答问题。

在开始分享案例之前,我们来了解一下本方案的技术背景,帮助大家更好的理解方案架构。

什么是Amazon SageMaker?

Amazon SageMaker 是一个完全托管的机器学习服务(大家可以理解为Serverless的Jupyter Notebook),专为应用开发和数据科学家设计,帮助他们快速构建、训练和部署机器学习模型。使用 SageMaker,您无需担心底层基础设施的管理,可以专注于模型的开发和优化。它提供了一整套工具和功能,包括数据准备、模型训练、超参数调优、模型部署和监控,简化了整个机器学习工作流程。

本方案将介绍以下内容:

1. 使用 SageMaker Jupyter Notebook进行dolly-v2-3b模型开发和微调

2. 在SageMaker部署微调后的大语言模型LLM并基于数据进行推理

3. 使用多场景的测试案例验证推理结果表现,并将部署的模型API节点集成进云端应用

项目搭建具体步骤:

下面跟着小李哥手把手微调一个亚马逊云科技AWS上的生成式AI模型(dolly-v2-3b)的软件应用,并将AI大模型部署与应用集成。

1. 在控制台进入Amazon SageMaker, 点击Notebook

2. 打开Jupyter Notebook

3. 创建一个新的Notebook:“lab-notebook.ipynb”并打开

4. 接下来我们在单元格内一步一步运行代码,检查CUDA的内存状态

!nvidia-smi

5.接下来,我们安装必要依赖并导入

%%capture!pip3 install -r requirements.txt --quiet
!pip install sagemaker --quiet --upgrade --force-reinstall
%%captureimport os
import numpy as np
import pandas as pd
from typing import Any, Dict, List, Tuple, Union
from datasets import Dataset, load_dataset, disable_caching
disable_caching() ## disable huggingface cachefrom transformers import AutoModelForCausalLM
from transformers import AutoTokenizer
from transformers import TextDatasetimport torch
from torch.utils.data import Dataset, random_split
from transformers import TrainingArguments, Trainer
import accelerate
import bitsandbytesfrom IPython.display import Markdown

6. 导入提前准备好的FAQs数据集

sagemaker_faqs_dataset = load_dataset("csv", data_files='data/amazon_sagemaker_faqs.csv')['train']
sagemaker_faqs_dataset
sagemaker_faqs_dataset[0]

7. 我们定义用于模型推理的提示词格式

from utils.helpers import INTRO_BLURB, INSTRUCTION_KEY, RESPONSE_KEY, END_KEY, RESPONSE_KEY_NL, DEFAULT_SEED, PROMPT
'''
PROMPT = """{intro}{instruction_key}{instruction}{response_key}{response}{end_key}"""
'''
Markdown(PROMPT)

8. 下面我们进入重头戏,导入一个提前预训练好的LLM大语言模型“databricks/dolly-v2-3b”。

tokenizer = AutoTokenizer.from_pretrained("databricks/dolly-v2-3b", padding_side="left")tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})model = AutoModelForCausalLM.from_pretrained("databricks/dolly-v2-3b",# use_cache=False,device_map="auto", #"balanced",load_in_8bit=True,
)

9. 对模型训练进行预准备, 处理数据集、优化模型训练(PEFT)效率

model.resize_token_embeddings(len(tokenizer))from functools import partial
from utils.helpers import mlu_preprocess_batchMAX_LENGTH = 256
_preprocessing_function = partial(mlu_preprocess_batch, max_length=MAX_LENGTH, tokenizer=tokenizer)encoded_sagemaker_faqs_dataset = sagemaker_faqs_dataset.map(_preprocessing_function,batched=True,remove_columns=["instruction", "response", "text"],
)processed_dataset = encoded_sagemaker_faqs_dataset.filter(lambda rec: len(rec["input_ids"]) < MAX_LENGTH)split_dataset = processed_dataset.train_test_split(test_size=14, seed=0)
split_dataset

10. 同时我们使用LoRA(Low-Rank Adaptation)模型加速我们的模型微调

from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, TaskTypeMICRO_BATCH_SIZE = 8  
BATCH_SIZE = 64
GRADIENT_ACCUMULATION_STEPS = BATCH_SIZE // MICRO_BATCH_SIZE
LORA_R = 256 # 512
LORA_ALPHA = 512 # 1024
LORA_DROPOUT = 0.05# Define LoRA Config
lora_config = LoraConfig(r=LORA_R,lora_alpha=LORA_ALPHA,lora_dropout=LORA_DROPOUT,bias="none",task_type="CAUSAL_LM"
)model = get_peft_model(model, lora_config)
model.print_trainable_parameters()from utils.helpers import MLUDataCollatorForCompletionOnlyLMdata_collator = MLUDataCollatorForCompletionOnlyLM(tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
)

11. 接下来我们定义模型训练参数并开始训练。其中Batch=1,Step=20000,epoch为10.

EPOCHS = 10
LEARNING_RATE = 1e-4  
MODEL_SAVE_FOLDER_NAME = "dolly-3b-lora"training_args = TrainingArguments(output_dir=MODEL_SAVE_FOLDER_NAME,fp16=True,per_device_train_batch_size=1,per_device_eval_batch_size=1,learning_rate=LEARNING_RATE,num_train_epochs=EPOCHS,logging_strategy="steps",logging_steps=100,evaluation_strategy="steps",eval_steps=100, save_strategy="steps",save_steps=20000,save_total_limit=10,
)trainer = Trainer(model=model,tokenizer=tokenizer,args=training_args,train_dataset=split_dataset['train'],eval_dataset=split_dataset["test"],data_collator=data_collator,
)
model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
trainer.train()

12. 接下来我们将微调后的模型保存在本地

trainer.model.save_pretrained(MODEL_SAVE_FOLDER_NAME)trainer.save_model()trainer.model.config.save_pretrained(MODEL_SAVE_FOLDER_NAME)tokenizer.save_pretrained(MODEL_SAVE_FOLDER_NAME)

13. 接下来,我们将保存到本地的模型进行部署,生成公开访问的API节点Endpoint

对部署所需要的参数进行定义和初始化

import boto3
import json
import sagemaker.djl_inference
from sagemaker.session import Session
from sagemaker import image_uris
from sagemaker import Modelsagemaker_session = Session()
print("sagemaker_session: ", sagemaker_session)aws_role = sagemaker_session.get_caller_identity_arn()
print("aws_role: ", aws_role)aws_region = boto3.Session().region_name
print("aws_region: ", aws_region)image_uri = image_uris.retrieve(framework="djl-deepspeed",version="0.22.1",region=sagemaker_session._region_name)
print("image_uri: ", image_uri)

进行模型部署

model_data="s3://{}/lora_model.tar.gz".format(mybucket)model = Model(image_uri=image_uri,model_data=model_data,predictor_cls=sagemaker.djl_inference.DJLPredictor,role=aws_role)

14.最后我们写入提示词,对大语言模型进行测试, 得到推理

outputs = predictor.predict({"inputs": "What solutions come pre-built with Amazon SageMaker JumpStart?"})from IPython.display import Markdown
Markdown(outputs)

15. 我们下面进入SageMaker Endpoint页面,得到刚部署的模型API端点的URL,通过这种方式我们就可以在应用中调用我们的微调后的大语言模型了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/46336.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【学习笔记】无人机(UAV)在3GPP系统中的增强支持(九)-无人机服务区分离

引言 本文是3GPP TR 22.829 V17.1.0技术报告&#xff0c;专注于无人机&#xff08;UAV&#xff09;在3GPP系统中的增强支持。文章提出了多个无人机应用场景&#xff0c;分析了相应的能力要求&#xff0c;并建议了新的服务级别要求和关键性能指标&#xff08;KPIs&#xff09;。…

Solidworks工程图替换参考零件

1.用solidworks选择工程图文件&#xff0c;点击参考。 2.双击文件名 3.选择新的参考零件&#xff0c;点击确定。

[ruby on rails]部署时候产生ActiveRecord::PreparedStatementCacheExpired错误的原因及解决方法

一、问题&#xff1a; 有时在 Postgres 上部署 Rails 应用程序时&#xff0c;可能会看到 ActiveRecord::PreparedStatementCacheExpired 错误。仅当在部署中运行迁移时才会发生这种情况。发生这种情况是因为 Rails 利用 Postgres 的缓存准备语句(PreparedStatementCache)功能来…

力扣第406场周赛

力扣第406场周赛 100352. 交换后字典序最小的字符串 - 力扣&#xff08;LeetCode&#xff09; 贪心&#xff0c;从 0 0 0开始扫描到 n n n如果有一个可以交换的就立马交换 class Solution { public:string getSmallestString(string s) {for(int i1;i<s.size();i){if(s[i…

【PyQt】

PyQT5线程基础&#xff08;2&#xff09; 线程案例案例一案例二 线程案例 案例一 案例一代码通过线程实现点击按钮向线程传输地址&#xff0c;程序等待20秒后&#xff0c;返回结果。 通过QtDesigner创建如下图所示的界面ui&#xff0c;并用UIC工具转成对应的py文件。 Ui_tes…

C语言之指针的奥秘(三)

一、字符指针变量 在指针的类型中&#xff0c;有字符指针char*&#xff0c;一般使用&#xff1a; #include<stdio.h> int main() {char ch w;char* p &ch;*p w;return 0; } 还有一种方式&#xff1a; #include<stdio.h> int main() {const char* p &qu…

123456

截止2023年10月&#xff0c;目前已公开的官方矢量数据有3个网站&#xff0c;按照公开时间顺序分别是&#xff1a;1&#xff09;全国地理信息资源目录服务系统&#xff1b;2&#xff09;西藏自治区自然资源厅&#xff1b;3&#xff09;福建省标准地图服务。我们将持续更新公开的…

2024较火的软件宣传单页HTML源码

源码介绍 2024较火的软件宣传单页HTML源码&#xff0c;源码由HTMLCSSJS组成&#xff0c;记事本打开源码文件可以进行内容文字之类的修改&#xff0c;双击html文件可以本地运行效果 效果截图 源码获取 2024较火的软件宣传单页HTML源码

自动驾驶可能解决的问题

首先是各种盲区&#xff0c;雷达可能检测到各种东西&#xff0c;而这些是视觉注意不到的 然后是每辆车可以互联互通&#xff0c;整体规划路线

Java二十三种设计模式-单例模式(1/23)

引言 在软件开发中&#xff0c;设计模式是一套被反复使用的、大家公认的、经过分类编目的代码设计经验的总结。单例模式作为其中一种创建型模式&#xff0c;确保一个类只有一个实例&#xff0c;并提供一个全局访问点。本文将深入探讨单例模式的概念、实现方式、使用场景以及潜…

通讯录管理(C++入门练习)

通讯录管理系统 系统需求&#xff1a; 通讯录是一个可以记录亲人、好友信息的工具。本文主要利用C来实现一个通讯录管理系统系统中需要实现的功能如下: 添加联系人:向通讯录中添加新人&#xff0c;信息包括(姓名、性别、年龄、联系电话、家庭住址)最多记录1000人 显示联系…

昇思25天学习打卡营第21天|DCGAN生成漫画头像

DCGAN原理 DCGAN&#xff08;深度卷积对抗生成网络&#xff0c;Deep Convolutional Generative Adversarial Networks&#xff09;是GAN的直接扩展。不同之处在于&#xff0c;DCGAN会分别在判别器和生成器中使用卷积和转置卷积层。 它最早由Radford等人在论文Unsupervised Re…

数据结构历年考研真题对应知识点(哈夫曼树和哈夫曼编码)

目录 5.5.1哈夫曼树和哈夫曼编码 1.哈夫曼树的定义 2.哈夫曼树的构造 【分析哈夫曼树的路径上权值序列的合法性&#xff08;2010&#xff09;】 【哈夫曼树的性质&#xff08;2010、2019&#xff09;】 3.哈夫曼编码 【根据哈夫曼编码对编码序列进行译码&#xff08;201…

【AMBA】AXI总线中的AXLEN、AXSIZE、AXBURST和4K边界

文章目录 AXI传输层级概念AXLEN[7:0]定义突发传输长度AXSIZE[2:0]定义突发传输transfer的位宽AXBURST[1:0]定义突发传输类型4K边界 AXI传输层级概念 在手册的术语表中&#xff0c;与 AXI 传输相关的有三个概念&#xff0c;分别是 transfer(beat)、burst、transaction。用一句话…

树莓派关机

文件 shutdown.sh #!/usr/bin/bash sudo shutdown -r nowpython 文件开头添加 #!/usr/bin/python3

C字符串和内存函数介绍(三)——其他的字符串函数

在#include<string.h>的这个头文件里面&#xff0c;除了前面给大家介绍的两大类——长度固定的字符串函数和长度不固定的字符串函数。还有一些函数以其独特的用途占据一席之地。 今天要给大家介绍的是下面这三个字符串函数&#xff1a;strstr&#xff0c;strtok&#xf…

前端web性能统计

前端web性能统计 1. 背景2. 业界方案2.1 腾讯2.2 蚂蚁金服2.3 字节跳动2.4 美团 3. 相关观念3.1 RAIL模型3.2 性能指标3.3 真实用户监控3.4 performance 4. 性能监控工具介绍5. 推荐采用方案 1. 背景 在如今的数字时代&#xff0c;网站和应用程序的性能对用户体验至关重要。用…

STM32MP135裸机编程:唯一ID(UID)、设备标识号、设备版本

0 资料准备 1.STM32MP13xx参考手册1 唯一ID&#xff08;UID&#xff09;、设备标识号、设备版本 1.1 寄存器说明 &#xff08;1&#xff09;唯一ID 唯一ID可以用于生成USB序列号或者为其它应用所使用&#xff08;例如程序加密&#xff09;。 &#xff08;2&#xff09;设备…

Git代码管理工具 — 3 Git基本操作指令详解

目录 1 获取本地仓库 2 基础操作指令 2.1 基础操作指令框架 2.2 git status查看修改的状态 2.3 git add添加工作区到暂存区 2.4 提交暂存区到本地仓库 2.5 git log查看提交日志 2.6 git reflog查看已经删除的记录 2.7 git reset版本回退 2.8 添加文件至忽略列表 1 获…

0.单片机工作原理

文章目录 最小系统 单片机芯片 时钟电路 复位电路 电源 最小系统 单片机芯片 本次51单片机的芯片为&#xff1a;STC89C52 Flash(闪存)程序存储器&#xff1a;存储程序的空间 SRAM&#xff1a;数据存储器&#xff0c;可用于存放程序执行的中间结果和过程数据 DPTR&#xff1a;…