引言
人工智能研究的前沿领域见证了显著的交叉融合。将计算机视觉和自然语言处理的领域融合,问题随之而来:AI能否直接从其视觉表现,即从原始像素中辨识和理解语言?在这篇博客中,我试图探究AI从图像中直接理解自然语言的能力。
代码仓库的链接在这里https://github.com/filipbasara0/visual-language-processing。请记住,该项目仍在进行中,README文件和代码仍需一些调整。
问题陈述
从视觉线索中解读语言不仅仅是一个光学字符识别(OCR)任务;它关乎理解上下文、语义甚至掩盖的信息。用通俗的话说,我正试图弄清楚AI模型如何能像我们一样,从原始图像像素中理解文本。为此,我采用了自监督的掩码语言模型(MLM)范式。
图1 - 用于训练的小型、中型和大型图像示例。标记的遮蔽方式类似于BERT的掩码语言模型(MLM)任务。文本少于50个标记的视为小型,少于100个的视为中型,其他的为大型。这样做是为了简化数据生成。这也是出于课程学习的考虑。该数据集包含多种字体大小和字体类型。
为了提供一些背景 — 最初,我在没有MLM部分的情况下构建了这个任务。目标是从图像中直接重构文本,就像光学字符识别(OCR)那样,但是端到端实现。这个方法效果不错,但模型未能理解文本。
为了解决这个问题,我在图像上遮蔽了一些词,留下未遮蔽的目标文本;图1中可以看到几个示例。这大大提高了理解文本的能力,但仍有很大的提升空间和更多任务表述方式。其中一种可能是让模型重构图像中被移除的像素。
从高层次来看,架构相当简单。
图1 - 用于训练的小型、中型和大型图像示例。标记的遮蔽方式类似于BERT的掩码语言模型(MLM)任务。文本少于50个标记的视为小型,少于100个的视为中型,其他的为大型。这样做是为了简化数据生成。这也是出于课程学习的考虑。该数据集包含多种字体大小和字体类型。
为了提供一些背景 — 最初,我在没有MLM部分的情况下构建了这个任务。目标是从图像中直接重构文本,就像光学字符识别(OCR)那样,但是端到端实现。这个方法效果不错,但模型未能理解文本。
为了解决这个问题,我在图像上遮蔽了一些词,留下未遮蔽的目标文本;图1中可以看到几个示例。这大大提高了理解文本的能力,但仍有很大的提升空间和更多任务表述方式。其中一种可能是让模型重构图像中被移除的像素。
从高层次来看,架构相当简单。
图2 - 用于训练的架构概览。
对于下游任务,如分类,变换器解码器和线性层会被移除。
卷积神经网络(CNN)用于捕获文本特征,对于收敛非常重要。编码器应该学习理解上下文,然后由解码器使用这些上下文输出重构的标记或预测的掩码标记。线性层将解码器的输出映射到我们词汇表中的一个标记。
有可能可以移除CNN层(我们可以使用类似于视觉变换器(ViT)的补丁),但到目前为止,我还没有在没有它的情况下取得好的结果。
架构
这一节稍微技术性一些,如果您只对结果感兴趣,可以跳过!
项目的核心依赖于混合卷积神经网络(CNNs)和变换器模型。让我们更深入地了解架构!
使用CNNs进行特征提取 - 这一层对于收敛非常重要,但仅增加其复杂性并没有放大结果。这表明潜在的饱和或需要更复杂的架构调整。
- ResNet特征提取器 - 我们的模型采用了一个基于ResNet的CNN,针对当前任务进行了稍许调整。该模块负责将原始图像转换为平整的特征映射集。这些特征映射中的每一个都捕获了图像中存在的复杂模式,为进一步处理做好准备。
import torch.nn as nn
from torchvision.models import resnet50class ResNetFeatureExtractor(nn.Module):def __init__(self, feature_map_size, out_features_size):super(ResNetFeatureExtractor, self).__init__()self.feature_extractor = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])self.input_proj = nn.Conv2d(feature_map_size,out_features_size,kernel_size=1)def forward(self, x):x = self.feature_extractor(x)x = self.input_proj(x)x = x.view(x.size(0), x.size(1), -1).permute(0, 2, 1)return x
SinePositionalEncoding: 继ResNet模块之后,SinePositionalEncoding接管。这一层至关重要,因为它为我们的特征映射注入了位置信息。与纯自然语言处理任务中的序列不同,图像本身并没有固有的序列,而这种位置编码为我们的模型提供了空间洞察。以下是用于位置编码的代码。一个潜在的改进可能是使用二维位置编码。
class SinePositionalEncoding(nn.Module):def __init__(self, d_model, max_len=1000):super(SinePositionalEncoding, self).__init__()pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)self.register_buffer("pe", pe)def forward(self, x):return x + self.pe[:, :x.size(1)].requires_grad_(False)
变换器和视觉语言处理(VLP)模型 - 选择变换器是因为其通用性和可扩展性。从CNN输出的信息被转发到变换器编码器,然后其输出传递给变换器解码器。这种编码器-解码器过程使我们能够逐标记处理图像数据,包括在某些情况下预测掩码标记。
- 变换器配置:通过model_config_factory函数,我们的架构保持高度可配置。它允许轻松指定模型参数,如model_dim、num_heads等。这确保我可以根据我们数据的需求和以往实验的洞察快速调整和扩展我们的模型。我遵循传统的变换器实现;完整代码可以在这里找到。以下是最佳性能模型所使用的配置:
...,
"encoder_decoder_lg": {"model_dim": 768,"ff_dim": 4096,"num_heads": 16,"num_layers": 12,"feature_map_size": 2048, # from the resnet model"dec_div": 2
}, ...
视觉语言处理(VLP)类:这是核心部分。一旦使用特征提取器提取了图像特征,这些特征就被传入我们的变换器进行处理。变换器的结构可以根据选择而变化 - 可以是编码器、解码器,或两者的组合。
class VLP(nn.Module):def __init__(self, model_dim, num_layers, ff_dim, num_heads,feature_map_size, vocab_size, dropout, transformer_type,dec_div):super(VLP, self).__init__()self.feature_extractor = nn.Sequential(ResNetFeatureExtractor(feature_map_size=feature_map_size,out_features_size=model_dim),SinePositionalEncoding(model_dim))self.transformer = get_transformer(num_layers=num_layers,model_dim=model_dim,ff_dim=ff_dim,num_heads=num_heads,vocab_size=vocab_size,dropout=dropout,transformer_type=transformer_type,dec_div=dec_div)def get_image_features(self, images):return self.feature_extractor(images)def forward(self, images, tgt=None, tgt_mask=None):image_features = self.get_image_features(images)return self.transformer(image_features, tgt, tgt_mask=tgt_mask)
使用VLPForTextMLM处理文本:建立在视觉语言处理(VLP)模型之上,VLPForTextMLM类专为我们的掩码语言模型(MLM)任务定制。它增加了一个额外的线性层,将变换器的输出映射到我们期望的类别数量(标记)。
class VLPForTextMLM(nn.Module):def __init__(self,model_dim,num_layers,num_heads,ff_dim,feature_map_size,num_classes,dec_div=2,dropout=0.0):super(VLPForTextMLM, self).__init__()self.vlp = VLP(num_layers=num_layers,model_dim=model_dim,ff_dim=ff_dim,num_heads=num_heads,feature_map_size=feature_map_size,vocab_size=None,dropout=dropout,transformer_type="encoder",dec_div=dec_div)self.out = nn.Linear(model_dim, num_classes)def forward(self, images):out = self.vlp(images)return self.out(out)
设计上的不对称性 - 编码器和解码器在大小上并不对称。这是一个有意的设计选择,编码器配备了更多的参数。目的是让编码器存储尽可能多的语言知识。这个选择在MNLI数据集结果上产生了积极的反馈。推理以自回归方式进行,使用贪婪解码。
数据和训练
我们的掩码语言模型(MLM)任务的训练数据由Wikipedia和Bookcorpus数据集的子集(大约30-50%,因硬件和时间限制)组成,文本经过筛选以保持少于144个标记的长度。我花费了最多的时间在创建数据集上 - 这方面还有许多改进的空间。该模型经过大约150万次迭代训练,批处理大小为16。encoder_decoder_lg变体的模型有1.29亿参数。
选择这种标记长度是在计算效率和足够的上下文信息之间的平衡。此外,创建包含大量文本的图像更加困难。
训练循环相当简单:
1. 通过VLPForTextMLM将图像进行前向传递。
2. 根据预测标记和实际标记之间的差异计算损失。
3. 反向传播误差以调整模型权重。
4. 迭代此过程多次,直至收敛或模型性能达到平稳。
我使用了交叉熵损失、AdamW优化器、OneCycle余弦退火学习率策略和混合精度进行训练。
结果和观察
为了测试学习表示的质量,我在MNLI文本蕴涵下游任务上训练了VLP模型。给定一个前提句和一个假设句,任务是预测前提是否蕴涵假设(蕴涵)、与假设相矛盾(矛盾)或两者皆非(中立)。这是一个分类任务,因此解码器层被移除。
图3 - 左侧是输入模型的图像,右侧是在MNLI数据集推理时的注意力图。注意力图是使用综合梯度方法提取的。
在MNLI数据集上实现了0.73的F1分数,模型显示出潜力。然而,它并非没有挑战。模型对CNN层的强依赖以及在增加CNN大小后结果停滞不前,表明潜在的瓶颈。
图4 - 来自MNLI数据集的验证图像上的额外示例。这两个例子都被正确分类。
为了分析学习到的嵌入的质量,我在imdb情感分类任务(正面/负面)上进行了线性探测,设置如下:
- 不限制标记数量 — 在这个设置中,VLP模型达到了0.7的F1分数,相较于BERT的0.8。这是有道理的,因为VLP模型没有在超过144个标记的文本上进行训练。
- 将标记数量限制为144 — 在这个设置中,VLP表现得更好,F1分数达到0.78,相比之下BERT为0.82。
总的来说,线性探测显示出有希望的结果,但在表示学习方法上仍有很大的改进空间。此外,探索VLP模型在更大上下文尺寸下的扩展性也将是一个有趣的课题。
当前状态和未来方向
该项目目前状态在文本识别方面表现良好。它熟练地理解了语言结构,如动词使用、标点符号和共指。然而,在需要记忆的任务中,例如“法国的首都是[MASK]”(不会产生“巴黎”)时,它表现不佳。该项目正适合进行探索,计划在像SQUAD和Ontonotes这样的数据集上进行训练。
解决性能停滞的一些显而易见的方法可能是:
- 扩展和改进训练数据集,
- 重新审视架构并引入替代技术。
我非常期待与同行研究人员和爱好者交流,欢迎您的见解和贡献!如果您有任何想法或问题,请随时与我联系!感谢您花时间阅读博客文章,希望您玩得开心!