[RFC] Sharded embeddings in separate FSDP group #182

Draft · wants to merge 1 commit into base: gh/awgu/2/base
Conversation

@awgu (Contributor) commented Apr 1, 2024

Stack from ghstack (oldest at bottom):

If we shard the embeddings as a separate FSDP parameter group, then:

  • In forward, there is a separate all-gather for the root's parameters first (final norm, output projection), followed by an all-gather for the embeddings. This makes the first all-gather smaller and allows overlapping the embedding's pre-forward casts with an all-gather.
  • In forward, the embedding parameters are resharded right after their use near the beginning of forward, before any transformer block runs.
  • In backward, the embedding parameters are only all-gathered near the end of backward, when memory usage is no longer close to its peak.
  • In backward, the embedding's reduce-scatter and the root's reduce-scatter both remain exposed (not overlapped with compute), since the embedding has the last gradient computation.

This saves roughly the embedding parameter size from peak memory without any first-order decrease in WPS. (It introduces 2 extra all-gathers and 1 extra reduce-scatter per iteration, which can hurt communication latency at large scale.)

For example, for Llama-7B with bf16 mixed precision, we save ~0.84 GiB, and on 8 GPUs, there is no noticeable effect on MFU.
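For concreteness, here is a minimal sketch of the proposed grouping (not the exact diff in this PR), assuming a Llama-style model with `tok_embeddings` and an iterable `layers` of transformer blocks, and the FSDP2 `fully_shard` API as it lived in nightlies around this time (`torch.distributed._composable.fsdp`):

```python
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

def apply_fsdp(model, mesh):
    mp_policy = MixedPrecisionPolicy(
        param_dtype=torch.bfloat16, reduce_dtype=torch.float32
    )
    # Each transformer block is its own FSDP parameter group (as before).
    for block in model.layers:
        fully_shard(block, mesh=mesh, mp_policy=mp_policy)
    # New in this RFC: the token embedding becomes a *separate* group, so it
    # all-gathers on its own, reshards right after its use at the start of
    # forward, and is only re-gathered near the end of backward.
    fully_shard(model.tok_embeddings, mesh=mesh, mp_policy=mp_policy)
    # The root group now contains only the parameters not already managed by
    # a nested fully_shard call (final norm and output projection), which is
    # what makes the first forward all-gather smaller.
    fully_shard(model, mesh=mesh, mp_policy=mp_policy)
    return model
```

Because `fully_shard` is applied to `model.tok_embeddings` before the root call, the embedding parameters are excluded from the root group and get their own collectives: one all-gather in forward, one re-gather in backward, and one reduce-scatter for the gradients, which is where the 2 extra all-gathers and 1 extra reduce-scatter per iteration come from.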

awgu added a commit that referenced this pull request Apr 1, 2024
ghstack-source-id: 8dcb875579c10dde3e5c7bfbed10636bc9ef39f0
Pull Request resolved: #182
@facebook-github-bot added the CLA Signed label on Apr 1, 2024
@awgu changed the title from "Sharded embeddings in separate FSDP group" to "[RFC] Sharded embeddings in separate FSDP group" on Apr 1, 2024