run sdpa with dtensor #180

Open

tianyu-l wants to merge 6 commits into base: gh/tianyu-l/7/base
Conversation

@tianyu-l (Contributor) commented Mar 30, 2024

Stack from ghstack (oldest at bottom):

This PR gets rid of the manual adjustment of the number of heads in the attention layers by using DTensor outputs of `wq`, `wk`, and `wv`, so that SDPA is aware of the distributedness.
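A minimal sketch of the contrast (illustrative shapes and names, not this PR's exact code): previously each tensor-parallel rank manually divided the head count and SDPA ran on plain local tensors, whereas keeping the projection outputs as DTensors lets the attention forward use the global head count.

```python
import torch

bs, seqlen, n_heads, head_dim, tp_degree = 2, 16, 32, 64, 8

# Before: model args were manually adjusted per TP rank.
local_n_heads = n_heads // tp_degree
xq_local = torch.randn(bs, seqlen, local_n_heads * head_dim)
xq_local = xq_local.view(bs, seqlen, local_n_heads, head_dim)

# After: wq/wk/wv return DTensors sharded on the hidden dimension, so the
# same view can use the global n_heads and SDPA dispatches through DTensor:
#   xq = wq(x)                                   # DTensor output
#   xq = xq.view(bs, seqlen, n_heads, head_dim)  # global head count
```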

tianyu-l added a commit that referenced this pull request Mar 30, 2024
ghstack-source-id: 33d3d0b6a19c747269aab1a95589bb61bf9c1f51
Pull Request resolved: #180
@facebook-github-bot added the CLA Signed label Mar 30, 2024
@tianyu-l mentioned this pull request Mar 30, 2024
@tianyu-l requested a review from wanchaol March 30, 2024
tianyu-l added a commit that referenced this pull request Mar 30, 2024
ghstack-source-id: 43941c1ca0dfc7a04589a7513a110b877c217917
Pull Request resolved: #180
"attention.wq": col_parallel_strategy(),
"attention.wk": col_parallel_strategy(),
"attention.wv": col_parallel_strategy(),
"attention.wq": col_parallel_strategy(use_local_output=False),
Review comment (Contributor):
🤔 I thought we needed to replicate the freq_cis, but here it seems we don't need to?
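For reference, a sketch of what explicit replication would look like (hypothetical usage; in some PyTorch versions this API lives under `torch.distributed._tensor`):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor

# Assumes an initialized process group; shapes and mesh size are placeholders.
tp_mesh = init_device_mesh("cuda", (8,))
freqs_cis = torch.randn(2048, 32, dtype=torch.complex64)  # (max_seqlen, head_dim // 2)
freqs_cis = distribute_tensor(freqs_cis, tp_mesh, [Replicate()])
```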

@wconstab (Contributor) commented Apr 30, 2024

Just curious, is this going to land soon, or does it have some risk or unfinished business?

Also, it looks like this could use a rebase. I got a little confused applying it on my branch because some of the sharding config seems to have changed (attention.wo and attention_norm).

@tianyu-l (Contributor, Author) replied:
> Just curious, is this going to land soon, or does it have some risk or unfinished business?
>
> Also, it looks like this could use a rebase. I got a little confused applying it on my branch because some of the sharding config seems to have changed (attention.wo and attention_norm).

It hasn't landed because there is a very strange bug (#267) associated with (but seemingly not caused by) multiplication using DTensor. It would be triggered in the rotary embedding computation if this PR landed. I will work on the bug soon, since fixing it will also benefit PP (IIUC). @wconstab

@wconstab (Contributor) replied:
> It would be triggered in the rotary embedding computation if this PR landed.

Oh, is this related to dispatching for complex numbers, by any chance?

@tianyu-l (Contributor, Author) replied:
> Oh, is this related to dispatching for complex numbers, by any chance?

@wconstab Possibly; we don't know. The aten.mul op returns bad results even when its inputs are raw torch.Tensors (unwrapped from DTensor), and the bug only shows up in the backward pass. Do you know whom I should ask for help?
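For reference, a sketch of where that multiply occurs, loosely following the Llama reference rotary embedding (names and shapes are illustrative, not this repo's exact code):

```python
import torch

def apply_rotary_emb(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Pair the last dim into complex values: (..., head_dim) -> (..., head_dim // 2)
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    # The elementwise complex multiply below lowers to aten.mul; per the
    # discussion above, its backward produced bad results around DTensor (#267).
    xq_out = torch.view_as_real(xq_ * freqs_cis)
    return xq_out.flatten(-2).type_as(xq)

xq = torch.randn(2, 16, 32, 64)  # (bs, seqlen, n_heads, head_dim)
freqs_cis = torch.randn(1, 16, 1, 32, dtype=torch.complex64)  # broadcastable
out = apply_rotary_emb(xq, freqs_cis)  # same shape as xq
```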

wconstab pushed a commit that referenced this pull request May 1, 2024
ghstack-source-id: 58ba72163a4b03d77f4b2ba7c97cef7e7e8b3096
Pull Request resolved: #180

wconstab pushed a commit that referenced this pull request May 2, 2024
ghstack-source-id: a18a3cb1ba48fb751f437a5ee44f186ff9a26e9a
Pull Request resolved: #180

wconstab pushed a commit that referenced this pull request May 2, 2024
ghstack-source-id: b8b2b58ffc72fcb8bfc88f4ba2a3455e3cc92c0a
Pull Request resolved: #180

wconstab pushed a commit that referenced this pull request May 2, 2024
ghstack-source-id: 55bb9e1ba289c212f4af58e19d9bede2ad0246a8
Pull Request resolved: #180