[Core] support saving and loading of sharded checkpoints #7830
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Always delightful to deal with the `from_pretrained` code ;)
I don't really have any bigger comments, as this should hopefully work well since it's based on the transformers implementation. Only some smaller comments.
Thanks for your work @sayakpaul! Left a suggestion (not a blocker; we can do it afterwards if needed). No major comments since @BenjaminBossan did a very thorough review already!
Thanks a lot @sayakpaul for the integration and iterating over it! Current code looks good to me :) I'd rather have another pair of eyes reviewing it, given it's fairly easy to miss something when iterating/reviewing several times on the same code.
Thanks again!
Yeah. @yiyixuxu would be the final approver here :)
thanks for the PR!!
I left some comments and questions :)
src/diffusers/configuration_utils.py (outdated)

```diff
@@ -349,7 +349,7 @@ def load_config(
     local_files_only = kwargs.pop("local_files_only", False)
     revision = kwargs.pop("revision", None)
     _ = kwargs.pop("mirror", None)
-    subfolder = kwargs.pop("subfolder", None)
+    subfolder = kwargs.pop("subfolder", None) or ""
```
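A small illustration of why the `or ""` coercion in the diff above helps (a sketch, not the actual diffusers call path): `os.path.join` raises on `None`, while an empty string is simply a no-op path component, so downstream path-building code can treat "no subfolder" uniformly.

```python
import os

# With the coercion from the diff: None becomes "", which os.path.join ignores.
subfolder = None
subfolder = subfolder or ""
path = os.path.join("repo", subfolder, "config.json")
print(path)  # same as os.path.join("repo", "config.json")

# Without the coercion, passing None straight through would raise a TypeError.
try:
    os.path.join("repo", None, "config.json")
except TypeError as e:
    print("fails without the coercion:", e)
```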
why don't we handle it where it fails then, at the `subfolder` call site? we would only need to change one place, no?
@yiyixuxu do the recent changes work for you? (I have run the tests)
What does this PR do?
Follow-up of #6396.
This PR adds support for saving a big model's state dict into multiple shards for efficient portability and loading. Adds support for loading the sharded checkpoints, too.
This is much akin to how big models like T5XXL are handled.
Also added a test to ensure that models with `_no_split_modules` specified can be sharded, loaded back, and used for inference, with numerical assertions on the outputs.

Here's a real use case. Consider this `Transformer2DModel` checkpoint: https://huggingface.co/sayakpaul/actual_bigger_transformer/. It was serialized like so:
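The actual serialization snippet is not preserved in this page. As a standalone sketch of the idea behind the PR, the sharding step amounts to: split the state dict into shards whose cumulative size stays under a limit, and write an index mapping each weight to its shard file. The function and file names below are illustrative, not the real diffusers implementation.

```python
def shard_state_dict(param_sizes, max_shard_size):
    """Sketch of state-dict sharding.

    param_sizes: dict mapping parameter name -> size in bytes.
    Returns (shards, weight_map), where weight_map mimics the
    *.safetensors.index.json files found in sharded Hub repos.
    """
    shards, current, current_size = [], {}, 0
    for name, size in param_sizes.items():
        # Start a new shard when adding this weight would exceed the limit.
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = size
        current_size += size
    if current:
        shards.append(current)

    # Index: which shard file does each weight live in?
    weight_map = {}
    total = len(shards)
    for i, shard in enumerate(shards, start=1):
        fname = f"diffusion_pytorch_model-{i:05d}-of-{total:05d}.safetensors"
        for name in shard:
            weight_map[name] = fname
    return shards, weight_map

sizes = {"a.weight": 400, "b.weight": 400, "c.weight": 300}
shards, index = shard_state_dict(sizes, max_shard_size=700)
print(len(shards))           # → 2
print(index["a.weight"])     # → diffusion_pytorch_model-00001-of-00002.safetensors
```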
As we can see from the Hub repo, its state dict is sharded. To perform inference with the model, all we have to do is this:
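The loading snippet itself is not preserved in this page either. Conceptually, the loading side is the inverse of sharding: resolve the index file, load each shard once, and merge the pieces into a single state dict. The sketch below simulates the shard files with in-memory dicts (diffusers itself resolves the index from the Hub and loads safetensors shards).

```python
def load_sharded(weight_map, shard_files):
    """Sketch of merging sharded weights back into one state dict.

    weight_map: weight name -> shard filename (as in an index.json).
    shard_files: shard filename -> dict of weights (stand-in for files on disk).
    """
    state_dict = {}
    # Load each distinct shard exactly once, then merge.
    for fname in sorted(set(weight_map.values())):
        state_dict.update(shard_files[fname])
    return state_dict

shard_files = {
    "model-00001-of-00002.safetensors": {"a.weight": 1},
    "model-00002-of-00002.safetensors": {"b.weight": 2},
}
weight_map = {
    "a.weight": "model-00001-of-00002.safetensors",
    "b.weight": "model-00002-of-00002.safetensors",
}
print(load_sharded(weight_map, shard_files))  # → {'a.weight': 1, 'b.weight': 2}
```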
I have purposefully not added documentation because all of this will become useful once we use it in the context of a full-fledged pipeline execution (up next) :)