
[Core] support saving and loading of sharded checkpoints #7830

Open
sayakpaul wants to merge 60 commits into main
Conversation

sayakpaul (Member)

What does this PR do?

Follow-up of #6396.

This PR adds support for saving a big model's state dict as multiple shards for efficient portability and loading. It also adds support for loading those sharded checkpoints back.

This is much akin to how big models like T5-XXL are handled.

Also added a test to ensure that models with _no_split_modules specified can be sharded and loaded back to perform inference, with numerical assertions on the outputs.
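
For illustration, here is a minimal sketch of what such a round-trip check can look like (the actual test lives in tests/models/test_modeling_common.py; the helper below, its tiny shard size, and the .sample output attribute are illustrative assumptions):

import tempfile

import torch

def check_sharded_roundtrip(model, inputs, max_shard_size="1MB"):
    # A tiny max_shard_size forces multiple shards even for a small test model.
    model.eval()
    with torch.no_grad():
        expected = model(**inputs).sample

    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir, max_shard_size=max_shard_size)
        reloaded = type(model).from_pretrained(tmp_dir).eval()

    with torch.no_grad():
        actual = reloaded(**inputs).sample

    # The reloaded sharded model should be numerically equivalent to the original.
    assert torch.allclose(expected, actual, atol=1e-5)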

Here's a real use-case. Consider this Transformer2DModel checkpoint: https://huggingface.co/sayakpaul/actual_bigger_transformer/.

It was serialized like so:

from diffusers import Transformer2DModel
from accelerate.utils import compute_module_sizes
from accelerate import init_empty_weights
import torch.nn as nn

def bytes_to_giga_bytes(bytes):
    return f"{(bytes / 1024 / 1024 / 1024):.3f}"

with init_empty_weights():
    # Instantiate on the meta device to measure sizes without allocating any memory.
    pixart_transformer = Transformer2DModel.from_config("PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer")
    bigger_transformer = Transformer2DModel.from_config(
        pixart_transformer.config, num_layers=72, num_attention_heads=36, cross_attention_dim=2592,
    )
    module_size = bytes_to_giga_bytes(compute_module_sizes(bigger_transformer)[""])
    print(f"{module_size=} GB")
    pytorch_total_params = sum(p.numel() for p in bigger_transformer.parameters()) / 1e9
    print(f"{pytorch_total_params=} B")

    # A plain stack of linear layers of a similar order of magnitude, for comparison.
    model = nn.Sequential(*[nn.Linear(8944, 8944) for _ in range(1000)])
    module_size = bytes_to_giga_bytes(compute_module_sizes(model)[""])
    print(f"{module_size=} GB")
    pytorch_total_params = sum(p.numel() for p in model.parameters()) / 1e9
    print(f"{pytorch_total_params=} B")

# Instantiate for real this time, and save in 10GB shards while pushing to the Hub.
actual_bigger_transformer = Transformer2DModel.from_config(
    pixart_transformer.config, num_layers=72, num_attention_heads=36, cross_attention_dim=2592
)
actual_bigger_transformer.save_pretrained("/raid/.cache/actual_bigger_transformer", max_shard_size="10GB", push_to_hub=True)
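
After the save, the repo contains the numbered shard files plus an index JSON that maps each parameter to its shard. A quick way to inspect the layout (the index filename below follows the transformers-style convention this PR is modeled on, and is an assumption here):

import json
import os

from huggingface_hub import snapshot_download

local_dir = snapshot_download("sayakpaul/actual_bigger_transformer")
print(sorted(os.listdir(local_dir)))  # numbered shard files plus an index JSON

# Assumed index filename; adjust to whatever the repo actually contains.
with open(os.path.join(local_dir, "diffusion_pytorch_model.safetensors.index.json")) as f:
    index = json.load(f)
print(index["metadata"]["total_size"])        # total state-dict size in bytes
print(list(index["weight_map"].items())[:3])  # parameter name -> shard file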

As we can see from the Hub repo, its state dict is sharded. To perform inference with the model, all we have to do is this:

from diffusers import Transformer2DModel
import torch

def get_inputs():
    sample = torch.randn(1, 4, 128, 128)
    timestep = torch.randint(0, 1000, size=(1,))
    encoder_hidden_states = torch.randn(1, 120, 4096)

    resolution = torch.tensor([1024, 1024]).repeat(1, 1)
    aspect_ratio = torch.tensor([1.0]).repeat(1, 1)
    added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
    return sample, timestep, encoder_hidden_states, added_cond_kwargs

with torch.no_grad():
    # max_memory = {0: "15GB"}  # reasonable estimate for a consumer GPU.
    new_model = Transformer2DModel.from_pretrained(
        "sayakpaul/actual_bigger_transformer",
        device_map="auto",  # let accelerate place the shards across the available devices
    )

    sample, timestep, encoder_hidden_states, added_cond_kwargs = get_inputs()
    out = new_model(
        hidden_states=sample,
        encoder_hidden_states=encoder_hidden_states,
        timestep=timestep,
        added_cond_kwargs=added_cond_kwargs,
    ).sample
    print(f"{out.shape=}, {out.device=}")

I have purposefully not added documentation because all of this will become useful once we use it in the context of a full-fledged pipeline execution (up next) :)

sayakpaul requested review from yiyixuxu and SunMarc on May 1, 2024 10:46
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member Author)

@yiyixuxu @SunMarc a gentle ping here.

@BenjaminBossan (Member) left a comment

Always delightful to deal with the from_pretrained code ;)

I don't really have any bigger comments, as this should hopefully work well since it's based on the transformers implementation. Only some smaller comments.

(Inline review threads on src/diffusers/models/modeling_utils.py, src/diffusers/utils/hub_utils.py, and tests/models/test_modeling_common.py.)
@SunMarc (Member) left a comment

Thanks for your work @sayakpaul! Left a suggestion (not a blocker, we can do it afterwards if needed)! No major comments since @BenjaminBossan did a very thorough review already!

(Inline review thread on src/diffusers/models/modeling_utils.py.)
sayakpaul requested a review from Wauplin on May 29, 2024 01:42
@Wauplin (Contributor) left a comment

Thanks a lot @sayakpaul for the integration and iterating over it! Current code looks good to me :) I'd rather have another pair of eyes reviewing it, given it's fairly easy to miss something when iterating/reviewing several times on the same code.

Thanks again!

(Inline review thread on src/diffusers/utils/hub_utils.py.)

sayakpaul commented May 29, 2024

I'd rather have another pair of eyes reviewing it, given it's fairly easy to miss something when iterating/reviewing several times on the same code.

Yeah. @yiyixuxu would be the final approver here :)

@yiyixuxu (Collaborator) left a comment

thanks for the PR!!
I left some comments and questions :)

@@ -349,7 +349,7 @@ def load_config(
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
         _ = kwargs.pop("mirror", None)
-        subfolder = kwargs.pop("subfolder", None)
+        subfolder = kwargs.pop("subfolder", None) or ""
@yiyixuxu (Collaborator) commented on this diff:

Why don't we handle it where it fails, then?

We would only need to change one place, no?
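
For context, the or "" coercion guards the downstream path handling against a None subfolder; a hypothetical illustration of the failure mode:

import os

subfolder = None
# os.path.join("repo_id", subfolder)  # raises TypeError on a None component
subfolder = subfolder or ""
print(os.path.join("repo_id", subfolder))  # "repo_id/" - an empty subfolder is effectively a no-op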

(Inline review threads on src/diffusers/models/modeling_utils.py and src/diffusers/utils/hub_utils.py.)

sayakpaul commented Jun 3, 2024

@yiyixuxu do the recent changes work for you?

(I have run the tests)
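
(For reference, something along these lines would run the relevant test module; the -k filter is a guess at the test names:)

python -m pytest tests/models/test_modeling_common.py -k "sharded"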

sayakpaul requested a review from yiyixuxu on June 3, 2024 12:35