
Metal shaders for memory efficient self attention on large sequences #964

Open · wants to merge 10 commits into main from user/bkeene/fast_sdpa_self_attention

Conversation

bpkeene (Contributor) commented Apr 6, 2024

Implements metal shaders for:

o = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=None)

Supports fp16, fp32 dtypes; flexible hidden dimension, currently templated for 64 and 128.

Causal masking for prompt encoding is not yet implemented; this shader currently targets full self-attention.
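
For reference, a minimal usage sketch of the fast path above (the (batch, heads, sequence, head_dim) layout and the 1/sqrt(head_dim) scale are assumptions based on the standard MLX SDPA convention, not something prescribed by this PR):

import math
import mlx.core as mx

# Assumed layout: (batch, num_heads, seq_len, head_dim); full self-attention, no mask.
B, H, L, D = 2, 38, 4096, 64
q = mx.random.normal((B, H, L, D), dtype=mx.float32)
k = mx.random.normal((B, H, L, D), dtype=mx.float32)
v = mx.random.normal((B, H, L, D), dtype=mx.float32)

scale = 1.0 / math.sqrt(D)  # conventional 1/sqrt(head_dim) scaling
o = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
mx.eval(o)       # force evaluation of the lazy graph
print(o.shape)   # (2, 38, 4096, 64)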

Context

We're continuing the work started with #735 and following up with memory savings for the self-attention use case common in diffusion transformer workflows such as Stable Diffusion 3. In particular, we observe memory savings of more than 5 GB for float32 in the SD3 8B use case. The current shaders are implemented using Steel-like GEMM primitives in MLX style, with potential for performance tuning to improve latency while retaining the memory savings.
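
As a rough sanity check on that figure, here is a back-of-the-envelope sketch; it assumes the dominant saving is the full (L, L) float32 score tensor per attention call that the fused kernel avoids materializing:

# SD3 8B-style configuration used in the measurements below: batch 2, 38 heads, seq_len 4250.
B, H, L = 2, 38, 4250
bytes_fp32 = 4
scores_bytes = B * H * L * L * bytes_fp32
print(f"{scores_bytes / 1e9:.2f} GB")  # ~5.5 GB, consistent with the >5 GB savings quoted above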

Supported

  • Supports mx.float16 and mx.float32 dtypes
  • Supports head_dim=64, 128 (covers most 7B+ LLMs)
  • MHA supported (no MQA, no GQA)
  • Unsupported use cases still go through MLX primitives under the hood (see the reference sketch after this list).
  • No backward pass implementation (inference-only kernel)
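
For context on that fallback path, here is a sketch of the equivalent computation in MLX primitives; it is a reference implementation under my own naming, not the exact code MLX dispatches to:

import mlx.core as mx

def sdpa_reference(q, k, v, scale, mask=None):
    # softmax(q @ k^T * scale) @ v, materializing the full (L, L) score matrix
    # that the fused Metal shader avoids for long sequences.
    scores = (q * scale) @ mx.swapaxes(k, -1, -2)
    if mask is not None:
        scores = scores + mask
    return mx.softmax(scores, axis=-1) @ v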

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

bpkeene (Contributor, Author) commented Apr 6, 2024

Marking as draft: currently working through some numerical issues via a separate workflow, and will add CPU-side bindings and dispatch, tests, and docs. Sharing the current status and will update this PR.

@bpkeene bpkeene marked this pull request as draft April 6, 2024 04:01
@bpkeene bpkeene force-pushed the user/bkeene/fast_sdpa_self_attention branch 2 times, most recently from 54d1412 to beccbf5 on April 6, 2024 04:07
@bpkeene bpkeene marked this pull request as ready for review April 12, 2024 06:46
@bpkeene bpkeene marked this pull request as draft May 3, 2024 17:19
1 /* self attention: unused */,
alpha};

set_array_buffer(compute_encoder, q, 0);
A Member commented:
These changed to compute_encoder.set_input_array (for inputs) and compute_encoder.set_output_array for outputs of the shader.

constexpr const uint rows_per_tgroup = 16;
const int tgp_y_indices = ((int(qseq) - 1) / rows_per_tgroup) + 1;

auto compute_encoder = d.get_command_encoder(s.index);
A Member commented:
This changes to auto& compute_encoder ...

@bpkeene bpkeene force-pushed the user/bkeene/fast_sdpa_self_attention branch from f9ede9b to 9da07d0 on May 16, 2024 03:46
@bpkeene bpkeene marked this pull request as ready for review May 16, 2024 03:47
@bpkeene bpkeene changed the title Metal shaders for efficient self attention on large sequences Metal shaders for memory efficient self attention on large sequences May 21, 2024
bpkeene (Contributor, Author) commented May 21, 2024

Hi folks,
[Graphs: Latency - Self Attention SDPA; Memory - Self Attention SDPA]

Attaching some graphs of measured latency on M3 Max and estimated memory savings per attention block (empirically observed at several data points; the graph here was obtained via formulas).

bpkeene (Contributor, Author) commented May 21, 2024

There is some room for improvement in latency on larger sequences, with a divergence after a sequence length of ~2300, though the memory savings exceed 1 GB around 2k and approach 5 GB at a sequence length of 4250 (the SD3 8B use case).

All measurements were with batch size 2, heads = 38, hidden dim = 64, and float32 on M3 Max / 48GB.
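
For anyone reproducing these numbers, here is a rough timing sketch under the same configuration; the harness is not part of this PR, and mx.metal.get_peak_memory is assumed to be available in the installed MLX version:

import math
import time
import mlx.core as mx

# Configuration from the measurements above: batch 2, 38 heads, head_dim 64, float32.
B, H, D = 2, 38, 64
L = 4096

q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
mx.eval(q, k, v)  # materialize inputs before timing

scale = 1.0 / math.sqrt(D)
tic = time.perf_counter()
o = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
mx.eval(o)
toc = time.perf_counter()
print(f"seq_len={L}: {1e3 * (toc - tic):.2f} ms")
print(f"peak memory: {mx.metal.get_peak_memory() / 2**30:.2f} GiB")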

mlx/fast.cpp: 3 review threads resolved (2 outdated)
awni (Member) commented May 22, 2024

@bpkeene left a few minor comments. Could you address them? Once updated we can run the tests and get this merged.

bpkeene (Contributor, Author) commented May 23, 2024

Updated with the requested changes, thank you for the prompt review!
