一、背景
接着上一篇代码解读 | Hybrid Transformers for Music Source Separation[02]文章,继续对Hybrid Transformer Demucs 代码进行解读。
解读目标:明确数据从进入算法,在算法内部,以及在算法输出 这三个阶段中 数据的大小是如何变换的。例如:算法输入数据大小为[BatchSize,Channels,Length],算法内部的数据大小为[BatchSize,Channels,Freqency,Time],算法输出的数据大小为[BatchSize,Channels,Length]。
二、解读
在htdemucs.py文件中编写测试代码print,把每个模块输出的数据大小打印出来。控制台打印的结果如下所示。例如:时域[1, 2, 258602] 可以理解成[BatchSize,Channels,Time];频域[1, 4, 2048, 253]可以理解成[BatchSize,Channels,Frequency,Time]。
再具体一些,不管是时域还是频域第一个维度数字1表示批大小(Batch_Size),第二个数字表示通道数(Channels),频域的第三个数字可以理解成频域维度属性(对应模型图中的freq),频域第四个数字理解成时间维度属性(对应模型图中的time steps)。时域的第三个数字可以理解成时间维度属性(对应模型图中的time steps)。根据这个先知条件,我们再和算法模型图中的Cin、Cout、xxx freq对比便一目了然了。
算法输入 torch.Size([1, 2, 258602])
频域输入(STFT输出) torch.Size([1, 4, 2048, 253])
频域输入(归一化) torch.Size([1, 4, 2048, 253])
时域输入(归一化) torch.Size([1, 2, 258602])
时域第1个编码层输出torch.Size([1, 48, 64651])
频域第1个编码层输出torch.Size([1, 48, 512, 253])
时域第2个编码层输出torch.Size([1, 96, 16163])
频域第2个编码层输出torch.Size([1, 96, 128, 253])
时域第3个编码层输出torch.Size([1, 192, 4041])
频域第3个编码层输出torch.Size([1, 192, 32, 253])
时域第4个编码层输出torch.Size([1, 384, 1011])
频域第4个编码层输出torch.Size([1, 384, 8, 253])
crosstransformer输出频域:torch.Size([1, 384, 8, 253]),时域:torch.Size([1, 384, 1011])
频域第1个解码层输出torch.Size([1, 192, 32, 253])
时域第1个解码层输出torch.Size([1, 192, 4041])
频域第2个解码层输出torch.Size([1, 96, 128, 253])
时域第2个解码层输出torch.Size([1, 96, 16163])
频域第3个解码层输出torch.Size([1, 48, 512, 253])
时域第3个解码层输出torch.Size([1, 48, 64651])
频域第4个解码层输出torch.Size([1, 16, 2048, 253])
时域第4个解码层输出torch.Size([1, 8, 258602])
频域输出(STFT输出) torch.Size([1, 4, 2, 258602])
时域输出(归一化) torch.Size([1, 4, 2, 258602])
算法输出(时域输出+频域输出) torch.Size([1, 4, 2, 258602])
总结:打印出每个环节的数据大小,这样就能和算法模型图中的参数对应上。至于各个模块更具体的细节还需仔细阅读源码进行理解。
感谢阅读,最近开始写公众号(分享好用的AI工具),欢迎大家一起见证我的成长(桂圆学AI)