复数卷积 tensorflow_PyTorch 中的傅里叶卷积

df690e7c1ed3d872b1e278353de8788a.png

欢迎关注 “小白玩转Python”,发现更多 “有趣”

注意: 在这个 Github repo 中提供了1D、2D 和3D Fourier 卷积的完整方法。我还提供了 PyTorch 模块,可以方便地将傅里叶卷积添加到可训练模型中。链接如下:

https://github.com/fkodom/fft-conv-pytorch

卷积

卷积在数据分析中无处不在。几十年来,它们一直被用于信号和图像处理。最近,它们成为现代神经网络的重要组成部分。如果你处理数据的话,你可能会遇到错综复杂的问题。

数学上,卷积表示为:

1e998bbad6b372b7db8326d7ab13778e.png

尽管离散卷积在计算应用程序中更为常见,但在本文的大部分内容中我将使用连续形式,因为使用连续变量来证明卷积定理(下面讨论)要容易得多。之后,我们将回到离散情况,并使用傅立叶变换在 PyTorch 中实现它。离散卷积可以看作是连续卷积的近似,其中连续函数离散在规则网格上。因此,我们不会为这个离散的案例重新证明卷积定理。

卷积定理

从数学上来说,卷积定理可以这样描述:

b38c5f2734ec3445b26f54911eb62f7f.png

其中的连续傅里叶变换是(达到正常化常数) :

78b9368ff02cb4b8956504f91af8bad5.png

换句话说,位置空间中的卷积等价于频率空间中的直乘。这个想法是相当不直观的,但是对于连续的情况来说,证明卷积定理是惊人的容易。要做到这一点,首先要写出等式的左边。

dc9f36dae65d15e91b213d9b48518a7c.png

现在切换积分的顺序,替换变量(x = y + z) ,并分离两个被积函数。

cce93edc4fbaad12c8e77c1deefa8564.png

我们为什么要关心这一切?

因为快速傅里叶变换的算法复杂度低于卷积。直接卷积运算具有复杂度 O(n^2) ,因为在 f 中,我们传递 g 中的每个元素,所以可以在 O(nlogn)时间内计算出快速傅立叶变换。当输入数组很大时,它们比卷积要快得多。在这些情况下,我们可以使用卷积定理计算频率空间中的卷积,然后执行逆傅里叶变换回到位置空间。

当输入较小时(例如3x3卷积内核) ,直接卷积仍然更快。在机器学习应用程序中,使用小内核更为常见,因此像 PyTorch 和 Tensorflow 这样的深度学习库只提供直接卷积的实现。但是在现实世界中有很多使用大内核的用例,其中傅立叶卷积算法更有效。

PyTorch 实现

现在,我将演示如何在 PyTorch 中实现傅里叶卷积函数。它应该模仿 torch.nn.functional.convNd 的功能,并利用 fft,而不需要用户做任何额外的工作。因此,它应该接受三个 Tensors (signal、kernel 和可选 bias)和应用于输入的 padding。从概念上讲,这个函数的内部工作原理是:

def fft_conv(    signal: Tensor, kernel: Tensor, bias: Tensor = None, padding: int = 0,) -> Tensor:    # 1. Pad the input signal & kernel tensors    # 2. Compute FFT for both signal & kernel    # 3. Multiply the transformed Tensors together    # 4. Compute inverse FFT    # 5. Add bias and return

让我们按照上面显示的操作顺序逐步构建 FFT 卷积。对于这个例子,我将构建一个一维傅里叶卷积,但是将其扩展到二维和三维卷积是很简单的。

1. 填充输入数组

我们需要确保 signal 和 kernel 在填充之后有相同的大小。应用初始填充 signal,然后调整 kernel 的填充以匹配。

# 1. Pad the input signal & kernel tensorssignal = f.pad(signal, [padding, padding])kernel_padding = [0, signal.size(-1) - kernel.size(-1)]padded_kernel = f.pad(kernel, kernel_padding)

注意,我只在一边填充 kernel。我们希望原始内核位于填充数组的左侧,这样它就可以与 signal 数组的开始对齐。

2. 计算傅立叶变换

这非常简单,因为 n 维 fft 已经在 PyTorch 中实现了。我们简单地使用内置函数,并计算沿每个张量的最后一个维数的 FFT。

# 2. Perform fourier convolutionsignal_fr = rfftn(signal, dim=-1)kernel_fr = rfftn(padded_kernel, dim=-1)

3. 变换张量相乘

令人惊讶的是,这是我们功能中最复杂的部分。这有两个原因。(1) PyTorch 卷积运行于多维张量上,因此我们的 signal 和 kernel 张量实际上是三维的。从 PyTorch 文档中的这个方程式,我们可以看到矩阵乘法是在前两个维度上运行的(不包括偏差项) :

18d8d249dc7e6964dd9946a005871c6c.png

我们将需要包括这个矩阵乘法,以及对转换后的维度的直接乘法。

PyTorch 实际上实现了互相关/值方法而不是卷积方法。(TensorFlow 和其他深度学习库也是如此。)互相关与卷积密切相关,但有一个重要的标志变化:

e1c7a8a3c1611b076a548be19e24c486.png

与卷积相比,这有效地逆转了核的方向(g)。我们不是手动翻转内核,而是在傅里叶空间中利用内核的共轭复数来纠正这个问题。由于我们不需要创建一个全新的 Tensor,所以这样做的速度明显更快,内存效率也更高。(本文末尾的附录中简要说明了这种方法的工作原理。)

# 3. Multiply the transformed matricesdef complex_matmul(a: Tensor, b: Tensor) -> Tensor:    """Multiplies two complex-valued tensors."""    # Scalar matrix multiplication of two tensors, over only the first two dimensions.    # Dimensions 3 and higher will have the same shape after multiplication.    scalar_matmul = partial(torch.einsum, "ab..., cb... -> ac...")    # Compute the real and imaginary parts independently, then manually insert them    # into the output Tensor.  This is fairly hacky but necessary for PyTorch 1.7.0,    # because Autograd is not enabled for complex matrix operations yet.  Not exactly    # idiomatic PyTorch code, but it should work for all future versions (>= 1.7.0).    real = scalar_matmul(a.real, b.real) - scalar_matmul(a.imag, b.imag)    imag = scalar_matmul(a.imag, b.real) + scalar_matmul(a.real, b.imag)    c = torch.zeros(real.shape, dtype=torch.complex64)    c.real, c.imag = real, imag    return c# Conjugate the kernel for cross-correlationkernel_fr.imag *= -1output_fr = complex_matmul(signal_fr, kernel_fr)

PyTorch 1.7改进了对复数的支持,但是在 autograd 中还不支持对复数张量的许多操作。现在,我们必须编写我们自己的复杂 matmul 方法作为一个补丁。虽然不是很理想,但是它确实有效,并且在未来的版本中不会出现问题。

4. 计算逆变换

使用 torch.irfftn 可以直接计算逆变换,然后裁剪出额外的数组填充。

# 4. Compute inverse FFT, and remove extra padded valuesoutput = irfftn(output_fr, dim=-1)output = output[:, :, :signal.size(-1) - kernel.size(-1) + 1]

5. 添加偏执项并返回

添加偏差项也很容易。请记住,对于输出阵列中的每个通道,偏置项都有一个元素,并相应地调整其形状。

# 5. Optionally, add a bias term before returning.if bias is not None:    output += bias.view(1, -1, 1)

将上述代码整合在一起

为了完整起见,让我们将所有这些代码片段编译成一个内聚函数。

def fft_conv_1d(    signal: Tensor, kernel: Tensor, bias: Tensor = None, padding: int = 0,) -> Tensor:    """    Args:        signal: (Tensor) Input tensor to be convolved with the kernel.        kernel: (Tensor) Convolution kernel.        bias: (Optional, Tensor) Bias tensor to add to the output.        padding: (int) Number of zero samples to pad the input on the last dimension.    Returns:        (Tensor) Convolved tensor    """    # 1. Pad the input signal & kernel tensors    signal = f.pad(signal, [padding, padding])    kernel_padding = [0, signal.size(-1) - kernel.size(-1)]    padded_kernel = f.pad(kernel, kernel_padding)    # 2. Perform fourier convolution    signal_fr = rfftn(signal, dim=-1)    kernel_fr = rfftn(padded_kernel, dim=-1)    # 3. Multiply the transformed matrices    kernel_fr.imag *= -1    output_fr = complex_matmul(signal_fr, kernel_fr)    # 4. Compute inverse FFT, and remove extra padded values    output = irfftn(output_fr, dim=-1)    output = output[:, :, :signal.size(-1) - kernel.size(-1) + 1]    # 5. Optionally, add a bias term before returning.    if bias is not None:        output += bias.view(1, -1, 1)    return output

直接卷积测试

最后,我们将使用 torch.nn.functional.conv1d 来确认这在数值上等同于直接一维卷积。我们为所有输入构造随机张量,并测量输出值的相对差异。

import torchimport torch.nn.functional as ftorch.manual_seed(1234)kernel = torch.randn(2, 3, 1025)signal = torch.randn(3, 3, 4096)bias = torch.randn(2)y0 = f.conv1d(signal, kernel, bias=bias, padding=512)y1 = fft_conv_1d(signal, kernel, bias=bias, padding=512)abs_error = torch.abs(y0 - y1)print(f'\nAbs Error Mean: {abs_error.mean():.3E}')print(f'Abs Error Std Dev: {abs_error.std():.3E}')# Abs Error Mean: 1.272E-05

考虑到我们使用的是32位精度,每个元素相差大约1e-5ー相当精确!让我们也执行一个快速的基准来测量每个方法的速度:

from timeit import timeitdirect_time = timeit(    "f.conv1d(signal, kernel, bias=bias, padding=512)",     globals=locals(),     number=100) / 100fourier_time = timeit(    "fft_conv_1d(signal, kernel, bias=bias, padding=512)",     globals=locals(),     number=100) / 100print(f"Direct time: {direct_time:.3E} s")print(f"Fourier time: {fourier_time:.3E} s")# Direct time: 1.523E-02 s# Fourier time: 1.149E-03 s

测量的基准将随着您使用的机器而发生显著的变化。(我正在用一台非常旧的 Macbook Pro 进行测试。)对于1025的内核,傅里叶卷积似乎要快10倍以上。

总结

我希望这已经提供了一个彻底的介绍傅里叶卷积。我认为这是一个非常酷的技巧,在现实世界中有很多应用程序可以使用它。我也喜欢数学,所以看到编程和纯数学的结合是很有趣的。欢迎和鼓励所有的评论和建设性的批评,如果你喜欢这篇文章,请鼓掌!

附录:

卷积 vs. 互相关

在本文的前面,我们通过在傅里叶空间中取得内核的互相关共轭复数来实现。这实际上颠倒了 kernel 的方向,现在我想演示一下为什么会这样。首先,记住卷积和互相关的公式:

096ded9048a935d22ba354c883a2bc93.png

然后,让我们来看看 g(x) 的傅里叶变换:

8c6e8de6fd71eeded7bd3429f5a55bc0.png

注意,g(x)是实值的,所以它不受共轭复数变化的影响。然后,更改变量(y =-x)并简化表达式。

579a64590355910c0c27b80cfd7036a4.png

·  END  ·

HAPPY LIFE

43ba687a02759329878d5c17ff508d52.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/500703.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

python3 枚举_Python3枚举类如何处理重复名称?

筛选重复的名称,相信对于各位小伙伴不是什么难事。那么,大家发现有重复的名称会如何进行解决呢?作为一位python的粉丝,小编优先选择的事这类的方法。在处理重复名称方面,小编选择的是用python3里的枚举法进行操作。没听…

python迷宫算法及实现_Python迷宫递归算法

所以我盯着这个有一段时间了,我不知道怎么才能回到这个迷宫的正确路径。在2代表墙MAZE [[2,2,2,2,1,2],[2,2,1,2,1,2],[2,2,1,2,1,2],[2,1,1,1,1,2],[2,1,2,2,2,2],[2,1,2,2,2,2]]START_ROW 5START_COL 1END_ROW 0END_COL 4was_here [[False]*6 for i in rang…

lucene索引MySQL原因_我如何在数据库中存储Lucene索引?

这是我的示例代码:MysqlDataSource dataSource new MysqlDataSource();dataSource.setUser("root");dataSource.setPassword("ncl");dataSource.setDatabaseName("userdb");dataSource.setEmulateLocators(true); //This is importa…

python time智能等待_Python Selenium智能等待

前言在使用python selenium进行自动化测试实践的过程中,经常会遇到元素定位不到,弹出框定位不到等等各种定位不到的情况,在大多数的情况下,无非是以下两种情况:1、有frame存在,定位前,未switch到…

python存储对象的数组_Python:在2d数组中存储对象并调用其方法

我正在尝试制作一个象棋应用程序。代码如下:#file containing pieces classesclass Piece(object):name "piece"value 0grid_name "____"class Pawn(Piece):# Rules for pawns.#If first move, then can move forward two spacesname "…

python正则表达式匹配括号并删除_如何使用正则表达式删除括号内的文本?

在括号匹配子串的图案不具有其他(和)字符之间(如(xyz 123)在Text (abc(xyz 123))是\([^()]*\)详细资料:\(-圆括号(请注意,在POSIX BRE中(应使用,请参见sed下面的示例)[^()]*-除否定的字符类别 / POSIX括号表达式中定义的字符以外的零个或多个…

现代软件工程 (备份)

自我介绍一下, 我叫邹欣, 是微软亚洲研究院 创新工程中心 首席研发主管 (Principal Development Manager). 我和同事们一起把研究成果转化为商业软件产品和服务。近期主要专注于垂直搜索,企业搜索,软件开发工具和数字娱乐等领域。 在工作之余, 我也写…

java钱_在Java中如何表示钱Money?

为什么不使用float / double?使用java时会遇到money类型的选择问题,首先想到的是float / double。如果只是简单的货币计算,很难发现用float会有问题。出现问题的原因是使用float / double(已经相应的包装类Float / Double)会出现舍入误差(rou…

期望

把原来一些关于软件工程,教学,和程序设计相关的博客搬过来。 [http://yishan.cc/blogs/xin/archive/2009/04/12/1058.aspx] 学校的期望 我在BBS 看到有人感慨说- 有家长让小孩在大学里专心学习,不要想别的。等到一毕业,就希望小孩…

java 数值变量_Java 中数值变量赋值问题

写了一段判断数值相等判断方法的程序://在-128~127 之外的数Integer i1 200;Integer i2 200;System.out.println("i1i2: "(i1i2));//false// 在-128~127 之内的数Integer i3 100;Integer i4 100;System.out.println("i3i4: "(i3i4));//true…

软件教育随想

[由于工作的关系,我在过去的几年中访问了十二三所软件学院/计算机学院,和不少老师,学生座谈过。我在研究院里也碰到了不少各个学校来的学生,谈得多了,就有下面的随想。] 想来的人来不了 学校里都是按照高考/考研的成绩…

java获取当前电脑的ip_使用Java获取当前计算机的IP地址

问题我正在尝试开发一个系统,其中有不同的节点在不同的系统上或在同一系统上的不同端口上运行。现在所有节点都创建一个Socket,其目标IP作为称为自举节点的特殊节点的IP。然后节点创建自己的ServerSocket并开始侦听连接。引导节点维护一个节点列表&#…

两千块钱带来的 希望

几年以前,我参加过一个全国 “软件学院” 的评审,得到两千块现金和一些希望。我后来把钱和希望都还给同学们了,现在说明一下。 [这是个人回忆,不代表任何组织,也不确保所有信息的完全准确] 我先…

java 内部变量_java 中的内置数据类型

1, 基本数据类型Java是强类型语言, 对于每一种数据都定义了类型,基本数据类型分为数值型,字符型,布尔型。数值型又分为了整型和浮点型。整型又分为byte, int, short long.浮点型又分为了float 和double.字符型是char 类型&#x…

DG导入mysql依赖包_MySql导入导出数据库(含远程导入导出)

1、先运行cmd,cd 到mysql安装目录中的bin文件夹2、mysqldump -u root -p 数据库名 > 导出文件名.sql其他情况下:1.导出整个数据库mysqldump -u 用户名 -p 数据库名 > 导出的文件名mysqldump -u wcnc -p smgp_apps_wcnc > wcnc.sql2.导出一个表m…

java线程的优点_Java使用多线程的优势

Java使用多线程的优势如果使用得当,线程可以有效地降低程序的开发和维护等成本,同时提升复杂应用程序的性能。那么Java使用多线程的优势具体有哪些呢,一起来了解一下!1、发挥多处理器的强大能力现在,多处理器系统正日益盛行&#…

开发软件不是闭卷考试

有人问我这个问题: “你正在做一个项目,这个项目有一项关键的feature需要实现,这个feature有一定的技术难度,你调试了很久,都没找到实现的途径,这时你已经在这个feature上花了很多时间了,而且无…

go语言mysql删除记录_MySQL数据库删除操作-Go语言中文社区

删除数据库DROP DATABASE [IF EXISTS] 数据库名;例如:删除school数据库IF EXISTS 为可选,判断是否存在,如果不存在则会抛出异常删除数据表DROP TABLE [IF EXISTS] 表名;例如:删除student表注意:删除具有主外键关系的表…

java csv下载_java 生成csv文件,弹出下载对话框。。。

1.最直接最简单的,方式是把文件地址直接放到html页面的一个链接中。这样做的缺点是把文件在服务器上的路径暴露了,并且还无法对文件下载进行其它的控制(如权限)。这个就不写示例了。2.在服务器端把文件转换成输出流,写…

中科大在50年代的教学理念

给中科大的学生上课, 得了解这个学校的教学理念,我从网上找到了一篇: 中国科学技术大学里的基础课中国科学技术大学力学和力学工程系系主任 钱学森中国科学技术大学是为我国培养尖端科学研究技术干部的,因此学生必需在学校里打下将来作研究工作的基础。…