Idefics2-8B多模态大模型微调指南

我们生活在大模型 (LLM) 时代,几乎每周你都会听到一种新的语言模型问世。从谷歌的 Gemini 和 Gemma 模型到 Meta 最新的 Llama 3 和微软的微型 Phi-3 模型,这些行业巨头之间正在进行激烈的在线竞争,以争夺头把交椅。

在这一连串的活动中,最引人注目的是这些科技巨头愿意向开发者社区开放其中一些语言模型。

开放模型有什么好处?

向开发者社区开放模型带来了几个好处,包括开发人员可以针对特定用例微调这些模型来解决有趣的问题。

如果你现在是一个大模型 (LLM)粉丝,我相信你已经尝试微调至少一个开放模型来探索它们的能力。在讨论所有这些流行的模型时,很少能找到同时开放和多模式的模型。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割 

我最近探索到的其中一个隐藏宝藏是 Hugging Face 🤗 构建的 Idefics2-8B 视觉语言模型。它是开放的,支持多模态,接受图像和文本输入序列。Idefics2 模型可以回答有关图像的问题、描述视觉内容、从多幅图像创建故事等等。

我一直在寻找详细介绍在自定义数据集上微调视觉语言模型的步骤的文章和教程。虽然大多数文章都涵盖了微调过程(在已有的数据集上),但它们往往忽略了数据准备的关键步骤。

本博客将为你带来这一点:微调 Idefics2 模型的综合指南,你不仅可以学习如何微调视觉语言模型,还可以从头开始准备自己的自定义数据集。你可以在此 GitHub 存储库中找到完整的 Colab 笔记本和原始数据集。

但在开始之前,让我们快速了解视觉语言模型的高级架构,以及它们与标准 LLM 有何不同。

视觉语言模型是从图像和文本中学习的多模态模型,从图像和文本输入生成文本输出。它们在零样本能力、泛化以及图像识别、问答和文档理解等各种任务方面表现出色。

视觉语言模型

这些模型还可以捕获空间属性并输出特定主题的边界框或分割蒙版。在此处详细了解 VLM 及其架构。

现在你已经对 VLM 有了基本的了解,是时候展示一下了。让我们深入了解代码。

1、数据准备

与 LLM 不同,VLM 的数据集格式略有不同,因为除了标准文本数据外还引入了图像和视频。

今天,我们将在文档图像上微调 Idefics2 模型,以进行视觉问答。我们的训练数据来自 DocVQA 数据集的子采样版本,并进行了轻微修改以从头开始重新创建整个数据集。

克隆以下存储库:

!git clone https://github.com/NSTiwari/Fine-tune-IDEFICS-Vision-Language-Model

在存储库中,你将找到一个数据集文件夹,其中包含一个图像子文件夹。此子文件夹包含训练和测试图像集。确切地说,有 1000 张图像用于训练,200 张图像用于测试。

以下是数据的主要内容:

DocVQA 图像数据集样本示例

这些图像是各种文档(如新闻文章、电子邮件、发票、报告、广告等)的黑白扫描件。

除了图像之外,存储库中还有一个 qa_text.csv 文件,其中包含有关所有图像的详细信息。由于我们正在处理问答任务,因此 CSV 文件包含每个图像的问题或查询及其对应的答案。

DocVQA CSV数据集样本示例

现在我们已经将图像及其相关内容分别存储起来,让我们将它们合并以创建一个用于训练模型的格式化数据集。

安装以下库:

!pip install -q git+https://github.com/huggingface/transformers.gi 
!pip install -q accelerate datasets peft bitsandbytes
from datasets import Dataset, DatasetDict, Image
import pandas as pd
import os# Define train and test size.
TRAIN_SAMPLES = 1000
TEST_SAMPLES = 200
TEST_SIZE = 0.166 ## Define the directory containing the images.
train_images_directory = '/content/Fine-tune-IDEFICS-Vision-Language-Model/dataset/images/train/'
test_images_directory = '/content/Fine-tune-IDEFICS-Vision-Language-Model/dataset/images/test/'# Read the CSV Q&A text.
qa_text = pd.read_csv('/content/Fine-tune-IDEFICS-Vision-Language-Model/dataset/qa_text.csv')# Create a list of image paths.
train_image_paths = [os.path.join(train_images_directory, f'train_{i}.jpg') for i in range(TRAIN_SAMPLES)]
test_image_paths = [os.path.join(test_images_directory, f'test_{i}.jpg') for i in range(TEST_SAMPLES)]
image_paths = train_image_paths + test_image_paths# Create a list of other columns such as id, query, and answer.
ids = ids = qa_text['id'].tolist()
queries = qa_text['query'].tolist()
answers = qa_text['answers'].tolist()# Create the dataset dictionary.
dataset_dict = {'id': ids,'image': image_paths,'query': queries,'answers': answers
}# Create the dataset.
dataset = Dataset.from_dict(dataset_dict)# Cast the 'image' column to Image type.
dataset = dataset.cast_column("image", Image())# Split the dataset into train and test.
split_dataset = dataset.train_test_split(test_size=TEST_SIZE, shuffle=False)# Push the dataset on Hugging Face Hub.
split_dataset.push_to_hub("NSTiwari/DocumentIDEFICS_QA")

上述脚本将图像与文本查询和答案结合起来,创建了一个统一的数据集,随后将其上传到 Hugging Face 🤗。你可以通过此链接访问数据集。

Hugging Face 的 DocumentIDEFICS_QA 数据集

恭喜您创建了自定义数据集。这确实很简单,不是吗?如果你希望为其他用例创建数据集,只需按照与之前相同的步骤操作即可。

2、加载数据集

现在我们已经准备好数据集,让我们继续加载它。

from datasets import load_datasettrain_dataset = load_dataset("NSTiwari/DocumentIDEFICS_QA", split="train")
eval_dataset = load_dataset("NSTiwari/DocumentIDEFICS_QA", split="test")

检查训练数据:

print(train_dataset[0])
train_dataset[0]['image']

这是训练数据集中记录的显示方式。图像作为 JpegImageFile 嵌入到数据集中:

{'id': 'train_0','image': <PIL.JpegImagePlugin.JpegImageFile image mode=L size=1695x2025>,'query': 'what is the date mentioned in this letter?','answers': "['1/8/93']"
}

train_0.jpg

3、配置 LoRA 适配器

训练或微调语言模型本身就是一项艰巨的任务,除非你拥有高计算能力的 GPU,否则几乎不可能完成。

随着参数高效微调 (PEFT) 的引入,你无需再担心训练整个语言模型。

借助 LoRA 和 QLoRA 等技术,你可以通过显著减少可训练参数的数量来有效地微调这些大型模型。这不仅可以加速微调过程,还可以节省内存使用量。

import torch
from peft import LoraConfig
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics2ForConditionalGenerationDEVICE = "cuda:0"
USE_LORA = False
USE_QLORA = Trueprocessor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b",do_image_splitting=False
)
if USE_QLORA or USE_LORA:lora_config = LoraConfig(r=8,lora_alpha=8,lora_dropout=0.1,target_modules='.*(text_model|modality_projection|perceiver_resampler).*(down_proj|gate_proj|up_proj|k_proj|q_proj|v_proj|o_proj).*$',use_dora=False if USE_QLORA else True,init_lora_weights="gaussian")if USE_QLORA:bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.float16)model = Idefics2ForConditionalGeneration.from_pretrained("HuggingFaceM4/idefics2-8b",torch_dtype=torch.float16,quantization_config=bnb_config if USE_QLORA else None,)model.add_adapter(lora_config)model.enable_adapters()
else:model = Idefics2ForConditionalGeneration.from_pretrained("HuggingFaceM4/idefics2-8b",torch_dtype=torch.float16,_attn_implementation="flash_attention_2", # Need GPUs like A100 or H100.).to(DEVICE)

上述代码块为 Idefics2–8b 模型配置了 LoRA 适配器。

QLoRA 是 Quantized LoRA 的缩写,是 LoRA 的增强版本。顾名思义,QLoRA 量化了权重参数的精度,将 32 位参数压缩为 4 位格式。

LLM 内存需求的大幅减少使得微调变得容易,在硬件资源有限的情况下尤其有用。

4、创建数据整理器

数据整理器是使用数据集元素列表作为输入来形成批处理的对象。这些元素与 train_dataset 或 eval_dataset 的元素类型相同。

import randomclass MyDataCollator:def __init__(self, processor):self.processor = processorself.image_token_id = processor.tokenizer.additional_special_tokens_ids[processor.tokenizer.additional_special_tokens.index("<image>")]def __call__(self, examples):texts = []images = []for example in examples:image = example["image"]question = example["query"]['en']answer = random.choice(example["answers"])messages = [{"role": "user","content": [{"type": "text", "text": "Answer briefly."},{"type": "image"},{"type": "text", "text": question}]},{"role": "assistant","content": [{"type": "text", "text": answer}]}]text = processor.apply_chat_template(messages, add_generation_prompt=False)texts.append(text.strip())images.append([image])batch = processor(text=texts, images=images, return_tensors="pt", padding=True)labels = batch["input_ids"].clone()labels[labels == processor.tokenizer.pad_token_id] = self.image_token_idbatch["labels"] = labelsreturn batchdata_collator = MyDataCollator(processor)

5、设置训练参数

配置超参数来训练模型:

from transformers import TrainingArguments, Trainertraining_args = TrainingArguments(output_dir = "IDEFICS_DocVQA",learning_rate = 2e-4,fp16 = True,per_device_train_batch_size = 2,per_device_eval_batch_size = 2,gradient_accumulation_steps = 8,dataloader_pin_memory = False,save_total_limit = 3,evaluation_strategy ="steps",save_strategy = "steps",eval_steps = 10,save_steps = 25,max_steps = 25,logging_steps = 5,remove_unused_columns = False,push_to_hub=False,label_names = ["labels"],load_best_model_at_end = False,report_to = "none",optim = "paged_adamw_8bit",
)
trainer = Trainer(model = model,args = training_args,data_collator = data_collator,train_dataset = train_dataset,eval_dataset = eval_dataset
)

6、开始训练

现在可以启动训练了:

trainer.train()

我使用 Google Colab 上的 T4 GPU 花了大约一个小时对模型进行了 25 步微调。使用的训练步骤和示例越多,结果就越好。考虑到 T4 的局限性,这是我训练模型的最佳方法。

训练结果

7、评估模型

现在,是时候在测试示例上评估模型了。

test_example = eval_dataset[0]
test_example["image"]

test_0.jpg

model.eval()image = test_example["image"]
query = test_example["query"]messages = [{"role": "user","content": [{"type": "text", "text": "Answer briefly."},{"type": "image"},{"type": "text", "text": query}]}
]text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=[text.strip()], images=[image], return_tensors="pt", padding=True)
generated_ids = model.generate(**inputs, max_new_tokens=64)
generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].size(1):], skip_special_tokens=True)
print(generated_texts)Question: What the location address of NSDA?
Answer: [‘1128 SIXTEENTH ST., N. W., WASHINGTON, D. C. 20036’, ‘1128 sixteenth st., N. W., washington, D. C. 20036’]

我要求微调后的模型看图回答如下问题:

What the location address of NSDA?

模型的回答如下:

[‘1128 SIXTEENTH ST., N. W., WASHINGTON, D. C. 20036’, ‘1128 sixteenth st., N. W., washington, D. C. 20036’]

这太棒了——该模型已经为测试图像中提出的问题提供了准确的答案。

8、将微调模型推送到 Hugging Face

如果你希望将来无需重新训练即可访问该模型,最好将其推送到 HF 上。

# Login to your HF account.
from huggingface_hub import notebook_login
notebook_login()from huggingface_hub import whoami
from pathlib import Path# Output directory.
output_dir = "IDEFICS_DocVQA"
repo_name = "IDEFICS2-DocVQA-fine-tuned"
username = whoami(token=Path("/root/.cache/huggingface/"))["name"]
repo_id = f"{username}/{repo_name}"
from huggingface_hub import upload_folder, create_reporepo_id = create_repo(repo_id, exist_ok=True).repo_idupload_folder(repo_id=repo_id,folder_path=output_dir,commit_message="Pushed the IDEFICS2 fine-tuned model.",ignore_patterns=["step_*", "epoch_*"],
)

就这些了。恭喜你成功在自定义数据集上微调了 Idefics2–8B 模型。我希望你学到了从头开始创建图像文本数据集和微调视觉语言模型的宝贵见解。


原文链接:Idefics2-8B微调简明教程 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/21465.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java 18新特性深度解析:提升开发效率与性能的革新工具

在Java的世界中&#xff0c;每一次更新都带来新的惊喜和挑战。Java 18作为长期支持版本&#xff0c;不仅延续了Java语言的稳定性和可靠性&#xff0c;还引入了一系列令人兴奋的新特性&#xff0c;旨在进一步提升开发者的生产力和应用程序的性能。本文将深入探讨Java 18中的关键…

AtCoder Regular Contest 179 (ABC题)视频讲解

A - Partition Problem Statement You are given integers N N N and K K K. The cumulative sums of an integer sequence X ( X 1 , X 2 , … , X N ) X(X_1,X_2,\dots ,X_N) X(X1​,X2​,…,XN​) of length N N N is defined as a sequence Y ( Y 0 , Y 1 , … , …

交互设计如何助力传统技艺在当代复兴?

背景介绍 榫卯是中国传统木工中一种独特的接合技术&#xff0c;它通过构件间的凸凹部分相互配合来实现两个或多个构件的紧密结合。这种结构方式不依赖于钉子或其他金属连接件&#xff0c;而是利用木材自身的特性&#xff0c;通过精巧的设计和工艺&#xff0c;实现构件间的稳定…

GEE数据集:美国植被干旱响应指数 (Vegetation Drought Response Index,VegDRI)数据集

植被干旱响应指数 (VegDRI) 简介 植被干旱响应指数&#xff08;VegDRI&#xff09;是一个每周一次的地理空间模型&#xff0c;用于描述干旱对美国本土植被造成的压力。VegDRI干旱监测工具是由美国地质调查局EROS中心、内布拉斯加大学国家干旱缓解中心&#xff08;NDMC&#…

计算机网络学习实践:配置主机通过DHCP获取IP并通过域名访问web服务器

计算机网络学习实践&#xff1a;配置主机通过DHCP获取IP并通过域名访问web服务器 点一点就能配置&#xff0c;不需要输入命令 1.实验准备 实验环境&#xff1a;思科的模拟器 实验设备&#xff1a; 3个服务器&#xff0c;1个二层交换机&#xff08;不是三层的&#xff09;&a…

一个弹出的虚假安全警告去除

虚假的安全警告 poratus.azurewebsites.net Pornographic spyware detected! Remove viruses with Avira Antivirus 通过 Microsoft Edge GPT-4 (OpenAI) 这个提示可能是一个虚假的安全警告&#xff0c;被称为“恐吓软件”&#xff08;scareware&#xff09;&#xff0c;旨在…

名下企业查询,清晰明了;在线操作,方便快捷

在现代社会&#xff0c;越来越多的人开始涉足创业和投资&#xff0c;拥有自己的企业成为一种时尚。然而&#xff0c;随之而来的是繁琐的企业注册流程和复杂的信息查询。为了解决这个问题&#xff0c;挖数据平台推出了一项名下企业查询接口&#xff0c;提供了一种方便快捷的方式…

计算机网络介绍

计算机网络介绍 概述网络概述相关硬件 链路层VLAN概念VLAN 特点VLAN 的划分帧格式端口类型原理 STP概念特点原理 Smart Link概念特点组网 网络层ARP概念原理 IP概念版本IP 地址 IPv4IP 地址数据报格式 IPv6特点IP 地址数据报格式 ICMP概念分类报文格式 VRRP概念原理报文格式 OS…

片上电控系统集成技术

一、背景 片上电机控制系统集成技术&#xff08;On-Chip Motor Control System Integration&#xff09;是一种先进的电子工程技术&#xff0c;它主要聚焦于将复杂的电机控制算法和硬件组件整合到单一集成电路&#xff08;IC&#xff09;中&#xff0c;以便于高效、精确地管理…

计算机毕业设计 | 基于Koa+vue的高校宿舍管理系统宿舍可视化系统

项目介绍 项目背景 随着科技的发展&#xff0c;智能化管理越来越重要。大学生在宿舍的时间超过了1/3&#xff0c;因此良好的宿舍管理对学生的生活和学习极为关键。学生宿舍管理系统能够合理安排新生分配宿舍&#xff0c;不浪费公共资源&#xff0c;减轻学校管理压力&#xff…

关于工作组

什么是局域网&#xff08;内网&#xff09; 我们常说的内网指的就是局域网&#xff0c;局域网&#xff08;Local Area Network&#xff0c;简称LAN&#xff09;是指在相对较小的地理范围内&#xff0c;如一个办公室、学校、住宅区或建筑群内部&#xff0c;通过通信设备&#xf…

上位机图像处理和嵌入式模块部署(f407 mcu中tf卡读写和fatfs挂载)

【 声明&#xff1a;版权所有&#xff0c;欢迎转载&#xff0c;请勿用于商业用途。 联系信箱&#xff1a;feixiaoxing 163.com】 很早之前&#xff0c;个人对tf卡并不是很重视&#xff0c;觉得它就是一个存储工具而已。后来在移植v3s芯片的时候&#xff0c;才发现很多的soc其实…

如何监控慢 SQL?

引言&#xff1a;在开发和维护数据库驱动的应用程序时&#xff0c;监控慢 SQL 查询是确保系统性能和稳定性的关键一环。慢 SQL 查询可能会导致系统性能下降、资源浪费和用户体验差等问题。因此&#xff0c;及时监控和优化慢 SQL 查询对于保障系统的正常运行和用户满意度至关重要…

k8s 1.28.x 配置nfs

1.安装nfs&#xff0c;在每个节点上安装 yum install -y nfs-utils 2.创建共享目录(主节点上操作) mkdir -p /opt/nfs/k8s 3.编写NFS的共享配置 /opt/nfs/k8s *(rw,no_root_squash) #*代表对所有IP都开放此目录&#xff0c;rw是读写 4.启动nfs systemctl enable nfs-ser…

动态代理(黑马笔记)

一、BigStar 大明星类 package com.itheima.mydynamicproxy1; public class BigStar implements Star {//实现接口要重写里边的抽象方法private String name;public BigStar() {}public BigStar(String name) {this.name name;}//唱歌Override //表示重写接口中的方法public…

Java | Leetcode Java题解之第127题单词接龙

题目&#xff1a; 题解&#xff1a; class Solution {Map<String, Integer> wordId new HashMap<String, Integer>();List<List<Integer>> edge new ArrayList<List<Integer>>();int nodeNum 0;public int ladderLength(String beginW…

算法-找出N个数组的共同元素

一、代码与执行结果 财经新闻是大众了解金融事件的重要渠道&#xff0c;现有N位编辑&#xff0c;分别对K篇新闻进行专业的编辑与排版。需要您找出被这N位编辑共同编辑过的新闻&#xff0c;并根据这些新闻ID升序排列返回一个数组。 import random# 查找编辑共同处理的新闻id def…

RunLoop小白入门

核心概念 什么是 RunLoop ? RunLoop 是 iOS 和 macOS 应用程序框架中的一个核心概念&#xff0c;用于管理线程的事件处理。它可以看作是一个循环&#xff0c;用于持续接收和处理各种事件&#xff0c;如用户输入、定时器、网络事件等。RunLoop 在保持应用程序响应用户交互和系…

系统与软件工程软件测试过程

系统与软件工程 软件测试 测试过程 &#xff1b;对应的国标是GB/T 38634.4 2020 &#xff0c;该标准的范围规定适应用于治理、管理和实施任何组织,项目或较小规模测试活动的软件测试的测试过程,定义了软件测试通用过程,给出了描述过程的支持信息图表。 一 术语和定义 1.1实测…

力扣173题:二叉搜索树迭代器(含模拟面试)

❤️❤️❤️ 欢迎来到我的博客。希望您能在这里找到既有价值又有趣的内容&#xff0c;和我一起探索、学习和成长。欢迎评论区畅所欲言、享受知识的乐趣&#xff01; 推荐&#xff1a;数据分析螺丝钉的首页 关注微信公众号 数据分析螺丝钉 免费领取价值万元的python/java/商业…