
[Feature] Support fft based convolution #811

Open
adonath opened this issue Mar 8, 2024 · 7 comments
Labels: enhancement (New feature or request), performance

Comments

@adonath

adonath commented Mar 8, 2024

It would be nice to have FFT-based convolution supported in MLX. FFT-based convolution shows much better performance for large images / arrays and kernels. The FFT building blocks are already supported in MLX, so it is mostly a matter of combining them into a convolution operation.

@adonath changed the title from "Feature request: support fft based convolution" to "[Feature] Support fft based convolution" Mar 8, 2024
@sebblanchet

@awni

Wouldn't mind looking into implementing this, what do you think?

@awni
Member

awni commented Mar 9, 2024

One challenge here is that FFT is not yet supported on the GPU (in Metal). So you could use it, but on the CPU it would almost certainly be much slower than our GPU convolution.

Also I think FFT-based convolution is more of an implementation detail. If there are some sizes that are slow for you, please share any benchmarks. We can then figure out the best way to make them faster (which may or may not require an FFT-based convolution).

@awni closed this as completed Mar 10, 2024
@adonath
Author

adonath commented Mar 11, 2024

Thanks @awni and @sebblanchet! I did a quick implementation of an FFT-based convolution in MLX:

import mlx.core as mx


def _centered(arr, newshape):
    """Crop `arr` to `newshape`, centered (as in scipy.signal._signaltools)."""
    newshape = mx.array(newshape)
    currshape = mx.array(arr.shape)

    startind = (currshape - newshape) // 2
    endind = startind + newshape
    myslice = [slice(startind[k].item(), endind[k].item()) for k in range(len(endind))]
    return arr[tuple(myslice)]


def convolve_fft(image, kernel, stream):
    """FFT-based 2D convolution for MLX arrays of shape (N, C, H, W)."""
    image_2d, kernel_2d = image[0, 0], kernel[0, 0]

    # Padded size for a full linear (non-circular) convolution.
    shape = [image_2d.shape[i] + kernel_2d.shape[i] - 1 for i in range(image_2d.ndim)]

    # Convolution theorem: pointwise product in the frequency domain.
    image_ft = mx.fft.rfft2(image, s=shape, stream=stream)
    kernel_ft = mx.fft.rfft2(kernel, s=shape, stream=stream)
    result = mx.fft.irfft2(image_ft * kernel_ft, s=shape, stream=stream)

    # Crop back to the input size ("same" mode).
    return _centered(result, image.shape)
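
For reference, a minimal usage sketch (the shapes are just illustrative):

image = mx.random.normal((1, 1, 1024, 1024))
kernel = mx.random.normal((1, 1, 64, 64))

# FFT is currently CPU-only in MLX, so pass the CPU stream.
out = convolve_fft(image, kernel, stream=mx.cpu)
mx.eval(out)  # force evaluation (MLX is lazy)
print(out.shape)  # (1, 1, 1024, 1024)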

I also did a simple benchmark. It uses a random image of size 1024x1024 and varying kernel sizes. It compares mx.conv2d on the GPU and CPU, the FFT-based algorithm from above, and, for reference, SciPy's FFT convolution implementation. The result is the following:

[Figure: mlx-conv-mini-benchmark — runtime vs. kernel size for mx.conv2d (GPU and CPU), the FFT-based implementation above, and SciPy's FFT convolution]

I think it matches the expectations exactly:

  • GPU is faster than CPU for the native convolution
  • the native convolution is faster for small kernel sizes
  • for large kernel sizes the FFT clearly wins, because of the N log(N) scaling instead of N^2
  • SciPy's FFT convolve is faster than mine (not too surprising)
  • the transition point where the CPU FFT becomes faster than the native GPU convolution is at a kernel size of about 20-30

In general I think it is still worth having an FFT-based convolution. For NNs with small kernels there is no point, but there are many scientific applications that rely on large kernels (think of cross-correlations, convolution with pathological point spread functions, etc.).

I think it is worth re-opening.

@awni
Member

awni commented Mar 12, 2024

Ok sounds good! Thanks for the benchmarks, that's really interesting!

@awni added the enhancement and performance labels Mar 12, 2024
@awni
Member

awni commented Mar 12, 2024

One option is to update the CPU convolution to dispatch to an FFT implementation when the input sizes make sense. We would want to benchmark it in a few settings to be sure it's a strict improvement.
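
A minimal sketch of what such a dispatch could look like, assuming a kernel-size threshold around the 20-30 crossover from the benchmark above (conv2d_cpu, direct_conv2d and fft_conv2d are hypothetical names, not MLX internals):

def conv2d_cpu(image, kernel, direct_conv2d, fft_conv2d, threshold=25):
    """Dispatch to the direct or the FFT path based on kernel size (illustrative only)."""
    kh, kw = kernel.shape[-2], kernel.shape[-1]
    if max(kh, kw) > threshold:
        return fft_conv2d(image, kernel)
    return direct_conv2d(image, kernel)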

@adonath
Author

adonath commented Mar 12, 2024

Thanks for re-opening @awni!

> One option is to update the CPU convolution to dispatch to an FFT implementation when the input sizes make sense. We would want to benchmark it in a few settings to be sure it's a strict improvement.

This is what SciPy does too, see https://github.com/scipy/scipy/blob/v1.12.0/scipy/signal/_signaltools.py#L1161. There is the option to measure or to actually compute the FLOPs. Measuring only makes sense for repeated convolutions, but probably gives the most accurate results on arbitrary architectures. Looking at the SciPy code, it seems that computing the FLOPs may be too complex. Or is there a general way to predict FLOPs for MLX operations? (Would be nice to have...)
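
For illustration, a rough sketch of the FLOP-counting approach; the 5·n·log2(n) FFT cost is a textbook rule of thumb, and the constants are only indicative:

import math

def direct_conv2d_flops(h, w, kh, kw):
    # One multiply-add per kernel tap per output pixel.
    return 2 * h * w * kh * kw

def fft_conv2d_flops(h, w, kh, kw):
    # Padded size for a full linear convolution.
    n = (h + kh - 1) * (w + kw - 1)
    # Two forward FFTs plus one inverse (~5 n log2 n each),
    # plus the pointwise complex product (~6 flops per element).
    return 3 * 5 * n * math.log2(n) + 6 * n

Counts like these could drive the dispatch at call time, without any runtime measurement.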

In general the performance of MLX operations is probably much more predictable across the relatively homogeneous Apple silicon architectures. So a third option could be to parametrize the scaling laws based on empirical benchmarks, or something similar...
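
A minimal sketch of the "measure" option, in the spirit of scipy.signal.choose_conv_method with measure=True (the function and its arguments are illustrative, not an existing API):

import time
import mlx.core as mx

def choose_conv_method(image, kernel, direct_fn, fft_fn, repeat=3):
    """Time both implementations and return the faster one's name.
    Only worthwhile when the same shapes are convolved many times."""
    best, best_t = None, float("inf")
    for name, fn in (("direct", direct_fn), ("fft", fft_fn)):
        start = time.perf_counter()
        for _ in range(repeat):
            mx.eval(fn(image, kernel))  # force MLX's lazy evaluation
        elapsed = (time.perf_counter() - start) / repeat
        if elapsed < best_t:
            best, best_t = name, elapsed
    return best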

@adonath
Author

adonath commented Mar 12, 2024

Here is the gist with the code for the benchmark: https://gist.github.com/adonath/3f16b30498c60f25cf1349792c15283c
