通过类似数据蒸馏或主动学习采样的方法,更加高效地学习良品数据分布

好的,我们先聚焦第一个突破点:

通过类似数据蒸馏或主动学习采样的方法,更加高效地学习良品数据分布。

这里我提供一个完整的代码示例:

Masked图像重建 + 残差热力图

这属于自监督蒸馏方法的一个变体:

  • 使用一个 预训练MAE模型(或轻量ViT)对正常样本进行遮挡重建
  • 用重建图与原图的残差来反映“异常程度”

✅ 示例环境依赖

pip install timm einops torchvision matplotlib

✅ 完整代码(以MVTec中的图像为例)

import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.utils import save_image
from torchvision.datasets.folder import default_loader
from einops import rearrange
import timm
import matplotlib.pyplot as plt
import os
from glob import glob
from PIL import Image
import numpy as np# ---------------------------
# 模型定义:ViT作为Encoder + 简单Decoder
# ---------------------------
class MAE(nn.Module):def __init__(self, encoder_name='vit_base_patch16_224', mask_ratio=0.4):super().__init__()self.encoder = timm.create_model(encoder_name, pretrained=True)self.mask_ratio = mask_ratioself.patch_size = self.encoder.patch_embed.patch_size[0]self.num_patches = self.encoder.patch_embed.num_patchesself.embed_dim = self.encoder.embed_dimself.decoder = nn.Sequential(nn.Linear(self.embed_dim, self.embed_dim),nn.GELU(),nn.Linear(self.embed_dim, self.patch_size**2 * 3))def forward(self, x):B, C, H, W = x.shapex_patch = self.encoder.patch_embed(x)  # [B, num_patches, dim]B, N, D = x_patch.shape# 随机遮挡rand_idx = torch.rand(B, N).argsort(dim=1)num_keep = int(N * (1 - self.mask_ratio))keep_idx = rand_idx[:, :num_keep]x_keep = torch.gather(x_patch, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))x_encoded = self.encoder.blocks(x_keep)x_decoded = self.decoder(x_encoded)# 恢复顺序(只对keep部分重建)output = torch.zeros(B, N, self.patch_size**2 * 3).to(x.device)output.scatter_(1, keep_idx.unsqueeze(-1).expand(-1, -1, self.patch_size**2 * 3), x_decoded)output = rearrange(output, 'b n (p c) -> b c (h p) (w p)', p=self.patch_size, c=3, h=int(H/self.patch_size), w=int(W/self.patch_size))return output# ---------------------------
# 数据加载 + 预处理
# ---------------------------
transform = T.Compose([T.Resize((224, 224)),T.ToTensor(),T.Normalize([0.5]*3, [0.5]*3)
])inv_transform = T.Compose([T.Normalize(mean=[-1]*3, std=[2]*3)
])def load_images(path):files = sorted(glob(os.path.join(path, '*.png')) + glob(os.path.join(path, '*.jpg')))images = []for f in files:img = default_loader(f)images.append(transform(img))return torch.stack(images)# ---------------------------
# 测试图像 → 重建图像 → 残差热图
# ---------------------------
def visualize_anomaly(original, recon, save_path='result.png'):residual = (original - recon).abs().sum(dim=1, keepdim=True)residual = residual / residual.max()fig, axs = plt.subplots(1, 3, figsize=(12, 4))axs[0].imshow(inv_transform(original[0]).permute(1, 2, 0).cpu().numpy())axs[0].set_title('Original')axs[1].imshow(inv_transform(recon[0]).permute(1, 2, 0).cpu().numpy())axs[1].set_title('Reconstruction')axs[2].imshow(residual[0, 0].cpu().numpy(), cmap='hot')axs[2].set_title('Anomaly Map')for ax in axs: ax.axis('off')plt.tight_layout()plt.savefig(save_path)plt.close()# ---------------------------
# 主程序执行
# ---------------------------
if __name__ == '__main__':device = 'cuda' if torch.cuda.is_available() else 'cpu'model = MAE().to(device)model.eval()# 替换为 MVTec / VisA 任一类别路径image_dir = './mvtec/bottle/good/'  # 只加载良品图像images = load_images(image_dir).to(device)with torch.no_grad():for i in range(min(5, len(images))):input_img = images[i:i+1]recon_img = model(input_img)visualize_anomaly(input_img, recon_img, f'output_{i}.png')

✅ 示例输出(保存为output_0.png等):

  • 左:原图
  • 中:重建图(模型“理解的良品”)
  • 右:异常热图(残差)

在正常样本上,残差图应接近0;如果输入的是异常图像,则对应区域将出现高响应。


✅ 可扩展方向

模块可扩展优化
Encoder更换为轻量ViT(如 vit_tiny_patch16_224
Mask策略使用结构化遮挡(如Block Mask)提升重建挑战
异常图像输入异常样本(如MVTec测试集中defect图)验证泛化能力
训练加入重建loss微调,提高良品建模精度

如果你希望我进一步扩展为:

  • 支持少量异常图像的快速修正版本;
  • 或加入主动样本选择机制;

很好,我们继续在上一套基于自监督重建(MAE)方法的基础上,
为其 封装 Gradio Demo,以实现更直观的异常检测体验。


✅ 新增功能目标

  1. 上传任意图片(良品或异常图)
  2. 实时显示:
    • 原图
    • 模型重建图
    • 残差热力图(高响应 = 异常区域)

✅ 完整代码(附Gradio界面)

import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.utils import save_image
from torchvision.datasets.folder import default_loader
from einops import rearrange
import timm
import gradio as gr
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import io# ---------------------------
# 模型定义(同上)
# ---------------------------
class MAE(nn.Module):def __init__(self, encoder_name='vit_base_patch16_224', mask_ratio=0.4):super().__init__()self.encoder = timm.create_model(encoder_name, pretrained=True)self.mask_ratio = mask_ratioself.patch_size = self.encoder.patch_embed.patch_size[0]self.num_patches = self.encoder.patch_embed.num_patchesself.embed_dim = self.encoder.embed_dimself.decoder = nn.Sequential(nn.Linear(self.embed_dim, self.embed_dim),nn.GELU(),nn.Linear(self.embed_dim, self.patch_size**2 * 3))def forward(self, x):B, C, H, W = x.shapex_patch = self.encoder.patch_embed(x)B, N, D = x_patch.shaperand_idx = torch.rand(B, N).argsort(dim=1)num_keep = int(N * (1 - self.mask_ratio))keep_idx = rand_idx[:, :num_keep]x_keep = torch.gather(x_patch, 1, keep_idx.unsqueeze(-1).expand(-1, -1, D))x_encoded = self.encoder.blocks(x_keep)x_decoded = self.decoder(x_encoded)output = torch.zeros(B, N, self.patch_size**2 * 3).to(x.device)output.scatter_(1, keep_idx.unsqueeze(-1).expand(-1, -1, self.patch_size**2 * 3), x_decoded)output = rearrange(output, 'b n (p c) -> b c (h p) (w p)', p=self.patch_size, c=3, h=int(H/self.patch_size), w=int(W/self.patch_size))return output# ---------------------------
# 预处理 & 后处理
# ---------------------------
transform = T.Compose([T.Resize((224, 224)),T.ToTensor(),T.Normalize([0.5]*3, [0.5]*3)
])inv_transform = T.Compose([T.Normalize(mean=[-1]*3, std=[2]*3)
])def tensor_to_pil(t):t = inv_transform(t.squeeze(0)).clamp(0, 1)return T.ToPILImage()(t)def residual_map(orig, recon):residual = (orig - recon).abs().sum(dim=1, keepdim=True)residual = residual / (residual.max() + 1e-8)heat = residual.squeeze().cpu().numpy()fig, ax = plt.subplots()ax.imshow(heat, cmap='hot')ax.axis('off')buf = io.BytesIO()plt.savefig(buf, format='png')plt.close(fig)buf.seek(0)return Image.open(buf)# ---------------------------
# 推理函数
# ---------------------------
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MAE().to(device)
model.eval()def infer(img_pil):img_tensor = transform(img_pil).unsqueeze(0).to(device)with torch.no_grad():recon = model(img_tensor)recon_img = tensor_to_pil(recon)input_img = tensor_to_pil(img_tensor)heatmap = residual_map(img_tensor, recon)return input_img, recon_img, heatmap# ---------------------------
# Gradio UI
# ---------------------------
demo = gr.Interface(fn=infer,inputs=gr.Image(type="pil", label="上传图像"),outputs=[gr.Image(type="pil", label="原图"),gr.Image(type="pil", label="重建图"),gr.Image(type="pil", label="残差热图")],title="基于良品数据的异常检测(MAE重建)",description="上传图像,模型将重建正常区域并生成异常残差热力图"
)if __name__ == '__main__':demo.launch()

✅ 使用效果

你可以上传如下类型图像进行实时检测:

  • ✔️ 良品图像:残差图整体应较为平滑,响应值低;
  • 异常图像(如划痕/破损):残差图中异常区域明显发亮(高响应);

✅ 后续扩展建议:

模块可增强
重建网络替换为 DRAEM / Reverse Distillation
异常评分计算全图平均残差 + Otsu二值化分割
多样本比较支持目录上传并批量可视化
迁移微调用少量目标数据 fine-tune 提升领域鲁棒性

需要我下一步为你实现:

  • ✅ 残差异常评分 + 二值掩码输出?
  • ✅ 支持少量异常样本微调功能?
  • ✅ 用 PatchCore / AnomalyCLIP 替换 MAE 结构?

你可以指定下一个要增强的方向,我这边可以直接给出代码。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/79300.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【课题推荐】多速率自适应卡尔曼滤波(MRAKF)用于目标跟踪

多速率自适应卡尔曼滤波(Multi-Rate Adaptive Kalman Filter, MRAKF)是一种针对多传感器异步数据融合的滤波算法,适用于传感器采样率不同、噪声特性时变的目标跟踪场景。本文给出一个多速率自适应卡尔曼滤波框架,以无人机跟踪场景为例,融合IMU和GPS数据 文章目录 背景多速…

软考 系统架构设计师系列知识点之杂项集萃(49)

接前一篇文章:软考 系统架构设计师系列知识点之杂项集萃(48) 第76题 某文件管理系统在磁盘上建立了位视图(bitmap),记录磁盘的使用情况。若磁盘上物理块的编号依次为:0、1、2、……&#xff1b…

HTTP:七.HTTP缓存

HTTP缓存介绍 HTTP缓存是一种通过存储网络资源的副本,以减少对原始服务器请求的技术。当客户端再次请求相同资源时,如果该资源未过期,服务器可以直接从本地缓存中提供响应,而无需再次从原始服务器获取。这大大减少了网络延迟,提高了加载速度,并减轻了服务器的负载。HTTP…

WPF 图标原地旋转

如何使元素原地旋转 - WPF .NET Framework | Microsoft Learn <ButtonRenderTransformOrigin"0.5,0.5"HorizontalAlignment"Left">Hello,World<Button.RenderTransform><RotateTransform x:Name"MyAnimatedTransform" Angle"…

NO.91十六届蓝桥杯备战|图论基础-图的存储和遍历|邻接矩阵|vector|链式前向星(C++)

图的基本概念 图的定义 图G是由顶点集V和边集E组成&#xff0c;记为G (V, E)&#xff0c;其中V(G)表⽰图G中顶点的有限⾮空集&#xff1b;E(G)表⽰图G中顶点之间的关系&#xff08;边&#xff09;集合。若 V { v 1 , v 2 , … , v n } V \left\{ v_{1},v_{2},\dots,v_{n} …

【项目日记(一)】-仿mudou库one thread oneloop式并发服务器实现

1、模型框架 客户端处理思想&#xff1a;事件驱动模式 事件驱动处理模式&#xff1a;谁触发了我就去处理谁。 &#xff08; 如何知道触发了&#xff09;技术支撑点&#xff1a;I/O的多路复用 &#xff08;多路转接技术&#xff09; 1、单Reactor单线程&#xff1a;在单个线程…

Go语言实现OAuth 2.0认证服务器

文章目录 1. 项目概述1.1 OAuth2 流程 2. OAuth 2.0 Storage接口解析2.1 基础方法2.2 客户端管理相关方法2.3 授权码相关方法2.4 访问令牌相关方法2.5 刷新令牌相关方法 2.6 方法调用时序2.7 关键注意点3. MySQL存储实现原理3.1 数据库设计3.2 核心实现 4. OAuth 2.0授权码流程…

结合 Python 与 MySQL 构建你的 GenBI Agent_基于 MCP Server

写在前面 商业智能(BI)正在经历一场由大型语言模型(LLM)驱动的深刻变革。传统的 BI 工具通常需要用户学习复杂的界面或查询语言,而生成式商业智能 (Generative BI, GenBI) 则旨在让用户通过自然语言与数据交互,提出问题,并获得由 AI 生成的数据洞察、可视化建议甚至完整…

Linux中常用命令

目录 1. linux目录结构 2. linux基本命令操作 2.1 目录操作命令 2.2 文件操作命令 2.3 查看登录用户命名 2.4 文件内容查看命令 2.5 系统管理类命令 3. bash通配符 4. 压缩与解压缩命令 4.1 压缩和解压缩 4.2 测试网络连通性命令 ping 4.3 vi编辑器 4.4 管道操作(…

C++ 与 MySQL 数据库优化实战:破解性能瓶颈,提升应用效率

&#x1f9d1; 博主简介&#xff1a;CSDN博客专家、CSDN平台优质创作者&#xff0c;高级开发工程师&#xff0c;数学专业&#xff0c;10年以上C/C, C#, Java等多种编程语言开发经验&#xff0c;拥有高级工程师证书&#xff1b;擅长C/C、C#等开发语言&#xff0c;熟悉Java常用开…

tcp特点+TCP的状态转换图+time_wait详解

tcp特点TCP的状态转换图time wait详解 目录 一、tcp特点解释 1.1 面向连接 1.1.1 连接建立——三次握手 1.1.2 连接释放——四次挥手 1.2 可靠的 1.2.1 应答确认 1.2.2 超时重传 1.2.3 乱序重排 1.2.4 去重 1.2.5 滑动窗口进行流量控制 1.3 流失服务&#xff08;字节…

探秘 Ruby 与 JavaScript:动态语言的多面风采

1 语法特性对比&#xff1a;简洁与灵活 1.1 Ruby 的语法优雅 Ruby 的语法设计旨在让代码读起来像自然语言一样流畅。它拥有简洁而富有表现力的语法结构&#xff0c;例如代码块、符号等。 以下是一个使用 Ruby 进行数组操作的简单示例&#xff1a; # 定义一个数组 numbers [1…

点评项目回顾

表结构 基于Session实现登录流程 发送验证码&#xff1a; 用户在提交手机号后&#xff0c;会校验手机号是否合法&#xff0c;如果不合法&#xff0c;则要求用户重新输入手机号 如果手机号合法&#xff0c;后台此时生成对应的验证码&#xff0c;同时将验证码进行保存&#xf…

OpenShift介绍,跟 Kubernetes ,Docker关系

1. OpenShift 简介 OpenShift是一个开源项目,基于主流的容器技术Docker及容器编排引擎Kubernetes构建。可以基于OpenShift构建属于自己的容器云平台。OpenShift的开源社区版本叫OpenShift Origin,现在叫OKD。 OpenShift 项目主页:https://www.okd.io/。OpenShift GitHub仓库…

Ubuntu服务器性能调优指南:从基础工具到系统稳定性提升

一、性能监控工具的三维应用 1.1 监控矩阵构建 通过组合工具搭建立体监控体系&#xff1a; # 实时进程监控 htop --sort-keyPERCENT_CPU# 存储性能采集 iostat -dx 2# 内存分析组合拳 vmstat -SM 1 | awk NR>2 {print "Active:"$5"MB Swpd:"$3"…

计算机视觉——基于MediaPipe实现人体姿态估计与不良动作检测

概述 正确的身体姿势是个人整体健康的关键。然而&#xff0c;保持正确的身体姿势可能会很困难&#xff0c;因为我们常常会忘记。本博客文章将逐步指导您构建一个解决方案。最近&#xff0c;我们使用 MediaPipe POSE 进行身体姿势检测&#xff0c;效果非常好&#xff01; 一、…

LSTM结合LightGBM高纬时序预测

1. LSTM 时间序列预测 LSTM 是 RNN&#xff08;Recurrent Neural Network&#xff09;的一种变体&#xff0c;它解决了普通 RNN 训练时的梯度消失和梯度爆炸问题&#xff0c;适用于长期依赖的时间序列建模。 LSTM 结构 LSTM 由 输入门&#xff08;Input Gate&#xff09;、遗…

六、adb通过Wifi连接

背景 收集是荣耀X40,数据线原装全新的&#xff0c;USB连上之后&#xff0c;老是断&#xff0c;电脑一直叮咚叮咚的响个不停&#xff0c;试试WIFI 连接是否稳定&#xff0c;需要手机和电脑用相同的WIFI. 连接 1.通过 USB 连接手机和电脑(打开USB调试等这些都略过) adb device…

如何理解前端开发中的“换皮“

"换皮"在前端开发中是一个常见的术语&#xff0c;通常指的是在不改变网站或应用核心功能和结构的情况下&#xff0c;只改变其外观和视觉表现。以下是关于前端"换皮"的详细理解&#xff1a; 基本概念 定义&#xff1a;换皮(Skinning)是指保持应用程序功能不…

从 Vue 到 React:深入理解 useState 的异步更新

目录 从 Vue 到 React&#xff1a;深入理解 useState 的异步更新与函数式写法1. Vue 的响应式回顾&#xff1a;每次赋值立即生效2. React 的状态更新是异步且批量的原因解析 3. 函数式更新&#xff1a;唯一的正确写法4. 对比 Vue vs React 状态更新5. React useState 的核心源码…