代码解读 | Hybrid Transformers for Music Source Separation[06]

一、背景

        0、Hybrid Transformer 论文解读

        1、代码复现|Demucs Music Source Separation_demucs架构原理-CSDN博客

        2、Hybrid Transformer 各个模块对应的代码具体在工程的哪个地方

        3、Hybrid Transformer 各个模块的底层到底是个啥(初步感受)?

        4、Hybrid Transformer 各个模块处理后,数据的维度大小是咋变换的?

        5、Hybrid Transformer 拆解STFT模块

        6、Hybrid Transformer 拆解频域编码模块


        从模块上划分,Hybrid Transformer Demucs 共包含 (STFT模块、时域编码模块、频域编码模块、Cross-Domain Transformer Encoder模块、时域解码模块、频域解码模块ISTFT模块)7个模块。已完成解读:STFT模块、频域编码模块(时域编码和频域编码类似,后续不再解读时域编码模块),待解读:Cross-Domain Transformer Encoder模块。

        本篇目标:拆解频域解码模块ISTFT模块的底层。时域解码和频域解码原理类似(后续不再拆解时域解码模块)。

二、频域解码模块


class HDecLayer(nn.Module):def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True,context_freq=True, rewrite=True):"""Same as HEncLayer but for decoder. See `HEncLayer` for documentation."""super().__init__()norm_fn = lambda d: nn.Identity()  # noqaif norm:norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqaif pad:pad = kernel_size // 4else:pad = 0self.pad = padself.last = lastself.freq = freqself.chin = chinself.empty = emptyself.stride = strideself.kernel_size = kernel_sizeself.norm = normself.context_freq = context_freqklass = nn.Conv1dklass_tr = nn.ConvTranspose1dif freq:kernel_size = [kernel_size, 1]stride = [stride, 1]klass = nn.Conv2dklass_tr = nn.ConvTranspose2dself.conv_tr = klass_tr(chin, chout, kernel_size, stride)self.norm2 = norm_fn(chout)if self.empty:returnself.rewrite = Noneif rewrite:if context_freq:self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)else:self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,[0, context])self.norm1 = norm_fn(2 * chin)self.dconv = Noneif dconv:self.dconv = DConv(chin, **dconv_kw)def forward(self, x, skip, length):if self.freq and x.dim() == 3:B, C, T = x.shapex = x.view(B, self.chin, -1, T)if not self.empty:x = x + skipif self.rewrite:y = F.glu(self.norm1(self.rewrite(x)), dim=1)else:y = xif self.dconv:if self.freq:B, C, Fr, T = y.shapey = y.permute(0, 2, 1, 3).reshape(-1, C, T)y = self.dconv(y)if self.freq:y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)else:y = xassert skip is Nonez = self.norm2(self.conv_tr(y))print('self.pad,self.last:', self.pad,self.last)if self.freq:if self.pad:z = z[..., self.pad:-self.pad, :]else:z = z[..., self.pad:self.pad + length]assert z.shape[-1] == length, (z.shape[-1], length)if not self.last:z = F.gelu(z)return z, y

        频域解码模块的核心代码如上所示。在上一篇频域编码模块的基础上,继续贴出完善之后的频域编解码模块全景图。

编码层:Conv2d+Norm1+GELU,  Norm1:Identity()

解码层:(Conv2d+Norm1+GLU)+(ConvTranspose2d+Norm2+倒数第二个维度裁剪+GELU),    Norm1\Norm2:Identity()

残差连接:(Conv1d+GroupNorm+GELU +Conv1d+GroupNorm+GLU+LayerScale())+(Conv2d+Norm2+GLU),Norm2:Identity() ,备注:Identity可以理解成直通

#频域编码层1-4的Conv2d分别是:
Conv2d(4, 48, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(48, 96, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(96, 192, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))
Conv2d(192, 384, kernel_size=(8, 1), stride=(4, 1), padding=(2, 0))#频域解码层4-1的Conv2d和ConvTranspose2d
Conv2d(384, 768, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ConvTranspose2d(384, 192, kernel_size=(8, 1), stride=(4, 1)) 
Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ConvTranspose2d(192, 96, kernel_size=(8, 1), stride=(4, 1))
Conv2d(96, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
ConvTranspose2d(96, 48, kernel_size=(8, 1), stride=(4, 1))
Conv2d(48, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 
ConvTranspose2d(48, 16, kernel_size=(8, 1), stride=(4, 1))

        残差连接模块如下所示。

#残差连接1
DConv((layers): ModuleList((0): Sequential((0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 6, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 96, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(48, 6, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 6, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(6, 96, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 96, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(48, 96, kernel_size=(1, 1), stride=(1, 1))#残差连接2
DConv((layers): ModuleList((0): Sequential((0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 12, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 192, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(96, 12, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 12, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(12, 192, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 192, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(96, 192, kernel_size=(1, 1), stride=(1, 1))#残差连接3
DConv((layers): ModuleList((0): Sequential((0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 24, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 384, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(192, 24, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 24, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(24, 384, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 384, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(192, 384, kernel_size=(1, 1), stride=(1, 1))#残差连接4
DConv((layers): ModuleList((0): Sequential((0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(1,))(1): GroupNorm(1, 48, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 768, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale())(1): Sequential((0): Conv1d(384, 48, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))(1): GroupNorm(1, 48, eps=1e-05, affine=True)(2): GELU(approximate=none)(3): Conv1d(48, 768, kernel_size=(1,), stride=(1,))(4): GroupNorm(1, 768, eps=1e-05, affine=True)(5): GLU(dim=1)(6): LayerScale()))
)
Conv2d(384, 768, kernel_size=(1, 1), stride=(1, 1))

三、ISTFT模块

        ISTFT模块的核心代码如下所示。

import torch as th
def ispectro(z, hop_length=None, length=None, pad=0):*other, freqs, frames = z.shapen_fft = 2 * freqs - 2z = z.view(-1, freqs, frames)win_length = n_fft // (1 + pad)is_mps = z.device.type == 'mps'if is_mps:z = z.cpu()x = th.istft(z,n_fft,hop_length,window=th.hann_window(win_length).to(z.real),win_length=win_length,normalized=True,length=length,center=True)_, length = x.shapereturn x.view(*other, length)

        其中,torch.istft【逆短时傅里叶变换(Inverse Short Time Fourier Transform,ISTFT)】,该函数期望是torch.stft函数的逆过程。它具有相同的参数(加上一个可选参数length),并且应该返回原始信号的最小二乘估计。算法将根据NOLA条件(非零重叠)进行检查。

#### torch.istft接口参数####
input (Tensor): 输入张量,期望是`torch.stft`的输出,可以是复数形式(`channel`, `fft_size`, `n_frame`),或者是实数形式(`channel`, `fft_size`, `n_frame`, 2),其中`channel`维度是可选的。

       deprecated:: 1.8.0
            实数输入已废弃,请使用`stft(..., return_complex=True)`返回的复数输入代替。
n_fft (int): 傅里叶变换的大小。
hop_length (Optional[int]): 相邻滑动窗口帧之间的距离。(默认:`n_fft // 4`)
win_length (Optional[int]): 窗口帧和STFT滤波器的大小。(默认:`n_fft`)
window (Optional[torch.Tensor]): 可选的窗函数。(默认:`torch.ones(win_length)`)
center (bool): 指示输入是否在两边进行了填充,使得第`t`帧位于时间`t × hop_length`处居中。(默认:`True`)
normalized (bool): 指示STFT是否被标准化。(默认:`False`)
onesided (Optional[bool]): 指示STFT是否为单边谱。(默认:如果输入尺寸中的`n_fft != fft_size`则为`True`)
length (Optional[int]): 修剪信号的长度,即原始信号的长度。(默认:整个信号)
return_complex (Optional[bool]):指示输出是否应为复数,或者输入是否应假定源自实信号和窗函数。注意,这与`onesided=True`不兼容。(默认:`False`)

        频域解码模块和ISTFT模块解读完毕。还剩一个Cross-Domain Transformer Encoder模块没有解读。后面又来新的活了,希望能把demucs落地~。


        感谢阅读,最近开始写公众号(分享好用的AI工具),欢迎大家一起见证我的成长(桂圆学AI)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/diannao/27243.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

八爪鱼现金流-022-mybatis插件加密和国密SM4算法

背景: 用户的金额数据,不希望被别人看到。 业务场景分析: 用户在页面上添加金额数据 -----> 服务器内存(加密、解密) -----> 存储数据库 调研及结果: 使用mybatis的拦截器插件,进行数…

win11电脑桌面倒计时提醒怎么设置?

在日常工作中,我们经常需要处理大量的工作任务,而且很多任务都有时间限制。如果将这些任务记录在桌面上,并设置倒计时提醒,无疑会大大提高我们的工作效率。想象一下,在繁忙的工作间隙,你只需一瞥桌面&#…

文件简单做二维码的方法,几步就能够完成操作

怎样用二维码来查看文件内容?随着网络的快速发展,通过二维码来查看文件是现在很常用的一种形式,能够更快让其他人获取文件内容,从而提升传播的速度和效率。比如用这种方式来下发通知文件、分享学习资料、浏览海报图片、传递个人简…

ESP32基础应用之esp32连接腾讯云并使用微信小程序控制的智能灯

文章目录 1. 项目简介1.1 功能接收1.2 使用资源1.3 测试平台 2 腾讯云物联网开发平台3 esp32设备开发3.1 准备参考例程3.2 vscode平台创建测试工程3.3 修改工程 问题总结使用PowerShell命令行终端生成的二维码不能用 1. 项目简介 1.1 功能接收 实现腾讯云创建项目与设备&…

防止暴力破解,教你如何在登录失败后实施10分钟账户锁定策略!

最近,在服务器上发现了异常的登录尝试。尽管您的团队已经采取了强密码策略和其他安全措施来加固服务器,但恶意程序仍然通过暴力破解的方式试图多次尝试猜测正确的凭据以获取访问权限。为了增强系统的安全性,特别是防止此类暴力破解攻击&#…

文章分享 | Ribo-seq与RNA-seq联合分析揭示uAUG-ds翻译调控机制

技术简介 RNA-seq主要从转录组水平分析基因的表达调控机制,检测用于核糖体翻译的RNA序列及二级结构。Ribo-seq主要用于检测核糖体翻译的起始位置、翻译富集区和翻译终止位置。RNA-seq与Ribo-seq联合分析可以准确检测mRNA上游5’UTR区的uORFs翻译调控结构&#xff0…

SSM小区疫情防控系统-计算机毕业设计源码03748

摘 要 随着社会的发展,社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。 小区疫情防控系统,主要的模块包括查看首页、轮播图(轮播图管理)、社区公告管理(社区公告&#…

Opengauss开源4年了,都谁在向其贡献代码?

2020 年 6 月 30 日,华为将Opengauss正式开源,截止目前已经过去4年时间,社区力量对这款数据库产品都起到了哪些作用,谁的代码贡献更大一些? 根据社区官网信息统计,截止目前(2024年6月12日&…

2024年护网行动全国各地面试题汇总(5)作者:————LJS

2024护网蓝队面试题第一大题 目录 一. 目前有防火墙,全流量检测,态势感知,IDS,waf,web服务器等设备,如何搭建一个安全的内网环境,请给出大概拓扑结构 (适用于中高级) 1.1…

查看电子磁盘ssd空间信息并释放zfs空间@FreeBSD

发现问题 在某宝买了一块32G的ssd电子盘,但是在FreeBSD里面使用df看到的空间较少,只有15G,一度怀疑是发错货了。不过自己清楚的记得swap分区还分了4G,这样铁定是大于16G的,应该是32G没错。但是少掉的那部分空间跑哪里…

安装前端依赖node-sass报错

文章目录 问题1:node-sass报错问题2:node-gyp报错问题3:node-sass再次报错问题4:node-sass三次报错 问题1:node-sass报错 问题描述:经常会碰到一个新的项目安装依赖时,会报node-sass版本的问题…

揭秘裂变客户背后的心理学:如何触动用户分享欲望?

在当今的社交媒体时代,裂变客户——即用户主动分享并推广某一产品或服务,已成为企业营销的重要策略。那么,如何触动用户的分享欲望呢?这背后其实隐藏着深刻的心理学原理。本文将以looka这个知名的国外设计工具为例,为s…

# RocketMQ 实战:模拟电商网站场景综合案例(五)

RocketMQ 实战&#xff1a;模拟电商网站场景综合案例&#xff08;五&#xff09; 一、mybatis 逆向工程使用 4、逆向工程 生成 的 .xml 配置文件。 4.1、生成的 TradeCouponMapper.xml 文件。 <?xml version"1.0" encoding"UTF-8" ?> <!DOC…

Spring Cloud Stream 消息驱动基础入门与实践总结

Spring Cloud Stream是用于构建与共享消息传递系统连接的高度可伸缩的事件驱动微服务框架&#xff0c;该框架提供了一个灵活的编程模型&#xff0c;它建立在已经建立和熟悉的Spring熟语和最佳实践上&#xff0c;包括支持持久化的发布/订阅、消费组以及消息分区这三个核心概念。…

激活和禁用Hierarchy面板上的物体

1、准备工作&#xff1a; (1) 在HIerarchy上添加待隐藏/显示的物体&#xff0c;名字自取。如&#xff1a;endImage (2) 在Inspector面板&#xff0c;该物体的名称前取消勾选&#xff08;隐藏&#xff09; (3) 在HIerarchy上添加按钮&#xff0c;名字自取。如&#xff1a;tip…

前端开发之TCP与UDP认识

上一篇&#x1f449;: 前端开发之性能优化 TCP与UDP 三次握手 1. 初始状态&#xff1a; 客户端开始时处于CLOSED状态&#xff0c;表明没有活动的连接。服务器监听特定端口&#xff0c;处于LISTEN状态&#xff0c;等待连接请求。 2. 第一次握手&#xff08;SYN_SENT状态&am…

sklearn(Scikit-learn)入门学习教程

sklearn&#xff08;Scikit-learn&#xff09;是一个功能强大的Python机器学习库&#xff0c;它提供了丰富的工具和方法&#xff0c;用于数据挖掘、数据分析和预测建模。以下是一个关于sklearn的清晰教程&#xff0c;涵盖了其主要特点和功能&#xff1a; 1. sklearn简介 定义…

FPGA “+:”、“-:“语法

“:”变量[起始地址 : 数据位宽] <–等价于–> 变量[(起始地址数据位宽-1)&#xff1a;起始地址] data[0 : 8] <–等价于–> data[7:0] data[15 : 2] <–等价于–> data[16:15] “-:”变量[结束地址 -: 数据位宽] <–等价于–> 变量[结束地址&#xf…

【机器学习300问】117、序列模型中的符号表示方法?以命名实体识别(NER)任务为例。

在序列模型中&#xff0c;特别是在命名实体识别(NER)任务中&#xff0c;我们通常会用一系列符号来表示输入序列、目标标签以及模型的结构和操作。本文列出一些常见的符号表示方法&#xff0c;结合NER任务进行解释。 一、什么是命名实体识别任务&#xff1f; &#xff08;1&am…

mysql8.0 sql_mode与ONLY_FULL_GROUP_BY报错

如果你的项目出现如下类似的错误 ### Error querying database. Cause: java.sql.SQLSyntaxErrorException: Expression #2 of SELECT list is not in GROUP BY clause and contains nonaggregated column 字段名 which is not functionally dependent on columns in GROUP BY…