SINet(CVPR2020)复现及问题记录

1. 创建虚拟环境

我创建的是 python 3.8 版本

conda create -n SINet python=3.8

然后进入虚拟环境

2. 克隆项目代码

git clone https://github.com/DengPingFan/SINet.git

3. 安装依赖

我安装的是 pytorch==1.11.0版本, 通过 conda 安装 pytorch,torchvision cudatoolkit 的命令

conda install pytorch==1.11.0 torchvision==0.12.0 cudatoolkit=11.3 -c pytorch

然后还需要安装 scipy 和 opencv-python, imageio

pip install scipy opencv-python imageio

4. 需要修改的代码

GPU参数设置

由于我是单卡训练,将 MyTrain.py 中原先默认的 default=1 修改成 default=0

  parser.add_argument('--gpu', type=int, default=0,

MyTrain.py 和 Src/utils/trainer.py

MyTrain.py 和 Src/utils/trainer.py 中都引入了apex的amp,
from apex import amp
但是目前已经不支持该API了,这里使用 torch.cuda 的amp来实现,
from torch.cuda.amp 所以修改后的

MyTrain.py

import torch
import argparse
from Src.SINet import SINet_ResNet50
from Src.utils.Dataloader import get_loader
from Src.utils.trainer import trainer, adjust_lr
from torch.cuda.amp import GradScalerif __name__ == "__main__":parser = argparse.ArgumentParser()parser.add_argument('--epoch', type=int, default=40,help='epoch number, default=30')parser.add_argument('--lr', type=float, default=1e-4,help='init learning rate, try `lr=1e-4`')parser.add_argument('--batchsize', type=int, default=36,help='training batch size (Note: ~500MB per img in GPU)')parser.add_argument('--trainsize', type=int, default=352,help='the size of training image, try small resolutions for speed (like 256)')parser.add_argument('--clip', type=float, default=0.5,help='gradient clipping margin')parser.add_argument('--decay_rate', type=float, default=0.1,help='decay rate of learning rate per decay step')parser.add_argument('--decay_epoch', type=int, default=30,help='every N epochs decay lr')parser.add_argument('--gpu', type=int, default=0,help='choose which gpu you use')parser.add_argument('--save_epoch', type=int, default=10,help='every N epochs save your trained snapshot')parser.add_argument('--save_model', type=str, default='./Snapshot/2020-CVPR-SINet/')parser.add_argument('--train_img_dir', type=str, default='./Dataset/TrainDataset/Image/')parser.add_argument('--train_gt_dir', type=str, default='./Dataset/TrainDataset/GT/')opt = parser.parse_args()num_gpus = torch.cuda.device_count()if opt.gpu >= num_gpus:raise ValueError(f"GPU device number is invalid. This system has {num_gpus} GPUs, but gpu {opt.gpu} was requested.")torch.cuda.set_device(opt.gpu)# TIPS: you also can use deeper network for better performance like channel=64model_SINet = SINet_ResNet50(channel=32).cuda()print('-' * 30, model_SINet, '-' * 30)optimizer = torch.optim.Adam(model_SINet.parameters(), opt.lr)LogitsBCE = torch.nn.BCEWithLogitsLoss()scaler = GradScaler()train_loader = get_loader(opt.train_img_dir, opt.train_gt_dir, batchsize=opt.batchsize,trainsize=opt.trainsize, num_workers=12)total_step = len(train_loader)print('-' * 30, "\n[Training Dataset INFO]\nimg_dir: {}\ngt_dir: {}\nLearning Rate: {}\nBatch Size: {}\n""Training Save: {}\ntotal_num: {}\n".format(opt.train_img_dir, opt.train_gt_dir, opt.lr,opt.batchsize, opt.save_model, total_step), '-' * 30)for epoch_iter in range(1, opt.epoch):adjust_lr(optimizer, epoch_iter, opt.decay_rate, opt.decay_epoch)trainer(train_loader=train_loader, model=model_SINet,optimizer=optimizer, epoch=epoch_iter,opt=opt, loss_func=LogitsBCE, total_step=total_step)torch.save(model_SINet.state_dict(), os.path.join(opt.save_model, 'SINet_Final.pth'))print("\n[Congratulations! Training Done]")

修改后的 Src/utils/trainer.py

import torch
from torch.autograd import Variable
from datetime import datetime
import os
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocastdef eval_mae(y_pred, y):"""evaluate MAE (for test or validation phase):param y_pred::param y::return: Mean Absolute Error"""return torch.abs(y_pred - y).mean()def numpy2tensor(numpy):"""convert numpy_array in cpu to tensor in gpu:param numpy::return: torch.from_numpy(numpy).cuda()"""return torch.from_numpy(numpy).cuda()def clip_gradient(optimizer, grad_clip):"""recalibrate the misdirection in the training:param optimizer::param grad_clip::return:"""for group in optimizer.param_groups:for param in group['params']:if param.grad is not None:param.grad.data.clamp_(-grad_clip, grad_clip)def adjust_lr(optimizer, epoch, decay_rate=0.1, decay_epoch=30):decay = decay_rate ** (epoch // decay_epoch)for param_group in optimizer.param_groups:param_group['lr'] *= decaydef trainer(train_loader, model, optimizer, epoch, opt, loss_func, total_step):"""Training iteration:param train_loader::param model::param optimizer::param epoch::param opt::param loss_func::param total_step::return:"""model.train()scaler = GradScaler()for step, data_pack in enumerate(train_loader):optimizer.zero_grad()images, gts = data_packimages = Variable(images).cuda()gts = Variable(gts).cuda()with autocast():cam_sm, cam_im = model(images)loss_sm = loss_func(cam_sm, gts)loss_im = loss_func(cam_im, gts)loss_total = loss_sm + loss_imscaler.scale(loss_total).backward()# clip_gradient(optimizer, opt.clip)scaler.step(optimizer)scaler.update()if step % 10 == 0 or step == total_step:print('[{}] => [Epoch Num: {:03d}/{:03d}] => [Global Step: {:04d}/{:04d}] => [Loss_s: {:.4f} Loss_i: {:0.4f}]'.format(datetime.now(), epoch, opt.epoch, step, total_step, loss_sm.data, loss_im.data))save_path = opt.save_modelos.makedirs(save_path, exist_ok=True)if (epoch+1) % opt.save_epoch == 0:torch.save(model.state_dict(), save_path + 'SINet_%d.pth' % (epoch+1))

MyTest.py

因为目前已经不支持 misc 进行图像文件写存了,所以使用 imageio 保存预测的图像
将 6 行的

from scipy import misc

修改成:

import imageio

将 49 行的

misc.imsave(save_path+name, cam)

修改成:

imageio.imsave(save_path + name, cam)

然后,原始代码只在 COD10K 测试集上进行了测试,如果需要在多个测试集进行测试,修改

for dataset in ['COD10K'']:

在这个 list 中添加数据集所在文件夹名称就 ok

for dataset in ['COD10K', 'CAMO', 'CHAMELEON', 'NC4K']:

此外,nn.functional.upsample 已被弃用,应使用 nn.functional.interpolate。需要

  1. F.upsample 替换为 F.interpolate
  2. 将浮点数数组转换为 uint8 格式,以便可以保存为 PNG 文件。具体来说,通过 (cam * 255).astype(np.uint8) 将浮点数数组转换为 uint8 格式。

完整的 MyTest.py

import torch
import torch.nn.functional as F
import numpy as np
import os
import argparse
import imageio
from Src.SINet import SINet_ResNet50
from Src.utils.Dataloader import test_dataset
from Src.utils.trainer import eval_mae, numpy2tensorparser = argparse.ArgumentParser()
parser.add_argument('--testsize', type=int, default=352, help='the snapshot input size')
parser.add_argument('--model_path', type=str,default='./Snapshot/2020-CVPR-SINet/SINet_40.pth')
parser.add_argument('--test_save', type=str,default='./Result/2020-CVPR-SINet-New/')
opt = parser.parse_args()model = SINet_ResNet50().cuda()
model.load_state_dict(torch.load(opt.model_path))
model.eval()for dataset in ['COD10K', 'CAMO', 'CHAMELEON', 'NC4K']:save_path = opt.test_save + dataset + '/'os.makedirs(save_path, exist_ok=True)# NOTES:#  if you plan to inference on your customized dataset without grouth-truth,#  you just modify the params (i.e., `image_root=your_test_img_path` and `gt_root=your_test_img_path`)#  with the same filepath. We recover the original size according to the shape of grouth-truth, and thus,#  the grouth-truth map is unnecessary actually.test_loader = test_dataset(image_root='./Dataset/TestDataset/{}/Image/'.format(dataset),gt_root='./Dataset/TestDataset/{}/GT/'.format(dataset),testsize=opt.testsize)img_count = 1for iteration in range(test_loader.size):# load dataimage, gt, name = test_loader.load_data()gt = np.asarray(gt, np.float32)gt /= (gt.max() + 1e-8)image = image.cuda()# inference_, cam = model(image)# reshape and squeezecam = F.interpolate(cam, size=gt.shape, mode='bilinear', align_corners=True)cam = cam.sigmoid().data.cpu().numpy().squeeze()# normalizecam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)# convert to uint8cam = (cam * 255).astype(np.uint8)imageio.imsave(save_path + name, cam)# evaluatemae = eval_mae(numpy2tensor(cam), numpy2tensor(gt))# coarse scoreprint('[Eval-Test] Dataset: {}, Image: {} ({}/{}), MAE: {}'.format(dataset, name, img_count,test_loader.size, mae))img_count += 1print("\n[Congratulations! Testing Done]")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/841003.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

IT学习笔记--Docker

Docker基本概念: Docker是一个开源的应用容器引擎,基于 Go 语言 并遵从 Apache2.0 协议开源。可以让开发者打包他们的应用以及依赖包到一个轻量级、可移植的容器中,然后发布到任何流行的 Linux 机器上,也可以实现虚拟化。容器是完…

代码随想录35期Day50-Java

Day50题目 LeetCode309买卖股票有冷冻期 核心思想:因为有冷冻期,和之前的题目的状态就有所不同,状态的确定是很难的 0 持有股票的状态 1 没有股票的状态 2 当天卖出股票 3 前一天卖出股票,今天就是冷冻期状态 class Solution {public int maxProfit(int[] prices) {// 四个状…

程序员做推广?我劝你别干

关注卢松松,会经常给你分享一些我的经验和观点。 这是卢松松会员专区,一位会员朋友的咨询,如果你也有自研产品,但不知道如何推广,一定要阅读本文!强烈建议收藏关注,因为你关注的人,决定你看到的…

AWS容器之Fargate

AWS Fargate是亚马逊提供的一种容器管理服务,它允许开发人员在AWS云中轻松运行容器化应用程序,而无需管理底层的服务器基础架构。Fargate可以自动管理容器的部署、扩展和负载平衡,并提供了与ECS和EKS等AWS容器服务集成的能力。适用于容器的无…

技术周总结 2024.05.20~05.26 (Java架构师 数据库理论 MyBatis)

文章目录 一、 问题01 在数据库理论的阿姆斯特朗公理中的自反性规则指什么自反性(Reflexivity)详细解释自反性规则的形式化描述自反性规则的具体含义示例 总结 二、问题02: 数据库理论中的笛卡尔联结和自然联结区别笛卡尔联结(Cartesian Join…

【机器学习300问】98、卷积神经网络中的卷积核到底有什么用?以边缘检测为例说明其意义。

卷积核是用于从输入数据中提取特征的关键工具。卷积核的设计直接关系到网络能够识别和学习的特征类型。本文让我以边缘检测为例,带大家深入理解卷积核的作用。 一、卷积核的作用 卷积核,又称为过滤器,本质上是一个小的矩阵,其元素…

微信小程序毕业设计-智慧旅游平台系统项目开发实战(附源码+演示视频+LW)

大家好!我是程序猿老A,感谢您阅读本文,欢迎一键三连哦。 💞当前专栏:微信小程序毕业设计 精彩专栏推荐👇🏻👇🏻👇🏻 🎀 Python毕业设计…

【算法】二分算法——山脉数组的峰顶索引

该题用二分算法解“山脉数组的峰顶索引”,有需要借鉴即可。 目录 1.题目2.总结 1.题目 题目链接:LINK 暴力求解很简单,这里不再提及。 这个可以根据峰顶值分为两部分,因而具有“二段性”,可以用二分算法&#xff0c…

Java抽象类

明确设计思想 子类越来越具体,父类需要越来越通用 父类和子类保证能够共享特征 父类的设计有时非常抽象,以至于它没有具体的实例 抽象类和抽象方法 abstract关键字修饰一个类,这个类叫做抽象类 abstract关键字修饰一个方法,…

git命令新建远程仓库

今天记录一下使用git命令新建远程分支的操作,因为公司的代码管理仓库界面没找到新建分支的操作界面,无奈只能通过git命令来新建分支。 1、新建本地分支 首先,你的至少应该已经有了一个master分支,然后你再master分支下面执行下面…

默认路由实现两个网段互通实验

默认路由实现两个网段互通实验 **默认路由:**是一种特殊的静态路由,当路由表中与数据包目的地址没有匹配的表项时,数据包将根据默认路由条目进行转发。默认路由在某些时候是非常有效的,例如在末梢网络中,默认路由可以…

React中的使用ref 操作Dom

跟vue 中的类似 也有ref 操作dom 由于 React 会自动更新 DOM 以匹配渲染输出,因此组件通常不需要操作 DOM。但是,有时可能需要访问由 React 管理的 DOM 元素——例如聚焦节点、滚动到此节点,以及测量它的尺寸和位置。React 没有内置的方法来执…

Postgresql源码(133)优化器动态规划生成连接路径的实例分析

物理算子的生成分为两步,基表的扫描路径生成set_base_rel_pathlists;连接路径生成(make_rel_from_joinlist动态规划)。本篇简单分析实现。看过代码会发现,“基表的扫描路径生成”其实就是作为连接路径生成dp计算的第一…

【Mac】MWeb Pro(好用的markdown编辑器) v4.5.9中文版安装教程

软件介绍 MWeb Pro for Mac是一款Mac上的Markdown编辑器软件,它支持实时预览,语法高亮,自动保存和备份等功能,并且有多种主题和样式可供选择。此外,MWeb还支持多种导出格式,包括HTML、PDF、Word、ePub等&a…

栈和队列的经典例题,LeetCode 括号匹配问题;栈实现队列;队列实现栈;队列带环问题

1.前序 又有很久没有更新文章了&#xff0c;这次带你们手撕几道基础题&#xff1b;真的就和康纳吃饭一样简单&#xff01;&#xff01;&#xff01; 如果还不会队列和栈的可以去看看之前写的博客&#xff1b; 栈的实现 队列概念以及实现 <- 快速传送 目录 1.前序 …

HTML中 video标签样式铺满全屏

video标签默认不是铺满的&#xff0c;即使手动设置宽高100%也不会生效&#xff0c;所以当需要video铺满div时&#xff0c;需要加上一个css样式 <videocontrolsstyle"width: 100%; height: 100%; object-fit: fill"autoplay:src"item.video" ></v…

自定义全局变量3

变量删除 语法 unset var_name演示 自定义常量 介绍 就是变量设置值以后不可以修改的变量叫常量, 也叫只读变量 语法 readonly var_name演示 自定义全局变量 父子Shell环境介绍 例如: 有2个Shell脚本文件 A.sh 和 B.sh 如果 在A.sh脚本文件中执行了B.sh脚本文件, 那么A.…

【Web】CISCN 2024初赛 题解(全)

目录 Simple_php easycms easycms_revenge ezjava mossfern sanic Simple_php 用php -r进行php代码执行 因为ban了引号&#xff0c;考虑hex2bin&#xff0c;将数字转为字符串 php -r eval(hex2bin(16进制)); 注意下面这段报错&#xff0c;因为加不了引号&#xff0c;开…

链表-设计LRU缓存结构

题目描述&#xff1a; 代码实现&#xff1a;这里记录了根据LRU算法原理最直接理解的代码实现。 import java.util.*;//存储输入内容&#xff0c;记录访问权值 class CounterInfo {int key;int value;int times;//代表key对应的权值&#xff0c;值越小优先级越高public Counter…

【第2章】SpringBoot配置文件

文章目录 前言一、编写配置信息1. properties2. yml 二、获取配置信息1.直接获取2.配置类形式 总结 前言 SpringBoot工程创建后&#xff0c;会为我们提供一个默认的配置文件(application.properties)&#xff0c;配置文件主要用于那些可能发生变化且经常改变的属性值。 一、编…