论文解读:DiAD之SG网络

目录

  • 一、SG网络功能介绍
  • 二、SG网络代码实现

一、SG网络功能介绍

DiAD论文最主要的创新点就是使用SG网络解决多类别异常检测中的语义信息丢失问题,那么它是怎么实现的保留原始图像语义信息的同时重建异常区域?

与稳定扩散去噪网络的连接: SG网络被设计为与稳定扩散(Stable Diffusion, SD)去噪网络相连接。SD去噪网络本身具有强大的图像生成能力,但可能无法在多类异常检测任务中保持图像的语义信息一致性。SG网络通过引入语义引导机制,使得在重构异常区域时能够参考并保留原始图像的语义上下文。整个框架图中,SG网络与去噪网络的连接如下图所示。
在这里插入图片描述在这里插入图片描述
这是论文给出的最终输出,我认为图中圈出来的地方有问题,应该改为SG网络的编码器才对。

语义一致性保持: SG网络在重构过程中,通过在不同尺度下处理噪声,并利用空间感知特征融合(Spatial-aware Feature Fusion, SFF)块融合特征,确保重建过程中保留语义信息。这样,即使在重构异常区域时,也能使修复后的区域与原始图像的语义上下文保持一致。
多尺度特征融合: SFF块将高尺度的语义信息集成到低尺度中,使得在保留原始正常样本信息的同时,能够处理大规模异常区域的重建。这种机制有助于在处理需要广泛重构的区域时,最大化重构的准确性,同时保持图像的语义一致性。从下图中可以看到,特征融合模块还是很好理解的。

在这里插入图片描述

与预训练特征提取器的结合: SG网络还与特征空间中的预训练特征提取器相结合。预训练特征提取器能够处理输入图像和重建图像,并在不同尺度上提取特征。通过比较这些特征,系统能够生成异常图(anomaly maps),这些图显示了图像中可能存在的异常区域,并给出了异常得分或置信度。这一步骤进一步验证了SG网络在保留语义信息方面的有效性。
避免类别错误: 相比于传统的扩散模型(如DDPM),SG网络通过引入类别条件解决了在多类异常检测任务中可能出现的类别错误问题。LDM虽然通过交叉注意力引入了条件约束,但在随机高斯噪声下去噪时仍可能丢失语义信息。SG网络则通过其语义引导机制,有效地避免了这一问题。

二、SG网络代码实现

这部分代码大概有300行

class SemanticGuidedNetwork(nn.Module):def __init__(self,image_size,in_channels,model_channels,hint_channels,num_res_blocks,attention_resolutions,dropout=0,channel_mult=(1, 2, 4, 8),conv_resample=True,dims=2,use_checkpoint=False,use_fp16=False,num_heads=-1,num_head_channels=-1,num_heads_upsample=-1,use_scale_shift_norm=False,resblock_updown=False,use_new_attention_order=False,use_spatial_transformer=False,  # custom transformer supporttransformer_depth=1,  # custom transformer supportcontext_dim=None,  # custom transformer supportn_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq modellegacy=True,disable_self_attentions=None,num_attention_blocks=None,disable_middle_self_attn=False,use_linear_in_transformer=False,):super().__init__()if use_spatial_transformer:assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'if context_dim is not None:assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'from omegaconf.listconfig import ListConfigif type(context_dim) == ListConfig:context_dim = list(context_dim)if num_heads_upsample == -1:num_heads_upsample = num_headsif num_heads == -1:assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'if num_head_channels == -1:assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'self.dims = dimsself.image_size = image_sizeself.in_channels = in_channelsself.model_channels = model_channelsif isinstance(num_res_blocks, int):self.num_res_blocks = len(channel_mult) * [num_res_blocks]else:if len(num_res_blocks) != len(channel_mult):raise ValueError("provide num_res_blocks either as an int (globally constant) or ""as a list/tuple (per-level) with the same length as channel_mult")self.num_res_blocks = num_res_blocksif disable_self_attentions is not None:# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or notassert len(disable_self_attentions) == len(channel_mult)if num_attention_blocks is not None:assert len(num_attention_blocks) == len(self.num_res_blocks)assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "f"This option has LESS priority than attention_resolutions {attention_resolutions}, "f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "f"attention will still not be set.")self.attention_resolutions = attention_resolutionsself.dropout = dropoutself.channel_mult = channel_multself.conv_resample = conv_resampleself.use_checkpoint = use_checkpointself.dtype = th.float16 if use_fp16 else th.float32self.num_heads = num_headsself.num_head_channels = num_head_channelsself.num_heads_upsample = num_heads_upsampleself.predict_codebook_ids = n_embed is not Nonetime_embed_dim = model_channels * 4self.time_embed = nn.Sequential(linear(model_channels, time_embed_dim),nn.SiLU(),linear(time_embed_dim, time_embed_dim),)self.input_blocks = nn.ModuleList([TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))])self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])self.input_hint_block = TimestepEmbedSequential(conv_nd(dims, hint_channels, 16, 3, padding=1),nn.SiLU(),conv_nd(dims, 16, 16, 3, padding=1),nn.SiLU(),conv_nd(dims, 16, 32, 3, padding=1, stride=2),nn.SiLU(),conv_nd(dims, 32, 32, 3, padding=1),nn.SiLU(),conv_nd(dims, 32, 96, 3, padding=1, stride=2),nn.SiLU(),conv_nd(dims, 96, 96, 3, padding=1),nn.SiLU(),conv_nd(dims, 96, 256, 3, padding=1, stride=2),nn.SiLU(),zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)))self._feature_size = model_channelsinput_block_chans = [model_channels]ch = model_channelsds = 1for level, mult in enumerate(channel_mult):for nr in range(self.num_res_blocks[level]):layers = [ResBlock(ch,time_embed_dim,dropout,out_channels=mult * model_channels,dims=dims,use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm,)]ch = mult * model_channelsif ds in attention_resolutions:if num_head_channels == -1:dim_head = ch // num_headselse:num_heads = ch // num_head_channelsdim_head = num_head_channelsif legacy:# num_heads = 1dim_head = ch // num_heads if use_spatial_transformer else num_head_channelsif exists(disable_self_attentions):disabled_sa = disable_self_attentions[level]else:disabled_sa = Falseif not exists(num_attention_blocks) or nr < num_attention_blocks[level]:layers.append(AttentionBlock(ch,use_checkpoint=use_checkpoint,num_heads=num_heads,num_head_channels=dim_head,use_new_attention_order=use_new_attention_order,) if not use_spatial_transformer else SpatialTransformer(ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,use_checkpoint=use_checkpoint))self.input_blocks.append(TimestepEmbedSequential(*layers))self.zero_convs.append(self.make_zero_conv(ch))self._feature_size += chinput_block_chans.append(ch)if level != len(channel_mult) - 1:out_ch = chself.input_blocks.append(TimestepEmbedSequential(ResBlock(ch,time_embed_dim,dropout,out_channels=out_ch,dims=dims,use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm,down=True,)if resblock_updownelse Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)))ch = out_chinput_block_chans.append(ch)self.zero_convs.append(self.make_zero_conv(ch))ds *= 2self._feature_size += chif num_head_channels == -1:dim_head = ch // num_headselse:num_heads = ch // num_head_channelsdim_head = num_head_channelsif legacy:# num_heads = 1dim_head = ch // num_heads if use_spatial_transformer else num_head_channelsself.middle_block = TimestepEmbedSequential(ResBlock(ch,time_embed_dim,dropout,dims=dims,use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm,),AttentionBlock(ch,use_checkpoint=use_checkpoint,num_heads=num_heads,num_head_channels=dim_head,use_new_attention_order=use_new_attention_order,) if not use_spatial_transformer else SpatialTransformer(  # always uses a self-attnch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,use_checkpoint=use_checkpoint),ResBlock(ch,time_embed_dim,dropout,dims=dims,use_checkpoint=use_checkpoint,use_scale_shift_norm=use_scale_shift_norm,),)self.middle_block_out = self.make_zero_conv(ch)self._feature_size += ch#SFF Blockself.down11 = nn.Sequential(zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down12 = nn.Sequential(zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down13 = nn.Sequential(zero_module(nn.Conv2d(640, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down21 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down22 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down23 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down31 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down32 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.down33 = nn.Sequential(zero_module(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1, bias=False)),nn.InstanceNorm2d(1280),nn.SiLU(),)self.silu = nn.SiLU()def make_zero_conv(self, channels):return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))def forward(self, x, hint, timesteps, context, **kwargs):t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)emb = self.time_embed(t_emb)guided_hint = self.input_hint_block(hint, emb, context)outs = []h = x.type(self.dtype)for module, zero_conv in zip(self.input_blocks, self.zero_convs):if guided_hint is not None:h = module(h, emb, context)h += guided_hintguided_hint = Noneelse:h = module(h, emb, context)outs.append(zero_conv(h, emb, context))#SFF Block Implementationouts[9] = self.silu(outs[9]+self.down11(outs[6])+self.down21(outs[7])+self.down31(outs[8]))outs[10] = self.silu(outs[10]+self.down12(outs[6])+self.down22(outs[7])+self.down32(outs[8]))outs[11] = self.silu(outs[11]+self.down13(outs[6])+self.down23(outs[7])+self.down33(outs[8]))h = self.middle_block(h, emb, context)outs.append(self.middle_block_out(h, emb, context))return outs

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/51550.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

从零开始:MT4软件下载与初步使用教程

MetaTrader 4&#xff08;简称MT4&#xff09;是一款由MetaQuotes开发并广泛使用的在线金融衍生品交易终端。它不仅在外汇市场占据重要地位&#xff0c;还支持各种股指和大宗商品的差价合约&#xff08;CFD&#xff09;产品交易。本文将详细指导您如何从零开始下载并初步使用MT…

昇思25天学习打卡营第3天|基础知识-数据集Dataset

目录 环境 环境 导包 数据集加载 数据集迭代 数据集常用操作 shuffle map batch 自定义数据集 可随机访问数据集 可迭代数据集 生成器 MindSpore提供基于Pipeline的数据引擎&#xff0c;通过数据集&#xff08;Dataset&#xff09;和数据变换&#xff08;Transfor…

skynet 入门篇

文章目录 概述1.实现了actor模型2.实现了服务器的基础组件 环境准备centosubuntumac编译安装 ActorActor模型定义组成 Actor调度工作线程流程工作线程权重工作线程执行规则 小结 概述 skynet 是一个轻量级服务器框架&#xff0c;而不仅仅用于游戏&#xff1b; 轻量级有以下几…

C语言百分号打印器

目录 开头程序程序的流程图程序输入与输出的效果例1输入输出 例2输入输出 例3输入输出 结尾 开头 大家好&#xff0c;我叫这是我58。今天&#xff0c;我们来看一下我用C语言编译的百分号打印器和与之相关的一些东西。 程序 #define _CRT_SECURE_NO_WARNINGS 1 #include <…

Redis数据结构之跳跃表(SkipList)

Redis是一个开源的、使用ANSI C语言编写、支持网络、可基于内存亦可持久化的日志型、Key-Value数据库&#xff0c;并提供多种语言的API。Redis凭借其高性能、高可用性、丰富的数据结构以及简洁的API而备受青睐。其中&#xff0c;跳跃表&#xff08;SkipList&#xff09;作为Red…

【RabbitMQ】MQ相关概念

一、MQ的基本概念 定义&#xff1a;MQ全称为Message Queue&#xff0c;是一种提供消息队列服务的中间件&#xff0c;也称为消息中间件。它允许应用程序通过读写队列中的消息来进行通信&#xff0c;而无需建立直接的连接。作用&#xff1a;主要用于分布式系统之间的通信&#x…

SQL Server数据库管理(五)从权限管理到数据恢复的全面指南

文章目录 SQL Server数据库管理&#xff1a;从权限管理到数据恢复的全面指南引言第一章&#xff1a;权限管理1.1 SQL Server的安全机制1.2 身份验证模式1.3 登录权限设置1.3.1 创建登录账户1.3.2 服务器级别权限 1.4 数据库级别权限1.5 对象级别权限1.6 实验&#xff1a;权限设…

CANoe在使用时碰到的一些很少见的Bug

CANoe作为一款成熟且稳定的总线仿真与测试工具&#xff0c;深受汽车工程师们的喜爱。CANoe虽然稳定&#xff0c;但作为一个软件来说&#xff0c;在使用中总会出现一些或大或小的Bug。最近全球范围内的大规模蓝屏事件&#xff0c;是由某个安全软件引起的。而很多CANoe使用者最近…

Java spring security 自定义登录逻辑实现

介绍 在使用框架自带的Security的登录认证时&#xff0c;默认只能使用用户名去查询&#xff0c;如果有业务需要其他字段也需要进行查询&#xff0c;只能采用根据用户名去找到对应的数据。 自定义鉴权接口CustomUsernamePasswordAuthenticationToken /*** author wuzhenyong* C…

【中项】系统集成项目管理工程师-第7章 软硬件系统集成-7.2基础设施集成

前言&#xff1a;系统集成项目管理工程师专业&#xff0c;现分享一些教材知识点。觉得文章还不错的喜欢点赞收藏的同时帮忙点点关注。 软考同样是国家人社部和工信部组织的国家级考试&#xff0c;全称为“全国计算机与软件专业技术资格&#xff08;水平&#xff09;考试”&…

【React】详解classnames工具:优化类名控制的全面指南

文章目录 一、classnames的基本用法1. 什么是classnames&#xff1f;2. 安装classnames3. 导入classnames4. classnames的基本示例 二、classnames的高级用法1. 动态类名2. 传递数组3. 结合字符串和对象4. 结合数组和对象 三、实际应用案例1. 根据状态切换类名2. 条件渲染和类名…

Kafka消息队列

目录 什么是消息队列 高可用性 高扩展性 高可用性 持久化和过期策略 consumer group 分组消费 ZooKeeper 什么是消息队列 普通版消息队列 说白了就是一个队列,生产者生产多少,放在消息队列中存储,而消费者想要多少拿多少,按序列号消费 缓存信息 生产者与消费者解耦…

VulnHub靶机入门篇--Kioptrix4

1.环境配置 下载地址&#xff1a; https://download.vulnhub.com/kioptrix/Kioptrix4_vmware.rar 下载完解压之后是一个vdmk文件&#xff0c;我们需要先创建一个新的虚拟机&#xff0c;将vdmk文件导入就行了 先移除原先硬盘&#xff0c;然后再进行添加&#xff0c;网络连接为…

EV代码签名证书具体申请流程

EV&#xff08;扩展验证&#xff09;代码签名证书是一种用于对代码进行数字签名的安全证书&#xff0c;它可以帮助用户验证软件发布者的身份&#xff0c;并确保软件未被篡改。对于Windows硬件开发者来说&#xff0c;这种证书尤其重要&#xff0c;因为它可以用来注册Windows硬件…

【Golang 面试 - 基础题】每日 5 题(八)

✍个人博客&#xff1a;Pandaconda-CSDN博客 &#x1f4e3;专栏地址&#xff1a;http://t.csdnimg.cn/UWz06 &#x1f4da;专栏简介&#xff1a;在这个专栏中&#xff0c;我将会分享 Golang 面试中常见的面试题给大家~ ❤️如果有收获的话&#xff0c;欢迎点赞&#x1f44d;收藏…

【DP】01背包

算法-01背包 前置知识 DP 思路 01背包一般分为两种&#xff0c;不妨叫做价值01背包和判断01背包。 价值01背包 01背包问题是这样的一类问题&#xff1a;给定一个背包的容量 m m m 和 n n n 个物品&#xff0c;每个物品有重量 w w w 和价值 v v v&#xff0c;求不超过背…

unity 导出 资源 -- 的 Player Settings... Inspector ----配置文件

--------------------------- unity 导出 资源 -- 的 Player Settings... Inspector ----配置文件名称--------OK 配置 文件位置&#xff1a;E:\BL\client\ProjectSettings\ProjectSettings.asset 具体操作&#xff1a; 复制一个备份配置 ------.unity--File--Build-Setting…

六、2 写PWM代码(函数介绍、呼吸灯代码)

目录 一、1、步骤 2、函数介绍 3、外设引脚和GPIO引脚的复用关系&#xff08;引脚定义表&#xff09; 二、1、呼吸灯 步骤 &#xff08;1&#xff09;初始化通道 1&#xff09;输出比较模式 2&#xff09;输出比较极性 &#xff08;2&#xff09;配置GPIO &#xff08…

Zabbix 部署 - docker

考虑方便移植&#xff0c;多环境部署&#xff0c;整体采用 docker-compose 方式部署 docker-compose 总共4个服务&#xff0c;数据库 后台服务 前端服务 Agent version: 3.7 services:zabbix-mysql:container_name: zabbix-mysqlimage: mysql:5.7.40restart: alwaysenviro…

肆[4],VisionMaster全局触发测试说明

1&#xff0c;环境 VisionMaster4.3 2&#xff0c;实现功能 2.1&#xff0c;全局触发进行流程控制执行。 2.2&#xff0c;取像完成&#xff0c;立即运动到下一个位置&#xff0c;同步进行图片处理。 2.3&#xff0c;发送结果的同时&#xff0c;还需要显示图像处理的痕迹。 …