Alright, let's focus on the first breakthrough point:
learning the distribution of good (defect-free) samples more efficiently, via approaches akin to data distillation or active-learning-style sampling.
Here is a complete code example:
✅ Masked image reconstruction + residual heatmap
This is a variant of self-supervised distillation:
- Use a pretrained MAE model (or a lightweight ViT) to perform masked reconstruction of normal samples
- Use the residual between the reconstruction and the original image to reflect the "degree of anomaly" (formalized just below)
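To make the second bullet concrete (notation mine, matching the code below): for an input image x and its masked reconstruction x̂, the per-pixel anomaly map sums the absolute residual over the three color channels, and an image-level score can be taken as its mean or maximum:

$$
A(i,j) = \sum_{c \in \{R,G,B\}} \bigl| x_c(i,j) - \hat{x}_c(i,j) \bigr|,
\qquad
s = \frac{1}{HW} \sum_{i,j} A(i,j) \;\;\text{or}\;\; \max_{i,j} A(i,j)
$$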
✅ Environment dependencies
```bash
pip install timm einops torchvision matplotlib
```
✅ Complete code (using MVTec images as an example)
```python
import os
from glob import glob

import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.datasets.folder import default_loader
from einops import rearrange
import timm
import matplotlib.pyplot as plt

# ---------------------------
# Model: ViT encoder + simple decoder
# ---------------------------
class MAE(nn.Module):
    def __init__(self, encoder_name='vit_base_patch16_224', mask_ratio=0.4):
        super().__init__()
        self.encoder = timm.create_model(encoder_name, pretrained=True)
        self.mask_ratio = mask_ratio
        self.patch_size = self.encoder.patch_embed.patch_size[0]
        self.num_patches = self.encoder.patch_embed.num_patches
        self.embed_dim = self.encoder.embed_dim
        # Lightweight decoder: project each token back to a flattened RGB patch
        self.decoder = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.GELU(),
            nn.Linear(self.embed_dim, self.patch_size ** 2 * 3)
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x_patch = self.encoder.patch_embed(x)  # [B, num_patches, dim]
        B, N, D = x_patch.shape

        # Randomly mask a fraction of the patches, keep the rest
        rand_idx = torch.rand(B, N, device=x.device).argsort(dim=1)
        num_keep = int(N * (1 - self.mask_ratio))
        keep_idx = rand_idx[:, :num_keep]
        x_keep = torch.gather(x_patch, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))

        # Encode the visible patches (positional embeddings are skipped for simplicity)
        x_encoded = self.encoder.blocks(x_keep)
        x_decoded = self.decoder(x_encoded)

        # Scatter the reconstructed (kept) patches back to their positions;
        # masked positions stay zero
        output = torch.zeros(B, N, self.patch_size ** 2 * 3, device=x.device)
        output.scatter_(1, keep_idx.unsqueeze(-1).expand(-1, -1, self.patch_size ** 2 * 3), x_decoded)
        output = rearrange(
            output, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
            p1=self.patch_size, p2=self.patch_size, c=3,
            h=H // self.patch_size, w=W // self.patch_size
        )
        return output

# ---------------------------
# Data loading + preprocessing
# ---------------------------
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.5] * 3, [0.5] * 3)
])
inv_transform = T.Compose([
    T.Normalize(mean=[-1] * 3, std=[2] * 3)  # undo the (x - 0.5) / 0.5 normalization
])

def load_images(path):
    files = sorted(glob(os.path.join(path, '*.png')) + glob(os.path.join(path, '*.jpg')))
    images = []
    for f in files:
        img = default_loader(f)
        images.append(transform(img))
    return torch.stack(images)

# ---------------------------
# Test image -> reconstruction -> residual heatmap
# ---------------------------
def visualize_anomaly(original, recon, save_path='result.png'):
    residual = (original - recon).abs().sum(dim=1, keepdim=True)
    residual = residual / (residual.max() + 1e-8)
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    axs[0].imshow(inv_transform(original[0]).clamp(0, 1).permute(1, 2, 0).cpu().numpy())
    axs[0].set_title('Original')
    axs[1].imshow(inv_transform(recon[0]).clamp(0, 1).permute(1, 2, 0).cpu().numpy())
    axs[1].set_title('Reconstruction')
    axs[2].imshow(residual[0, 0].cpu().numpy(), cmap='hot')
    axs[2].set_title('Anomaly Map')
    for ax in axs:
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

# ---------------------------
# Main
# ---------------------------
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = MAE().to(device)
    model.eval()

    # Replace with any MVTec / VisA category path; only good (defect-free) images are loaded
    image_dir = './mvtec/bottle/good/'
    images = load_images(image_dir).to(device)

    with torch.no_grad():
        for i in range(min(5, len(images))):
            input_img = images[i:i + 1]
            recon_img = model(input_img)
            visualize_anomaly(input_img, recon_img, f'output_{i}.png')
```
✅ Example output (saved as output_0.png, etc.):
- Left: original image
- Middle: reconstruction (the model's notion of a "good" sample)
- Right: anomaly heatmap (residual)
On normal samples, the residual map should be close to zero; if an anomalous image is fed in, the corresponding regions will show a high response.
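If you also want a single number per image rather than just a heatmap, here is a minimal sketch (my addition, reusing the same residual definition as `visualize_anomaly`) of an image-level anomaly score; the max tends to be more sensitive to small localized defects, the mean to large ones:

```python
def anomaly_score(original, recon):
    """Image-level anomaly scores from the per-pixel residual: (mean, max) per image."""
    residual = (original - recon).abs().sum(dim=1)           # [B, H, W]
    return residual.mean(dim=(1, 2)), residual.amax(dim=(1, 2))

# Hypothetical usage inside the main loop above:
# mean_s, max_s = anomaly_score(input_img, recon_img)
# print(f'output_{i}: mean={mean_s.item():.4f}, max={max_s.item():.4f}')
```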
✅ Possible extensions
| Module | Possible extension |
| --- | --- |
| Encoder | Swap in a lightweight ViT (e.g. `vit_tiny_patch16_224`) |
| Mask strategy | Use structured masking (e.g. block masks) to make reconstruction harder (see the sketch below) |
| Anomalous images | Feed anomalous samples (e.g. defect images from the MVTec test set) to verify generalization |
| Training | Fine-tune with a reconstruction loss to model good samples more accurately |
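The "Mask strategy" row deserves a small illustration. Below is a minimal sketch (a hypothetical helper of mine, not part of the code above) of structured block masking: instead of dropping random individual patches, a contiguous block of the patch grid is masked, which makes reconstruction harder and tends to sharpen the residual on local defects. The grid and block sizes are assumptions for a 224x224 input with 16x16 patches.

```python
import torch

def block_mask_indices(grid_h, grid_w, block_h, block_w, device='cpu'):
    """Pick one random contiguous block of patches to mask; return (keep_idx, mask_idx)."""
    top = torch.randint(0, grid_h - block_h + 1, (1,)).item()
    left = torch.randint(0, grid_w - block_w + 1, (1,)).item()
    mask = torch.zeros(grid_h, grid_w, dtype=torch.bool, device=device)
    mask[top:top + block_h, left:left + block_w] = True
    mask = mask.flatten()
    all_idx = torch.arange(grid_h * grid_w, device=device)
    return all_idx[~mask], all_idx[mask]

# Example: a 224x224 image with 16x16 patches gives a 14x14 grid;
# masking a 6x6 block hides roughly 18% of the patches.
keep_idx, mask_idx = block_mask_indices(14, 14, 6, 6)
```

Inside `MAE.forward`, `keep_idx` from this helper (expanded to the batch dimension) would take the place of the random `keep_idx`.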
If you'd like, I can extend this further into:
- a quick-correction version that supports a small number of anomalous images;
- or one that adds an active sample-selection mechanism;
just let me know.
Great. Building on the self-supervised reconstruction (MAE) approach above, let's now wrap it in a Gradio demo for a more intuitive anomaly-detection experience.
✅ New feature goals
- Upload any image (good or anomalous)
- Display in real time:
  - the original image
  - the model's reconstruction
  - the residual heatmap (high response = anomalous region)
✅ Complete code (with Gradio UI)
```python
import io

import torch
import torch.nn as nn
import torchvision.transforms as T
from einops import rearrange
import timm
import gradio as gr
from PIL import Image
import matplotlib.pyplot as plt

# ---------------------------
# Model definition (same as above)
# ---------------------------
class MAE(nn.Module):
    def __init__(self, encoder_name='vit_base_patch16_224', mask_ratio=0.4):
        super().__init__()
        self.encoder = timm.create_model(encoder_name, pretrained=True)
        self.mask_ratio = mask_ratio
        self.patch_size = self.encoder.patch_embed.patch_size[0]
        self.num_patches = self.encoder.patch_embed.num_patches
        self.embed_dim = self.encoder.embed_dim
        self.decoder = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.GELU(),
            nn.Linear(self.embed_dim, self.patch_size ** 2 * 3)
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x_patch = self.encoder.patch_embed(x)
        B, N, D = x_patch.shape
        rand_idx = torch.rand(B, N, device=x.device).argsort(dim=1)
        num_keep = int(N * (1 - self.mask_ratio))
        keep_idx = rand_idx[:, :num_keep]
        x_keep = torch.gather(x_patch, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))
        x_encoded = self.encoder.blocks(x_keep)
        x_decoded = self.decoder(x_encoded)
        output = torch.zeros(B, N, self.patch_size ** 2 * 3, device=x.device)
        output.scatter_(1, keep_idx.unsqueeze(-1).expand(-1, -1, self.patch_size ** 2 * 3), x_decoded)
        output = rearrange(
            output, 'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
            p1=self.patch_size, p2=self.patch_size, c=3,
            h=H // self.patch_size, w=W // self.patch_size
        )
        return output

# ---------------------------
# Pre- and post-processing
# ---------------------------
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.5] * 3, [0.5] * 3)
])
inv_transform = T.Compose([
    T.Normalize(mean=[-1] * 3, std=[2] * 3)
])

def tensor_to_pil(t):
    t = inv_transform(t.squeeze(0)).clamp(0, 1)
    return T.ToPILImage()(t)

def residual_map(orig, recon):
    residual = (orig - recon).abs().sum(dim=1, keepdim=True)
    residual = residual / (residual.max() + 1e-8)
    heat = residual.squeeze().cpu().numpy()
    fig, ax = plt.subplots()
    ax.imshow(heat, cmap='hot')
    ax.axis('off')
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)

# ---------------------------
# Inference function
# ---------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MAE().to(device)
model.eval()

def infer(img_pil):
    # Convert to RGB so grayscale/RGBA uploads don't break the 3-channel normalization
    img_tensor = transform(img_pil.convert('RGB')).unsqueeze(0).to(device)
    with torch.no_grad():
        recon = model(img_tensor)
    recon_img = tensor_to_pil(recon)
    input_img = tensor_to_pil(img_tensor)
    heatmap = residual_map(img_tensor, recon)
    return input_img, recon_img, heatmap

# ---------------------------
# Gradio UI
# ---------------------------
demo = gr.Interface(
    fn=infer,
    inputs=gr.Image(type="pil", label="Upload image"),
    outputs=[
        gr.Image(type="pil", label="Original"),
        gr.Image(type="pil", label="Reconstruction"),
        gr.Image(type="pil", label="Residual heatmap")
    ],
    title="Good-sample-based anomaly detection (MAE reconstruction)",
    description="Upload an image; the model reconstructs its normal appearance and generates a residual anomaly heatmap."
)

if __name__ == '__main__':
    demo.launch()
```
✅ What to expect
You can upload the following types of images for real-time detection:
- ✔️ Good images: the residual map should be fairly smooth overall, with low response values;
- ❌ Anomalous images (e.g. scratches / breakage): the anomalous regions light up clearly in the residual map (high response);
✅ Suggested next steps:
| Module | Possible enhancement |
| --- | --- |
| Reconstruction network | Replace with DRAEM / Reverse Distillation |
| Anomaly scoring | Compute the whole-image mean residual + Otsu binarization for segmentation (see the sketch below) |
| Multi-sample comparison | Support directory upload and batch visualization |
| Transfer fine-tuning | Fine-tune on a small amount of target-domain data to improve domain robustness |
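As a concrete starting point for the "Anomaly scoring" row, here is a minimal sketch (my addition; it assumes scikit-image is installed via `pip install scikit-image`) that turns the residual map into a global score plus a binary defect mask via Otsu thresholding:

```python
import numpy as np
from skimage.filters import threshold_otsu

def score_and_mask(orig, recon):
    """Global anomaly score and binary defect mask from the reconstruction residual."""
    residual = (orig - recon).abs().sum(dim=1).squeeze(0).cpu().numpy()  # [H, W]
    score = float(residual.mean())               # higher = more anomalous overall
    thresh = threshold_otsu(residual)            # Otsu threshold over the residual values
    mask = (residual > thresh).astype(np.uint8)  # 1 = suspected defect pixel
    return score, mask

# Hypothetical usage inside infer(): score, mask = score_and_mask(img_tensor, recon)
```

One caveat: Otsu always splits the residual into two classes, so on good images the mask is not meaningful; in practice you would gate the mask on the global score exceeding a threshold calibrated on good samples.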
Would you like me to implement next:
- ✅ residual-based anomaly scoring + binary mask output?
- ✅ fine-tuning support with a small number of anomalous samples?
- ✅ replacing the MAE backbone with PatchCore / AnomalyCLIP?
Just name the direction you want to enhance next, and I can provide the code directly.