Rich Human Feedback for Text-to-Image Generation 读论文笔记
- 摘要
- 方法细节
- 收集数据的过程
- 人类反馈确认
- 数据集
- VQA
- 使用方法
- 数据分析
- 分数统计
- 评价一致性(pair alignment)
- 实验
- 模型
- 模型架构
- 模型变体
- 模型其他优化
- 实验
- Metrics
- Score
- Heatmap
- Misalignment
- 量化结果
- Score
- Heatmap
- Misalignment
- 定性分析
- 从反馈中学习
摘要
Motivation:探索如何优化如Stable Diffusion T2I生成模型的优化问题,因为这些模型都会有诸如伪影,与文字描述不匹配和美学质量低等问题。本文参考大语言模型强化学习的方式,训练奖励模型来改进生成模型。
Contribusion:在收集的数据集(RichHF18K)收集feedback,通过选择高质量的训练数据和改进来生成模型,或者使用预测的heatmap来创建掩码,修复有问题的区域。
- Rich Human Feedback dataset
- 一个多模态Transformer模型对生成的图像进行丰富的反馈预测
- improve method:方式:1. 标记有问题的图像区域 2.标记文本描述不匹配的prompt(被误报或漏报)3. 使用预测的分数来帮助微调图像生成模型
方法细节
收集数据的过程
RichHF-18K数据集:
每个图片包含的标注和分数:
- 图像高度的1 / 20为半径标记伪影和错位标注(两个heatmap,implausibility and misalignment heatmap)。
- 没有对齐的关键词的标注
- 四个细粒度的分数(合理性、一致性、美观性、总体评分)
人类反馈确认
每个图像-文本对由三个标注员进行注释,所以对于分数直接做平均,文字对齐标注采取多数原则,点标注使用每个点区域的平均值(每个点被转换为热图上的一个磁盘区域,然后计算三个热图之间的平均热图)
数据集
在Pick-a-Pic dataset数据集选取的子集。选取的部分是照片等级的图像。为了平衡类别,使用PaLI visual question answering (VQA) model从Pick - a - Pic数据样本中提取一些基本特征。
VQA
是一种用于能够结合大语言模型和图像理解技术的多模态模型。
使用方法
输入问题:
- 图像有真实感吗
- 那个类别最能描述图像?在"人"、“动物”、“物”、“室内场景”、"室外场景"中任选其一
18K的数据集,16K作为训练集,1K作为验证,1K作为测试。
数据分析
分数统计
s − s min s max − s min = s − 1 5 − 1 \frac{s - s_{\text{min}}}{s_{\text{max}} - s_{\text{min}}} = \frac{s - 1}{5 - 1} smax−smins−smin=5−1s−1
得到的分布如下:
基本符合高斯分布
评价一致性(pair alignment)
maxdiff = max ( scores ) − min ( scores ) \text{maxdiff} = \max(\text{scores}) - \min(\text{scores}) maxdiff=max(scores)−min(scores)
实验
模型
模型架构
这个架构中有两个计算流,分别关注视觉和文本的部分,使用的架构分别是VIT和T5X。
文本信息通过对齐程度和heatmap传递给图像token,视觉信息传递给文本token用于视觉感知。使用WebLi预训练模型。
- 生成的图像输入ViT,然后在输出的地方成为高级表征,text则是嵌入成dense向量。
- 将两种token经过T5X的自注意力级联编码
- 编码后使用三种预测器来预测不同的输出。
type | operate |
---|---|
heatmap | 输入:图像token 经过卷积反卷积和sigmoid 输出:不可信和heatmap |
score | 输入:feature map 经过卷积,线性和sigmoid 输出:细粒度scores |
misalignment | 输入:原始caption,target:修改的caption 使用T5X的解码器,不对齐的用后缀0表示,e.g.:如果生成的图像中包含黑猫,且黄色单词与图像不对齐,则为黄色0猫。 |
模型变体
- Multi-head 每个评分,heat map和misalignment有一个头对应,共七个
- 对每个预测类型使用单个头,即总共3个头,分别用于热图、得分和misalignment。
在实验中,第二种方法具体操作是:增加能够让模型判断输出类型的prompt,比如如 “implausibility heatmap”,这样能够明确任务类型。通过将这种prompt与相应的任务进行结合,单个热图(得分)头就可以预测不同的热图(得分)。能够在有些任务中得到比第一种更好的结果。
模型其他优化
损失函数是热图MSE损失、评分MSE和序列CE的加权组合。
实验
针对三种标注和打分的方法:
Metrics
Score
使用的系数:Pearson线性相关系数(PLCC)和斯皮尔曼等级相关系数(SRCC)。PLCC测量预测和真实分数之间的线性相关性,表明预测以线性方式近似实际分数的程度。SRCC测量预测和实际分数之间的关系可以使用单调函数来描述,重点是排名顺序而不是确切值。
Heatmap
标准的显着性热图评估指标,如归一化扫描路径显着性(NSS),Kullback-Leibler发散(KLD)
Misalignment
Token-level precision, recall, 和 F1-score.精度测量预测的未对齐关键字的准确性(即,正确的预测关键字的比例),查全率测量完整性(即,被正确预测的实际未对齐关键字的比例),而F1-score通过计算它们的调和平均值来提供精确度和召回率之间的平衡。
量化结果
Score
Heatmap
Misalignment
表1和表3中变体都超过了ResNet50,表2中多头版本不如resnet50,但是三头版本优于resnet50。
作者在这里预测的原因是:可能在多头版本中,所有7个预测任务都使用相同的prompt(相对于3头版本),因此所有任务的特征图和文本标记都是相同的。在这些任务之间找到一个好的折衷可能并不容易,因此一些任务如伪影/不可信热图的性能会变得更差。
注意到misalignment heat map预测通常比伪影heatmap预测的结果更差,这可能是因为错配区域的定义较少,因此注释可能更嘈杂。
定性分析
从反馈中学习
研究从这些反馈中能不能学到知识用于改善图像生成。
使用基于遮蔽变换器架构的Muse模型作为改进的目标。
首先,我们使用预训练的Muse模型为12,564个prompt(通过PaLM 2生成的提示集)生成了八张图像。我们为每张图像预测RAHF分数,如果每个提示生成的图像中最高分超过一个固定阈值,它将被选为我们微调数据集的一部分。然后,Muse模型与这个数据集一起进行微调。前后对比:
量化Muse微调的收益:作者使用100个新提示生成图像,并请6名注释者进行两张图像的并排比较,这两张图像分别来自原始的Muse和微调后的Muse。注释者在不知道哪个模型用于生成图像A/B的情况下,从五种可能的反应中选择(图像A明显/稍微好于图像B,大致相同,图像B稍微/明显好于图像A)。表5的结果显示,与原始Muse相比,经过RAHF可信度分数微调的Muse具有显著更少的人工痕迹/不可信之处。
展示了一个使用RAHF审美分数作为分类器指导对潜在扩散模型的示例
对于每张图像,首先预测不可信度heatmap,然后通过处理heatmap(使用阈值和扩张)创建一个掩码。在掩码区域内应用Muse修复,生成与文本提示相匹配的新图像。生成多张图像,最终图像由我们的RAHF预测的最高可信度分数选择。
总结来说就是使用训练的模型来判断生成模型中不合理的地方,并使用掩码模型做遮蔽处理,好让模型重新生成有问题的部位,类似图像编辑的内容。