视觉语言模型解释 |
文章目录
- 一. 视觉语言模型解析
- 1.什么是视觉语言模型?
- 2. 开源视觉语言模型概览
- 3. 如何找到合适的视觉语言模型
- MMMU
- MMBench
- 4. 技术细节
- 5.使用变压器 (transformers) 运用视觉语言模型
- 6. 使用 TRL 微调视觉语言模型
- 二. 参考文章
一. 视觉语言模型解析
视觉语言模型是一类能够同时从图像和文本中学习,以处理从视觉问题回答到图像描述等多种任务的模型。本文将深入探讨视觉语言模型的核心组成部分,介绍其工作原理,如何选取适合的模型,以及如何利用 trl 的新版本进行便捷的微调。
1.什么是视觉语言模型?
视觉语言模型是指能够从图像和文本中学习的多模态模型。这类模型属于生成模型,能够接收图像和文本输入,并产生文本输出。大型视觉语言模型具备优秀的零样本能力,能够广泛适应多种图像类型,如文档、网页等,并且表现出良好的泛化性。应用场景包括图像聊天、图像识别指导、视觉问答、文档理解和图像描述等。部分视觉语言模型还能识别图像中的空间属性,例如,在被要求检测或分割特定对象时,能够输出边界框或分割蒙版,或定位不同实体并回答关于它们的相对或绝对位置的问题。当前大型视觉语言模型在训练数据、图像编码方式上具有多样性,因此它们的能力也各不相同。
2. 开源视觉语言模型概览
Hugging Face Hub 提供了众多开源视觉语言模型。以下是一些显著的模型:
- 包括基础模型和为聊天应用而微调的模型,均可用于对话模式。
- 部分模型具备“定位(grounding)”功能,有助于减少幻觉现象。
- 除非另有说明,所有模型均采用英语进行训练。
3. 如何找到合适的视觉语言模型
选择适合特定用例的模型有多种方法:
Vision Arena 是一个基于模型输出的匿名投票排行榜,不断更新。在这个平台上,用户提交图像及提示,系统从两个不同模型中生成输出,用户则根据偏好选择输出,从而构建基于用户偏好的排行榜。
Open VLM排行榜 根据各种视觉语言模型的表现和平均得分进行排名。用户还可以根据模型大小、是否为开源以及不同的性能指标来筛选模型。
VLMEvalKit 是一个评估工具包,可以在视觉语言模型上运行基准测试,支持 Open VLM 排行榜。另一个评估套件是 LMMS-Eval,它提供了一个标准的命令行界面,用于使用托管在 Hugging Face Hub 上的数据集评估选定的 Hugging Face 模型,示例如下:
accelerate launch --num_processes=8 -m lmms_eval --model llava --model_args pretrained=\"liuhaotian/llava-v1.5-7b\" --tasks mme,mmbench_en --batch_size 1 --log_samples --log_samples_suffix llava_v1.5_mme_mmbenchen --output_path ./logs/
如果您希望探索更多模型,可以在 Hub 上
浏览执行image-text-to-text
任务的模型。
您可能会在排行榜上看到不同的视觉语言模型评估基准。下面是其中一些基准的介绍:
MMMU
全面的多学科多模态理解与推理基准,用于评估专家级通用人工智能 (AGI) (MMMU) 是评估视觉语言模型最全面的基准之一。它包含 11.5K 个多模态挑战,涉及艺术、工程等多个学科,需要大学级别的知识和推理能力。
MMBench
MMBench 是一个包含 3000 个单项选择题的评估基准,涵盖 20 种不同的技能,如 OCR、对象定位等。该论文还引入了一种名为 CircularEval 的评估策略,通过不同组合混洗答案选项,并期望模型在每次都能给出正确答案。此外,还有其他更具体的跨不同领域的基准,如 MathVista(视觉数学推理)、AI2D(图表理解)、ScienceQA(科学问答)和 OCRBench(文档理解)。
4. 技术细节
预训练视觉语言模型有多种方法。核心技术是统一图像和文本的表示,并将其输入文本解码器进行生成。常见的模型通常包括 图像编码器、用于对齐图像和文本表征的嵌入投影器(通常是一个密集型神经网络)及文本解码器。
例如,LLaVA 模型包括一个 CLIP 图像编码器、多模态投影器和一个 Vicuna 文本解码器。开发者将图像及其标题的数据集输入到 GPT-4 中,生成相关的问题。在此过程中,图像编码器和文本解码器被冻结,只对多模态投影器进行训练,以通过比较模型输出和实际标题来对齐图像和文本特征。预训练完成后,保持图像编码器冻结,解冻文本解码器,并继续训练投影器和解码器。这种预训练及微调方法是训练视觉语言模型的常见方式。
典型视觉语言模型的结构
投影和文本嵌入的连接
KOSMOS-2 采用了全面的端到端训练方法,与 LLAVA 式的预训练相比,在计算上更为昂贵。开发者后续还进行了语言指令的微调以优化模型。另一个例子是 Fuyu-8B,该模型不使用图像编码器,而是直接将图像块输入到投影层,然后通过自回归解码器进行处理。通常情况下,您可以使用现有的视觉语言模型,或根据自己的需求对模型进行微调。
5.使用变压器 (transformers) 运用视觉语言模型
您可以使用以下代码使用 Llava 模型进行推断:
首先初始化模型和处理器。
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torchdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf",torch_dtype=torch.float16,low_cpu_mem_usage=True
)
model.to(device)
现在我们将图像和文本提示传递给处理器,然后将处理过的输入传递给generate
。请注意,每个模型使用自己的提示模板,正确使用以确保性能。
from PIL import Image
import requestsurl = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"inputs = processor(prompt, image, return_tensors="pt").to(device)
output = model.generate(**inputs, max_new_tokens=100)
调用 decode 方法来解码输出 token。
print(processor.decode(output[0], skip_special_tokens=True))
6. 使用 TRL 微调视觉语言模型
我们高兴地宣布,TRL 的 SFTTrainer
现在支持视觉语言模型的实验性训练!我们提供了一个示例,展示如何使用 llava-instruct 数据集对 Llava 1.5 VLM 进行 SFT,该数据集包含 260k 张图像对话对。这些数据集包含的是格式化为消息序列的用户与助手之间的互动,例如,每个对话都配对一张图像,用户会就此图像提问。
要使用这项实验性训练支持,请安装 TRL 的最新版本,使用命令 pip install -U trl
。完整的示例脚本可在此处查看。
from trl.commands.cli_utils import SftScriptArguments, TrlParserparser = TrlParser((SftScriptArguments, TrainingArguments))
args, training_args = parser.parse_args_and_config()
我们现在初始化用于指令性微调的聊天模板。
LLAVA_CHAT_TEMPLATE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""
我们将初始化我们的模型和 tokenizer。
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration
import torchmodel_id = "llava-hf/llava-1.5-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE
processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer = tokenizermodel = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16)
创建一个数据整合器来结合文本和图像对。
class LLavaDataCollator:def __init__(this, processor):this.processor = processordef __call__(this, examples):texts = []images = []for example in examples:messages = example["messages"]text = this.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)texts.append(text)images.append(example["images"][0])batch = this.processor(texts, images, return_tensors="pt", padding=True)labels = batch["input_ids"].clone()if this.processor.tokenizer.pad_token_id is not None:labels[labels == this.processor.tokenizer.pad_token_id] = -100batch["labels"] = labelsreturn batchdata_collator = LLavaDataCollator(processor)
加载我们的数据集。
from datasets import load_datasetraw_datasets = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft")
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["test"]
初始化 SFTTrainer,传入模型、数据集分割、PEFT 配置和数据整合器,并调用 train()
。要将我们的最终检查点推送到 Hub,请调用 push_to_hub()
。
from trl import SFTTrainertrainer = SFTTrainer(model=model,args=training_args,train_dataset=train_dataset,eval_dataset=eval_dataset,dataset_text_field="text", # need a dummy fieldtokenizer=tokenizer,data_collator=data_collator,dataset_kwargs={"skip_prepare_dataset": True},
)trainer.train()
保存模型并将其推送到 Hugging Face Hub。
trainer.save_model(training_args.output_dir)
trainer.push_to_hub()
您可以在下方的 VLM 游乐场中直接体验我们新训练的模型 ⬇️
致谢
我们要感谢 Pedro Cuenca, Lewis Tunstall, Kashif Rasul 和 Omar Sanseviero 对这篇博客文章的审阅和建议。
二. 参考文章
Vision Language Models Explained:https://huggingface.co/blog/vlms