多模态大语言模型(MLLM)-Blip3/xGen-MM

论文链接:https://www.arxiv.org/abs/2408.08872
代码链接:https://github.com/salesforce/LAVIS/tree/xgen-mm

本次解读xGen-MM (BLIP-3): A Family of Open Large Multimodal Models
可以看作是
[1] Blip: Bootstrapping language-image pre-training for unified vision-language understanding and generation
[2] BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models
的后继版本

前言

在这里插入图片描述
没看到Blip和Blip2的一作Junnan Li,不知道为啥不参与Blip3
整体pipeline服从工业界的一贯做法,加数据,加显卡,模型、训练方式简单,疯狂scale up

创新点

  • 开源模型在模型权重、训练数据、训练方法上做的不好
  • Blip2用的数据不够多、质量不够高;Blip2用的Q-Former、训练Loss不方便scale up;Blip2仅支持单图输入,不支持多图输入
  • Blip3收集超大规模数据集,并且用相对简单的训练方式,实现多图、文本的交互。
  • 开放两个数据集:BLIP3-OCR-200M(大规模OCR标注数据集),BLIP3-GROUNDING-50M(大规模visual grounding数据集)

具体细节

模型结构

在这里插入图片描述
整体结构非常简单

  • 图像经过ViT得到patch embedding,再经过token sampler得到vision token。(先经过Token Sampler,得到视觉embedding,而后经过VL connector,得到vision token)
  • 文本通过tokenizer获得text token
  • 文本、图像输入均送到LLM中,并且仅对本文加next prediction loss
  • 注意:ViT参数冻结,其他参数可训练
  • 注意:支持图像和文本交替输入,支持多图,任意分辨率图像
  • ViT:所用模型有DFN、SigLIP,在不同任务上,效果不同,如下:
    在这里插入图片描述
  • LLM:所用模型为phi3-mini
  • 模型结构代码见https://github.com/salesforce/LAVIS/blob/xgen-mm/open_flamingo/src/factory.py
  • token Sampler代码见https://github.com/salesforce/LAVIS/blob/xgen-mm/open_flamingo/src/vlm.py
  • VL connector代码见https://github.com/salesforce/LAVIS/blob/xgen-mm/open_flamingo/src/helpers.py

Token Sampler

详见博客https://blog.csdn.net/weixin_40779727/article/details/142019977,就不赘述了

VL Connector

整体结构如下:

class PerceiverAttention(nn.Module):def __init__(self, *, dim, dim_head=64, heads=8):super().__init__()self.scale = dim_head**-0.5self.heads = headsinner_dim = dim_head * headsself.norm_media = nn.LayerNorm(dim)self.norm_latents = nn.LayerNorm(dim)self.to_q = nn.Linear(dim, inner_dim, bias=False)self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)self.to_out = nn.Linear(inner_dim, dim, bias=False)def forward(self, x, latents, vision_attn_masks=None):"""Args:x (torch.Tensor): image featuresshape (b, T, n1, D)latent (torch.Tensor): latent featuresshape (b, T, n2, D)"""x = self.norm_media(x)latents = self.norm_latents(latents)h = self.headsq = self.to_q(latents)kv_input = torch.cat((x, latents), dim=-2) # TODO: Change the shape of vision attention mask according to this.if vision_attn_masks is not None:vision_attn_masks = torch.cat((vision_attn_masks, torch.ones((latents.shape[0], latents.shape[-2]), dtype=latents.dtype, device=latents.device)),dim=-1)k, v = self.to_kv(kv_input).chunk(2, dim=-1)q, k, v = rearrange_many((q, k, v), "b t n (h d) -> b h t n d", h=h)q = q * self.scale# attentionsim = einsum("... i d, ... j d  -> ... i j", q, k)# Apply vision attention mask here.# Reference: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attentionif vision_attn_masks is not None:attn_bias = torch.zeros((q.size(0), 1, 1, q.size(-2), k.size(-2)), dtype=q.dtype, device=q.device)vision_attn_masks = repeat(vision_attn_masks, 'b n -> b 1 1 l n', l=q.size(-2))attn_bias.masked_fill_(vision_attn_masks.logical_not(), float("-inf"))sim += attn_biassim = sim - sim.amax(dim=-1, keepdim=True).detach()attn = sim.softmax(dim=-1)out = einsum("... i j, ... j d -> ... i d", attn, v)out = rearrange(out, "b h t n d -> b t n (h d)", h=h)return self.to_out(out)class PerceiverResampler(VisionTokenizer):def __init__(self,*,dim,dim_inner=None,depth=6,dim_head=96,heads=16,num_latents=128,max_num_media=None,max_num_frames=None,ff_mult=4,):"""Perceiver module which takes in image features and outputs image tokens.Args:dim (int): dimension of the incoming image featuresdim_inner (int, optional): final dimension to project the incoming image features to;also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.depth (int, optional): number of layers. Defaults to 6.dim_head (int, optional): dimension of each head. Defaults to 64.heads (int, optional): number of heads. Defaults to 8.num_latents (int, optional): number of latent tokens to use in the Perceiver;also corresponds to number of tokens per sequence to output. Defaults to 64.max_num_media (int, optional): maximum number of media per sequence to input into the Perceiverand keep positional embeddings for. If None, no positional embeddings are used.max_num_frames (int, optional): maximum number of frames to input into the Perceiverand keep positional embeddings for. If None, no positional embeddings are used.ff_mult (int, optional): dimension multiplier for the feedforward network. Defaults to 4."""if dim_inner is not None:projection = nn.Linear(dim, dim_inner)else:projection = Nonedim_inner = dimsuper().__init__(dim_media=dim, num_tokens_per_media=num_latents)self.projection = projectionself.latents = nn.Parameter(torch.randn(num_latents, dim))# positional embeddingsself.frame_embs = (nn.Parameter(torch.randn(max_num_frames, dim))if exists(max_num_frames)else None)self.media_time_embs = (nn.Parameter(torch.randn(max_num_media, 1, dim))if exists(max_num_media)else None)self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),FeedForward(dim=dim, mult=ff_mult),]))self.norm = nn.LayerNorm(dim)def forward(self, x, vision_attn_masks):"""Args:x (torch.Tensor): image featuresshape (b, T, F, v, D)vision_attn_masks (torch.Tensor): attention masks for padded visiont tokens (i.e., x)shape (b, v)Returns:shape (b, T, n, D) where n is self.num_latents"""b, T, F, v = x.shape[:4]# frame and media time embeddingsif exists(self.frame_embs):frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)x = x + frame_embsx = rearrange(x, "b T F v d -> b T (F v) d")  # flatten the frame and spatial dimensionsif exists(self.media_time_embs):x = x + self.media_time_embs[:T]# blockslatents = self.latentslatents = repeat(latents, "n d -> b T n d", b=b, T=T)for attn, ff in self.layers:latents = attn(x, latents, vision_attn_masks) + latentslatents = ff(latents) + latentsif exists(self.projection):return self.projection(self.norm(latents)) else:return self.norm(latents)

训练及数据

预训练
  • 训练数据
    在这里插入图片描述
    用了0.1T token的多模态数据训练,和一些知名的MLLM相比,例如Qwen2VL 0.6T,还是不太够
  • 训练方式:针对文本的next token prediction方式训练,图像输入为384x384
有监督微调(SFT)
  • 训练数据:从不同领域(multi-modal conversation、 image captioning、chart/document understanding、science、math),收集一堆开源数据。从中采样1百万,包括图文指令+文本指令数据。
    训练1epoch
  • 训练方式:针对文本的next token prediction方式训练
交互式多图有监督微调(Interleaved Multi-Image Supervised Fine-tuning)
  • 训练数据:首先,收集多图指令微调数据(MANTIS和Mmdu)。为避免模型过拟合到多图数据,选择上一阶段的单图指令微调数据子集,与收集的多图指令微调数据合并,构成新的训练集合。
  • 训练方式:针对文本的next token prediction方式训练
后训练(Post-training)
DPO提升Truthfulness
part1
  • 训练数据:利用开源的VLFeedback数据集。VLFeedback数据集构造方式:输入指令,让多个VLM模型做生成,随后GPT4-v从helpfulness, visual faithfulness, ethics三个维度对生成结果打分。分值高的输出作为preferred responses,分值低的输出作为dispreferred responses。BLIP3进一步过滤掉一部分样本,最终得到62.6K数据。
  • 训练方式:DPO为训练目标,用LoRA微调LLM 2.5%参数,总共训练1 epoch
part2
  • 训练数据:根据该工作,生成一组额外responses。该responses能够捕捉LLM的内在幻觉,作为额外dispreferred responses,采用DPO训练。
  • 训练方式:同part1,再次训练1 epoch
Safety微调(Safety Fine-tuning)提升Harmlessness
  • 训练数据:用2k的VLGua数据集+随机5K SFT数据集。VLGuard包括两个部分:
    这段话可以翻译为:
    (1) 恶心图配上安全指示及安全回应
    (2) 安全图配上安全回应及不安全回应
  • 训练方式:用上述7k数据,训练目标为next token prediction,用LoRA微调LLM 2.5%参数,总共训练1 epoch

实验效果

预训练

对比类似于预训练任务的VQA、Captioning任务,效果在使用小参数量LLM的MLLM里,效果不错。
在这里插入图片描述

有监督微调(SFT)

在这里插入图片描述

交互式多图有监督微调(Interleaved Multi-Image Supervised Fine-tuning)

在这里插入图片描述

后训练(Post-training)

在这里插入图片描述

消融实验

预训练
预训练数据量

在这里插入图片描述

预训练数据配比

在这里插入图片描述

视觉backbone

在这里插入图片描述

有监督微调(SFT)
视觉Token Sampler对比

在这里插入图片描述
base resolution:直接把图片resize到目标大小
anyres-fixed-sampling (ntok=128):把所有图像patch的表征concat起来,经过perceiver resampler,得到128个vision token
anyres-fixed-sampling (ntok=256):把所有图像patch的表征concat起来,经过perceiver resampler,得到256个vision token
anyres-patch-sampling:本文采用的方法

Instruction-Aware Vision Token Sampling.

在这里插入图片描述
XGen-MM:输入图像,获取vision token
XGen-MM(instruction-aware):同时输入图像+指令,获取vision token

Quality of the Text-only Instruction Data.

在这里插入图片描述仅利用文本指令数据,训练SFT模型,对比效果


https://blog.csdn.net/weixin_40779727/article/details/142019977

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/882121.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

TCP/IP 寻址

TCP/IP 寻址 概述 TCP/IP(传输控制协议/互联网协议)是一组用于数据网络的通信协议。它们定义了数据如何在网络上从一个设备传输到另一个设备。在TCP/IP网络中,每个设备都有一个唯一的地址,称为IP地址,用于标识网络上的设备。本文将深入探讨TCP/IP寻址的概念、类型、分配…

Lua脚本的原子性

Lua脚本之所以被认为是原子性的,主要源于Redis的内部实现机制和Lua脚本的执行方式。以下是对Lua脚本原子性的详细解释: 一、Redis的单线程模型 Redis是一个基于内存、可基于Key-Value等多种数据结构的存储系统,它使用单线程模型来处理客户端的请求。这意味着在任何给定的时…

Vue3中防止按钮重复点击的方式

本文列两种方式&#xff0c;推荐第一种&#xff0c;经过长时间测试第二种防止的还是会漏&#xff0c;这里也列一下 ①使用定时器&#xff08;推荐&#xff09; 判断3秒钟之内方法只能执行一次 <el-button click"handleClick" type"primary" :loading…

二叉树算法之二叉树遍历(前序、中序、后序、层次遍历)

二叉树遍历是指按照某种顺序访问二叉树的所有节点。常见的二叉树遍历方式包括前序遍历&#xff08;Preorder Traversal&#xff09;、中序遍历&#xff08;Inorder Traversal&#xff09;、后序遍历&#xff08;Postorder Traversal&#xff09;和层次遍历&#xff08;Level-or…

stm32 bootloader写法

bootloader写法&#xff1a; 假设app的起始地址&#xff1a;0x08020000&#xff0c;则bootloader的范围是0x0800,0000~0x0801,FFFF。 #define APP_ADDR 0x08020000 // 应用程序首地址定义 typedef void (*APP_FUNC)(void); // 函数指针类型定义 /*main函数中调用rum_app&#x…

Luogu P1528 切蛋糕 || SCOI2005 栅栏

假设最多能满足 x x x个人&#xff0c;那么这 x x x个人一定可以是按照每个人吃蛋糕的需求将他们从小到大排序后的前 x x x个人。&#xff08;有两个人他们吃蛋糕的需求分别为 x 1 x_1 x1​和 x 2 x_2 x2​&#xff0c;且 x 1 < x 2 x_1<x_2 x1​<x2​&#xff0c;如果…

【从零开始的LeetCode-算法】504. 七进制数

给定一个整数 num&#xff0c;将其转化为 7 进制&#xff0c;并以字符串形式输出。 示例 1: 输入: num 100 输出: "202"示例 2: 输入: num -7 输出: "-10"提示&#xff1a; -107 < num < 107 我的解答 class Solution {public String convertT…

大数据存储计算平台EasyMR:大数据集群动态扩缩容,快速提升集群服务能力

在当今的数据驱动时代&#xff0c;组织面临着数据量的爆炸性增长。为了有效管理和存储这些数据&#xff0c;许多组织依赖于 Hadoop 这样的分布式存储系统。Hadoop 集群通过在多个节点上存储数据的冗余副本&#xff0c;提供了高可靠性和可扩展性。然而&#xff0c;随着数据量的不…

【MySQL】提高篇—复杂查询:多表连接(INNER JOIN、LEFT JOIN、RIGHT JOIN、FULL JOIN)

在关系数据库中&#xff0c;数据通常分散在多个表中。为了获取相关联的数据&#xff0c;需要使用表连接&#xff08;JOIN&#xff09;操作。 表连接允许我们在一个查询中结合多个表的数据&#xff0c;这在实际应用中非常重要。 例如&#xff0c;在一个电商系统中&#xff0c;…

如何更改MySQL的root密码

前言 在管理数据库时&#xff0c;有时可能会忘记MySQL的root用户密码或需要更改默认设置的密码。以下是在Windows环境下更改MySQL root密码的详细步骤。请注意&#xff0c;这些步骤适用于MySQL 5.7及以上版本&#xff1b;对于其他版本&#xff0c;请参考相应版本的文档。 准备…

倪师学习笔记-天纪-易经入门

1、易经-易的意思 变易、变化简易、简单 2、神、象 神&#xff1a;永恒象&#xff1a;一直变&#xff0c;不会重复&#xff0c;但对应的神一样 3、机锋 突发的、短暂的具有预兆功能的事情 4、先天八卦&#xff08;伏羲八卦&#xff09; 特点&#xff1a;为体&#xff0c…

ChatGPT国内中文版镜像网站整理合集(2024/10/06)

一、GPT中文镜像站 ① yixiaai.com 支持GPT4、4o以及o1&#xff0c;支持MJ绘画 ② chat.lify.vip 支持通用全模型&#xff0c;支持文件读取、插件、绘画、AIPPT ③ AI Chat 支持GPT3.5/4&#xff0c;4o以及MJ绘画 1. 什么是镜像站 镜像站&#xff08;Mirror Site&#xff…

Spring Boot在线考试系统:JavaWeb技术的应用案例

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

几个常值问题导致formality失败debug方法

在做fomality的时候&#xff0c;如果出现寄存器unmatch问题&#xff0c;通常是由于常值推导不一致&#xff0c;导致寄存器被优化。 几个基本定位方法如下&#xff1a; 1.verify r:wrok/top/xyz_reg[0] -constant0 此命令可查看ref中的寄存器是否是常值。 2.guide guide_re…

英飞达医学影像存档与通信系统 WebUserLogin.asmx 信息泄露漏洞复现

0x01 产品简介 英飞达医学影像存档与通信系统 Picture Archiving and Communication System,它是应用在医院影像科室的系统,主要的任务就是把日常产生的各种医学影像(包括核磁,CT,超声,各种X光机,各种红外仪、显微仪等设备产生的图像)通过各种接口(模拟,DICOM,网络…

概率 随机变量以及分布

一、基础定义及分类 1、随机变量 随机变量是一个从样本空间&#xff08;所有可能结果的集合&#xff09;到实数集的函数。&#xff08;随机变量的值可以是离散的&#xff0c;也可以是连续的。 &#xff09; 事件可以定义为随机变量取特定值的集合。 2、离散型随机变量 随机变…

npm-run-all 使用实践

参考: npm-run-all 背景 在前端开发中&#xff0c;你是否存在以下烦恼: 写 package.json 的 scripts 命令时&#xff0c;命令太过冗长&#xff0c;例如编译命令 build 需要执行清理 clean, 编译css build:css, 编译js build:js, 编译html build:html 命令&#xff0c;则 bui…

OpenCV高级图形用户界面(17)设置一个已经创建的滚动条的最小值函数setTrackbarMin()的使用

操作系统&#xff1a;ubuntu22.04 OpenCV版本&#xff1a;OpenCV4.9 IDE:Visual Studio Code 编程语言&#xff1a;C11 算法描述 cv::setTrackbarMin 这个函数的作用就是设置指定窗口中轨迹条的最小位置。这使得开发者能够在程序运行时动态地调整轨迹条的范围&#xff0c;而不…

Leetcode—1242. 多线程网页爬虫【中等】Plus(多线程)

2024每日刷题&#xff08;187&#xff09; Leetcode—1242. 多线程网页爬虫 实现代码 /*** // This is the HtmlParsers API interface.* // You should not implement it, or speculate about its implementation* class HtmlParser {* public:* vector<string>…

Go程序的一生——Go如何跑起来的?

引入编译链接概述 编译过程 词法分析语法分析语义分析中间代码生成目标代码生成与优化链接过程Go 程序启动GoRoot 和 GoPathGo 命令详解 go buildgo installgo run总结参考资料 引入 我们从一个 Hello World 的例子开始&#xff1a; package mainimport "fmt"func…