如何使用C#实现Padim算法的训练和推理

目录

说明

项目背景

算法实现

预处理模块——图像预处理

主要模块——训练:Resnet层信息提取

主要模块——信息处理,计算Anomaly Map

主要模块——评估

主要模块——评估:门限值的确定

主要模块——推理

写在最后

项目下载链接


说明

作者:来瓶霸王防脱发

项目地址:

https://github.com/IntptrMax/PadimSharp

原文地址:

https://blog.csdn.net/qq_30270773/article/details/143029865

项目背景

缺陷检测(Anomaly Detection)算法是一个区分正常类别与异常类别的二分类问题,但在工业场景中大多数数据都为良品,不良数据难以获取,更难枚举,所以训练一个全监督的模型是不切实际的。因此,异常检测模型通常以单类别学习的模式。Padim算法是一种十分优秀的缺陷检测算法,直接上图可以看一下这个算法的效果。

良品图片

图片

不良品图片

图片

检测效果

图片

C#是一种十分受欢迎的编程语言,这种编程语言在工业场景下使用也是十分广泛的。在一些AI领域,会在Python下将模型转化为onnx形式,通过onnxruntime加载使用,进行推理。但是在onnx形式下进行训练十分困难。很多C#开发者不太熟悉Python环境,或者某些条件下希望在纯粹的C#环境下进行深度学习的训练和使用。这个还是有一定的困难的。

目前搜索了Github和CSDN排名靠前的几十条数据,还没有Padim算法在除Python平台下的训练+推理的相关项目或资料。本文就是在C#平台实现了Padim的训练+推理过程,应该在相关领域也算是独一份了。

算法实现

Padim算法的“训练”过程其实并没有涉及到真正的训练,而是使用Resnet18算法提取关键信息加以处理,在推理时再次使用,因此“训练”过程速度非常快,这也是这个算法的优点之一。Padim算法的具体实现还请参考相关论文:PaDiM: a Patch Distribution Modeling Framework for Anomaly Detection and Localization

https://arxiv.org/abs/2011.08785

如果论文看起来困难,还有一些大佬对该算法在Python平台下的解读,也可以参考:PaDiM 原理与代码解析

https://blog.csdn.net/ooooocj/article/details/127601035

预处理模块——图像预处理

图像预处理使用的方法比较常规,使用了缩放等方式,此处并没有使用LetterBox,也可以达到预期效果:

var transformers = torchvision.transforms.Compose([torchvision.transforms.Resize(resizeHeight,resizeWidth),
torchvision.transforms.CenterCrop(cropHeight,cropWidth),
torchvision.transforms.Normalize(means, stdevs)]);

主要模块——训练:Resnet层信息提取

使用Resnet模型进行推理,并提取Layer1、Layer2、Layer3层的信息,并进行了拼接(EmbeddingConcat)。注意:这里提取时使用了钩子,钩子在使用时会有资源释放,因此这里使用了比较迂回的方式记录结果

实现代码如下:

public List<(string, Tensor)> Forward(Tensor input)
{List<(string, Tensor)> outputs = new List<(string, Tensor)>();List<TempTensor> tempTensors = new List<TempTensor>();foreach (var named_module in model.named_children()){string name = named_module.name;if (name == "layer1" || name == "layer2" || name == "layer3"){((Sequential)named_module.module).register_forward_hook((Module, input, output) =>{tempTensors.Add(new TempTensor{Data = output.data<float>().ToArray(),Name = name,Shape = output.shape,});return null;});}}model.forward(input);var layer1output = tempTensors.Find(a => a.Name == "layer1");var layer2output = tempTensors.Find(a => a.Name == "layer2");var layer3output = tempTensors.Find(a => a.Name == "layer3");Tensor l1 = torch.tensor(layer1output.Data, layer1output.Shape, device: input.device);Tensor l2 = torch.tensor(layer2output.Data, layer2output.Shape, device: input.device);Tensor l3 = torch.tensor(layer3output.Data, layer3output.Shape, device: input.device);outputs.Add(new("layer1", l1));outputs.Add(new("layer2", l2));outputs.Add(new("layer3", l3));GC.Collect();return outputs;
}
private Tensor EmbeddingConcat(Tensor[] features)
{var embeddings = features[0];for (int i = 1; i < features.Length; i++){var layerEmbedding = features[i];layerEmbedding = torch.nn.functional.interpolate(layerEmbedding, size: [embeddings.shape[2], embeddings.shape[2]], mode: InterpolationMode.Nearest);embeddings = torch.cat([embeddings, layerEmbedding], 1);}return embeddings;
}

主要模块——信息处理,计算Anomaly Map

这一块主要对信息进行处理,获取矩阵的mean和cov(协方差矩阵),代码如下:

public Tensor ComputeAnomalyMapInternal(Tensor embedding, Tensor mean, Tensor covariance)
{var scoreMap = ComputeDistance(embedding, mean, covariance);var upSampledScoreMap = UpSample(scoreMap);var smoothedAnomalyMap = SmoothAnomalyMap(upSampledScoreMap);return smoothedAnomalyMap;
}public Tensor ComputeAnomalyMap(List<(string, Tensor)> outputs, Tensor mean, Tensor covariance, Tensor idx)
{Tensor embedding = GetEmbedding(outputs);var embeddingVectors = torch.index_select(embedding, 1, idx);return ComputeAnomalyMapInternal(embeddingVectors, mean, covariance);
}

主要模块——评估

与训练过程开始部分相似,也是获取图像的Embeddings,然后利用之前获取的Cov和mean计算马氏距离,以此评估图像的异常情况。马氏距离的计算方法如下:

private Tensor ComputeDistance(Tensor embedding, Tensor mean, Tensor covariance)
{long batch = embedding.shape[0];long channel = embedding.shape[1];long height = embedding.shape[2];long width = embedding.shape[3];Tensor inv_covariance = covariance.permute(2, 0, 1).inverse();var embedding_reshaped = embedding.reshape(batch, channel, height * width);var delta = (embedding_reshaped - mean).permute(2, 0, 1);var distances = (torch.matmul(delta, inv_covariance) * delta).sum(2).permute(1, 0);distances = distances.reshape(batch, 1, height, width);distances = distances.clamp(0).sqrt();return distances;
}

主要模块——评估:门限值的确定

这里需要确定图像的评估门限和像素值的评估门限。如果在评估时有负向样本,这个值会更准确,如果只有正向样本也是可以的。在Python下有个precision_recall_curve包,可以计算相关参数,但是在C#下时没有的,因此在此处仍旧只能造轮子,代码如下:

private (float[] precisions, float[] recalls, float[] thresholds) _precision_recall_curve_compute_single_class(Tensor yTrue, Tensor yScores, int pos_label = 1)
{var (fps, tps, thresholds) = BinaryClfCurve(yScores, yTrue, pos_label);var precision = tps / (tps + fps);var recall = tps / tps[-1];var lastInd = torch.where(tps == tps[-1])[0][0].ToInt32();int[] sl = new int[lastInd + 1];for (int i = 0; i < sl.Length; i++){sl[i] = i;}var reversedPrecision = precision[sl].flip(0);var reversedRecall = recall[sl].flip(0);var reversedThresholds = thresholds[sl].flip(0);precision = torch.cat(new Tensor[] { reversedPrecision, torch.ones(1, dtype: precision.dtype, device: precision.device) });recall = torch.cat(new Tensor[] { reversedRecall, torch.zeros(1, dtype: recall.dtype, device: recall.device) });return (precision.data<float>().ToArray(), recall.data<float>().ToArray(), reversedThresholds.data<float>().ToArray());
}private (Tensor fps, Tensor tps, Tensor thresholds) BinaryClfCurve(Tensor preds, Tensor target, int posLabel = 1)
{using (torch.no_grad()){if (preds.ndim > target.ndim){preds = preds[TensorIndex.Ellipsis, 0];}var descScoreIndices = torch.argsort(preds, descending: true);preds = preds[descScoreIndices];target = target[descScoreIndices];Tensor weight = torch.tensor(1.0f);var distinctValueIndices = torch.nonzero(preds[1..] - preds[..^1]).squeeze();var thresholdIdxs = torch.cat(new Tensor[] { distinctValueIndices, torch.tensor(new long[] { target.shape[0] - 1 }, device: preds.device) });target = (target == posLabel).to_type(ScalarType.Int64);var tps = torch.cumsum(target * weight, dim: 0)[thresholdIdxs];Tensor fps = 1 + thresholdIdxs - tps;return (fps, tps, preds[thresholdIdxs]);}
}

主要模块——推理

这个过程与上面过程也十分相似,正向计算出图像的Anomaly Map后,取出这个张量中最大的值,与图像的门限值进行比较,即可评估图像是否是良品。然后对这个张量中每个元素与像素门限值做对比,即可得到按像素的异常区域,以便绘制Mask和热力图。

Tensor orgImg = tensors["orgImage"].clone().to(device);
Tensor t = anomaly_map > pixel_threshold;
anomaly_map = (anomaly_map * t).squeeze(0);
anomaly_map = torchvision.transforms.functional.resize(anomaly_map, (int)orgImg.size(2), (int)orgImg.size(1));
Tensor heatmapNormalized = (anomaly_map - anomaly_map.min()) / (anomaly_map.max() - anomaly_map.min());
Tensor coloredHeatmap = torch.zeros([3, (int)orgImg.size(2), (int)orgImg.size(1)],device:anomaly_map.device);coloredHeatmap[0] = heatmapNormalized.squeeze(0);float alpha = 0.3f;
Tensor blendedImage = (1 - alpha) * (orgImg / 255.0f) + alpha * coloredHeatmap;
var imageTensor = blendedImage.clamp(0, 1).mul(255).to(ScalarType.Byte);torchvision.io.write_jpeg(imageTensor.cpu(), "result.jpg");

写在最后

使用C#开发深度学习项目,尤其是训练的项目,是一个十分困难的过程。或者说除了Python平台,训练都十分困难。C#进行深度学习训练这个方向在国内基本很少有人开展,所以能查得到的资料很少。本人十分喜爱C#这门语言,又十分喜爱深度学习,因此仅半年一直在这方面努力。遇到了很多困难,也收获了很多。

这条路走的不容易,希望能有更多人能加入进来,一起开发,一起学习。

我在Github上已经将完整的代码发布了,项目地址为:

https://github.com/IntptrMax/PadimSharp

,期待你能在Github上送我一颗小星星。在我的Github里还GGMLSharp这个项目,这个项目也是C#平台下深度学习的开发包,希望能得到你的支持。

项目下载链接

https://download.csdn.net/download/qq_30270773/89897710

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/55508.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

进入 Searing-66 火焰星球:第一周游戏指南

Alpha 第四季已开启&#xff0c;穿越火焰星球 Searing-66&#xff0c;带你开启火热征程。准备好勇闯炙热的沙漠&#xff0c;那里有无情的高温和无情的挑战在等待着你。从高风险的烹饪对决到炙热的冒险&#xff0c;Searing-66 将把你的耐力推向极限。带上充足的水&#xff0c;天…

Java线程的状态以及转换条件,与操作系统线程状态的区别?

先看图增加点记忆。 Java线程状态&#xff1a; 线程状态转换图&#xff1a; 背景知识 JAVA的线程模型与操作系统线程的对应关系是1:1的&#xff0c;线程的调度权是由操作系统控制的。 为什么java的线程状态与操作系统不一致&#xff1f; JVM线程状态&#xff1a;RUNNAB…

【含开题报告+文档+PPT+源码】基于SSM的景行天下旅游网站的设计与实现

开题报告 随着互联网的快速发展&#xff0c;旅游业也逐渐进入了数字化时代。作为一个旅游目的地&#xff0c;云浮市意识到了互联网在促进旅游业发展方面的巨大潜力。为了更好地推广云浮的旅游资源&#xff0c;提高旅游服务质量&#xff0c;云浮市决定开发一个专门的旅游网站。…

【红日安全】vulnstack (一)

&#x1f3d8;️个人主页&#xff1a; 点燃银河尽头的篝火(●’◡’●) 如果文章有帮到你的话记得点赞&#x1f44d;收藏&#x1f497;支持一下哦 【红日安全】vulnstack &#xff08;一&#xff09; 靶场搭建靶场渗透明确目标信息收集phpadmin后台getshell 靶场搭建 靶场下载…

IP报文格式、IPv6概述

IPv4报文格式 IPv4报文首部长度至少为20字节(没有可选字段和填充的情况下)&#xff0c;下面来逐一介绍首部各个字段的含义 Version版本&#xff1a;表示采用哪一种具体的IP协议&#xff0c;对于IPv4来说该字段就填充4以表示&#xff0c;如果是IPv6就填充6IHL首部长度&#xff…

网络资源模板--Android Studio 实现简易计算器App

目录 一、项目演示 二、项目测试环境 三、项目详情 四、完整的项目源码 一、项目演示 网络资源模板--基于Android studio 实现的简易计算器 二、项目测试环境 三、项目详情 动态绑定按钮&#xff1a; 使用循环遍历 buttons 数组&#xff0c;根据动态生成的按钮 ID (btn_0, …

SQL进阶技巧:如何删除第N次连续出现NULL值所存在的行?

目录 0 场景描述 1 数据准备 2 问题分析 问题拓展:如何删除第2次、第3次、第N次连续出现NULL值所在的行? 3 小结 0 场景描述 有下面的场景: 我们希望删除某id中连续存在NULL值的所有行,但是保留第一次出现不为NULL值的以下所有存在NULL值的行。具体如下图所示: 如…

iframe的使用详解

目录 一、基本概念和语法 二、优点 1.内容整合与复用&#xff1a; 2.独立的浏览环境&#xff1a; 3.跨域数据展示&#xff1a; 三、缺点 1.可访问性问题&#xff1a; 2.性能问题&#xff1a; 3.安全风险&#xff1a; 四、替代方案 1.使用JavaScript框架进行组件化开…

Unity开发Hololens项目

Unity打包Hololens设备 目录Visual Studio2019 / Visual Studio2022 远端部署设置Visual Studio2019 / Visual Studio2022 USB部署设置Hololens设备如何查找自身IPHololens设备门户Unity工程内的打包设置 目录 记录下自己做MR相关&#xff1a;Unity和HoloLens设备的历程。 Vi…

大规模多传感器滑坡检测数据集,利用landsat,哨兵2,planet,无人机图像等多种传感器采集数据共2w余副图像,mask准确标注滑坡位置

大规模多传感器滑坡检测数据集&#xff0c;利用landsat&#xff0c;哨兵2&#xff0c;planet&#xff0c;无人机图像等多种传感器采集数据共2w余副图像&#xff0c;mask准确标注滑坡位置 大规模多传感器滑坡检测数据集介绍 数据集概述 名称&#xff1a;大规模多传感器滑坡检测…

Python | Leetcode Python题解之第491题非递减子序列

题目&#xff1a; 题解&#xff1a; class Solution:def findSubsequences(self, nums: List[int]) -> List[List[int]]:def dfs(i, tmp):if i len(nums):if len(tmp) > 2:res.append(tmp[:]) # 拷贝&#xff0c;tmp[:]而非tmpreturn# 选 nums[i]if not tmp or nu…

2d 数字人实时语音聊天对话使用案例;支持asr、llm、tts实时语音交互

参考: https://github.com/lyz1810/live2dSpeek 下载live2dSpeek项目 ## 下载live2dSpeek git clone https://github.com/lyz1810/live2dSpeek cd live2dSpeek-main ## 运行live2dSpeek npm install -g http-server http-server .更改新的index.html页面 index.html

【SQL Server】数据库在新建查询后闪退——解决方案:以管理员的身份运行

我的SQLServer2022之前都是可以用的&#xff0c;隔了好久没有使用&#xff0c;今天要用到去写一些SQL 语句 结果在点击新建查询后闪退了&#xff0c; 经过查询后&#xff0c;解决方案&#xff1a; 以管理员的身份运行后点击新建查询&#xff0c;发现正常了 总结&#xff1a;以…

记一次库版本升级引起程序自动停止

记一次库版本升级引起程序自动停止 最近我们的应用升级了jedis 版本,版本从 2.10.2 升级 到3.8.0。发现我们的任务应用启动后立马自动关闭了。 这就奇怪了&#xff0c;为什么升级个版本&#xff0c;会导致程序启动后自动关闭呢。带着这个疑问我们看下代码。 表现如下&#x…

C语言_指针_进阶

引言&#xff1a;在前面的c语言_指针初阶上&#xff0c;我们了解了简单的指针类型以及使用&#xff0c;下面我们将进入更深层次的指针学习&#xff0c;对指针的理解会有一个极大的提升。从此以后&#xff0c;指针将不再是难点&#xff0c;而是学习底层语言的一把利器。 本章重点…

vr体验馆计时收银软件试用版下载 佳易王VR游戏厅计时计费管理系统使用操作教程

一、前言 【软件试用版资源文件下载可以点击文章最后卡片了解】 vr体验馆计时收银软件试用版下载 佳易王VR游戏厅计时计费管理系统使用操作教程 VR体验馆计时计费软件是专门为VR体验馆设计的管理工具&#xff0c;旨在提高服务效率和客户的满意度。软件能够记录客户使用设备的…

vue组件调用生命周期

《vue基础学习-组件》提到组件传递数据方式&#xff1a; 1. props/$emit 父传子&#xff1a;子组件通过 props 显式声明 自定义 属性&#xff0c;接收父组件的传值。子传父&#xff1a;子组件通过 $emit() 显式声明 自定义 事件&#xff0c;父组件调用自定义事件接收子组件返…

Docker-compose提示specified IP address..configured subnets问题以及Docker容器相关操作记录保存

一、Docker-compose提示user specified IP address is supported only when connecting to networks with user configured subnets 在网上下载的一些docker-compose.yml在执行的时碰到过多次如下报错&#xff1a; ERROR: for 5307e2acb....user specified IP address is supp…

2024.10.17 软考学习笔记

刷题网站&#xff1a; 软考中级软件设计师在线试题、软考解析及答案-51CTO题库-软考在线做题备考工具

vue2项目 实现上边两个下拉框,下边一个输入框 输入框内显示的值为[“第一个下拉框选中值“ -- “第二个下拉框选中的值“]

效果: 思路: 采用vue中 [computed:] 派生属性的方式实现联动效果,上边两个切换时,下边的跟随变动 demo代码: <template><div><!-- 第一个下拉框 --><select v-model"firstValue"><option v-for"option in options" :key&q…