一、概述
1、是什么
Mini-Monkey 论文全称《Mini-Monkey: Multi-Scale Adaptive Cropping for Multimodal Large Language Models》,是一个轻量级MLLM(多模态的视觉-文本模型),基于InternViT、MLP和InternLLM,其实就是全套的Intern VL2。用途论文提到:不仅在各种通用多模态理解任务上展示了领先的性能,而且在文档理解能力上也显示出一致的改进,所以关于多图问对话、视频理解对话 、json格式、代码编写和debug、函数调用论文暂时未提。
2、亮点
论文宣称:
*在2B参数MLLM中取得了最先进的性能。它不仅在各种通用多模态理解任务上展示了领先的性能,而且在文档理解能力上也显示出一致的改进。在OCRBench上,Mini-Monkey达到了802分,超过了8B参数的最先进模型InternVL2-8B。
*此外,我们的模型和训练策略非常高效,只需要8个RTX 3090即可训练。
*创新点:图像的处理方式——多尺度自适应裁剪策略(MSAC)+一种尺度压缩机制(SCM)。Mini-Monkey能够自适应地生成多尺度表示,允许它从不同尺度中选择未被分割的物体,并有效地压缩图像标记。
*开源。
PS
个人还是和当初这个团队的Monkey一样,不建议阅读和复现。理由如下:
*论文的主要创新点MSAC真实代码并没有adaptive,其实就是选择了最优(这里也有疑问)比例+次优比例进行切图,然后再使用SCM就是用LLM的激活(后面细讲)来筛选部分图像切片来推理,增加切片数+参数量然后比原生Intern VL2高一个点,然后还有git用户反馈无法复现,所以大家都懂了。。。
*同样反感的就是公众号的PR稿。
综上,本文就主要结合代码介绍他的创新点的实现方式(个人严重怀疑是PR稿,有点**时间),关于模型结构(参考Intern VL2)、训练方式、训练数据和结果就不展开了,可以参考原文。
二、模型
1、模型结构
主要有VIT、投影层、LLM组成。下面主要介绍VIT对应上图的两个改进点MACS和SCM。
MACS
全称:多尺度自适应裁剪策略,就是在原来大家常用的整图(图里面叫做global layer)+切片图(图里面叫做Detailed Layer)的基础上多了一个图里面的Adaptive layer。源码实现的时候其实就是先运行获得Detailed Layer的切片,然后剔除Detailed Layer的切片对应的最优解和有等比切割的解,选择一个次优比例解,结束。对应的源码dynamic_preprocess2函数如下:Monkey/project/mini_monkey/demo.py at main · Yuliang-Liu/Monkey · GitHub
整体代码很短,可以自己扫一下
import numpy as np
import torch
import torchvision.transforms as T
from decord import VideoReader, cpu
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizerIMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)def build_transform(input_size):MEAN, STD = IMAGENET_MEAN, IMAGENET_STDtransform = T.Compose([T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),T.ToTensor(),T.Normalize(mean=MEAN, std=STD)])return transformdef find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):best_ratio_diff = float('inf')best_ratio = (1, 1)area = width * heightfor ratio in target_ratios:target_aspect_ratio = ratio[0] / ratio[1]ratio_diff = abs(aspect_ratio - target_aspect_ratio)if ratio_diff < best_ratio_diff:best_ratio_diff = ratio_diffbest_ratio = ratioelif ratio_diff == best_ratio_diff:if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:best_ratio = ratioreturn best_ratiodef dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):orig_width, orig_height = image.sizeaspect_ratio = orig_width / orig_height# calculate the existing image aspect ratiotarget_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) ifi * j <= max_num and i * j >= min_num)target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])# find the closest aspect ratio to the targettarget_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, target_ratios, orig_width, orig_height, image_size)# calculate the target width and heighttarget_width = image_size * target_aspect_ratio[0]target_height = image_size * target_aspect_ratio[1]blocks = target_aspect_ratio[0] * target_aspect_ratio[1]# resize the imageresized_img = image.resize((target_width, target_height))processed_images = []for i in range(blocks):box = ((i % (target_width // image_size)) * image_size,(i // (target_width // image_size)) * image_size,((i % (target_width // image_size)) + 1) * image_size,((i // (target_width // image_size)) + 1) * image_size)# split the imagesplit_img = resized_img.crop(box)processed_images.append(split_img)assert len(processed_images) == blocksif use_thumbnail and len(processed_images) != 1:thumbnail_img = image.resize((image_size, image_size))processed_images.append(thumbnail_img)return processed_images, target_aspect_ratiodef dynamic_preprocess2(image, min_num=1, max_num=12, prior_aspect_ratio=None, image_size=448, use_thumbnail=False):orig_width, orig_height = image.sizeaspect_ratio = orig_width / orig_height# calculate the existing image aspect ratiotarget_ratios = set((i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) ifi * j <= max_num and i * j >= min_num)target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])new_target_ratios = []for i in target_ratios:if prior_aspect_ratio[0]%i[0] or prior_aspect_ratio[1]%i[1]:new_target_ratios.append(i)else:continue# find the closest aspect ratio to the targettarget_aspect_ratio = find_closest_aspect_ratio(aspect_ratio, new_target_ratios, orig_width, orig_height, image_size)# calculate the target width and heighttarget_width = image_size * target_aspect_ratio[0]target_height = image_size * target_aspect_ratio[1]blocks = target_aspect_ratio[0] * target_aspect_ratio[1]# resize the imageresized_img = image.resize((target_width, target_height))processed_images = []for i in range(blocks):box = ((i % (target_width // image_size)) * image_size,(i // (target_width // image_size)) * image_size,((i % (target_width // image_size)) + 1) * image_size,((i // (target_width // image_size)) + 1) * image_size)# split the imagesplit_img = resized_img.crop(box)processed_images.append(split_img)assert len(processed_images) == blocksif use_thumbnail and len(processed_images) != 1:thumbnail_img = image.resize((image_size, image_size))processed_images.append(thumbnail_img)return processed_imagesdef load_image(image_file, input_size=448, min_num=1, max_num=12):image = Image.open(image_file).convert('RGB')transform = build_transform(input_size=input_size)images, target_aspect_ratio = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num)pixel_values = [transform(image) for image in images]pixel_values = torch.stack(pixel_values)return pixel_values, target_aspect_ratiodef load_image2(image_file, input_size=448, min_num=1, max_num=12, target_aspect_ratio=None):image = Image.open(image_file).convert('RGB')transform = build_transform(input_size=input_size)images = dynamic_preprocess2(image, image_size=input_size, use_thumbnail=True, min_num=min_num, max_num=max_num, prior_aspect_ratio=target_aspect_ratio)pixel_values = [transform(image) for image in images]pixel_values = torch.stack(pixel_values)return pixel_values# If you want to load a model using multiple GPUs, please refer to the `Multiple GPUs` section.
path = 'minimonkey'
model = AutoModel.from_pretrained(path,torch_dtype=torch.bfloat16,low_cpu_mem_usage=True,trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True, use_fast=False)# set the max number of tiles in `max_num`
pixel_values, target_aspect_ratio = load_image('xxx.jpg', min_num=4, max_num=12)
pixel_values = pixel_values.to(torch.bfloat16).cuda()
pixel_values2 = load_image2('xxx.jpg', min_num=3, max_num=7, target_aspect_ratio=target_aspect_ratio)
pixel_values2 = pixel_values2.to(torch.bfloat16).cuda()
pixel_values = torch.cat([pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0)generation_config = dict(do_sample=False, max_new_tokens=512)question = "Read the all text in the image."
response, history = model.chat(tokenizer, pixel_values, target_aspect_ratio, question, generation_config, history=None, return_history=True)
print(f'User: {question} Assistant: {response}')
SCM
全称:尺度压缩机制
*动机:尽管提出的MSAC显著提高了模型性能,但某些场景可能会对计算有要求,引入了一种称为尺度压缩机制(SCM)的无参数令牌压缩方法,用于减少视觉令牌。
*原理:来自详细层的令牌信息密度较低,相比之下,来自自适应层和全局层的视觉令牌为LLM提供了完整的空间信息,因此关注压缩详细层令牌。
*做法:具体来说,一个训练有素的MLLM的LLM可以有效地根据输入问题选择必要的视觉特征。因此,SCM利用LLM的前两层选择视觉令牌,而不产生任何额外的参数。输入的视觉令牌包括Vd ∈ RL1×C,Va ∈ RL2×C和Vg ∈ RL3×C,文本令牌Tt ∈ RT ×C将被送入LLM的层。Vd代表来自详细层的令牌。Va代表来自自适应层的令牌。Vg代表来自全局层的令牌。值得注意的是,重用了LLM的这一层作为这个LLM的层。LLM的层将输出一个注意力图。选择来自自适应层、全局层和文本令牌的视觉令牌来关注来自详细层的视觉令牌。注意力的计算可以表示如下:其中PE代表位置编码,D表示LLM的维度。Cat()代表序列连接操作。计算注意力机制后,将注意力图的第一个维度Attnw ∈ R(L2+L3+T)×L1平均化,以获得一个权重向量Wa ∈ RL1。然后,根据这个权重向量Wa从详细层中选择前K个视觉特征。这些选定的令牌,连同来自自适应层、全局层的令牌和文本令牌一起输入到LLM中以生成结果。与FastV相比,SCM通过使用高相对信息密度的令牌来压缩低信息密度的令牌,更具针对性。
这里看源码实现其实K选择的是0.5倍的详细层视觉token,再回顾前面分割的超参数,详细层是4-12块,“adaptive”层是3-7块,其实大致可以猜到最终大多数是1:1最终token量级。
源码位于:Monkey/project/mini_monkey/internvl/model/internvl_chat/modeling_minimonkey_chat.py at 83a7899dd57ad9d3abc85c9b26207432be7f5cbb · Yuliang-Liu/Monkey · GitHub
if use_scm:self.language_model.model.img_idx = torch.where(selected == True)self.language_model.model.high_token = target_aspect_ratio[0] * target_aspect_ratio[1] * self.num_image_tokenbatch_size, seq_length = input_embeds.shape[:2]device = input_embeds.deviceposition_ids = torch.arange(0, seq_length, dtype=torch.long, device=device)position_ids = position_ids.unsqueeze(0)new_attention_mask = self.language_model.model._prepare_decoder_attention_mask(attention_mask, (batch_size, seq_length), input_embeds, 0)tmp_layer_outputs = self.language_model.model.layers[0](input_embeds,attention_mask=new_attention_mask,position_ids=position_ids,past_key_value=None,output_attentions=False,use_cache=False,)tmp_layer_outputs2 = self.language_model.model.layers[1](tmp_layer_outputs[0],attention_mask=new_attention_mask,position_ids=position_ids,past_key_value=None,output_attentions=True,use_cache=False,)tmp_attn = tmp_layer_outputs2[1]tmp_attn = tmp_attn[:, :, self.language_model.model.img_idx[0][0] + self.language_model.model.high_token:,self.language_model.model.img_idx[0][0]:self.language_model.model.img_idx[0][0] + self.language_model.model.high_token]tmp_attn = tmp_attn.mean(2)tmp_idx = tmp_attn.mean(1).topk(int(tmp_attn.shape[-1] * 0.5)).indices + self.language_model.model.img_idx[0][0]top_attention_rank_index = tmp_idx.sort().values[0]device = input_embeds.devicetop_attention_rank_index = torch.cat((torch.arange(self.language_model.model.img_idx[0][0], device=device),top_attention_rank_index, torch.arange(self.language_model.model.img_idx[0][0] + self.language_model.model.high_token + 1, input_embeds.shape[1],device=device)))input_embeds = input_embeds[:, top_attention_rank_index]attention_mask = torch.ones((input_embeds.shape[0], input_embeds.shape[1]), dtype=torch.bool, device=device)
2、模型亮点
介绍模型的结构亮点。
PS
自己面向模型的一些理解,主要是一些补充和槽点。
三、数据
1、数据标签
数据的label构成,主要会涉及到loss计算。
2、数据构成
数据集的构成,来源有哪些,这个对于现在的AIGC很重要,可以快速知道训练集、测试集来源,还有就是快速识别一些不客观的对比(模型A在数据集1上训练过所以比模型B在数据集1上好)
3、数据清洗
数据的清洗方式,这个几乎是大模型的命脉,预训练数据的清洗和微调数据的清洗,不过现在多数不开源微调数据的清洗方式。
四、策略
1、训练过程
几个阶段训练、冻结哪个网络模块、训练超参。
2、推理过程
推理的时候是不是有后处理等等
五、结果
1、多维度对比
多个数据集上的对比结果。
2、消融实验
网络、数据、超参等等的消融实验,能够验证想法的有效性,并且同时增加对不同模块其他方法无效的认知。
六、使用方法
一般开源的话,除非确实有bug,这里会写出踩坑日志。
七、待解决
通过论文、代码、询问等我也仍无法理解的,可以大家一起沟通。
八、参考链接
Monkey/project/mini_monkey/demo.py at main · Yuliang-Liu/Monkey · GitHub