[CVPR 2023]PyramidFlow-训练并推理-附bug调试

CVPR2023-PyramidFlow-zero shot异常检测网络 代码调试记录

  • 一.论文以及开源代码
  • 二.前期代码准备
  • 三.环境配置
  • 四.bug调试
    • num_samples should be a positive integer value, but got num_samples=0
    • AttributeError: Can't pickle local object 'fix_randseed.<locals>.seed_worker'
  • 五.数据集准备
  • 六.训练
  • 七.推理

一.论文以及开源代码

PyramidFlow一篇2023年发表于CVPR的关于无监督异常检测算法的论文,由浙江大学出品,下面附上论文和代码链接:
论文链接:PyramidFlow论文
代码链接:PyramidFlow源代码

二.前期代码准备

首先,我们需要把我在一中提到的代码先git clone到我们的项目路径中,这是我们接下去的训练代码,当然其中也包括了验证和测试(推理过程也包含在内部了,需要自己写一小部分)。然后我们还需要去作者的官网git clone一份名为autoFlow的项目代码,这里面包含了训练代码中将会调用的一些函数,十分重要:
进入训练代码的链接后,点击作者头像,如图所示
进入训练代码的链接后,点击作者头像,如图所示。然后我们便进入了作者的github主页,点击主页下方的这个链接:
在这里插入图片描述
就可以跳转到这个页面:
在这里插入图片描述
红框中包含了两个链接,其中一个是我们在第一步就已经clone好的训练代码,不用管他了,现在我们点击蓝色框中的链接:
在这里插入图片描述
点击code然后复制链接,然后打开git工具使用git clone命令行即可:
在这里插入图片描述
此时,两个项目都已经拷贝下来了,我这里选择将autoflow这个文件夹直接复制到了PyramidFlow里面,这样方便PyramidFlow中代码的调用:
在这里插入图片描述

三.环境配置

PyramidFlow的环境,作者已经在Readme中给出,按照里面的版本pip install即可,如果下载速度过慢,可以设置默认源为清华源,可以大大方便我们配置环境。

四.bug调试

这里默认大家的环境已经按照要求配置好了。

num_samples should be a positive integer value, but got num_samples=0

这时候我们直接运行PyramidFlow中的train.py时一般会报这个错误:
在这里插入图片描述
这个问题的原因主要是torch库中的DataLoader函数加载数据时,如果已经设置了batch_size,就不需要设置shuffle=True来打乱数据了,此时需要把shuffle设置为False,DataLoader的具体参数可以参考如下:

DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=1, persistent_workers=True, pin_memory=True, drop_last=True, **loader_dict)

参数解释:
dataset:包含所有数据的数据集,加载的数据集(Dataset对象)

batch_size :每个batch包含的数据数量

Shuffle : 是否打乱数据位置。

sampler : 自定义从数据集中采样的策略,如果制定了采样策略,shuffle则必须为False.

Batch_sampler:和sampler一样,但是每次返回一组的索引,和batch_size, shuffle, sampler, drop_last 互斥。

num_workers : 使用线程的数量,当为0时数据直接加载到主程序,默认为0。

collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认的拼接方式即可

pin_memory:s 是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些

drop_last: dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃

AttributeError: Can’t pickle local object ‘fix_randseed..seed_worker’

打开util.py:
在这里插入图片描述
我将这段代码中的seed_worker直接独立出来:
在这里插入图片描述
然后在train.py代码中,我们需要创建一个变量接收fix_randseed的返回值,然后将这个返回值作为seed_worker的参数传入:
在这里插入图片描述

五.数据集准备

用到的是Mvtec数据集,放在项目文件夹的同一级路径下,改名为如下所示:
在这里插入图片描述

六.训练

作者在源代码中是在训练代码的最后一npz的形式保存了模型的权重,由于我对这个npz了解甚少,并且我平时推理常用的都是pt,onnx或者tensorRT的engine等,因此,我在训练代码的最后加了一句torch.save()来将模型以pt的方式保存,见110行代码:

import torch
import torch.nn as nn
from torch.utils.data import DataLoaderimport numpy as np
import time, argparse
from sklearn.metrics import roc_auc_scorefrom model import PyramidFlow
from util import MVTecAD, BatchDiffLoss
from util import fix_randseed, compute_pro_score_fast, getLogger, seed_worker
import cv2def train(logger, save_name, cls_name, datapath, resnetX, num_layer, vn_dims, \ksize, channel, num_stack, device, batch_size, save_memory, ):# save configsave_dict = {'cls_name': cls_name, 'resnetX': resnetX, 'num_layer': num_layer, 'vn_dims': vn_dims,\'ksize': ksize, 'channel': channel, 'num_stack': num_stack, 'batch_size': batch_size}#我的改动seed_ = fix_randseed(seed=0)loader_dict = seed_worker(seed_)# model flow = PyramidFlow(resnetX, channel, num_layer, numStack, ksize, vn_dims, saveMem).to(device)x_size = 256 if resnetX==0 else 1024optimizer = torch.optim.Adam(flow.parameters(), lr=2e-4, eps=1e-04, weight_decay=1e-5, betas=(0.5, 0.9)) # using cs-flow optimizerLoss = BatchDiffLoss(batch_size, p=2)# datasettrain_dataset = MVTecAD(cls_name, mode='train', x_size=x_size, y_size=256, datapath=datapath)val_dataset = MVTecAD(cls_name, mode='val', x_size=x_size, y_size=256, datapath=datapath)test_dataset = MVTecAD(cls_name, mode='test', x_size=x_size, y_size=256, datapath=datapath)train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1, persistent_workers=True, pin_memory=True, drop_last=True, **loader_dict)val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=1, persistent_workers=True, pin_memory=True, drop_last=False, **loader_dict)test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=1, persistent_workers=True, pin_memory=True, **loader_dict)# training & evaluationpixel_auroc_lst = [0]pixel_pro_lst = [0]image_auroc_lst = [0]losses_lst = [0]t0 = time.time()for epoch in range(15):# trainflow.train()losses = []for train_dict in train_loader:image, labels = train_dict['images'].to(device), train_dict['labels'].to(device)optimizer.zero_grad()pyramid2= flow(image)diffes = Loss(pyramid2)diff_pixel = flow.pyramid.compose_pyramid(diffes).mean(1)  loss = torch.fft.fft2(diff_pixel).abs().mean() # Fourier lossloss.backward()nn.utils.clip_grad_norm_(flow.parameters(), max_norm=1e0) # Avoiding numerical explosionsoptimizer.step()losses.append(loss.item())mean_loss = np.mean(losses)logger.info(f'Epoch: {epoch}, mean_loss: {mean_loss:.4f}, time: {time.time()-t0:.1f}s')losses_lst.append(mean_loss)# val for template flow.eval()feat_sum, cnt = [0 for _ in range(num_layer)], 0for val_dict in val_loader:image = val_dict['images'].to(device)with torch.no_grad():pyramid2= flow(image) cnt += 1feat_sum = [p0+p for p0, p in zip(feat_sum, pyramid2)]feat_mean = [p/cnt for p in feat_sum]# testflow.eval()diff_list, labels_list = [], []for test_dict in test_loader:image, labels = test_dict['images'].to(device), test_dict['labels']with torch.no_grad():pyramid2 = flow(image) pyramid_diff = [(feat2 - template).abs() for feat2, template in zip(pyramid2, feat_mean)]diff = flow.pyramid.compose_pyramid(pyramid_diff).mean(1, keepdim=True)# b,1,h,wdiff_list.append(diff.cpu())labels_list.append(labels.cpu()==1)# b,1,h,wlabels_all = torch.concat(labels_list, dim=0)# b1hw amaps = torch.concat(diff_list, dim=0)# b1hw amaps, labels_all = amaps[:, 0], labels_all[:, 0] # both b,h,wpixel_auroc = roc_auc_score(labels_all.flatten(), amaps.flatten()) # pixel scoreimage_auroc = roc_auc_score(labels_all.amax((-1,-2)), amaps.amax((-1,-2))) # image scorepixel_pro = compute_pro_score_fast(amaps, labels_all) # pro scorelogger.info(f'   TEST Pixel-AUROC: {pixel_auroc}, time: {time.time()-t0:.1f}s')logger.info(f'   TEST Image-AUROC: {image_auroc}, time: {time.time()-t0:.1f}s')logger.info(f'   TEST Pixel-PRO: {pixel_pro}, time: {time.time()-t0:.1f}s')if pixel_auroc > np.max(pixel_auroc_lst):save_dict['state_dict_pixel'] = {k: v.cpu() for k, v in flow.state_dict().items()} # save ckptif pixel_pro > np.max(pixel_pro_lst):save_dict['state_dict_pro'] = {k: v.cpu() for k, v in flow.state_dict().items()} # save ckptpixel_auroc_lst.append(pixel_auroc)pixel_pro_lst.append(pixel_pro)image_auroc_lst.append(image_auroc)del amaps, labels_all, diff_list, labels_listsave_dict['pixel_auroc_lst'] = pixel_auroc_lstsave_dict['image_auroc_lst'] = image_auroc_lstsave_dict['pixel_pro_lst']   = pixel_pro_lstsave_dict['losses_lst'] = losses_lsttorch.save(flow, "best.pt")np.savez(f'saveDir/{save_name}.npz', **save_dict) # save allif __name__ == '__main__':parser = argparse.ArgumentParser(description='Training on MVTecAD')parser.add_argument('--cls', type=str, default='bottle', choices=\['tile', 'leather', 'hazelnut', 'toothbrush', 'wood', 'bottle', 'cable', \'capsule', 'pill', 'transistor', 'carpet', 'zipper', 'grid', 'screw', 'metal_nut'])parser.add_argument('--datapath', type=str, default='../mvtec_anomaly_detection')# hyper-parameters of architectureparser.add_argument('--encoder', type=str, default='resnet18', choices=['none', 'resnet18', 'resnet34'])parser.add_argument('--numLayer', type=str, default='auto', choices=['auto', '2', '4', '8'])parser.add_argument('--volumeNorm', type=str, default='auto', choices=['auto', 'CVN', 'SVN'])# non-key parameters of architectureparser.add_argument('--kernelSize', type=int, default=7, choices=[3, 5, 7, 9, 11])parser.add_argument('--numChannel', type=int, default=16)parser.add_argument('--numStack', type=int, default=4)# other parametersparser.add_argument('--gpu', type=int, default=0)parser.add_argument('--batchSize', type=int, default=2)parser.add_argument('--saveMemory', type=bool, default=True) args = parser.parse_args()cls_name = args.clsresnetX = 0 if args.encoder=='none' else int(args.encoder[6:])if args.volumeNorm == 'auto':vn_dims = (0, 2, 3) if cls_name in ['carpet', 'grid', 'bottle', 'transistor'] else (0, 1)elif args.volumeNorm == 'CVN':vn_dims = (0, 1)elif args.volumeNorm == 'SVN':vn_dims = (0, 2, 3)if args.numLayer == 'auto':num_layer = 4if cls_name in ['metal_nut', 'carpet', 'transistor']:num_layer = 8elif cls_name in ['screw',]:num_layer = 2else:num_layer = int(args.numLayer)ksize = args.kernelSizenumChannel = args.numChannelnumStack = args.numStackgpu_id = args.gpubatchSize = args.batchSizesaveMem = args.saveMemorydatapath = args.datapathlogger, save_name = getLogger(f'./saveDir')logger.info(f'========== Config ==========')logger.info(f'> Class: {cls_name}')logger.info(f'> MVTecAD dataset root: {datapath}')logger.info(f'> Encoder: {args.encoder}')logger.info(f"> Volume Normalization: {'CVN' if len(vn_dims)==2 else 'SVN'}")logger.info(f'> Num of Pyramid Layer: {num_layer}')logger.info(f'> Conv Kernel Size in NF: {ksize}')logger.info(f'> Num of Channels in NF: {numChannel}')logger.info(f'> Num of Stack Block: {numStack}')logger.info(f'> Batch Size: {batchSize}')logger.info(f'> GPU device: cuda:{gpu_id}')logger.info(f'> Save Training Memory: {saveMem}')logger.info(f'============================')train(logger, save_name, cls_name, datapath,\resnetX, num_layer, vn_dims, \ksize=ksize, channel=numChannel, num_stack=numStack, \device=f'cuda:{gpu_id}', batch_size=batchSize, save_memory=saveMem)

到这里,训练完之后,我们就得到了模型的权重pt文件,为我们后面的推理做准备。

七.推理

未完待续

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/56571.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

springboot使用properties

一、方式1&#xff1a; 1.1.配置类&#xff1a; package cn.zyq.stater.config;import cn.zyq.stater.bean.User4; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework…

解决:Appium Inspector刷新页面一直加载转圈

目录 问题&#xff1a;Appium Inspector刷新页面一直加载转圈 解决办法&#xff1a; 1.进入设置页面-电池-后台耗电管理 2.找到下面3个应用&#xff0c;修改为允许后台高耗电 问题&#xff1a;Appium Inspector刷新页面一直加载转圈 1、手机进行操作后&#xff0c;Appium I…

Go 语言的实战案例 SOCKS5 代理 | 青训营

Powered by:NEFU AB-IN 文章目录 Go 语言的实战案例 SOCKS5 代理 | 青训营 引入TCP echo serverauth 认证请求阶段relay阶段 Go 语言的实战案例 SOCKS5 代理 | 青训营 GO语言工程实践课后作业&#xff1a;实现思路、代码以及路径记录 引入 代理是指在计算机网络中&#xff…

Cpp学习——编译链接

目录 ​编辑 一&#xff0c;两种环境 二&#xff0c;编译环境下四个部分的 1.预处理 2.编译 3.汇编 4.链接 三&#xff0c;执行环境 一&#xff0c;两种环境 在程序运行时会有两种环境。第一种便是编译环境&#xff0c;第二种则是执行环境。如下图&#xff1a; 在程序运…

5G NR:协议 - PDCCH信道

1、基本概念 不同于LTE中的控制信道包括PCFICH、PHICH和PDCCH&#xff0c;在5G NR中&#xff0c;控制信道仅包括PDCCH&#xff08;Physical Downlink Control Channel&#xff09;&#xff0c;负责物理层各种关键控制信息的传递&#xff0c;PDCCH中传递的下行控制信息&#xff…

【LeetCode】面试题总结 消失的数字 最小k个数

1.消失的数字 两种思路 1.先升序排序&#xff0c;再遍历并且让后一项与前一项比较 2.转化为数学问题求等差数列前n项和 &#xff08;n的大小为数组的长度&#xff09;&#xff0c;将根据公式求得的应有的和数与数组中实际的和作差 import java.util.*; class Solution {public …

代码随想录算法训练营第四十六天 | 139.单词拆分

代码随想录算法训练营第四十六天 | 139.单词拆分 139.单词拆分 139.单词拆分 题目链接 视频讲解 给你一个字符串 s 和一个字符串列表 wordDict 作为字典。请你判断是否可以利用字典中出现的单词拼接出 s 注意&#xff1a;不要求字典中出现的单词全部都使用&#xff0c;并且字典…

【LeetCode】227. 基本计算器 II

227. 基本计算器 II&#xff08;中等&#xff09; 方法&#xff1a;双栈解法 思路 我们可以使用两个栈 nums 和 ops 。 nums &#xff1a; 存放所有的数字ops &#xff1a;存放所有的数字以外的操作 然后从前往后做&#xff0c;对遍历到的字符做分情况讨论&#xff1a; 空格 …

安全测试-django防御安全策略

django安全性 django针对安全方面有一些处理&#xff0c;学习如何进行处理设置&#xff0c;也有利于学习安全测试知识。 CSRF 跨站点请求伪造&#xff08;Cross-Site Request Forgery&#xff0c;CSRF&#xff09;是一种网络攻击方式&#xff0c;攻击者欺骗用户在自己访问的网…

Python 包管理(pip、conda)基本使用指南

Python 包管理 概述 介绍 Python 有丰富的开源的第三方库和包&#xff0c;可以帮助完成各种任务&#xff0c;扩展 Python 的功能&#xff0c;例如 NumPy 用于科学计算&#xff0c;Pandas 用于数据处理&#xff0c;Matplotlib 用于绘图等。在开始编写 Pytlhon 程序之前&#…

【力扣】2813 子序列最大优雅度

class Solution//诡异的数据结构维护反悔贪心 { public:long long findMaximumElegance(vector<vector<int>>& items, int k){sort(items.begin(), items.end(), [](const auto &a, const auto &b){return a[0] > b[0];});//奇妙的排序方法long lon…

K8S最新版本集群部署(v1.28) + 容器引擎Docker部署(上)

温故知新 &#x1f4da;第一章 前言&#x1f4d7;背景&#x1f4d7;目的&#x1f4d7;总体方向 &#x1f4da;第二章 基本环境信息&#x1f4d7;机器信息&#x1f4d7;软件信息&#x1f4d7;部署用户kubernetes &#x1f4da;第三章 Kubernetes各组件部署&#x1f4d7;安装kube…

Linux(实操篇一)

Linux实操篇 Linux(实操篇一)1. 常用基本命令1.1 帮助命令1.1.1 man获得帮助信息1.1.2 help获得shell内置命令的帮助信息1.1.3 常用快捷键 1.2 文件目录类1.2.1 pwd显示当前 工作目录的绝对路径1.2.2 ls列出目录的内容1.2.3 cd切换目录1.2.4 mkdir创建一个新的目录1.2.5 rmdir删…

Linux环境搭建SVN服务器并实现公网访问 - cpolar端口映射

文章目录 前言1. Ubuntu安装SVN服务2. 修改配置文件2.1 修改svnserve.conf文件2.2 修改passwd文件2.3 修改authz文件 3. 启动svn服务4. 内网穿透4.1 安装cpolar内网穿透4.2 创建隧道映射本地端口 5. 测试公网访问6. 配置固定公网TCP端口地址6.1 保留一个固定的公网TCP端口地址6…

【HashMap】key和value能否为null

【HashMap】key和value能否为null 【一】HashMap【二】HashTable【三】ConcurrentHashMap【四】测试代码【五】底层代码分析 【一】HashMap &#xff08;1&#xff09;结论&#xff1a;HashMap对象的key、value值均可为null HashMap 的 key 和 value 都可以为 null 值。在 Jav…

Ubuntu20.04下安装搜狗输入法Linux版

Ubuntu20.04下安装搜狗输入法Linux版 参考搜狗输入法的官网安装指南&#xff1b; 第一步&#xff1a;打开搜狗输入法官网&#xff1b; https://shurufa.sogou.com/ 点击X86_64后将会自动跳转到搜狗输入法的安装指南中&#xff1b; 安装指南 Ubuntu搜狗输入法安装指南 搜狗…

Linux的Man Page知识记录

Man&#xff08;short for manual&#xff09; Page是Unix和Linux操作系统中的一个重要文档&#xff0c;提供命令、函数、系统调用等的详细介绍和使用说明。它是以纯文本的形式出现&#xff0c;通常在终端&#xff08;terminal&#xff09;中使用man命令访问。Man Page按照章节…

elementui的el-tabs标签页样式修改

一、官网样式&#xff1a; 二、修改样式 1.去掉下划线 效果&#xff1a; 代码: /* 去掉tabs标签栏下的下划线 */ ::v-deep .el-tabs__nav-wrap::after {position: static !important;/* background-color: #fff; */ } 2.改变下划线颜色 效果&#xff1a; 代码&#xff1a;…

Docker网络-探索容器网络如何相互通信

当今世界&#xff0c;企业热衷于容器化&#xff0c;这需要强大的网络技能来正确配置容器架构&#xff0c;因此引入了 Docker Networking 的概念。Docker 是一种容器化平台&#xff0c;允许您在独立、轻量级的容器中运行应用程序和服务。Docker 提供了一套强大的网络功能&#x…

QNAP(威联通)NAS外远程访问指南,免费内网穿透工具的应用和配置指导——“cpolar内网穿透”

文章目录 前言1. 威联通安装cpolar内网穿透2. 内网穿透2.1 创建隧道2.2 测试公网远程访问 3. 配置固定二级子域名3.1 保留二级子域名3.2 配置二级子域名 4. 使用固定二级子域名远程访问 前言 购入威联通NAS后&#xff0c;很多用户对于如何在外在公网环境下的远程访问威联通NAS…