视频去噪网络BSVD的实现

前些天写了视频去噪网络BSVD论文的理解,详情请点击这里,这两个星期动手实践了一下,本篇就来记录一下这个模型的实现。

这个网络的独特之处在于,它的训练和推理在实现上有所差别。在训练阶段,其使用了TSM(Time Shift Module)结构,而在推理时则使用了BBB(Bidirectional Buffer Block)结构。训练时,网络是一个MIMO(多输入多输出)形式,而在推理时,则将其设计成了单输入、单输出的流式形式。推理时,由于网络中存在16个双向buffer,即BBB,因此,前16帧会输出空数据,16帧之后开始正常输出去噪视频帧,到视频序列结束后,还会继续输出16帧的去噪视频帧,也就是,流式推理整体存在16帧的延迟。这在一些对实时性要求不太高的应用中可以推广,但对于实时性要求严格,并且存储资源有限的应用中,就无法有效应用了。

下面,我们就通过对官方代码的理解,来聊一聊BSVD的实现。

官方代码地址:GitHub - ChenyangQiQi/BSVD: [ACM MM 2022] Real-time Streaming Video Denoising with Bidirectional Buffers

BSVD网络采用了两个UNet级联的方式。

1. 训练阶段的网络实现

在训练阶段,网络的实现如下:

class WNet(nn.Module):def __init__(self, chns=[32, 64, 128], mid_ch=3, shift_input=False, stage_num=2, in_ch=4, out_ch=3, norm='bn', act='relu', interm_ch=30, blind=False):# def __init__(self, chns=[32, 64, 128], mid_ch=3, shift_input=False, stage_num=2, in_ch=4, out_ch=3, norm='bn', act='relu', blind=False):super(WNet, self).__init__()self.stage_num = stage_numself.nets_list = nn.ModuleList()for i in np.arange(stage_num):if i == 0:stage_in_ch = in_chelse:stage_in_ch = mid_chif i == (stage_num-1):stage_out_ch = out_chelse:stage_out_ch = mid_ch# self.nets_list.append(DenBlock(chns=chns, out_ch=stage_out_ch, in_ch=stage_in_ch, shift_input=shift_input, norm=norm, act=act, interm_ch=interm_ch))if i == 0:self.nets_list.append(DenBlock(chns=chns, out_ch=stage_out_ch, in_ch=stage_in_ch, shift_input=shift_input, norm=norm, act=act, blind=blind, interm_ch=interm_ch))else:self.nets_list.append(DenBlock(chns=chns, out_ch=stage_out_ch,in_ch=stage_in_ch, shift_input=shift_input, norm=norm, act=act, interm_ch=interm_ch))# self.temp2 = DenBlock(chns=chns, in_ch=mid_ch, shift_input=shift_input)# Init weightsself.reset_params()@staticmethoddef weight_init(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, nonlinearity='relu')def reset_params(self):for _, m in enumerate(self.modules()):self.weight_init(m)def forward(self, x, debug=False):# if debug: x_in = x# x = self.temp1(x)for i in np.arange(self.stage_num):if debug: x_temp1 = xx = self.nets_list[i](x)# if debug: x_temp2 = xreturn x

网络由两个DenBlock组成,每个DenBlock是一个UNet结构:


class DenBlock(nn.Module):""" Definition of the denosing block of FastDVDnet.Inputs of constructor:num_input_frames: int. number of input framesInputs of forward():xn: input frames of dim [N, C, H, W], (C=3 RGB)noise_map: array with noise map of dim [N, 1, H, W]"""def __init__(self, chns=[32, 64, 128], out_ch=3, in_ch=4, shift_input=False, norm='bn', bias=True,  act='relu', interm_ch=30, blind=False):# def __init__(self, chns=[32, 64, 128], out_ch=3, in_ch=4, shift_input=False, norm='bn', bias=True,  act='relu', blind=False):super(DenBlock, self).__init__()self.chs_lyr0, self.chs_lyr1, self.chs_lyr2 = chns# if stage2: in_ch=3if shift_input:self.inc = CvBlock(in_ch=in_ch, out_ch=self.chs_lyr0, norm=norm, bias=bias, act=act)else:self.inc = InputCvBlock(num_in_frames=1, out_ch=self.chs_lyr0, in_ch=in_ch, norm=norm, bias=bias, act=act, interm_ch=interm_ch, blind=blind)# num_in_frames=1, out_ch=self.chs_lyr0, in_ch=in_ch, norm=norm, bias=bias, act=act, blind=blind)self.downc0 = DownBlock(in_ch=self.chs_lyr0, out_ch=self.chs_lyr1, norm=norm, bias=bias, act=act)self.downc1 = DownBlock(in_ch=self.chs_lyr1, out_ch=self.chs_lyr2, norm=norm, bias=bias, act=act)self.upc2 = UpBlock(in_ch=self.chs_lyr2, out_ch=self.chs_lyr1, norm=norm, bias=bias,    act=act)self.upc1 = UpBlock(in_ch=self.chs_lyr1, out_ch=self.chs_lyr0, norm=norm, bias=bias,    act=act)self.outc = OutputCvBlock(in_ch=self.chs_lyr0, out_ch=out_ch, norm=norm, bias=bias,     act=act)self.reset_params()@staticmethoddef weight_init(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, nonlinearity='relu')def reset_params(self):for _, m in enumerate(self.modules()):self.weight_init(m)def forward(self, in1):'''Args:inX: Tensor, [N, C, H, W] in the [0., 1.] rangenoise_map: Tensor [N, 1, H, W] in the [0., 1.] range'''# Input convolution blockx0 = self.inc(in1)# Downsamplingx1 = self.downc0(x0)x2 = self.downc1(x1)# Upsamplingx2 = self.upc2(x2)x1 = self.upc1(x1+x2)# Estimationx = self.outc(x0+x1)# Residualx[:, :3, :, :] = in1[:, :3, :, :] - x[:, :3, :, :]return x

这段代码与论文中的UNet结构相对应(见下图),包含一个输入层,两个下采样层,两个上采样层,一个输出层。

输入层没什么特别可说的,主要是两个Conv2d=>BN=>ReLU的组合;输出层也是常规实现,Con2d=>BN=>ReLU=>Con2d,需要注意的是,作者在实现过程中,BN层是没有使用的,是透传通过。

需要花心思理解的是下采样层和上采样层的实现,因为这两个模块在训练和推理过程中,是有所不同的。

两个模块的初始实现很简单,定义如下:

class DownBlock(nn.Module):'''Downscale + (Conv2d => BN => ReLU)*2'''def __init__(self, in_ch, out_ch, norm='bn', bias=True, act='relu'):super(DownBlock, self).__init__()norm_fn = get_norm_function(norm)act_fn = get_act_function(act)self.convblock = nn.Sequential(nn.Conv2d(in_ch, out_ch, kernel_size=3,padding=1, stride=2, bias=bias),norm_fn(out_ch),act_fn(inplace=True),CvBlock(out_ch, out_ch, norm=norm, bias=bias, act=act))def forward(self, x):return self.convblock(x)class UpBlock(nn.Module):'''(Conv2d => BN => ReLU)*2 + Upscale'''def __init__(self, in_ch, out_ch, norm='bn', bias=True, act='relu'):super(UpBlock, self).__init__()# norm_fn = get_norm_function(norm)self.convblock = nn.Sequential(CvBlock(in_ch, in_ch, norm=norm, bias=bias, act=act),nn.Conv2d(in_ch, out_ch*4, kernel_size=3, padding=1, bias=bias),nn.PixelShuffle(2))return self.convblock(x)

关键在于两者共同调用的子模块CvBlock的实现,在定义时,CvBlock被常规定义为:

class CvBlock(nn.Module):'''(Conv2d => BN => ReLU) x 2'''def __init__(self, in_ch, out_ch, norm='bn', bias=True, act='relu'):super(CvBlock, self).__init__()norm_fn = get_norm_function(norm)act_fn = get_act_function(act)self.c1 = nn.Conv2d(in_ch, out_ch, kernel_size=3,padding=1, bias=bias)self.b1 = norm_fn(out_ch)self.relu1 = act_fn(inplace=True)self.c2 = nn.Conv2d(out_ch, out_ch, kernel_size=3,padding=1, bias=bias)self.b2 = norm_fn(out_ch)self.relu2 = act_fn(inplace=True)def forward(self, x):x = self.c1(x)x = self.b1(x)x = self.relu1(x)x = self.c2(x)x = self.b2(x)x = self.relu2(x)return x

但接下来,上述定义中的c1和c2则被替换成了TSM实现:

其中,shift模块的核心实现代码如下,对输入的channels分别向左和向右移动了一定单位(fold)。

def shift(x, n_segment, shift_type, fold_div=3, stride=1, inplace=False):nt, c, h, w = x.size()n_batch = nt // n_segmentx = x.view(n_batch, n_segment, c, h, w)fold = c // fold_div # 32/8 = 4if inplace:# Due to some out of order error when performing parallel computing. # May need to write a CUDA kernel.print("WARNING: use inplace shift. it has bugs")raise NotImplementedError  else:out = torch.zeros_like(x)if not 'toFutureOnly' in shift_type:out[:, :-stride, :fold] = x[:, stride:, :fold]  # backward (left shift)out[:, stride:, fold: 2 * fold] = x[:, :-stride, fold: 2 * fold]  # forward (right shift)else:out[:, stride:, : 2 * fold] = x[:, :-stride, : 2 * fold] # right shift onlyout[:, :, 2 * fold:] = x[:, :, 2 * fold:]  # not shiftreturn out.view(nt, c, h, w)

2. 推理阶段的网络实现

在推理阶段,网络实现就显得复杂一些了。大致的网络结构没变,但由于内部的TSM替换成了BBB, 因此没办法严格进行整体网络的加载,只能每一层单独加载训练出来的state_dict。并且,网络推理变成了流式推理,整个网络的定义显得比较凌乱,结构如下:

class BSVD(nn.Module):"""Bidirection-buffer based framework with pipeline-style inference"""def __init__(self, chns=[32, 64, 128], mid_ch=3, shift_input=False, in_ch=4, out_ch=3, norm='bn', act='relu', interm_ch=30, blind=False, pretrain_ckpt='./experiments/pretrained_ckpt/bsvd-64.pth'):super(BSVD, self).__init__()self.temp1 = DenBlock(chns=chns, out_ch=mid_ch, in_ch=in_ch,  shift_input=shift_input, norm=norm, act=act, blind=blind, interm_ch=interm_ch)self.temp2 = DenBlock(chns=chns, out_ch=out_ch, in_ch=mid_ch, shift_input=shift_input, norm=norm, act=act, blind=blind, interm_ch=interm_ch)self.shift_num = self.count_shift()# Init weightsself.reset_params()if pretrain_ckpt is not None:self.load(pretrain_ckpt)def reset(self):self.temp1.reset()self.temp2.reset()def load(self, path):ckpt = torch.load(path)print("load from %s"%path)ckpt_state = ckpt['params']# split the dict hereif 'module' in list(ckpt_state.keys())[0]:base_name = 'module.base_model.'else:base_name = 'base_model.'ckpt_state_1 = extract_dict(ckpt_state, string_name=base_name+'nets_list.0.')ckpt_state_2 = extract_dict(ckpt_state, string_name=base_name+'nets_list.1.')self.temp1.load_from(ckpt_state_1)self.temp2.load_from(ckpt_state_2)@staticmethoddef weight_init(m):if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, nonlinearity='relu')def reset_params(self):for _, m in enumerate(self.modules()):self.weight_init(m)def feedin_one_element(self, x):x   = self.temp1(x)x   = self.temp2(x)return xdef forward(self, input, noise_map=None):# N, F, C, H, W -> (N*F, C, H, W)if noise_map != None:input = torch.cat([input, noise_map], dim=2)N, F, C, H, W = input.shapeinput = input.reshape(N*F, C, H, W)base_out = self.streaming_forward(input)NF, C, H, W = base_out.shapebase_out = base_out.reshape(N, F, C, H, W)return base_outdef streaming_forward(self, input_seq):"""pipeline-style inferenceArgs:Noisy video streamReturns:Denoised video stream"""out_seq = []if isinstance(input_seq, torch.Tensor):n,c,h,w = input_seq.shapeinput_seq = [input_seq[i:i+1, ...] for i in np.arange(n)]assert type(input_seq) == list, "convert the input into a sequence"_,c,h,w = input_seq[0].shapewith torch.no_grad():for i, x in enumerate(input_seq):x_cuda = x.cuda()x_cuda = self.feedin_one_element(x_cuda)# if x_cuda is not None: x_cuda = x_cuda.cpu()if isinstance(x_cuda, torch.Tensor):out_seq.append(x_cuda)else:out_seq.append(x_cuda)end_out = self.feedin_one_element(None)out_seq.append(end_out)# end stagewhile 1:end_out = self.feedin_one_element(None)if len(out_seq) == (self.shift_num+len(input_seq)):breakout_seq.append(end_out)# number of temporal shift is 2, last element is 0# TODO fix init and end framesout_seq_clip = out_seq[self.shift_num:]self.reset()return torch.cat(out_seq_clip, dim=0)def count_shift(self):count = 0for name, module in self.named_modules():# print(type(module))if "BiBufferConv" in str(type(module)):count+=1return count

两个UNet的定义(DenBlock)大体上没发生变化,但下采样模块和上采样模块的定义发生了改变。

下采样层如下,原来带有TSM的CvBlock换成了MemCvBlock:

上采样模块也类似:

 

而MemCvBlock则调用了BBB模块,BBB模块的实现如下,这是整个算法的核心:

class BiBufferConv(nn.Module):def __init__(self,in_channels,out_channels,kernel_size,stride=1,padding=0,bias=True) -> None:super(BiBufferConv, self).__init__()self.op = ShiftConv(in_channels,out_channels,kernel_size,stride,padding,bias)self.out_channels = out_channelsself.left_fold_2fold = None# self.zero_tensor = Noneself.center = Nonedef reset(self):self.left_fold_2fold = Noneself.center = Nonedef forward(self, input_right, verbose=False):fold_div = 8if input_right is not None:self.n, self.c, self.h, self.w = input_right.size()self.fold = self.c//fold_div# Case1: In the start or end stage, the memory is emptyif self.center is None:self.center = input_right# if verbose:if input_right is not None:if self.left_fold_2fold is None:# In the start stage, the memory and left tensor is emptyself.left_fold_2fold = torch.zeros((self.n, self.fold, self.h, self.w), device=torch.device('cuda'))if verbose: print("%f+none+%f = none"%(torch.mean(self.left_fold_2fold), torch.mean(input_right)))else:# in the end stage, both feed in and memory are emptyif verbose: print("%f+none+none = none"%(torch.mean(self.left_fold_2fold)))# print("self.center is None")return None# Case2: Center is not None, but input_right is Noneelif input_right is None:# In the last procesing stage, center is 0output =  self.op(self.left_fold_2fold, self.center, torch.zeros((self.n, self.fold, self.h, self.w), device=torch.device('cuda')))if verbose: print("%f+%f+none = %f"%(torch.mean(self.left_fold_2fold), torch.mean(self.center), torch.mean(output)))else:output =  self.op(self.left_fold_2fold, self.center, input_right)if verbose: print("%f+%f+%f = %f"%(torch.mean(self.left_fold_2fold), torch.mean(self.center), torch.mean(input_right), torch.mean(output)))# if output == 57:# a = 1self.left_fold_2fold = self.center[:, self.fold:2*self.fold, :, :]self.center = input_rightreturn output

这样,通过BBB模块,就实现了16个双向Buffer的填充、更新和清空。

限于篇幅,先梳理出个大体的思路,实际上还有很多细节需要特别关注,留待下一篇来写吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/116068.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于斑马优化的BP神经网络(分类应用) - 附代码

基于斑马优化的BP神经网络(分类应用) - 附代码 文章目录 基于斑马优化的BP神经网络(分类应用) - 附代码1.鸢尾花iris数据介绍2.数据集整理3.斑马优化BP神经网络3.1 BP神经网络参数设置3.2 斑马算法应用 4.测试结果:5.M…

虚拟机安装centos系统后配置桥接网络

一.桥接网络和nat网络的区别 桥接模式 通过使用物理机网卡 具有单独ip,但是需要手动配置。 在bridged模式下,VMWare虚拟出来的操作系统就像是局域网中的一台独立的主机,它可以访问网内任何一台机器。主机网卡和虚拟网卡的IP地址处于同一个网段&#xff…

Mybatis的SqlRunner执行流程

Mybatis的SqlRunner执行流程 SqlRunner exec new SqlRunner(connection); Map<String, Object> row exec.selectOne("SELECT * FROM PRODUCT WHERE PRODUCTID ?", "FI-SW-01");connection.close();assertEquals("FI-SW-01", row.ge…

【QT开发(10)】QT 进程

文章目录 1.1 运行一个新进程1.2 QProcess 还可以对一些信号进行关联2 进程间通信2.1 使用共享内存实现进程通信2.2 演示 代码仓库参考 1.1 运行一个新进程 使用类 QProcess&#xff0c;允许将一个进程堪称一个顺序IO设备。 在Qt中&#xff0c;QProcess类是用于启动外部进程的…

大模型与知识图谱如何相互助力

目前各行各业在数字化、智能化发展的大势所趋下&#xff0c;信息新技术不断涌现&#xff0c;也在加快深入融合到传统实体行业应用中&#xff0c;比如知识图谱、人工智能、数字孪生等等&#xff0c;特别是基于人工智能的大模型在去年底被chatgpt的带领下涌现出一波又一波的浪潮&…

驱动开发1 概念、内核模块编程、内核消息打印函数printk函数的使用、内核模块传参、内核导出符号

1 驱动相关概念 2 内核模块编程 内核模块编写实例代码注释 #include <linux/init.h> #include <linux/module.h>//入口函数&#xff0c;安装内核模块时执行 static int __init mycdev_init(void) {//static 修饰当前函数只能在本文件使用//int 函数的返回值类型&a…

【Leetcode】【中等】1726.同积元组

力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能&#xff0c;轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/tuple-with-same-product/ 给你…

适用于 Mac 电脑的 10 款最佳数据恢复工具集

无论是个人照片还是重要的商业文档&#xff0c;对于那些依赖计算机获取重要文件的人来说&#xff0c;数据丢失都是一场噩梦。 值得庆幸的是&#xff0c;Mac用户可以使用各种数据恢复工具&#xff0c;可以帮助您恢复丢失或意外删除的文件。 在本文中&#xff0c;我们将采用适用于…

input框输入中文时,输入未完成触发事件。Vue中文输入法不触发input事件?

前言 在做搜索输入框时&#xff0c;产品期待实时搜索&#xff0c;就是边输入边搜索&#xff0c;然而对于中文输入法出现的效果&#xff0c;不同的产品可能有不同的意见&#xff0c;有的觉得输入未完成也应该触发搜索。但有的却认为应该在中文输入完成后再触发搜索。我发现在vu…

Docker Swarm 集群搭建

Docker Swarm Mode Docker Swarm 集群搭建 Docker Swarm 节点维护 Docker Service 创建 1.准备主机 搭建一个 docker swarm 集群&#xff0c;包含 5 个 swarm 节点。这 5 个 swarm 节点的 IP 与暂 时的角色分配如下&#xff08;注意&#xff0c;搭建完成后会切换角色&#xff…

23年上半年上午题复习

敏捷方法 耦合 软件维护 消息 面向对象测试 面向对象设计原则 包图 原型模式 数据库三级模型 数据库函数依赖 哈夫曼树 左0右1 折半查找 画一个折半查找树&#xff0c;这个树只会往一个方向查找&#xff0c;一个节点不会同时出现左右子树&#xff0c;较小的作为左子树&#…

通义大模型使用指南之通义千问

一、注册 我们可以打开以下网站&#xff0c;用手机号注册一个账号即可。 通义大模型 (aliyun.com) 二、使用介绍 如图&#xff0c;我们可以看到有三个大项功能&#xff0c;通义千问、通义万相、通义听悟。下来我们体验一下通义千问的功能。 1、通义千问 通义千问主要有两个功能…

如何使用VSCode将iPad Pro转化为功能强大的开发工具?

文章目录 前言1. 本地环境配置2. 内网穿透2.1 安装cpolar内网穿透(支持一键自动安装脚本)2.2 创建HTTP隧道 3. 测试远程访问4. 配置固定二级子域名4.1 保留二级子域名4.2 配置二级子域名 5. 测试使用固定二级子域名远程访问6. iPad通过软件远程vscode6.1 创建TCP隧道 7. ipad远…

# 开发趋势 Java Lambda 表达式 第三篇

开发趋势 Java Lambda 表达式 第三篇 一&#xff0c;Lambda 整合集合常规操作 List Java Lambda 表达式可以与List集合和常规操作进行整合&#xff0c;以提供一种更简洁、更可读的代码编写方式。以下是几个示例&#xff1a; 集合遍历操作&#xff1a; List<String> n…

使用IO流完成项目实战水果库存系统

以下内容本人都是在 Maven 工程下总结的 需求介绍显示主菜单让程序无线运行下去加载数据显示库存列表根据名称查找特定库存记录添加库存记录查看_下架_退出功能实现持久化数据 package com.csdn.fruit.pojo; import lombok.AllArgsConstructor; import lombok.Data; import lom…

Http长连接同一个socket多个请求和响应如何保证一一对应?

HTTP/2引入二进制数据帧和流的概念&#xff0c;其中帧对数据进行顺序标识&#xff0c;如下图所示&#xff0c;这样浏览器收到数据之后&#xff0c;就可以按照序列对数据进行合并&#xff0c;而不会出现合并后数据错乱的情况。同样是因为有了序列&#xff0c;服务器就可以并行的…

使用Packstack安装器安装一体化OpenStack云平台

【实训目的】 初步掌握OpenStack快捷安装的方法。掌握OpenStack图形界面的基本操作。 【实训准备】 &#xff08;1&#xff09;准备一台能够安装OpenStack的实验用计算机&#xff0c;建议使用VMware虚拟机。 &#xff08;2&#xff09;该计算机应安装CentOS 7&#xff0c;建…

【微服务】Feign 整合 Sentinel,深入探索 Sentinel 的隔离和熔断降级规则,以及授权规则和自定义异常返回结果

文章目录 前言一、Feign 整合 Sentinel1.1 实现步骤1.2 FallbackFactory 示例 二、Sentinel 实现隔离2.1 隔离的实现方法2.2 Sentinel 实现线程隔离示例 三、熔断降级规则3.1 熔断降级原理及其流程3.2 熔断策略 —— 慢调用3.3 熔断策略 —— 异常比例和异常数 四、授权规则4.1…

数据结构中的七大排序(Java实现)

目录 一、直接插入排序 二、希尔排序 三、直接选择排序 四、堆排序 五、冒泡排序 六、快速排序 七、归并排序 一、直接插入排序 思想&#xff1a; 定义i下标之前的元素全部已经有序&#xff0c;遍历一遍要排序的数组&#xff0c;把i下标前的元素全部进行排序&#xff0…

ArGIS Engine专题(15)之GP模型在地图服务与地图服务之间实现叠置分析

前一篇文章已经介绍过导入要素范围与地图服务的叠加分析,相当于单要素与多要素之间的分析,这篇文章介绍地图服务与地图服务之间的叠加分析,即是多要素有多要素之间的相交分析,功能基本类似。 一、结果预览 二、需求简介 以下是一些常见的业务场景: (1)空间规划和土地…