论文解读:DiAD之SG网络

目录

  • 一、SG网络功能介绍
  • 二、SG网络代码实现

一、SG网络功能介绍

DiAD论文最主要的创新点就是使用SG网络解决多类别异常检测中的语义信息丢失问题,那么它是怎么实现的保留原始图像语义信息的同时重建异常区域?

与稳定扩散去噪网络的连接: SG网络被设计为与稳定扩散(Stable Diffusion, SD)去噪网络相连接。SD去噪网络本身具有强大的图像生成能力,但可能无法在多类异常检测任务中保持图像的语义信息一致性。SG网络通过引入语义引导机制,使得在重构异常区域时能够参考并保留原始图像的语义上下文。整个框架图中,SG网络与去噪网络的连接如下图所示。
在这里插入图片描述在这里插入图片描述
这是论文给出的最终输出,我认为图中圈出来的地方有问题,应该改为SG网络的编码器才对。

语义一致性保持: SG网络在重构过程中,通过在不同尺度下处理噪声,并利用空间感知特征融合(Spatial-aware Feature Fusion, SFF)块融合特征,确保重建过程中保留语义信息。这样,即使在重构异常区域时,也能使修复后的区域与原始图像的语义上下文保持一致。
多尺度特征融合: SFF块将高尺度的语义信息集成到低尺度中,使得在保留原始正常样本信息的同时,能够处理大规模异常区域的重建。这种机制有助于在处理需要广泛重构的区域时,最大化重构的准确性,同时保持图像的语义一致性。从下图中可以看到,特征融合模块还是很好理解的。

在这里插入图片描述

与预训练特征提取器的结合: SG网络还与特征空间中的预训练特征提取器相结合。预训练特征提取器能够处理输入图像和重建图像,并在不同尺度上提取特征。通过比较这些特征,系统能够生成异常图(anomaly maps),这些图显示了图像中可能存在的异常区域,并给出了异常得分或置信度。这一步骤进一步验证了SG网络在保留语义信息方面的有效性。
避免类别错误: 相比于传统的扩散模型(如DDPM),SG网络通过引入类别条件解决了在多类异常检测任务中可能出现的类别错误问题。LDM虽然通过交叉注意力引入了条件约束,但在随机高斯噪声下去噪时仍可能丢失语义信息。SG网络则通过其语义引导机制,有效地避免了这一问题。

二、SG网络代码实现

这部分代码大概有300行

class SemanticGuidedNetwork(nn.Module):def __init__(self,image_size,in_channels,model_channels,hint_channels,num_res_blocks,attention_resolutions,dropout=0,channel_mult=(1, 2, 4, 8),conv_resample=True,dims=2,use_checkpoint=False,use_fp16=False,num_heads=-1,num_head_channels=-1,num_heads_upsample=-1,use_scale_shift_norm=False,resblock_updown=False,use_new_attention_order=False,use_spatial_transformer=False,  # custom transformer supporttransformer_depth=1,  # custom transformer supportcontext_dim=None,  # custom transformer supportn_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq modellegacy=True,disable_self_attentions=None,num_attention_blocks=None,disable_middle_self_attn=False,use_linear_in_transformer=False,):super().__init__()if use_spatial_transformer:assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'if context_dim is not None:assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'from omegaconf.listconfig import ListConfigif type(context_dim) == ListConfig:context_dim = list(context_dim)if num_heads_upsample == -1:num_heads_upsample = num_headsif num_heads == -1:assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'if num_head_channels == -1:assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'self.dims = dimsself.image_size = image_sizeself.in_channels = in_channelsself.model_channels = model_channelsif isinstance(num_res_blocks, int):self.num_res_blocks = len(channel_mult) * [num_res_blocks]else:if len(num_res_blocks) != len(channel_mult):raise ValueError("provide num_res_blocks either as an int (globally constant) or ""as a list/tuple (per-level) with the same length as channel_mult")self.num_res_blocks = num_res_blocksif disable_self_attentions is not None:# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or notassert len(disable_self_attentions) == len(channel_mult)if num_attention_blocks is not None:assert len(num_attention_blocks) == len(self.num_res_blocks)assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "f"This option has LESS priority than attention_resolutions {attention_resolutions}, "f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "f"attention will still not be set.")self.attention_resolutions = attention_resolutionsself.dropout = dropoutself.channel_mult = channel_multself.conv_resample = conv_resampleself.use_checkpoint = use_checkpointself.dtype = th.float16 if use_fp16 else th.float32self.num_heads = num_headsself.num_head_channels = num_head_channelsself.num_heads_upsample = num_heads_upsampleself.predict_codebook_ids = n_embed is not Nonetime_embed_dim = model_channels * 4self.time_embed = nn.Sequential(linear(model_channels, time_embed_dim),nn.SiLU(),linear(time_embed_dim, time_embed_dim),)self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))])self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])self.input_hint_block = TimestepEmbedSequential(conv_nd(dims, hint_channels, 16, 3, padding=1),nn.SiLU(),conv_nd(dims, 16, 16, 3, padding=1),nn.SiLU(),conv_nd(dims, 16, 32, 3, padding=1, stride=2),nn.SiLU(),conv_nd(dims, 32, 32, 3, padding=1),nn.SiLU(),conv_nd(dims, 32, 96, 3, padding=1, stride=2),nn.SiLU(),conv_nd(dims, 96, 96, 3, padding=1),nn.SiLU(),conv_nd(dims, 96, 256, 3, padding=1, stride=2),nn.SiLU(),zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)))self._feature_size = model_channelsinput_block_chans = [model_channels]ch = model_channelsds = 1for level, mult in enumerate(channel_mult):for nr in range(self.num_res_blocks[level]):layers = [ResBlock(ch,time_embed_dim,dropout,out_channels=mult * model_channels,dims=dims,use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm,)]ch = mult * model_channelsif ds in attention_resolutions:if num_head_channels == -1:dim_head = ch // num_headselse:num_heads = ch // num_head_channelsdim_head = num_head_channelsif legacy:# num_heads = 1dim_head = ch // num_heads if use_spatial_transformer else num_head_channelsif exists(disable_self_attentions):disabled_sa = disable_self_attentions[level]else:disabled_sa = Falseif not exists(num_attention_blocks) or nr < num_attention_blocks[level]:layers.append(AttentionBlock(ch,use_checkpoint=use_checkpoint,num_heads=num_heads,num_head_channels=dim_head,use_new_attention_order=use_new_attention_order,) if not use_spatial_transformer else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,use_checkpoint=use_checkpoint))self.input_blocks.append(TimestepEmbedSequential(*layers))self.zero_convs.append(self.make_zero_conv(ch))self._feature_size += chinput_block_chans.append(ch)if level != len(channel_mult) - 1:out_ch = chself.input_blocks.append(TimestepEmbedSequential(ResBlock(ch,time_embed_dim,dropout,out_channels=out_ch,dims=dims,use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm,down=True,)if resblock_updownelse Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)))ch = out_chinput_block_chans.append(ch)self.zero_convs.append(self.make_zero_conv(ch))ds *= 2self._feature_size += chif num_head_channels == -1:dim_head = ch // num_headselse:num_heads = ch // num_head_channelsdim_head = num_head_channelsif legacy:# num_heads = 1dim_head = ch // num_heads if use_spatial_transformer else num_head_channelsself.middle_block = TimestepEmbedSequential(ResBlock(ch,time_embed_dim,dropout,dims=dims,use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm,),AttentionBlock(ch,use_checkpoint=use_checkpoint,num_heads=num_heads,num_head_channels=dim_head,use_new_attention_order=use_new_attention_order,) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attnch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,use_checkpoint=use_checkpoint),ResBlock(ch,time_embed_dim,dropout,dims=dims,use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm,),)self.middle_block_out = self.make_zero_conv(ch)self._feature_size += ch#SFF Blockself.down11 = nn.Sequential(zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down12 = nn.Sequential(zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down13 = nn.Sequential(zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down21 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down22 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down23 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down31 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down32 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down33 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.silu = nn.SiLU()def make_zero_conv(self, channels):return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))def forward(self, x, hint, timesteps, context, **kwargs):t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)emb = self.time_embed(t_emb)guided_hint = self.input_hint_block(hint, emb, context)outs = []h = x.type(self.dtype)for module, zero_conv in zip(self.input_blocks, self.zero_convs):if guided_hint is not None:h = module(h, emb, context)h += guided_hintguided_hint = Noneelse:h = module(h, emb, context)outs.append(zero_conv(h, emb, context))#SFF Block Implementationouts[9] = self.silu(outs[9]+self.down11(outs[6])+self.down21(outs[7])+self.down31(outs[8]))outs[10] = self.silu(outs[10]+self.down12(outs[6])+self.down22(outs[7])+self.down32(outs[8]))outs[11] = self.silu(outs[11]+self.down13(outs[6])+self.down23(outs[7])+self.down33(outs[8]))h = self.middle_block(h, emb, context)outs.append(self.middle_block_out(h, emb, context))return outs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/51550.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

昇思25天学习打卡营第3天|基础知识-数据集Dataset

目录 环境 环境 导包 数据集加载 数据集迭代 数据集常用操作 shuffle map batch 自定义数据集 可随机访问数据集 可迭代数据集 生成器 MindSpore提供基于Pipeline的数据引擎&#xff0c;通过数据集&#xff08;Dataset&#xff09;和数据变换&#xff08;Transfor…

skynet 入门篇

文章目录 概述1.实现了actor模型2.实现了服务器的基础组件 环境准备centosubuntumac编译安装 ActorActor模型定义组成 Actor调度工作线程流程工作线程权重工作线程执行规则 小结 概述 skynet 是一个轻量级服务器框架&#xff0c;而不仅仅用于游戏&#xff1b; 轻量级有以下几…

C语言百分号打印器

目录 开头程序程序的流程图程序输入与输出的效果例1输入输出 例2输入输出 例3输入输出 结尾 开头 大家好&#xff0c;我叫这是我58。今天&#xff0c;我们来看一下我用C语言编译的百分号打印器和与之相关的一些东西。 程序 #define _CRT_SECURE_NO_WARNINGS 1 #include <…

【RabbitMQ】MQ相关概念

一、MQ的基本概念 定义&#xff1a;MQ全称为Message Queue&#xff0c;是一种提供消息队列服务的中间件&#xff0c;也称为消息中间件。它允许应用程序通过读写队列中的消息来进行通信&#xff0c;而无需建立直接的连接。作用&#xff1a;主要用于分布式系统之间的通信&#x…

CANoe在使用时碰到的一些很少见的Bug

CANoe作为一款成熟且稳定的总线仿真与测试工具&#xff0c;深受汽车工程师们的喜爱。CANoe虽然稳定&#xff0c;但作为一个软件来说&#xff0c;在使用中总会出现一些或大或小的Bug。最近全球范围内的大规模蓝屏事件&#xff0c;是由某个安全软件引起的。而很多CANoe使用者最近…

【中项】系统集成项目管理工程师-第7章 软硬件系统集成-7.2基础设施集成

前言&#xff1a;系统集成项目管理工程师专业&#xff0c;现分享一些教材知识点。觉得文章还不错的喜欢点赞收藏的同时帮忙点点关注。 软考同样是国家人社部和工信部组织的国家级考试&#xff0c;全称为“全国计算机与软件专业技术资格&#xff08;水平&#xff09;考试”&…

【React】详解classnames工具:优化类名控制的全面指南

文章目录 一、classnames的基本用法1. 什么是classnames&#xff1f;2. 安装classnames3. 导入classnames4. classnames的基本示例 二、classnames的高级用法1. 动态类名2. 传递数组3. 结合字符串和对象4. 结合数组和对象 三、实际应用案例1. 根据状态切换类名2. 条件渲染和类名…

Kafka消息队列

目录 什么是消息队列 高可用性 高扩展性 高可用性 持久化和过期策略 consumer group 分组消费 ZooKeeper 什么是消息队列 普通版消息队列 说白了就是一个队列,生产者生产多少,放在消息队列中存储,而消费者想要多少拿多少,按序列号消费 缓存信息 生产者与消费者解耦…

VulnHub靶机入门篇--Kioptrix4

1.环境配置 下载地址&#xff1a; https://download.vulnhub.com/kioptrix/Kioptrix4_vmware.rar 下载完解压之后是一个vdmk文件&#xff0c;我们需要先创建一个新的虚拟机&#xff0c;将vdmk文件导入就行了 先移除原先硬盘&#xff0c;然后再进行添加&#xff0c;网络连接为…

EV代码签名证书具体申请流程

EV&#xff08;扩展验证&#xff09;代码签名证书是一种用于对代码进行数字签名的安全证书&#xff0c;它可以帮助用户验证软件发布者的身份&#xff0c;并确保软件未被篡改。对于Windows硬件开发者来说&#xff0c;这种证书尤其重要&#xff0c;因为它可以用来注册Windows硬件…

【Golang 面试 - 基础题】每日 5 题(八)

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/UWz06 &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会分享 Golang 面试中常见的面试题给大家~ ❤️如果有收获的话&#xff0c;欢迎点赞&#x1f44d;收藏…

六、2 写PWM代码(函数介绍、呼吸灯代码)

目录 一、1、步骤 2、函数介绍 3、外设引脚和GPIO引脚的复用关系&#xff08;引脚定义表&#xff09; 二、1、呼吸灯 步骤 &#xff08;1&#xff09;初始化通道 1&#xff09;输出比较模式 2&#xff09;输出比较极性 &#xff08;2&#xff09;配置GPIO &#xff08…

肆[4],VisionMaster全局触发测试说明

1&#xff0c;环境 VisionMaster4.3 2&#xff0c;实现功能 2.1&#xff0c;全局触发进行流程控制执行。 2.2&#xff0c;取像完成&#xff0c;立即运动到下一个位置&#xff0c;同步进行图片处理。 2.3&#xff0c;发送结果的同时&#xff0c;还需要显示图像处理的痕迹。 …

H616设计时候存在的问题

1.存在大量孤铜的问题&#xff1a; 这种情况是绝对不允许的&#xff0c;但是GBA焊盘打大量的过孔会出现很多这样的孤铜&#xff1a; 解决办法&#xff1a; 像这种出现大量重复焊盘的&#xff0c;用导线连接起来&#xff0c;之后铺铜形成铜皮&#xff0c;再在这个小铜皮上面打…

全网首创!基于GaitSet的一种多人步态识别方法公示

有源代码V细聊&#xff0c;可商用/私用/毕设等&#xff1a;NzqDssm16 &#x1f349;1 绪论 经过相关研究确认&#xff0c;步态识别是足以达到应用级别的生物识别技术&#xff0c;在现代社会中自始至终都存在着广泛的应用前景。之所以迟迟没有普及&#xff0c;主要是实…

【Oracle 进阶之路】Oracle 简介

一、简述 Oracle Database&#xff0c;又名Oracle RDBMS&#xff0c;或简称Oracle。是甲骨文公司的一款关系数据库管理系统。它是在数据库领域一直处于领先地位的产品。可以说Oracle数据库系统是世界上流行的关系数据库管理系统&#xff0c;系统可移植性好、使用方便、功能强&…

华为ensp中链路聚合两种(lacp-static)模式配置方法

作者主页&#xff1a;点击&#xff01; ENSP专栏&#xff1a;点击&#xff01; 创作时间&#xff1a;2024年4月26日11点54分 链路聚合&#xff08;Link Aggregation&#xff09;&#xff0c;又称为端口聚合&#xff08;Port Trunking&#xff09;&#xff0c;是一种将多条物理…

【编程工具使用技巧】VS如何显示行号

&#x1f493; 博客主页&#xff1a;倔强的石头的CSDN主页 &#x1f4dd;Gitee主页&#xff1a;倔强的石头的gitee主页 ⏩ 文章专栏&#xff1a;《编程工具与技巧探索》 期待您的关注 目录 引言 一、VS编译器行号显示的基本步骤 1.打开VS与项目 2.进入选项设置 3.找到并…

【Linux】远程连接Linux虚拟机(MobaXterm)

【Linux】远程连接Linux虚拟机&#xff08;MobaXterm&#xff09; 零、原因 有时候我们在虚拟机中操作Linux不太方便&#xff0c;比如不能复制粘贴&#xff0c;不能传文件等等&#xff0c;我们在主机上使用远程连接软件远程连接Linux虚拟机后可以解决上面的问题。 壹、软件下…

成为git砖家(5): 理解 HEAD

文章目录 1. git rev-parse 命令2. 什么是 HEAD2.1 创建分支当并未切换&#xff0c; HEAD 不变2.2 切换分支&#xff0c;HEAD 改变2.3 再次切换分支&#xff0c; HEAD 再次改变 3. detached HEAD4. HEAD 表示分支、表示 detached HEAD 有什么区别&#xff1f;区别相同点 5. HEA…