【CVPR2022】MatteFormer: Transformer-Based Image Matting via Prior-Tokens
https://blog.csdn.net/Thinkobj/article/details/128209388
This paper makes two core contributions: 1. PA-WSA (Prior-Attentive Window Self-Attention); 2. prior-tokens generated from the trimap. Most current transformer-based matting methods just tweak Swin Transformer and bolt on a few extra modules, and it is hard to verify that the added modules are what helps — the gain may simply come from the transformer itself.
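The prior-token idea is easy to state: for each trimap region (foreground / background / unknown), pool the features covered by that region into a single global token. A minimal sketch of that masked-average pooling, with illustrative names (not the repo's API):

```python
import torch

def make_prior_tokens(feat, trimap_onehot):
    """feat: (B, C, H, W) features; trimap_onehot: (B, 3, H, W) one-hot region masks.
    Returns (B, 3, C): one prior-token per trimap region (masked global average)."""
    tokens = []
    for k in range(3):
        mask = trimap_onehot[:, k:k + 1]                # (B, 1, H, W) region mask
        denom = mask.sum(dim=(2, 3)).clamp(min=1)       # avoid divide-by-zero for empty regions
        tok = (feat * mask).sum(dim=(2, 3)) / denom     # (B, C) average over the region
        tokens.append(tok)
    return torch.stack(tokens, dim=1)                   # (B, 3, C)
```

In the paper these tokens are recomputed at every stage from the stage's own features, so each prior-token summarizes the current resolution's foreground/background/unknown statistics.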
Evaluation results on Composition-1k:
Code:
train_image_file = ImageFileTrain(alpha_dir=config.data.train_alpha, fg_dir=config.data.train_fg, bg_dir=config.data.train_bg)
test_image_file = ImageFileTest(alpha_dir=config.data.test_alpha, merged_dir=config.data.test_merged, trimap_dir=config.data.test_trimap)
train_dataset = DataGenerator(train_image_file, phase='train')
test_dataset = DataGenerator(test_image_file, phase='val')
train_dataloader = DataLoader(train_dataset, ...)
train_dataloader = Prefetcher(train_dataloader)
trainer = Trainer(train_dataloader, test_dataloader, ...)
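Before the trimap can drive prior-tokens and the per-region masks below, it has to become a 3-channel one-hot mask. A sketch of that conversion, assuming the usual {0, 128, 255} pixel encoding; the function name and channel order (0=background, 1=unknown, 2=foreground, following the pixel values) are illustrative and may differ from the repo's own ordering:

```python
import torch
import torch.nn.functional as F

def trimap_to_onehot(trimap):
    """trimap: (B, 1, H, W) with values {0, 128, 255} -> (B, 3, H, W) float one-hot."""
    idx = torch.zeros_like(trimap, dtype=torch.long)   # 0 -> background channel
    idx[trimap == 128] = 1                             # unknown region
    idx[trimap == 255] = 2                             # foreground region
    return F.one_hot(idx.squeeze(1), num_classes=3).permute(0, 3, 1, 2).float()
```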
- build_model()->
- G = network.get_generator()->
-- generator = Generator_MatteFormer(is_train)
-- encoder = MatteFormer(embed_dim=96,...)
--- patch_embed = PatchEmbed()
--- pos_drop = nn.Dropout()
--- layer = BasicLayer(dim,depth,...)
---- blocks = nn.ModuleList([PASTBlock() for i in range(depth)])
----- attn = PAWSA()
----- norm2 = norm_layer()
----- mlp = MLP()
-- decoder = decoders.__dict__['res_shortcut_decoder']()
- G_optimizer = torch.optim.Adam(G.parameters(),lr...)
- build_lr_scheduler()
trainer.train()
- image, alpha, trimap = image_dict['image'], image_dict['alpha'], image_dict['trimap']
- pred = self.G(image, trimap)
-- inp = torch.cat((image, trimap), axis=1)
-- x = self.encoder(inp, trimap)
-- embedding = x[-1]
-- outs = self.decoder(embedding, x[:-1])
--- x = self.patch_embed(x)
--- trimapmask = F.interpolate(trimapmask, scale_factor=1/4, mode='nearest')
--- # get outs[1]
--- outs.append(self.shortcut[1](F.upsample_bilinear(x, scale_factor=2.0)))
--- # dropout
--- x = x.flatten(2).transpose(1, 2)
--- x = self.pos_drop(x)
--- # get outs[2~5]
--- for i in range(self.num_layers):
---     layer = self.layers[i]
---     trimapmask_ = F.interpolate(trimapmask, scale_factor=1 / (pow(2, i)), mode='nearest')
---     area_fg = trimapmask_[:, 0, :, :].unsqueeze(-1)  # foreground area
---     area_bg = trimapmask_[:, 2, :, :].unsqueeze(-1)  # background area
---     area_uk = trimapmask_[:, 1, :, :].unsqueeze(-1)  # unknown area
---     x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww, area_fg=area_fg, area_bg=area_bg, area_uk=area_uk)
---     if i in self.out_indices:
---         norm_layer = getattr(self, f'norm{i}')
---         x_out = norm_layer(x_out)
---         out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
---         out = self.shortcut[i + 2](out)
---         outs.append(out)
--- return tuple(outs)
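Inside each PAST block, PA-WSA works by letting every window's queries attend to the window's own tokens plus the prior-tokens: the priors are appended only on the key/value side, so each window can pull in global foreground/background/unknown context. A single-head, simplified sketch under that reading (weight matrices and names are illustrative, not the repo's API):

```python
import torch
import torch.nn.functional as F

def pa_wsa(win_tokens, prior_tokens, wq, wk, wv):
    """win_tokens: (nW, N, C) tokens of each window; prior_tokens: (nW, P, C).
    Simplified single-head Prior-Attentive Window Self-Attention."""
    q = win_tokens @ wq                                # queries: window tokens only
    kv = torch.cat([win_tokens, prior_tokens], dim=1)  # keys/values also see the prior-tokens
    k, v = kv @ wk, kv @ wv
    attn = F.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)  # (nW, N, N+P)
    return attn @ v                                    # (nW, N, C): windows enriched with prior context
```

The real implementation is multi-head, adds relative position bias for the window-to-window part, and handles shifted windows; this sketch only shows where the prior-tokens enter the attention.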