使用Colaboratory免费GPU资源微调Llama3-8b

Llama3微调过程

准备工作

Google Colaboratory

Google Colaboratory,也称为 Colab,是一个基于云的平台,允许用户编写和执行 Python 代码。 它为机器学习和数据分析任务提供了便利的环境,并内置了对 TensorFlow 等流行库的支持。

在Google 创建Colab笔记副本

  • 新建->更多->关联更多应用

image-20240605171618341

  • 搜索Google Colaboratory

image-20240605171558617

  • 安装 Colaboratory

image-20240605171745990

  • 再次新建 Google Colaboratory 文件

image-20240605171914405

  • 进入Colab笔记副本

image-20240605173617598

连接到T4 GPU

  • 更改运行时类型

image-20240605174816472

  • 选择T4

image-20240605174731452

连接到谷歌云盘

from google.colab import drive
drive.mount('/content/drive')

运行之后要弹出一个页面进行授权

下载Unsloth

unsloth是一个开源的大模型训练加速项目, 用它来微调 Llama 3、Mistral、Phi-3 和 Gemma 速度提高 2-5 倍,内存减少 80%

目前开源版本的 Unsloth,仅支持单机单卡训练,且仅支持 Llama2、Llama3、Mistral、Gemma、Zephyr、TinyLlama、Phi-3 等模型

%%capture
# Installs Unsloth, Xformers (Flash Attention) and all other packages!
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps xformers trl peft accelerate bitsandbytes

下载预训练模型

from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 
dtype = None 
load_in_4bit = True # 支持的预4位量化模型,可实现4倍更快的下载速度和无OOM。
fourbit_models = ["unsloth/mistral-7b-bnb-4bit","unsloth/mistral-7b-instruct-v0.2-bnb-4bit","unsloth/llama-2-7b-bnb-4bit","unsloth/gemma-7b-bnb-4bit","unsloth/gemma-7b-it-bnb-4bit", "unsloth/gemma-2b-bnb-4bit","unsloth/gemma-2b-it-bnb-4bit","unsloth/llama-3-8b-bnb-4bit", 
] model, tokenizer = FastLanguageModel.from_pretrained(model_name = "unsloth/llama-3-8b-bnb-4bit",max_seq_length = max_seq_length,dtype = dtype,load_in_4bit = load_in_4bit,# token = "hf_...", # 如果使用像meta-llama/Llama-2-7b-hf这样的门控模型,请使用其中一个
)

注意 : 这里运行之前要连接到T4 GPU, 否则会报cuda缺失错误

运行之后:

image-20240605175501190

设置LoRA训练参数

model = FastLanguageModel.get_peft_model(model,r = 16, # 选择任何大于0的数字!建议使用8、16、32、64、128target_modules = ["q_proj", "k_proj", "v_proj", "o_proj","gate_proj", "up_proj", "down_proj",],lora_alpha = 16,lora_dropout = 0, # 支持任何值,但等于0时经过优化bias = "none",    # 支持任何值,但等于"none"时经过优化# [NEW] "unsloth" 使用的VRAM减少30%,适用于2倍更大的批处理大小!use_gradient_checkpointing = "unsloth", # True或"unsloth"适用于非常长的上下文random_state = 3407,use_rslora = False,  # 我们支持排名稳定的LoRAloftq_config = None, # 以及LoftQ
)

什么是 LoRA?

LoRA 的核心思想是通过引入低秩矩阵的变化来代替对原始大矩阵的更新,从而减少训练过程中需要更新的参数数量。具体来说,在模型的某些权重矩阵中引入一个低秩分解(两个小矩阵的乘积),并只训练这些小矩阵,而不是原始的大矩阵。

数据准备

这里采用的是来自yahma的Alpaca数据集,这是原始Alpaca数据集的经过筛选的版本,包含了来自原始数据集中的52000条数据。可以用自己的数据准备替换此代码部分

alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.### Instruction:
{}### Input:
{}### Response:
{}"""EOS_TOKEN = tokenizer.eos_token # 必须添加EOS_TOKEN
def formatting_prompts_func(examples):instructions = examples["instruction"]inputs       = examples["input"]outputs      = examples["output"]texts = []for instruction, input, output in zip(instructions, inputs, outputs):# 必须添加EOS_TOKEN,否则生成将无法停止!text = alpaca_prompt.format(instruction, input, output) + EOS_TOKENtexts.append(text)return { "text" : texts, }
passfrom datasets import load_dataset
dataset = load_dataset("yahma/alpaca-cleaned", split = "train")
dataset = dataset.map(formatting_prompts_func, batched = True,)

运行之后:

image-20240605175810993

训练模型

参数设置

from trl import SFTTrainer  # SFTTrainer:来自 trl(用于大语言模型的低秩适应性训练的库),是一个用于训练模型的类。
from transformers import TrainingArguments # 来自 transformers 库,是一个包含训练参数的类。trainer = SFTTrainer(model = model,  # model:预训练的语言模型。tokenizer = tokenizer, # tokenizer:对应的分词器train_dataset = dataset, # train_dataset:用于训练的文本数据集。dataset_text_field = "text", # dataset_text_field:数据集中包含文本的字段名称。max_seq_length = max_seq_length, # max_seq_length:每个输入序列的最大长度。dataset_num_proc = 2, # 数据处理时使用的进程数量,设置为 2 表示使用两个进程。packing = False, # 是否启用序列打包。打包可以提高短序列训练的效率,这里设置为 False。args = TrainingArguments(per_device_train_batch_size = 2, # 每个设备(如 GPU)上的训练批量大小,设置为 2gradient_accumulation_steps = 4,# 梯度累积步骤数,设置为 4。即每累计 4 个批次的梯度后才进行一次权重更新。warmup_steps = 5, # 学习率预热步骤数,设置为 5max_steps = 60, # 训练的最大步骤数,设置为 60learning_rate = 2e-4, # 初始学习率,设置为 2e-4fp16 = not torch.cuda.is_bf16_supported(), # 是否使用 16 位浮点数进行训练bf16 = torch.cuda.is_bf16_supported(), # 如果 GPU 支持 bf16,则使用 bf16logging_steps = 1, # 日志记录的步数间隔,设置为 1,即每步都记录日志optim = "adamw_8bit", # 优化器,设置为 adamw_8bit,表示使用 8 位精度的 AdamW 优化器,可以减少显存占用weight_decay = 0.01, # 权重衰减系数,设置为 0.01,用于防止过拟合lr_scheduler_type = "linear", # 学习率调度器类型,设置为 linear,表示线性调度seed = 3407, # 随机种子,设置为 3407,用于确保结果的可重复性。output_dir = "outputs", # 输出目录,训练生成的模型和日志会保存在这个目录下),
)

显示当前内存状态

gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. 最大内存 = {max_memory} GB.")
print(f"{start_gpu_memory} GB of 内存剩余。")
  • 输出

image-20240605180217554

开始训练

trainer_stats = trainer.train()

运行之后:

image-20240605182659016

运行模型

# alpaca_prompt = Copied from above
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
[alpaca_prompt.format("Continue the fibonnaci sequence.", # instruction"1, 1, 2, 3, 5, 8", # input"", # output - leave this blank for generation!)
], return_tensors = "pt").to("cuda")outputs = model.generate(**inputs, max_new_tokens = 64, use_cache = True)
tokenizer.batch_decode(outputs)

连接对话

# alpaca_prompt = Copied from above
FastLanguageModel.for_inference(model) # Enable native 2x faster inference
inputs = tokenizer(
[alpaca_prompt.format("Continue the fibonnaci sequence.", # instruction"1, 1, 2, 3, 5, 8", # input"", # output - leave this blank for generation!)
], return_tensors = "pt").to("cuda")from transformers import TextStreamer
text_streamer = TextStreamer(tokenizer)
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 128)

保存模型

  • 获取当前的保存的地址
import os
current_path = os.getcwd()
model.save_pretrained("lora_model") # 本地保存
print(f"保存地址 {current_path}/lora_model")
  • 将模型保存为预训练的GGUF格式
if True: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")

运行后:

image-20240605191455492

生成的模型移动到谷歌云

import shutilsource_path = '/content/model-unsloth.Q4_K_M.gguf'
destination_path = '/content/drive/MyDrive/'# 移动文件,内容有点大需要点时间
shutil.move(source_path, destination_path)
print("请使用谷歌云MyDrive中下载该内容")

运行后:

image-20240605192100443

查看模型

image-20240605192125976

LM Studio

导入微调后的模型

image-20240606112704431

使用微调后的模型

image-20240606112616928

问题记录

  • GPU 计算能力不够

image-20240605180526918

原因 :

之前安装的 PyTorch 版本与 xFormers 编译时使用的 PyTorch 版本不匹配

解决办法:

  • 重新安装xformers
pip uninstall xformers -y
!pip install --no-deps xformers trl peft accelerate bitsandbytes

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/pingmian/24362.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

1.vue2.x-初识及环境搭建

目录 1.下载nodejs v16.x 2.设置淘宝镜像源 3.安装脚手架 4.创建一个项目 5.项目修改 代码地址:source-code: 源码笔记 1.下载nodejs v16.x 下载地址:Node.js — Download Node.js 2.设置淘宝镜像源 npm config set registry https://registry.…

【c语言】指针就该这么学(3)

🌟🌟作者主页:ephemerals__ 🌟🌟所属专栏:C语言 目录 一、函数指针 1.函数指针变量的创建 2.函数指针变量的使用 二、typedef关键字 三、函数指针数组 1.函数指针数组的概念 2.函数指针数…

从零开始实现自己的串口调试助手(8)-循环发送

循环发送 准备 创建槽函数 设置QSpinBox的最大值 注意: // 我们不能在qt的ui线程中延时,否则将导致页面刷新问题 //QThread::msleep(ui->spinBox->text().toInt());//设置下次发送时间间隔 定时器实现 关联信号与槽: //添加自动换行定…

Pycharm创建Conda虚拟环境时显示CondaHTTPErOT

原因:conda源出问题了,之前可以用,现在报错。 最好的解决方案:找到conda源,换源即可。 步骤: 1.修改 .condarc 文件(文件的位置在:C:\Users\(你的用户名)\.condarc)&a…

Python中的@staticmethod和@classmethod装饰器

名词解释 本文主要介绍静态方法staticmethod和类方法classmethod在类中的应用,在介绍这两个函数装饰器之前,先介绍类中的几个名词,便于后面的理解: 类对象:定义的类就是类对象 类属性:定义在__init__ 外…

基于自动化工具autox.js的抢票(猫眼)

1.看到朋友圈抢周杰伦、林俊杰演唱会票贼难信息,特研究了一段时间,用autox.js写了自动化抢票脚本,购票页面自动点击下单(仅限安卓手机)。 2.脚本运行图 3.前期准备工作 (1)autox.js社区官网&am…

DNF手游攻略:主C职业推荐,云手机强力辅助!

在《地下城与勇士》手游中,你是否厌倦了重复刷图和无休止的手动操作?利用VMOS云手机,你可以一键解决这些烦恼,实现自动打怪、一机多开,让游戏变得更加轻松愉快。下面我们将介绍如何使用VMOS云手机,以及推荐…

MySQL-Explain使用

MySQL-Explain使用 type列 type列 这一列表示关联类型或访问类型,即MySQL决定如何查找表中的行,查找数据行记录的大概范围。 依次从最优到最差分别为:system > const > eq_ref > ref > range > index > ALL 一般来说&…

rk3568 norflash+pcei nvme 配置

文章目录 rk3568 norflashpcei nvme 配置1,添加parameter_nor.txt文件2 修改编译规则3 修改uboot4 修改BoardConfig.mk5 修改kernel pcei配置6 编译7 烧录 rk3568 norflashpcei nvme 配置 1,添加parameter_nor.txt文件 device/rockchip/rk356x/rk3568_…

【学习笔记】Windows GDI绘图(十三)动画播放ImageAnimator(可调速)

文章目录 前言定义方法CanAnimate 是否可动画显示Animate 动画显示多帧图像UpdateFramesStopAnimate终止动画Image.GetFrameCount 获取动画总帧数Image.GetPropertyItem(0x5100) 获取帧延迟 自定义GIF播放(可调速) 前言 在前一篇文章中用到ImageAnimator获取了GIF动画的一些属…

vue3 监听器,组合式API的watch用法

watch函数 在组合式 API 中,我们可以使用 watch 函数在每次响应式状态发生变化时触发回调函数 watch(ref,callback(newValue,oldValue),option:{}) ref:被监听的响应式量,可以是一个 ref (包括计算属性)、一个响应式…

STM32—按键控制LED(定时器)

目录 1 、 电路构成及原理图 2 、编写实现代码 main.c exit.c 3、代码讲解 4、烧录到开发板调试、验证代码 5、检验效果 此笔记基于朗峰 STM32F103 系列全集成开发板的记录。 1 、 电路构成及原理图 EXTI(External interrupt/event controller&#xff…

查询SQL02:寻找用户推荐人

问题描述 找出那些 没有被 id 2 的客户 推荐 的客户的姓名。 以 任意顺序 返回结果表。 结果格式如下所示。 题目分析: 这题主要是要看这null值会不会用,如果说Java玩多了,你去写SQL时就会有问题。在SQL中判断是不是null值用的是is null或…

泛微开发修炼之旅--10基于Ecology实现附件上传,并将上传后的文件id存入表单附件控件中的示例及源码

文章链接:泛微开发修炼之旅--10基于Ecology实现附件上传,并将上传后的文件id存入表单附件控件中的示例及源码

tomcat10部署踩坑记录-公网IP和服务器系统IP搞混

1. 服务器基本条件 使用的阿里云服务器,镜像系统是Ubuntu16.04java version “17.0.11” 2024-04-16 LTS装的是tomcat10.1.24阿里云服务器安全组放行了:8080端口 服务器防火墙关闭: 监听情况和下图一样: tomcat正常启动&#xff…

MySQL进阶——索引使用规则

在上篇文章我们学习了MySQL进阶——索引,这篇文章学习MySQL进阶——索引使用规则。 索引使用规则 在使用索引时,需要遵守一些使用规则,否则索引会部分失效或全部失效。 最左前缀法则 最左前缀法则是查询从索引的最左列开始,并…

Vxe UI vxe-form 实现折叠表单,当表单很多时实现自动收起与展开

Vxe UI vue vxe-form 实现折叠表单&#xff0c;当表单很多时实现自动收起与展开 代码 folding 用于将当前表单项设置为默认隐藏 collapse-node 设置折叠按钮&#xff0c;加上之后会自动在该表单项的右侧显示一个折叠按钮 <template><div><vxe-formtitle-colo…

谷歌上架防关联,打包环境到底是不是关联因素之一?

在Google play上架应用&#xff0c;防关联是开发者们最关注的问题之一&#xff0c;只要开发者账号被谷歌审核系统与其它违规的开发者账号或应用存在关联&#xff0c;就很有可能被封号。 如果账号被封了&#xff0c;通常谷歌的封号通知邮件里只是写了因为关联或高风险、多次违规…

kafka-生产者拦截器(SpringBoot整合Kafka)

文章目录 1、生产者拦截器1.1、创建生产者拦截器1.2、KafkaTemplate配置生产者拦截器1.3、使用Java代码创建主题分区副本1.4、application.yml配置----v1版1.5、屏蔽 kafka debug 日志 logback.xml1.6、引入spring-kafka依赖1.7、控制台日志 1、生产者拦截器 1.1、创建生产者拦…

BeanDefinitionReader接口,Spring加载Bean的过程(非常流畅和容易理解)(Spring源码分析1)

一、前言 前言部分&#xff0c;介绍Spring框架的工作和大致原理&#xff0c;有基础的小伙伴可以跳过。 我们现在最常使用的开发框架SSM&#xff0c;分别是Spring、Spring MVC和Mybatis&#xff0c;其功能已经超出原生Spring非常多&#xff0c;所以想学习Spring原理&#xff0c;…