代码解读 | Hybrid Transformers for Music Source Separation[06]

一、背景

        0、Hybrid Transformer 论文解读

        1、代码复现|Demucs Music Source Separation_demucs架构原理-CSDN博客

        2、Hybrid Transformer 各个模块对应的代码具体在工程的哪个地方

        3、Hybrid Transformer 各个模块的底层到底是个啥(初步感受)?

        4、Hybrid Transformer 各个模块处理后,数据的维度大小是咋变换的?

        5、Hybrid Transformer 拆解STFT模块

        6、Hybrid Transformer 拆解频域编码模块


        从模块上划分,Hybrid Transformer Demucs 共包含 (STFT模块、时域编码模块、频域编码模块、Cross-Domain Transformer Encoder模块、时域解码模块、频域解码模块ISTFT模块)7个模块。已完成解读:STFT模块、频域编码模块(时域编码和频域编码类似,后续不再解读时域编码模块),待解读:Cross-Domain Transformer Encoder模块。

        本篇目标:拆解频域解码模块ISTFT模块的底层。时域解码和频域解码原理类似(后续不再拆解时域解码模块)。

二、频域解码模块


class HDecLayer(nn.Module):def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,context_freq=True, rewrite=True):"""Same as HEncLayer but for decoder. See `HEncLayer` for documentation."""super().__init__()norm_fn = lambda d: nn.Identity()  # noqaif norm:norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqaif pad:pad = kernel_size // 4else:pad = 0self.pad = padself.last = lastself.freq = freqself.chin = chinself.empty = emptyself.stride = strideself.kernel_size = kernel_sizeself.norm = normself.context_freq = context_freqklass = nn.Conv1dklass_tr = nn.ConvTranspose1dif freq:kernel_size = [kernel_size, 1]stride = [stride, 1]klass = nn.Conv2dklass_tr = nn.ConvTranspose2dself.conv_tr = klass_tr(chin, chout, kernel_size, stride)self.norm2 = norm_fn(chout)if self.empty:returnself.rewrite = Noneif rewrite:if context_freq:self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)else:self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,[0, context])self.norm1 = norm_fn(2 * chin)self.dconv = Noneif dconv:self.dconv = DConv(chin, **dconv_kw)def forward(self, x, skip, length):if self.freq and x.dim() == 3:B, C, T = x.shapex = x.view(B, self.chin, -1, T)if not self.empty:x = x + skipif self.rewrite:y = F.glu(self.norm1(self.rewrite(x)), dim=1)else:y = xif self.dconv:if self.freq:B, C, Fr, T = y.shapey = y.permute(0, 2, 1, 3).reshape(-1, C, T)y = self.dconv(y)if self.freq:y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)else:y = xassert skip is Nonez = self.norm2(self.conv_tr(y))print('self.pad,self.last:', self.pad,self.last)if self.freq:if self.pad:z = z[..., self.pad:-self.pad, :]else:z = z[..., self.pad:self.pad + length]assert z.shape[-1] == length, (z.shape[-1], length)if not self.last:z = F.gelu(z)return z, y

        频域解码模块的核心代码如上所示。在上一篇频域编码模块的基础上,继续贴出完善之后的频域编解码模块全景图。

编码层:Conv2d+Norm1+GELU,  Norm1:Identity()

解码层:(Conv2d+Norm1+GLU)+(ConvTranspose2d+Norm2+倒数第二个维度裁剪+GELU),    Norm1\Norm2:Identity()

残差连接:(Conv1d+GroupNorm+GELU +Conv1d+GroupNorm+GLU+LayerScale())+(Conv2d+Norm2+GLU),Norm2:Identity() ,备注:Identity可以理解成直通

#频域编码层1-4的Conv2d分别是:
Conv2d(4, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(48, 96, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(96, 192, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(192, 384, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))#频域解码层4-1的Conv2d和ConvTranspose2d
Conv2d(384, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ConvTranspose2d(384, 192, kernel_size=(8, 1), stride=(4, 1)) 
Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ConvTranspose2d(192, 96, kernel_size=(8, 1), stride=(4, 1))
Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ConvTranspose2d(96, 48, kernel_size=(8, 1), stride=(4, 1))
Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
ConvTranspose2d(48, 16, kernel_size=(8, 1), stride=(4, 1))

        残差连接模块如下所示。

#残差连接1
DConv((layers): ModuleList((0): Sequential((0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 6, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 96, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 6, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 96, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))#残差连接2
DConv((layers): ModuleList((0): Sequential((0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 12, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 192, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 12, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 192, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1))#残差连接3
DConv((layers): ModuleList((0): Sequential((0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 24, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 384, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 24, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 384, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))#残差连接4
DConv((layers): ModuleList((0): Sequential((0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 48, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 768, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 48, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 768, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))

三、ISTFT模块

        ISTFT模块的核心代码如下所示。

import torch as th
def ispectro(z, hop_length=None, length=None, pad=0):*other, freqs, frames = z.shapen_fft = 2 * freqs - 2z = z.view(-1, freqs, frames)win_length = n_fft // (1 + pad)is_mps = z.device.type == 'mps'if is_mps:z = z.cpu()x = th.istft(z,n_fft,hop_length,window=th.hann_window(win_length).to(z.real),win_length=win_length,normalized=True,length=length,center=True)_, length = x.shapereturn x.view(*other, length)

        其中,torch.istft【逆短时傅里叶变换(Inverse Short Time Fourier Transform,ISTFT)】,该函数期望是torch.stft函数的逆过程。它具有相同的参数(加上一个可选参数length),并且应该返回原始信号的最小二乘估计。算法将根据NOLA条件(非零重叠)进行检查。

#### torch.istft接口参数####
input (Tensor): 输入张量,期望是`torch.stft`的输出,可以是复数形式(`channel`, `fft_size`, `n_frame`),或者是实数形式(`channel`, `fft_size`, `n_frame`, 2),其中`channel`维度是可选的。

       deprecated:: 1.8.0
            实数输入已废弃,请使用`stft(..., return_complex=True)`返回的复数输入代替。
n_fft (int): 傅里叶变换的大小。
hop_length (Optional[int]): 相邻滑动窗口帧之间的距离。(默认:`n_fft // 4`)
win_length (Optional[int]): 窗口帧和STFT滤波器的大小。(默认:`n_fft`)
window (Optional[torch.Tensor]): 可选的窗函数。(默认:`torch.ones(win_length)`)
center (bool): 指示输入是否在两边进行了填充,使得第`t`帧位于时间`t × hop_length`处居中。(默认:`True`)
normalized (bool): 指示STFT是否被标准化。(默认:`False`)
onesided (Optional[bool]): 指示STFT是否为单边谱。(默认:如果输入尺寸中的`n_fft != fft_size`则为`True`)
length (Optional[int]): 修剪信号的长度,即原始信号的长度。(默认:整个信号)
return_complex (Optional[bool]):指示输出是否应为复数,或者输入是否应假定源自实信号和窗函数。注意,这与`onesided=True`不兼容。(默认:`False`)

        频域解码模块和ISTFT模块解读完毕。还剩一个Cross-Domain Transformer Encoder模块没有解读。后面又来新的活了,希望能把demucs落地~。


        感谢阅读,最近开始写公众号(分享好用的AI工具),欢迎大家一起见证我的成长(桂圆学AI)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/27243.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

八爪鱼现金流-022-mybatis插件加密和国密SM4算法

背景: 用户的金额数据,不希望被别人看到。 业务场景分析: 用户在页面上添加金额数据 -----> 服务器内存(加密、解密) -----> 存储数据库 调研及结果: 使用mybatis的拦截器插件,进行数…

win11电脑桌面倒计时提醒怎么设置?

在日常工作中,我们经常需要处理大量的工作任务,而且很多任务都有时间限制。如果将这些任务记录在桌面上,并设置倒计时提醒,无疑会大大提高我们的工作效率。想象一下,在繁忙的工作间隙,你只需一瞥桌面&#…

文件简单做二维码的方法,几步就能够完成操作

怎样用二维码来查看文件内容?随着网络的快速发展,通过二维码来查看文件是现在很常用的一种形式,能够更快让其他人获取文件内容,从而提升传播的速度和效率。比如用这种方式来下发通知文件、分享学习资料、浏览海报图片、传递个人简…

ESP32基础应用之esp32连接腾讯云并使用微信小程序控制的智能灯

文章目录 1. 项目简介1.1 功能接收1.2 使用资源1.3 测试平台 2 腾讯云物联网开发平台3 esp32设备开发3.1 准备参考例程3.2 vscode平台创建测试工程3.3 修改工程 问题总结使用PowerShell命令行终端生成的二维码不能用 1. 项目简介 1.1 功能接收 实现腾讯云创建项目与设备&…

防止暴力破解,教你如何在登录失败后实施10分钟账户锁定策略!

最近,在服务器上发现了异常的登录尝试。尽管您的团队已经采取了强密码策略和其他安全措施来加固服务器,但恶意程序仍然通过暴力破解的方式试图多次尝试猜测正确的凭据以获取访问权限。为了增强系统的安全性,特别是防止此类暴力破解攻击&#…

文章分享 | Ribo-seq与RNA-seq联合分析揭示uAUG-ds翻译调控机制

技术简介 RNA-seq主要从转录组水平分析基因的表达调控机制,检测用于核糖体翻译的RNA序列及二级结构。Ribo-seq主要用于检测核糖体翻译的起始位置、翻译富集区和翻译终止位置。RNA-seq与Ribo-seq联合分析可以准确检测mRNA上游5’UTR区的uORFs翻译调控结构&#xff0…

SSM小区疫情防控系统-计算机毕业设计源码03748

摘 要 随着社会的发展,社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 小区疫情防控系统,主要的模块包括查看首页、轮播图(轮播图管理)、社区公告管理(社区公告&#…

Opengauss开源4年了,都谁在向其贡献代码?

2020 年 6 月 30 日,华为将Opengauss正式开源,截止目前已经过去4年时间,社区力量对这款数据库产品都起到了哪些作用,谁的代码贡献更大一些? 根据社区官网信息统计,截止目前(2024年6月12日&…

查看电子磁盘ssd空间信息并释放zfs空间@FreeBSD

发现问题 在某宝买了一块32G的ssd电子盘,但是在FreeBSD里面使用df看到的空间较少,只有15G,一度怀疑是发错货了。不过自己清楚的记得swap分区还分了4G,这样铁定是大于16G的,应该是32G没错。但是少掉的那部分空间跑哪里…

安装前端依赖node-sass报错

文章目录 问题1:node-sass报错问题2:node-gyp报错问题3:node-sass再次报错问题4:node-sass三次报错 问题1:node-sass报错 问题描述:经常会碰到一个新的项目安装依赖时,会报node-sass版本的问题…

揭秘裂变客户背后的心理学:如何触动用户分享欲望?

在当今的社交媒体时代,裂变客户——即用户主动分享并推广某一产品或服务,已成为企业营销的重要策略。那么,如何触动用户的分享欲望呢?这背后其实隐藏着深刻的心理学原理。本文将以looka这个知名的国外设计工具为例,为s…

Spring Cloud Stream 消息驱动基础入门与实践总结

Spring Cloud Stream是用于构建与共享消息传递系统连接的高度可伸缩的事件驱动微服务框架,该框架提供了一个灵活的编程模型,它建立在已经建立和熟悉的Spring熟语和最佳实践上,包括支持持久化的发布/订阅、消费组以及消息分区这三个核心概念。…

激活和禁用Hierarchy面板上的物体

1、准备工作: (1) 在HIerarchy上添加待隐藏/显示的物体,名字自取。如:endImage (2) 在Inspector面板,该物体的名称前取消勾选(隐藏) (3) 在HIerarchy上添加按钮,名字自取。如:tip…

【机器学习300问】117、序列模型中的符号表示方法?以命名实体识别(NER)任务为例。

在序列模型中,特别是在命名实体识别(NER)任务中,我们通常会用一系列符号来表示输入序列、目标标签以及模型的结构和操作。本文列出一些常见的符号表示方法,结合NER任务进行解释。 一、什么是命名实体识别任务? (1&am…

YUV格式与RGB格式详解

图像处理 文章目录 图像处理前言YUV 格式YUV 采样 前言 像素格式描述了像素数据存储所用的格式,定义了像素在内存中的编码方式。RGB 和 YUV 为两种经常使用的像素格式。/ 1024 / 1024 2.63 MB 存储空间。 RGB 和 RGBA 格式 RGB 图像具有三个通道 R、G、B&#xff…

HyperBDR新版本上线,自动化容灾兼容再升级!

本次HyperBDR v5.5.0版本新增完成HCS(Huawei Cloud Stack)8.3.x和HCSO(Huawei Cloud Stack Online)自动化对接,另外还突破性完成了Oracle云(块存储模式)的自动化对接。 HyperBDR,云原生业务级别容灾工具。支…

确定性网络_v0

目录 一、背景二、技术参考文献 一、背景 确定性网络(Deterministic Networking)是提供确定性服务质量的网络技术,是在以太网的基础上为多种业务提供端到端确定性服务质量保障的一种新技术。通过对网络数据转发行为的控制,将时延…

【渗透测试】|dvwa命令注入乱码问题

法一: 解决方法如下: 1、按住winr,在运行框中输入cmd弹出命令行,在命令行中输入“control intl.cpl” 2、这个命令是使用control命令行工具来打开"区域和语言设置"对话框 3、选中对话框中的管理选项卡 4、可以看到这里…

linux 安装 Nginx 并部署 vue 项目

1、安装 yum install nginx2、使用 nginx 命令 查看nginx状态 systemctl status nginx 启动服务 systemctl start nginx停止服务 systemctl stop nginx重启服务 systemctl restart nginx修改配置后重载 systemctl reload nginx 加入开机自启动 systemctl enable ngin…

企业应该先上ERP系统还是先实施MES管理系统

在当今日益激烈的市场竞争中,企业信息化已成为提升竞争力的关键。ERP系统与MES管理系统作为企业信息化建设的两大核心系统,各自扮演着不可或缺的角色。然而,在资源有限的情况下,企业往往需要在两者之间做出选择。本文将深入探讨ER…