Metal shaders for memory efficient self attention on large sequences #964
base: main
Conversation
Marking as draft; currently working through some numerical issues via a separate workflow, and will add CPU-side bindings + dispatch, tests & docs. Sharing current status and will update this PR.
54d1412 to beccbf5 (compare)
1 /* self attention: unused */,
alpha};

set_array_buffer(compute_encoder, q, 0);
These changed to compute_encoder.set_input_array for inputs and compute_encoder.set_output_array for outputs of the shader.
constexpr const uint rows_per_tgroup = 16;
const int tgp_y_indices = ((int(qseq) - 1) / rows_per_tgroup) + 1;

auto compute_encoder = d.get_command_encoder(s.index);
This changes to auto& compute_encoder ...
Updated fast attention: GEMM-ified with Steel primitives; uses flash attention 1 for scale correction
f9ede9b to 9da07d0 (compare)
Some room for improvement on larger sequences re: latency, with a divergence after ~2300 sequence length, though the memory savings exceed 1GB at ~2k and approach 5GB at 4250 sequence length (the SD3 8B use case). All measurements were with batch size 2, heads = 38, hidden dim = 64, and float32 on an M3 Max / 48GB.
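For reference, a minimal sketch of how a single datapoint at this configuration could be measured. The shapes mirror the stated config, but the sequence length, the timing via time.perf_counter, and the use of mx.metal.get_peak_memory() are assumptions about the environment, not the script behind the numbers above:

```python
# Sketch only -- not the script used for the numbers in this comment.
# Config mirrors the comment: batch 2, 38 heads, head dim 64, float32.
# Assumes mx.metal.get_peak_memory() is available in your MLX build.
import time
import mlx.core as mx

B, H, L, D = 2, 38, 4096, 64   # L (sequence length) is illustrative
scale = D ** -0.5

q, k, v = (mx.random.normal((B, H, L, D), dtype=mx.float32) for _ in range(3))
mx.eval(q, k, v)

tic = time.perf_counter()
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
mx.eval(out)
toc = time.perf_counter()

print(f"latency: {1e3 * (toc - tic):.2f} ms")
print(f"peak GPU memory: {mx.metal.get_peak_memory() / 2**30:.2f} GB")
```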
@bpkeene left a few minor comments. Could you address them? Once updated we can run the tests and get this merged.
Updated with the requested changes; thank you for the prompt review!
Implements Metal shaders for:
o = mx.fast.scaled_dot_product_attention(queries, keys, values, scale=scale, mask=None)
Supports fp16 and fp32 dtypes; flexible hidden dimension, currently templated for 64 and 128.
Causal masking for prompt encoding is not yet implemented; this shader is focused at present on full self-attention. A usage sketch follows below.
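As an illustration, a small sketch exercising both supported dtypes and both templated head dimensions; the batch size, head count, and sequence length here are arbitrary assumptions, not values from the PR:

```python
# Illustrative usage; shapes and sequence lengths are assumptions, not from the PR.
import mlx.core as mx

def run_fused(L, D, dtype):
    # Full (unmasked) self attention: batch 1, 8 heads, sequence length L, head dim D.
    q = mx.random.normal((1, 8, L, D)).astype(dtype)
    k = mx.random.normal((1, 8, L, D)).astype(dtype)
    v = mx.random.normal((1, 8, L, D)).astype(dtype)
    return mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** -0.5, mask=None)

# Supported dtypes (fp16/fp32) with the currently templated head dims (64, 128):
mx.eval(run_fused(2048, 64, mx.float16))
mx.eval(run_fused(2048, 128, mx.float32))
```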
Context
We're continuing the work started with #735 and following up with memory savings for the self-attention use case common in diffusion transformer workflows such as Stable Diffusion 3. In particular, we observe memory savings of > 5GB for the 8B use case in float32. Current shaders are implemented using Steel-like GEMM primitives in MLX style, with potential for performance tuning to improve latency while retaining the memory savings.
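For intuition on where the savings come from, here is a rough Python reference of the flash-attention-1 style streaming softmax (with scale correction) that the shaders compute blockwise with Steel GEMM tiles. This is an illustrative recurrence for a single query row, with an assumed block size, not the Metal implementation:

```python
# Illustrative online-softmax recurrence; the Metal shaders do this blockwise over
# tiles, so the full L x L score matrix is never materialized.
import math
import mlx.core as mx

def streaming_attention_row(q, K, V, scale, block=256):
    # q: (D,), K and V: (L, D). Live memory stays O(block * D) instead of O(L).
    m = -math.inf                   # running max of the logits
    denom = 0.0                     # running softmax denominator
    acc = mx.zeros(V.shape[-1])     # running unnormalized output
    for start in range(0, K.shape[0], block):
        s = scale * (K[start:start + block] @ q)      # logits for this key block
        m_new = max(m, mx.max(s).item())
        correction = math.exp(m - m_new)              # rescale the previous state
        p = mx.exp(s - m_new)
        denom = denom * correction + mx.sum(p).item()
        acc = acc * correction + p @ V[start:start + block]
        m = m_new
    return acc / denom
```

Up to floating-point error this matches softmax(scale * q @ K.T) @ V for that row, while only ever holding one block of scores at a time.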
Supported: mx.float16 and mx.float32 dtypes

Checklist
Put an x in the boxes that apply.
- I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes