基于lora技术对Gemma(2B)大模型的微调实践

一、概述

本文主要基于Lora技术,在Google colab上用A100对Gemma 2B大模型进行了指令微调,第一次指令微调是采用databricks-dolly-15k 作为数据集,取得了不错的微调效果,能准确用英文回答问题,但databricks-dolly-15k 毕竟是英文数据集,微调后的模型对中文的理解并不好。为了使模型对中文有更好的理解,笔者采用COIG-CQIA数据集对模型进行了指令微调,并展示了微调前后的效果对比。

《两个数据集说明》

databricks-dolly-15k 是一个开源数据集,其中包含数千名 Databricks 员工在 InstructGPT 论文中概述的几个行为类别中生成的指令跟踪记录,包括头脑风暴、分类、封闭式 QA、生成、信息提取、开放式 QA 和摘要。

图:databricks-dolly-15k Dataset card

图:databricks-dolly-15k 具体内容

COIG-CQIA全称为Chinese Open Instruction Generalist - Quality is All You Need, 是一个开源的高质量指令微调数据集,旨在为中文NLP社区提供高质量且符合人类交互行为的指令微调数据。COIG-CQIA以中文互联网获取到的问答及文章作为原始数据,经过深度清洗、重构及人工审核构建而成。

图:COIG-CQIA-full Dataset card

图:COIG-CQIA-full 具体内容

二、前置条件

获得模型访问权,选择Colab运行时,配置训练环境。

先在Kaggle上注册,然后获得Gemma 2B 的访问权;

然后在Google colab 配置环境,主要是GPU的选择,免费的是T4,建议采用付费的A100(为了节省时间,微调训练耗时T4需要30分钟左右,A100只需要2分钟左右)

最后 在Kaggle 上的account上生成令牌文件(主要是usename 和 API Key),并将令牌文件配置到colab环境。

三、微调步骤

因为直接使用databricks-dolly-15k进行微调时,可以基于原有代码进行快速验证,为了采用COIG-CQIA-full数据集进行微调,最直接的想法就是把代码中databricks-dolly-15k部分替换为COIG-CQIA-full,然后对代码进行稍微修改,但这一步花费了笔者非常多的时间,因为无论如何修改,总会报错。最总采用了以下办法:

1、先通过代码上传COIG-CQIA-full代码,

2、然后将其转换为和databricks-dolly-15k同格式的内容,并将新内容命名为databricks-dolly-15k-fb,

3、然后下载到本地检测下内容是否正确。在正确的前提下,取databricks-dolly-15k-fb内容的前1600行,保存到另外一个文件databricks-dolly-15k-fb1,

4、最后使用databricks-dolly-15k-fb1进行微调。

通过此方法,可以基本不修改原有微调语料处理代码的前提下,完成微调训练。

图:COIG-CQIA-full 格式转化后内容

本次选择1600行,主要是为了减少上传的时间,当然也可以更少,openai建议的50行

Example count recommendations计数建议示例

To fine-tune a model, you are required to provide at least 10 examples. We typically see clear improvements from fine-tuning on 50 to 100 training examples with gpt-3.5-turbo but the right number varies greatly based on the exact use case.
要微调模型,您需要提供至少 10 个示例。我们通常会看到对 50 到 100 个训练示例进行微调的明显改进, gpt-3.5-turbo 但正确的数量会根据确切的用例而有很大差异。

We recommend starting with 50 well-crafted demonstrations and seeing if the model shows signs of improvement after fine-tuning. In some cases that may be sufficient, but even if the model is not yet production quality, clear improvements are a good sign that providing more data will continue to improve the model. No improvement suggests that you may need to rethink how to set up the task for the model or restructure the data before scaling beyond a limited example set.
我们建议从 50 个精心制作的演示开始,看看模型在微调后是否显示出改进的迹象。在某些情况下,这可能就足够了,但即使模型尚未达到生产质量,明显的改进也是一个好兆头,表明提供更多数据将继续改进模型。没有改进表明,在扩展到有限的示例集之前,您可能需要重新考虑如何为模型设置任务或重组数据。

此外数据格式的检测也非常关键,openai官网有专门的格式检查代码。

Check data formatting检查数据格式

Once you have compiled a dataset and before you create a fine-tuning job, it is important to check the data formatting. To do this, we created a simple Python script which you can use to find potential errors, review token counts, and estimate the cost of a fine-tuning job.
编译数据集后,在创建微调作业之前,检查数据格式非常重要。为此,我们创建了一个简单的 Python 脚本,您可以使用它来查找潜在错误、查看令牌计数以及估计微调作业的成本。

四、微调前后效果展示

基于databricks-dolly-15k微调的效果不做展示,主要展示下基于COIG-CQIA-full微调后的效果展示。

4.1微调前的表现

图:微调前问答表现

图:微调后问答表现

再看几个微调后的表现

图:微调后问答表现2

4.2微调前后训练参数对比

采用Lora微调,训练参数量由25亿降低到130万左右;

图:训练参数对比

4.3用A100微调训练的关键片段

训练耗时104s,80ms每步,这里如果采用T4,需要训练接近30分钟。

图:采用A100进行微调

五、关键源码

一、格式转换代码

import json
from google.colab import files  # 如果你是在Google Colab环境中运行,需要导入该模块进行文件上传# 提示上传文件
print("请上传 COIG-CQIA-full.jsonl 文件")
uploaded = files.upload()# 获取上传文件名
uploaded_filename = list(uploaded.keys())[0]# 读取上传的COIG-CQIA-full.jsonl文件
with open(uploaded_filename, "r", encoding="utf-8") as f:coig_data = f.readlines()# 转换格式
converted_data = []
for line in coig_data:coig_entry = json.loads(line.strip())converted_entry = {"instruction": coig_entry["instruction"],"context": coig_entry["input"],"response": coig_entry["output"],"category": "open_qa" if coig_entry["task_type"]["minor"] == ["问答"] else "classification"}converted_data.append(converted_entry)# 将转换后的数据写入databricks-dolly-15k-fb.jsonl文件
with open("databricks-dolly-15k-fb.jsonl", "w", encoding="utf-8") as f:for entry in converted_data:f.write(json.dumps(entry, ensure_ascii=False) + "\n")print("转换完成,结果已保存到 databricks-dolly-15k-fb.jsonl 文件中")

二、数据处理代码

import json
data = []
with open("databricks-dolly-15k-fb1.jsonl") as file:for line in file:features = json.loads(line)# Filter out examples with context, to keep it simple.if features["context"]:continue# Format the entire example as a single string.template = "Instruction:\n{instruction}\n\nResponse:\n{response}"data.append(template.format(**features))# Only use 1000 training examples, to keep it fast.
data = data[:1000]

三、Lora微调代码

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(learning_rate=5e-5,weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])gemma_lm.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),optimizer=optimizer,weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

参考文档:

1、https://ai.google.dev/gemma/docs/lora_tuning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/806511.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

智慧公厕中的大数据、云计算和物联网技术引领未来公厕管理革命

现代社会对于公共卫生和环境保护的要求越来越高,智慧公厕作为城市基础设施建设的重要组成部分,正引领着公厕管理的革命。随着科技的不断进步,大数据、云计算和物联网技术的应用为智慧公厕带来了全新的可能性,(ZonTree中…

Spring Boot统一功能处理之拦截器

本篇主要介绍Spring Boot的统一功能处理中的拦截器。 目录 一、拦截器的基本使用 二、拦截器实操 三、浅尝源码 初始化DispatcherServerlet 处理请求(doDispatch) 四、适配器模式 一、拦截器的基本使用 在一般的学校或者社区门口,通常会安排几个…

selenium添加代理(有账号密码)

以下为各种尝试的记录,正确实现可直接参考最后一条! 1,导入Proxy库来添加capabilities属性:可以访问网站,但ip还是本机ip from selenium import webdriver from selenium.webdriver.chrome.options import Options f…

【TensorRT】TensorRT C# API 项目更新 (1):支持动态Bath输入模型推理(下篇)

4. 接口应用 关于该项目的调用方式在上一篇文章中已经进行了详细介绍,具体使用可以参考《最新发布!TensorRT C# API :基于C#与TensorRT部署深度学习模型》,下面结合Yolov8-cls模型详细介绍一下更新的接口使用方法。 4.1 创建并配…

Java零基础入门-Java反射机制

一、概述 我们都听说过java有个反射机制,通过反射机制我们可以更深入的控制程序的运行过程。例如,在程序进入到运行期间,由用户输入一个类名,然后我们可以动态获取到该类拥有的所有类结构、属性名和方法,甚至还可以任意…

Java快速入门系列-9(Spring框架与Spring Boot —— 深度探索及实践指南)

第九章:Spring框架与Spring Boot —— 深度探索及实践指南 9.1 Spring框架概述9.2 Spring IoC容器9.3 Spring AOP9.4 Spring MVC9.5 Spring Data JPA/Hibernate9.6 Spring Boot快速入门与核心特性9.7 Spring Boot的自动配置与启动流程详解9.8 创建RESTful服务与数据库交互实践…

专为苹果系统设计的精美可视化图表 | 开源日报 No.219

danielgindi/Charts Stars: 27.3k License: Apache-2.0 Charts 是为 iOS/tvOS/OSX 提供美观图表的开源项目,是跨平台 MPAndroidChart 在苹果设备上的实现。该项目提供了以下主要功能和优势: 支持 iOS、tvOS 和 macOS 平台使用 Swift 编写,可…

Ceph学习 -6.Nautilus版本集群部署

文章目录 1.集群部署1.1 环境概述1.1.1 基础知识1.1.2 环境规划1.1.3 小结 1.2 准备工作1.2.1 基本环境1.2.2 软件安装1.2.3 小结 1.3 Ceph部署1.3.1 集群创建1.3.2 部署Mon1.3.3 小结 1.4 Ceph部署21.4.1 Mon认证1.4.2 Mgr环境1.4.3 小结 1.5 OSD环境1.5.1 基本环境1.5.2 OSD实…

数据结构-移除元素(简单)

题目描述 给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素,并返回移除后数组的新长度。 不要使用额外的数组空间,你必须仅使用 O(1) 额外空间并 原地 修改输入数组。 元素的顺序可以改变。你不需要考虑数组中超出…

可视化大屏的应用(11):智慧运维领域的得力助手

一、什么是智慧运维 智慧运维(Smart Operations and Maintenance,简称智慧运维)是一种利用先进的信息技术和数据分析手段,对设备、设施或系统进行监测、分析和优化管理的运维方式。它通过实时监测数据、智能分析和预测&#xff0…

Redis中的集群(五)

集群 在集群中执行命令 MOVED错误。 当节点发现键所在的槽并非由自己负责处理的时候&#xff0c;节点就会向客户端返回一个MOVED错误&#xff0c;指引客户端转向至正在负责槽的节点&#xff0c;MOVED错误的格式为: MOVED <slot> <ip>:<port>其中slot为键…

centos 7.9 nginx本地化安装,把镜像改成阿里云

1.把centos7.9系统切换到阿里云的镜像源 1.1.先备份 mv /etc/yum.repos.d/CentOS-Base.repo /etc/yum.repos.d/CentOS-Base.repo.backup1.2.下载新的CentOS-Base.repo配置文件 wget -O /etc/yum.repos.d/CentOS-Base.repo http://mirrors.aliyun.com/repo/Centos-7.repo特别…

盲人独立出行的新里程:“盲人软件”赋能无障碍生活

作为一名资深记者&#xff0c;我始终致力于探索并分享那些以科技之力提升特殊群体生活质量的故事。最近&#xff0c;一款名为蝙蝠避障的盲人软件进入了我的视野&#xff0c;其强大的避障导航功能正悄然改变着视障人士的出行方式&#xff0c;赋予他们前所未有的独立生活能力。 …

【多线程】线程池Future和FutureTask

【多线程】线程池Future和FutureTask 【一】Future概述【1】Future的出现原因【2】Future结构图 【二】Future详解【1】Future接口源码【2】Future的5个方法【3】ThreadPoolExecutor提供了三个方法&#xff0c;来获取返回值&#xff08;1&#xff09;submit(Runnable r)&#x…

【ARM 裸机】汇编 led 驱动之原理分析

1、我们为什么要学习汇编&#xff1f;&#xff1f;&#xff1f; 之前我们或许接触过 STM32 以及其他的 32 位的 MCU ,都是基于 C 语言环境进行编程的&#xff0c;都没怎么注意汇编&#xff0c;是因为 ST 公司早已将启动文件写好了&#xff0c;新建一个 STM32 工程的时候&#…

回归预测 | Matlab实现SSA-GRNN麻雀算法优化广义回归神经网络多变量回归预测(含优化前后预测可视化)

回归预测 | Matlab实现SSA-GRNN麻雀算法优化广义回归神经网络多变量回归预测(含优化前后预测可视化) 目录 回归预测 | Matlab实现SSA-GRNN麻雀算法优化广义回归神经网络多变量回归预测(含优化前后预测可视化)预测效果基本介绍程序设计参考资料预测效果

图像生成:Pytorch实现一个简单的对抗生成网络模型

图像生成&#xff1a;Pytorch实现一个简单的对抗生成网络模型 前言相关介绍具体步骤准备并读取数据集定义生成器定义判别器定义损失函数定义优化器开始训练完整代码 训练生成的图片 前言 由于本人水平有限&#xff0c;难免出现错漏&#xff0c;敬请批评改正。更多精彩内容&…

实战Java高并发程序设计课

课程介绍 实战Java高并发程序设计课是一门针对Java开发者的培训课程&#xff0c;重点关注如何设计和优化高并发的程序。学员将学习到并发编程的基本概念、线程池的使用、锁机制、并发集合等技术&#xff0c;并通过实际案例进行实践操作。这门课程旨在帮助开发者掌握并发编程的…

最祥解决python 将Dataframe格式数据上传数据库所碰到的问题

碰到的问题 上传Datafrane格式的数据到数据库 会碰见很多错误 举几个很普遍遇到的问题(主要以SqlServer举例) 这里解释下 将截断字符串或二进制数据 这个是字符长度超过数据库设置的长度 然后还有字符转int失败 或者字符串转换日期/或时间失败 这个是碰到的需要解决的最多的问…

Java面试题戏剧

目录 第一幕 、第一场&#xff09;某大厦楼下大门前第二场&#xff09;电梯中第三场&#xff09;走廊中 第二幕、第一场&#xff09;公司前台第二场&#xff09;公司卫生间 第三幕、第一场&#xff09;一场异常面试 第四幕 、第一场&#xff09;大厦楼下门口第二场&#xff09;…