首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >使用FFT实现2D卷积

使用FFT实现2D卷积
EN

Stack Overflow用户
提问于 2018-05-22 01:38:04
回答 2查看 8K关注 0票数 3

对于使用大核(滤镜)卷积大图像,TensorFlow.conv2d()的速度非常慢。将1024x1024图像与相同大小的内核进行卷积需要几分钟时间。为了进行比较,cv2.filter2D()会立即返回结果。

我找到tf.fft2()tf.rfft()了。

然而,我并不清楚如何使用这些功能执行简单的图像过滤。

如何使用快速傅立叶变换在TensorFlow中实现快速2D图像滤波?

EN

回答 2

Stack Overflow用户

发布于 2018-05-24 18:13:34

x * y形式的线性离散卷积可以使用卷积定理和离散时间傅立叶变换来计算。如果x * y是一个圆形离散卷积,那么它可以用离散傅立叶变换来计算。

卷积定理状态x * y可以使用傅里叶变换计算如下

哪里

表示傅里叶变换和

傅里叶逆变换。当xy是离散的,并且它们的卷积是线性卷积时,使用DTFT计算如下

如果xy是离散的,并且它们的卷积是循环卷积,则上面的离散傅立叶变换被离散傅立叶变换所代替。注:线性卷积问题可以嵌入到循环卷积问题中。

我对MATLAB比较熟悉,但是通过阅读tf.signal.fft2dtf.signal.ifft2d的TensorFlow文档,下面的解决方案应该可以很容易地转换为TensorFlow,只需替换MATLAB的函数fft2ifft2

在MATLAB (和TensorFlow)中,fft2 (和tf.signal.fft2d)使用快速傅立叶变换算法来计算密度傅立叶变换。如果xy的卷积是循环的,则可以通过以下公式计算

代码语言:javascript
复制
ifft2(fft2(x).*fft2(y))

其中.*表示MATLAB中元素与元素的乘法。然而,如果它是线性的,那么我们将数据零填充到长度2N-1,其中N是一维的长度(在问题中是1024)。在MATLAB中,可以通过以下两种方法之一进行计算。首先,通过

代码语言:javascript
复制
h = ifft2(fft2(x, 2*N-1, 2*N-1).*fft2(y, 2*N-1, 2*N-1));

其中MATLAB通过零填充计算xy2*N-1-point 2D傅立叶变换,然后计算2*N-1-point 2D逆傅立叶变换。这种方法不能在TensorFlow中使用(根据我对文档的理解),因此下一种方法是唯一的选择。在MATLAB和TensorFlow中,可以通过首先扩展xy来计算2*N-1 x 2*N-1的大小,然后计算2*N-1-point二维傅立叶变换和逆傅立叶变换来计算卷积

代码语言:javascript
复制
x_extended = x;
x_extended(2*N-1, 2*N-1) = 0;

y_extended = y;
y_extended(2*N-1, 2*N-1) = 0;

h_extended = ifft2(fft2(x_extended).*fft2(y_extended));

在MATLAB中,hh_extended是完全相等的。无需傅立叶变换即可计算xy的卷积

代码语言:javascript
复制
hC = conv2(x, y);

在MATLAB中实现。

在我笔记本电脑上的MATLAB中,conv2(x, y)需要55秒,而傅立叶变换方法只需要不到0.4秒。

票数 4
EN

Stack Overflow用户

发布于 2019-12-04 03:22:48

这可以以类似于例如实现scipy.signal.fftconvolve的方式来完成。

这里是一个例子,假设我们有一个图像(2维,如果你也有多个通道,你可以使用3d而不是2个函数) (im)和一个滤波器(例如高斯)。

首先,对图像进行傅里叶变换并定义fft_lenghts (如果滤镜的形状不同,则非常有用,在这种情况下,它将被填充为零)。

代码语言:javascript
复制
fft_lenght1 = tf.shape(im)[0]
fft_lenght2 = tf.shape(im)[1]
im_fft = tf.signal.rfft2d(im, fft_length=[fft_lenght1, fft_lenght2])

代码语言:javascript
复制
kernel_fft = tf.signal.rfft2d(kernel, fft_length=[fft_lenght1, fft_lenght2])

最后,取反变换得到卷积后的图像

代码语言:javascript
复制
im_blurred = tf.signal.irfft2d(im_fft * kernel_fft, [fft_lenght1, fft_lenght2])
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50453981

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档