【DA-CLIP】test.py解读,调用DA-CLIP和IRSDE模型复原计算复原图与GT图SSIM、PSNR、LPIPS

文件路径daclip-uir-main/universal-image-restoration/config/daclip-sde/test.py

代码有部分修改

导包

import argparse
import logging
import os.path
import sys
import time
from collections import OrderedDict
import torchvision.utils as tvutilsimport numpy as np
import torch
from IPython import embed
import lpipsimport options as option
from models import create_modelsys.path.insert(0, "../../")
import open_clip
import utils as util
from data import create_dataloader, create_dataset
from data.util import bgr2ycbcr

注意open_clip使用的是项目里的代码,而非环境里装的那个。data、util、option同样是项目里有的包

声明

#### options
parser = argparse.ArgumentParser()
parser.add_argument("-opt", type=str, default='options/test.yml', help="Path to options YMAL file.")
opt = option.parse(parser.parse_args().opt, is_train=False)opt = option.dict_to_nonedict(opt)

配置文件 

设置配置文件相对地址options/test.yml

在该配置文件中配置GT和LQ图像文件地址

datasets:test1:name: Testmode: LQGTdataroot_GT: C:\Users\86136\Desktop\LQ_test\shadow\GTdataroot_LQ: C:\Users\86136\Desktop\LQ_test\shadow\LQ

设置results_root结果地址,每次计算结束这个地址保存要求记录的计算结果

该目录下Test文件夹将保存一张GT一张LQ一张复原图像  。

不设置也会默认在项目内 daclip-uir-main\results\daclip-sde\universal-ir

#### path
path:pretrain_model_G: E:\daclip\pretrained\universal-ir.pthdaclip: E:\daclip\pretrained\daclip_ViT-B-32.ptresults_root: C:\Users\86136\Desktop\daclip-uir-main\results\daclip-sde\universal-irlog: 

 

#### mkdir and logger
util.mkdirs((pathfor key, path in opt["path"].items()if not key == "experiments_root"and "pretrain_model" not in keyand "resume" not in key)
)# os.system("rm ./result")
# os.symlink(os.path.join(opt["path"]["results_root"], ".."), "./result")

 报错执行代码没有删除再创建权限?我把相关os操作注释了,全部保存到result对我影响不大

加载创建数据对

#### Create test dataset and dataloader
test_loaders = []
for phase, dataset_opt in sorted(opt["datasets"].items()):test_set = create_dataset(dataset_opt)test_loader = create_dataloader(test_set, dataset_opt)logger.info("Number of test images in [{:s}]: {:d}".format(dataset_opt["name"], len(test_set)))test_loaders.append(test_loader)

 自定义包含复原IR-SDE模型的外层类model,参考app.py

# load pretrained model by default
model = create_model(opt)
device = model.device

 加载DA-CLIP、IR-SDE

# clip_model, _preprocess = clip.load("ViT-B/32", device=device)
if opt['path']['daclip'] is not None:clip_model, preprocess = open_clip.create_model_from_pretrained('daclip_ViT-B-32', pretrained=opt['path']['daclip'])
else:clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
clip_model = clip_model.to(device)

else是直接使用CLIP的ViT-B-32模型进行测试的代码。与我测DA-CLIP无关。

想使用的话 目测要预先下载对应模型权重并手动修改pretrained为文件地址,否则报错hf无法连接

sde = util.IRSDE(max_sigma=opt["sde"]["max_sigma"], T=opt["sde"]["T"], schedule=opt["sde"]["schedule"], eps=opt["sde"]["eps"], device=device)
sde.set_model(model.model)
lpips_fn = lpips.LPIPS(net='alex').to(device)scale = opt['degradation']['scale']

加载IR-SDE、LPIPS

如果不指定crop_border后续crop_border=scale

处理并计算


for test_loader in test_loaders:test_set_name = test_loader.dataset.opt["name"]  # path opt['']logger.info("\nTesting [{:s}]...".format(test_set_name))test_start_time = time.time()dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name)util.mkdir(dataset_dir)test_results = OrderedDict()test_results["psnr"] = []test_results["ssim"] = []test_results["psnr_y"] = []test_results["ssim_y"] = []test_results["lpips"] = []test_times = []for i, test_data in enumerate(test_loader):single_img_psnr = []single_img_ssim = []single_img_psnr_y = []single_img_ssim_y = []need_GT = False if test_loader.dataset.opt["dataroot_GT"] is None else Trueimg_path = test_data["GT_path"][0] if need_GT else test_data["LQ_path"][0]img_name = os.path.splitext(os.path.basename(img_path))[0]#### input dataset_LQLQ, GT = test_data["LQ"], test_data["GT"]img4clip = test_data["LQ_clip"].to(device)with torch.no_grad(), torch.cuda.amp.autocast():image_context, degra_context = clip_model.encode_image(img4clip, control=True)image_context = image_context.float()degra_context = degra_context.float()noisy_state = sde.noise_state(LQ)model.feed_data(noisy_state, LQ, GT, text_context=degra_context, image_context=image_context)tic = time.time()model.test(sde, save_states=False)toc = time.time()test_times.append(toc - tic)visuals = model.get_current_visuals()SR_img = visuals["Output"]output = util.tensor2img(SR_img.squeeze())  # uint8LQ_ = util.tensor2img(visuals["Input"].squeeze())  # uint8GT_ = util.tensor2img(visuals["GT"].squeeze())  # uint8suffix = opt["suffix"]if suffix:save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png")else:save_img_path = os.path.join(dataset_dir, img_name + ".png")util.save_img(output, save_img_path)# remove it if you only want to save output imagesLQ_img_path = os.path.join(dataset_dir, img_name + "_LQ.png")GT_img_path = os.path.join(dataset_dir, img_name + "_HQ.png")util.save_img(LQ_, LQ_img_path)util.save_img(GT_, GT_img_path)if need_GT:gt_img = GT_ / 255.0sr_img = output / 255.0crop_border = opt["crop_border"] if opt["crop_border"] else scaleif crop_border == 0:cropped_sr_img = sr_imgcropped_gt_img = gt_imgelse:cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border]cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border]psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)lp_score = lpips_fn(GT.to(device) * 2 - 1, SR_img.to(device) * 2 - 1).squeeze().item()test_results["psnr"].append(psnr)test_results["ssim"].append(ssim)test_results["lpips"].append(lp_score)if len(gt_img.shape) == 3:if gt_img.shape[2] == 3:  # RGB imagesr_img_y = bgr2ycbcr(sr_img, only_y=True)gt_img_y = bgr2ycbcr(gt_img, only_y=True)if crop_border == 0:cropped_sr_img_y = sr_img_ycropped_gt_img_y = gt_img_yelse:cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border]cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border]psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255)ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255)test_results["psnr_y"].append(psnr_y)test_results["ssim_y"].append(ssim_y)logger.info("img{:3d}:{:15s} - PSNR: {:.6f} dB; SSIM: {:.6f}; LPIPS: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.".format(i, img_name, psnr, ssim, lp_score, psnr_y, ssim_y))else:logger.info("img:{:15s} - PSNR: {:.6f} dB; SSIM: {:.6f}.".format(img_name, psnr, ssim))test_results["psnr_y"].append(psnr)test_results["ssim_y"].append(ssim)else:logger.info(img_name)ave_lpips = sum(test_results["lpips"]) / len(test_results["lpips"])ave_psnr = sum(test_results["psnr"]) / len(test_results["psnr"])ave_ssim = sum(test_results["ssim"]) / len(test_results["ssim"])logger.info("----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n".format(test_set_name, ave_psnr, ave_ssim))if test_results["psnr_y"] and test_results["ssim_y"]:ave_psnr_y = sum(test_results["psnr_y"]) / len(test_results["psnr_y"])ave_ssim_y = sum(test_results["ssim_y"]) / len(test_results["ssim_y"])logger.info("----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n".format(ave_psnr_y, ave_ssim_y))logger.info("----average LPIPS\t: {:.6f}\n".format(ave_lpips))print(f"average test time: {np.mean(test_times):.4f}")

开头往log记录了相应配置文件内容,不需要可以注释。

遍历测试数据集(test_loaders)计算各种评价指标,如峰值信噪比(PSNR)、结构相似性(SSIM)和感知损失(LPIPS)。

在处理过程中,代码首先会创建一个目录来保存测试结果。

然后,对于每个测试图像,代码会加载对应的图像(如果可用),并使用一个名为clip_model的模型对图像进行编码。

接下来,代码会使用一个名为sde的随机微分方程模型和名为model的深度学习模型来处理带有噪声的图像,并生成复原图像(SR_img)。额可能作者拿了以前做超分的代码没改变量名

在这个过程中,text_contextimage_context被用作模型的输入,

图像都会被保存到之前创建的目录中。

此外,代码还会计算并记录每个图像的PSNR、SSIM和LPIPS分数,并在最后打印出这些分数的平均值。 代码中还包含了一些用于图像处理的实用函数,如util.tensor2img用于将张量转换为图像,util.save_img用于保存图像,以及util.calculate_psnrutil.calculate_ssim用于计算PSNR和SSIM分数。psnr_y和ssim_y 不用可以把相关代码注释。

最后,代码还计算了平均测试时间,并将其打印出来。

结果

log处理的单张图像报错的信息 0是该处理的图像排序序号,即正在处理第0张图

24-04-03 17:28:24.697 - INFO: img  0:_MG_2374_no_shadow - PSNR: 27.779773 dB; SSIM: 0.863140; LPIPS: 0.078669; PSNR_Y: 29.135256 dB; SSIM_Y: 0.869278.

 

可以给复原结果图加个后缀方便区分。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/792421.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数学知识--(质数,约数)

本文用于个人算法竞赛学习,仅供参考 目录 一.质数的判定 二.分解质因数 三.质数筛 1.朴素筛法 2.埃氏筛法 3.线性筛法 四.约数 1.求一个数的所有约数 2.约数个数和约数之和 3.欧几里得算法(辗转相除法)-- 求最大公约数 一.质数的判定 …

新型智慧城市大数据解决方案(附下载)

随着云计算、大数据、移动互联网等技术的发展,由城市运行产生的交通、环境、市政、商业等各领域数据量巨大,这些数据经过合理的分析挖掘可产生大量传统数据不能反映的城市运行信息,已成为智慧城市的重要资产。 在大数据时代,数据信…

理解main方法的语法

由于JVM需要调用类的main()方法,所以该方法的访问权限必须是public,又因为JVM在执行main()方法时不必创建对象,所以该方法必须是static的,该方法接收一个String类型的数组参数,该数组中保存执行Java命令时传递给所运行…

问题解决:写CSDN博文时图片大小不适应,不清晰,没法排版

项目环境: Window10,Edge123.0.2420.65 问题描述: 当我在CSDN写博文的时候,会经常插入一些图片,但有时候我插入的图片太大了,影响了整体排版。 比如我加入了一张图片,就变成了下面这个样子&…

释放 ChatGPT4 的威力

原文:Unleashing the Power of ChatGPT-4: Strategies for Building a Personal Income Stream 译者:飞龙 协议:CC BY-NC-SA 4.0 I. 介绍 在当今快速发展的数字领域中,人工智能(AI)已经成为无数行业的重要…

Kubernetes探索-Deployment面试

1. 简述Deployment的升级策略 在Deployment的定义中,可以通过spec.strategy指定Pod更新的策略,目前支持两种策略:Recreate(重建)和RollingUpdate(滚动更新),默认值为RollingUpdate。…

PEFT-LISA

LISA是LoRA的简化版,但其抓住了LoRA微调的核心,即LoRA侧重更新LLM的底层embedding和顶层head。 根据上述现象,LISA提出两点改进: 始终更新LLM的底层embedding和顶层head随机更新中间层的hidden state 实验结果 显存占用 毕竟模型…

RAMS (Mesoscale Model System) 和 WRF 区别

历史和发展: RAMS:RAMS 最早于1970年代由美国科罗拉多州立大学开发,并在之后几十年不断发展壮大。它是最早用于模拟地区尺度大气动力学、热力学和降水过程的模型之一。WRF:WRF 是由美国国家大气研究中心(NCAR&#xff…

openstack云计算(一)————openstack安装教程,创建空白虚拟机,虚拟机的环境准备

1、创建空白虚拟机 需要注意的步骤会截图一下,其它的基本都是下一步,默认的即可 ----------------------------------------------------------- 2、在所建的空白虚拟机上安装CentOS 7操作系统 (1)、在安装CentOS 7的启动界面中…

RuoYi-Vue若依框架-集成mybatis-plus报错Unknown column ‘search_value‘ in ‘field list‘

报错信息 ### Error querying database. Cause: java.sql.SQLSyntaxErrorException: Unknown column search_value in field list ### The error may exist in com/ruoyi/sales/mapper/ZcSpecificationsMapper.java (best guess) ### The error may involve defaultParameter…

C++之STL的algorithm(6)之排序算法(sort、merge)整理

C之STL的algorithm(6)之排序算法(sort、merge)整理 注:整理一些突然学到的C知识,随时mark一下 例如:忘记的关键字用法,新关键字,新数据结构 C 的排序算法整理 C之STL的al…

Oracle数据库——分组函数四

12.1什么是分组函数 分组函数作用于一组数据,并对一组数据返回一个值。例如求平均数,最大值等等。 1.组函数类型 AVG :求平均数COUNT :COUNT(expr) 返回expr不为空的记录总数。 MAX 求最大值MIN

【简单讲解下WebSocket】

🌈个人主页:程序员不想敲代码啊 🏆CSDN优质创作者,CSDN实力新星,CSDN博客专家 👍点赞⭐评论⭐收藏 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共…

Vue探索之Vue2.x源码分析(一)

一.响应式数据之数组的处理 <template><div><ul><li v-for"(item, index) in items" :key"index">{{ item }}<button click"removeItem(index)">Remove</button></li></ul><input v-model&…

Python卷积网络车牌识别系统(V2.0)

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

基于单片机的无线红外报警系统

**单片机设计介绍&#xff0c;基于单片机的无线红外报警系统 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机的无线红外报警系统是一种结合了单片机控制技术和无线红外传感技术的安防系统。该系统通过无线红外传感器实…

SpringBoot -- 外部化配置

我们如果要对普通程序的jar包更改配置&#xff0c;那么我们需要对jar包解压&#xff0c;并在其中的配置文件中更改配置参数&#xff0c;然后再打包并重新运行。可以看到过程比较繁琐&#xff0c;SpringBoot也注意到了这个问题&#xff0c;其可以通过外部配置文件更新配置。 我…

鸿蒙系统前端:构建智能互联新时代的界面之美

随着华为鸿蒙系统的推出&#xff0c;前端技术也迎来了新的挑战与机遇。鸿蒙系统&#xff0c;作为华为自主研发的分布式操作系统&#xff0c;旨在打通各类智能设备&#xff0c;为用户提供一个无缝的智能互联体验。在这个宏大的愿景下&#xff0c;鸿蒙系统的前端设计显得尤为重要…

Java作业练习_第六周作业多态性(小白学习记录,仅供参考,有错指出)

题目排序&#xff08;点击直达&#xff09; 第一题第二题第三题第四题第五题第六题第七题第八题免责声明 第一题 写出下列程序的运行结果&#xff1a; package com.cxl.ch5.demo5; public class Base {int m0;public int getM(){return m;} } package com.cxl.ch5.demo5;publ…

第18章 JDK8-17新特性

1. Java版本迭代概述 1.1 发布特点&#xff08;小步快跑&#xff0c;快速迭代&#xff09; 发行版本发行时间备注Java 1.01996.01.23Sun公司发布了Java的第一个开发工具包Java 5.02004.09.30①版本号从1.4直接更新至5.0&#xff1b;②平台更名为JavaSE、JavaEE、JavaMEJava 8…