视频去噪网络BSVD的实现

前些天写了视频去噪网络BSVD论文的理解,详情请点击这里,这两个星期动手实践了一下,本篇就来记录一下这个模型的实现。

这个网络的独特之处在于,它的训练和推理在实现上有所差别。在训练阶段,其使用了TSM(Time Shift Module)结构,而在推理时则使用了BBB(Bidirectional Buffer Block)结构。训练时,网络是一个MIMO(多输入多输出)形式,而在推理时,则将其设计成了单输入、单输出的流式形式。推理时,由于网络中存在16个双向buffer,即BBB,因此,前16帧会输出空数据,16帧之后开始正常输出去噪视频帧,到视频序列结束后,还会继续输出16帧的去噪视频帧,也就是,流式推理整体存在16帧的延迟。这在一些对实时性要求不太高的应用中可以推广,但对于实时性要求严格,并且存储资源有限的应用中,就无法有效应用了。

下面,我们就通过对官方代码的理解,来聊一聊BSVD的实现。

官方代码地址:GitHub - ChenyangQiQi/BSVD: [ACM MM 2022] Real-time Streaming Video Denoising with Bidirectional Buffers

BSVD网络采用了两个UNet级联的方式。

1. 训练阶段的网络实现

在训练阶段,网络的实现如下:

class WNet(nn.Module):def __init__(self, chns=[32, 64, 128], mid_ch=3, shift_input=False, stage_num=2, in_ch=4, out_ch=3, norm='bn', act='relu', interm_ch=30, blind=False):# def __init__(self, chns=[32, 64, 128], mid_ch=3, shift_input=False, stage_num=2, in_ch=4, out_ch=3, norm='bn', act='relu', blind=False):super(WNet, self).__init__()self.stage_num = stage_numself.nets_list = nn.ModuleList()for i in np.arange(stage_num):if i == 0:stage_in_ch = in_chelse:stage_in_ch = mid_chif i == (stage_num-1):stage_out_ch = out_chelse:stage_out_ch = mid_ch# self.nets_list.append(DenBlock(chns=chns, out_ch=stage_out_ch, in_ch=stage_in_ch, shift_input=shift_input, norm=norm, act=act, interm_ch=interm_ch))if i == 0:self.nets_list.append(DenBlock(chns=chns, out_ch=stage_out_ch, in_ch=stage_in_ch, shift_input=shift_input, norm=norm, act=act, blind=blind, interm_ch=interm_ch))else:self.nets_list.append(DenBlock(chns=chns, out_ch=stage_out_ch,in_ch=stage_in_ch, shift_input=shift_input, norm=norm, act=act, interm_ch=interm_ch))# self.temp2 = DenBlock(chns=chns, in_ch=mid_ch, shift_input=shift_input)# Init weightsself.reset_params()@staticmethoddef weight_init(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, nonlinearity='relu')def reset_params(self):for _, m in enumerate(self.modules()):self.weight_init(m)def forward(self, x, debug=False):# if debug: x_in = x# x = self.temp1(x)for i in np.arange(self.stage_num):if debug: x_temp1 = xx = self.nets_list[i](x)# if debug: x_temp2 = xreturn x

网络由两个DenBlock组成,每个DenBlock是一个UNet结构:


class DenBlock(nn.Module):""" Definition of the denosing block of FastDVDnet.Inputs of constructor:num_input_frames: int. number of input framesInputs of forward():xn: input frames of dim [N, C, H, W], (C=3 RGB)noise_map: array with noise map of dim [N, 1, H, W]"""def __init__(self, chns=[32, 64, 128], out_ch=3, in_ch=4, shift_input=False, norm='bn', bias=True,  act='relu', interm_ch=30, blind=False):# def __init__(self, chns=[32, 64, 128], out_ch=3, in_ch=4, shift_input=False, norm='bn', bias=True,  act='relu', blind=False):super(DenBlock, self).__init__()self.chs_lyr0, self.chs_lyr1, self.chs_lyr2 = chns# if stage2: in_ch=3if shift_input:self.inc = CvBlock(in_ch=in_ch, out_ch=self.chs_lyr0, norm=norm, bias=bias, act=act)else:self.inc = InputCvBlock(num_in_frames=1, out_ch=self.chs_lyr0, in_ch=in_ch, norm=norm, bias=bias, act=act, interm_ch=interm_ch, blind=blind)# num_in_frames=1, out_ch=self.chs_lyr0, in_ch=in_ch, norm=norm, bias=bias, act=act, blind=blind)self.downc0 = DownBlock(in_ch=self.chs_lyr0, out_ch=self.chs_lyr1, norm=norm, bias=bias, act=act)self.downc1 = DownBlock(in_ch=self.chs_lyr1, out_ch=self.chs_lyr2, norm=norm, bias=bias, act=act)self.upc2 = UpBlock(in_ch=self.chs_lyr2, out_ch=self.chs_lyr1, norm=norm, bias=bias,    act=act)self.upc1 = UpBlock(in_ch=self.chs_lyr1, out_ch=self.chs_lyr0, norm=norm, bias=bias,    act=act)self.outc = OutputCvBlock(in_ch=self.chs_lyr0, out_ch=out_ch, norm=norm, bias=bias,     act=act)self.reset_params()@staticmethoddef weight_init(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, nonlinearity='relu')def reset_params(self):for _, m in enumerate(self.modules()):self.weight_init(m)def forward(self, in1):'''Args:inX: Tensor, [N, C, H, W] in the [0., 1.] rangenoise_map: Tensor [N, 1, H, W] in the [0., 1.] range'''# Input convolution blockx0 = self.inc(in1)# Downsamplingx1 = self.downc0(x0)x2 = self.downc1(x1)# Upsamplingx2 = self.upc2(x2)x1 = self.upc1(x1+x2)# Estimationx = self.outc(x0+x1)# Residualx[:, :3, :, :] = in1[:, :3, :, :] - x[:, :3, :, :]return x

这段代码与论文中的UNet结构相对应(见下图),包含一个输入层,两个下采样层,两个上采样层,一个输出层。

输入层没什么特别可说的,主要是两个Conv2d=>BN=>ReLU的组合;输出层也是常规实现,Con2d=>BN=>ReLU=>Con2d,需要注意的是,作者在实现过程中,BN层是没有使用的,是透传通过。

需要花心思理解的是下采样层和上采样层的实现,因为这两个模块在训练和推理过程中,是有所不同的。

两个模块的初始实现很简单,定义如下:

class DownBlock(nn.Module):'''Downscale + (Conv2d => BN => ReLU)*2'''def __init__(self, in_ch, out_ch, norm='bn', bias=True, act='relu'):super(DownBlock, self).__init__()norm_fn = get_norm_function(norm)act_fn = get_act_function(act)self.convblock = nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3,padding=1, stride=2, bias=bias),norm_fn(out_ch),act_fn(inplace=True),CvBlock(out_ch, out_ch, norm=norm, bias=bias, act=act))def forward(self, x):return self.convblock(x)class UpBlock(nn.Module):'''(Conv2d => BN => ReLU)*2 + Upscale'''def __init__(self, in_ch, out_ch, norm='bn', bias=True, act='relu'):super(UpBlock, self).__init__()# norm_fn = get_norm_function(norm)self.convblock = nn.Sequential(CvBlock(in_ch, in_ch, norm=norm, bias=bias, act=act),nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1, bias=bias),nn.PixelShuffle(2))return self.convblock(x)

关键在于两者共同调用的子模块CvBlock的实现,在定义时,CvBlock被常规定义为:

class CvBlock(nn.Module):'''(Conv2d => BN => ReLU) x 2'''def __init__(self, in_ch, out_ch, norm='bn', bias=True, act='relu'):super(CvBlock, self).__init__()norm_fn = get_norm_function(norm)act_fn = get_act_function(act)self.c1 = nn.Conv2d(in_ch, out_ch, kernel_size=3,padding=1, bias=bias)self.b1 = norm_fn(out_ch)self.relu1 = act_fn(inplace=True)self.c2 = nn.Conv2d(out_ch, out_ch, kernel_size=3,padding=1, bias=bias)self.b2 = norm_fn(out_ch)self.relu2 = act_fn(inplace=True)def forward(self, x):x = self.c1(x)x = self.b1(x)x = self.relu1(x)x = self.c2(x)x = self.b2(x)x = self.relu2(x)return x

但接下来,上述定义中的c1和c2则被替换成了TSM实现:

其中,shift模块的核心实现代码如下,对输入的channels分别向左和向右移动了一定单位(fold)。

def shift(x, n_segment, shift_type, fold_div=3, stride=1, inplace=False):nt, c, h, w = x.size()n_batch = nt // n_segmentx = x.view(n_batch, n_segment, c, h, w)fold = c // fold_div # 32/8 = 4if inplace:# Due to some out of order error when performing parallel computing. # May need to write a CUDA kernel.print("WARNING: use inplace shift. it has bugs")raise NotImplementedError  else:out = torch.zeros_like(x)if not 'toFutureOnly' in shift_type:out[:, :-stride, :fold] = x[:, stride:, :fold]  # backward (left shift)out[:, stride:, fold: 2 * fold] = x[:, :-stride, fold: 2 * fold]  # forward (right shift)else:out[:, stride:, : 2 * fold] = x[:, :-stride, : 2 * fold] # right shift onlyout[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shiftreturn out.view(nt, c, h, w)

2. 推理阶段的网络实现

在推理阶段,网络实现就显得复杂一些了。大致的网络结构没变,但由于内部的TSM替换成了BBB, 因此没办法严格进行整体网络的加载,只能每一层单独加载训练出来的state_dict。并且,网络推理变成了流式推理,整个网络的定义显得比较凌乱,结构如下:

class BSVD(nn.Module):"""Bidirection-buffer based framework with pipeline-style inference"""def __init__(self, chns=[32, 64, 128], mid_ch=3, shift_input=False, in_ch=4, out_ch=3, norm='bn', act='relu', interm_ch=30, blind=False, pretrain_ckpt='./experiments/pretrained_ckpt/bsvd-64.pth'):super(BSVD, self).__init__()self.temp1 = DenBlock(chns=chns, out_ch=mid_ch, in_ch=in_ch,  shift_input=shift_input, norm=norm, act=act, blind=blind, interm_ch=interm_ch)self.temp2 = DenBlock(chns=chns, out_ch=out_ch, in_ch=mid_ch, shift_input=shift_input, norm=norm, act=act, blind=blind, interm_ch=interm_ch)self.shift_num = self.count_shift()# Init weightsself.reset_params()if pretrain_ckpt is not None:self.load(pretrain_ckpt)def reset(self):self.temp1.reset()self.temp2.reset()def load(self, path):ckpt = torch.load(path)print("load from %s"%path)ckpt_state = ckpt['params']# split the dict hereif 'module' in list(ckpt_state.keys())[0]:base_name = 'module.base_model.'else:base_name = 'base_model.'ckpt_state_1 = extract_dict(ckpt_state, string_name=base_name+'nets_list.0.')ckpt_state_2 = extract_dict(ckpt_state, string_name=base_name+'nets_list.1.')self.temp1.load_from(ckpt_state_1)self.temp2.load_from(ckpt_state_2)@staticmethoddef weight_init(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, nonlinearity='relu')def reset_params(self):for _, m in enumerate(self.modules()):self.weight_init(m)def feedin_one_element(self, x):x   = self.temp1(x)x   = self.temp2(x)return xdef forward(self, input, noise_map=None):# N, F, C, H, W -> (N*F, C, H, W)if noise_map != None:input = torch.cat([input, noise_map], dim=2)N, F, C, H, W = input.shapeinput = input.reshape(N*F, C, H, W)base_out = self.streaming_forward(input)NF, C, H, W = base_out.shapebase_out = base_out.reshape(N, F, C, H, W)return base_outdef streaming_forward(self, input_seq):"""pipeline-style inferenceArgs:Noisy video streamReturns:Denoised video stream"""out_seq = []if isinstance(input_seq, torch.Tensor):n,c,h,w = input_seq.shapeinput_seq = [input_seq[i:i+1, ...] for i in np.arange(n)]assert type(input_seq) == list, "convert the input into a sequence"_,c,h,w = input_seq[0].shapewith torch.no_grad():for i, x in enumerate(input_seq):x_cuda = x.cuda()x_cuda = self.feedin_one_element(x_cuda)# if x_cuda is not None: x_cuda = x_cuda.cpu()if isinstance(x_cuda, torch.Tensor):out_seq.append(x_cuda)else:out_seq.append(x_cuda)end_out = self.feedin_one_element(None)out_seq.append(end_out)# end stagewhile 1:end_out = self.feedin_one_element(None)if len(out_seq) == (self.shift_num+len(input_seq)):breakout_seq.append(end_out)# number of temporal shift is 2, last element is 0# TODO fix init and end framesout_seq_clip = out_seq[self.shift_num:]self.reset()return torch.cat(out_seq_clip, dim=0)def count_shift(self):count = 0for name, module in self.named_modules():# print(type(module))if "BiBufferConv" in str(type(module)):count+=1return count

两个UNet的定义(DenBlock)大体上没发生变化,但下采样模块和上采样模块的定义发生了改变。

下采样层如下,原来带有TSM的CvBlock换成了MemCvBlock:

上采样模块也类似:

 

而MemCvBlock则调用了BBB模块,BBB模块的实现如下,这是整个算法的核心:

class BiBufferConv(nn.Module):def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0,bias=True) -> None:super(BiBufferConv, self).__init__()self.op = ShiftConv(in_channels,out_channels,kernel_size,stride,padding,bias)self.out_channels = out_channelsself.left_fold_2fold = None# self.zero_tensor = Noneself.center = Nonedef reset(self):self.left_fold_2fold = Noneself.center = Nonedef forward(self, input_right, verbose=False):fold_div = 8if input_right is not None:self.n, self.c, self.h, self.w = input_right.size()self.fold = self.c//fold_div# Case1: In the start or end stage, the memory is emptyif self.center is None:self.center = input_right# if verbose:if input_right is not None:if self.left_fold_2fold is None:# In the start stage, the memory and left tensor is emptyself.left_fold_2fold = torch.zeros((self.n, self.fold, self.h, self.w), device=torch.device('cuda'))if verbose: print("%f+none+%f = none"%(torch.mean(self.left_fold_2fold), torch.mean(input_right)))else:# in the end stage, both feed in and memory are emptyif verbose: print("%f+none+none = none"%(torch.mean(self.left_fold_2fold)))# print("self.center is None")return None# Case2: Center is not None, but input_right is Noneelif input_right is None:# In the last procesing stage, center is 0output =  self.op(self.left_fold_2fold, self.center, torch.zeros((self.n, self.fold, self.h, self.w), device=torch.device('cuda')))if verbose: print("%f+%f+none = %f"%(torch.mean(self.left_fold_2fold), torch.mean(self.center), torch.mean(output)))else:output =  self.op(self.left_fold_2fold, self.center, input_right)if verbose: print("%f+%f+%f = %f"%(torch.mean(self.left_fold_2fold), torch.mean(self.center), torch.mean(input_right), torch.mean(output)))# if output == 57:# a = 1self.left_fold_2fold = self.center[:, self.fold:2*self.fold, :, :]self.center = input_rightreturn output

这样,通过BBB模块,就实现了16个双向Buffer的填充、更新和清空。

限于篇幅,先梳理出个大体的思路,实际上还有很多细节需要特别关注,留待下一篇来写吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/116068.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于斑马优化的BP神经网络(分类应用) - 附代码

基于斑马优化的BP神经网络(分类应用) - 附代码 文章目录 基于斑马优化的BP神经网络(分类应用) - 附代码1.鸢尾花iris数据介绍2.数据集整理3.斑马优化BP神经网络3.1 BP神经网络参数设置3.2 斑马算法应用 4.测试结果:5.M…

【数据仓库-零】数据仓库知识体系 ing

文章目录 一. 数仓基本概念二. 离线数仓建设方法论三. etl流程四. 数仓规范建设指南四. 数据仓库架构五. 数据可视化 通过熟悉构建数仓整体的过程,可以系统的了解 数仓构建理论:能够站在全局角度看数仓的运行架构,数仓执行流程。了解到构建数…

车道线检测laneatt 学习笔记

目录 图片检测可视化 图片检测可视化 import logging import argparse import os import timeimport cv2 import numpy as np import torchfrom lib.config import Config from lib.runner import Runner from lib.experiment import Experimentdef parse_args():parser = ar…

虚拟机安装centos系统后配置桥接网络

一.桥接网络和nat网络的区别 桥接模式 通过使用物理机网卡 具有单独ip,但是需要手动配置。 在bridged模式下,VMWare虚拟出来的操作系统就像是局域网中的一台独立的主机,它可以访问网内任何一台机器。主机网卡和虚拟网卡的IP地址处于同一个网段&#xff…

App爬虫之强大的Airtest的操作总结

App爬虫之强大的Airtest的操作总结 App爬虫之强大的Airtest的操作总结 # Python使用该框架需要安装的依赖库 pip install airtest pip install poco pip install pocouifrom airtest.core.api import * from airtest.cli.parser import cli_setup from poco.drivers.android.…

Mybatis的SqlRunner执行流程

Mybatis的SqlRunner执行流程 SqlRunner exec new SqlRunner(connection); Map<String, Object> row exec.selectOne("SELECT * FROM PRODUCT WHERE PRODUCTID ?", "FI-SW-01");connection.close();assertEquals("FI-SW-01", row.ge…

【QT开发(10)】QT 进程

文章目录 1.1 运行一个新进程1.2 QProcess 还可以对一些信号进行关联2 进程间通信2.1 使用共享内存实现进程通信2.2 演示 代码仓库参考 1.1 运行一个新进程 使用类 QProcess&#xff0c;允许将一个进程堪称一个顺序IO设备。 在Qt中&#xff0c;QProcess类是用于启动外部进程的…

大模型与知识图谱如何相互助力

目前各行各业在数字化、智能化发展的大势所趋下&#xff0c;信息新技术不断涌现&#xff0c;也在加快深入融合到传统实体行业应用中&#xff0c;比如知识图谱、人工智能、数字孪生等等&#xff0c;特别是基于人工智能的大模型在去年底被chatgpt的带领下涌现出一波又一波的浪潮&…

驱动开发1 概念、内核模块编程、内核消息打印函数printk函数的使用、内核模块传参、内核导出符号

1 驱动相关概念 2 内核模块编程 内核模块编写实例代码注释 #include <linux/init.h> #include <linux/module.h>//入口函数&#xff0c;安装内核模块时执行 static int __init mycdev_init(void) {//static 修饰当前函数只能在本文件使用//int 函数的返回值类型&a…

【Leetcode】【中等】1726.同积元组

力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能&#xff0c;轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/tuple-with-same-product/ 给你…

适用于 Mac 电脑的 10 款最佳数据恢复工具集

无论是个人照片还是重要的商业文档&#xff0c;对于那些依赖计算机获取重要文件的人来说&#xff0c;数据丢失都是一场噩梦。 值得庆幸的是&#xff0c;Mac用户可以使用各种数据恢复工具&#xff0c;可以帮助您恢复丢失或意外删除的文件。 在本文中&#xff0c;我们将采用适用于…

Arrays 中的 asList()方法

public static <T> List<T> asList&#xff08; T . . . a &#xff09;{ return new ArrayList<>&#xff08;a&#xff09;&#xff1b; } 返回由指定数组支持的固定大小的 list集合。对数组所做的更改将在返回的 l…

【USRP】通信之有线通信

有线通信&#xff1a; 有线通信是指使用物理线路或媒体&#xff08;例如&#xff0c;铜线、同轴电缆、光纤&#xff09;进行数据、声音和视频传输的通信方式。由于它依赖于实体传输媒介&#xff0c;有线通信通常具有较高的稳定性和可靠性&#xff0c;并能支持长距离的高带宽通…

input框输入中文时,输入未完成触发事件。Vue中文输入法不触发input事件?

前言 在做搜索输入框时&#xff0c;产品期待实时搜索&#xff0c;就是边输入边搜索&#xff0c;然而对于中文输入法出现的效果&#xff0c;不同的产品可能有不同的意见&#xff0c;有的觉得输入未完成也应该触发搜索。但有的却认为应该在中文输入完成后再触发搜索。我发现在vu…

Docker Swarm 集群搭建

Docker Swarm Mode Docker Swarm 集群搭建 Docker Swarm 节点维护 Docker Service 创建 1.准备主机 搭建一个 docker swarm 集群&#xff0c;包含 5 个 swarm 节点。这 5 个 swarm 节点的 IP 与暂 时的角色分配如下&#xff08;注意&#xff0c;搭建完成后会切换角色&#xff…

23年上半年上午题复习

敏捷方法 耦合 软件维护 消息 面向对象测试 面向对象设计原则 包图 原型模式 数据库三级模型 数据库函数依赖 哈夫曼树 左0右1 折半查找 画一个折半查找树&#xff0c;这个树只会往一个方向查找&#xff0c;一个节点不会同时出现左右子树&#xff0c;较小的作为左子树&#…

git将当前分支A强制推送远程分支pro上

前言 开发中基于线上分支pro创建了A分支&#xff0c;开发完成之后。又基于线上分支pro创建了B分支&#xff0c;都以此合并到测试分支&#xff0c;两个分支更改中都动用部分共同的文件&#xff0c;这就导致后续开发合并代码越来越乱&#xff0c;这时你想把本地开发的分支强推到…

数据库设计阶段-架构真题(五十七)

下面关于联合需求计划JRP叙述&#xff0c;不正确的是&#xff08;&#xff09;。 JRP是一种相对成本较高但十分有效的需求获取方法在讨论期间尽量避免使用专业术语JRP的主要目的是对需求进行分析和验证在JRP实施之前&#xff0c;应制定详细的议程&#xff0c;并严格遵照议程进…

力扣每日一题57:插入区间

题目描述&#xff1a; 给你一个 无重叠的 &#xff0c;按照区间起始端点排序的区间列表。 在列表中插入一个新的区间&#xff0c;你需要确保列表中的区间仍然有序且不重叠&#xff08;如果有必要的话&#xff0c;可以合并区间&#xff09;。 示例 1&#xff1a; 输入&#x…

通义大模型使用指南之通义千问

一、注册 我们可以打开以下网站&#xff0c;用手机号注册一个账号即可。 通义大模型 (aliyun.com) 二、使用介绍 如图&#xff0c;我们可以看到有三个大项功能&#xff0c;通义千问、通义万相、通义听悟。下来我们体验一下通义千问的功能。 1、通义千问 通义千问主要有两个功能…