
Question on Model Init #312

Open
XinDongol opened this issue May 6, 2024 · 7 comments
Labels
question Further information is requested

Comments

@XinDongol

I noticed that there are two parts of the implementation related to model initialization.

Instantiating the model on the meta device

torchtitan/train.py

Lines 177 to 181 in f72a2a0

with torch.device("meta"):
logger.info(
f"Building {model_name} {job_config.model.flavor} with {model_config}"
)
model = model_cls.from_model_args(model_config)

Doing explicit model initialization

torchtitan/train.py

Lines 209 to 210 in f72a2a0

model.to_empty(device="cuda")
model.init_weights()

The issue is that if we do any weight initialization when instantiating the module, it will be ineffective because of the meta tensors.
As a result, we have to do all initialization explicitly in model.init_weights().
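A minimal sketch in plain PyTorch (not torchtitan code) of the point above: parameters created under the meta device have no storage, so initialization done during construction never writes real values, and everything must be redone after materialization.

```python
import torch
import torch.nn as nn

# Hedged sketch, plain PyTorch (not torchtitan code): parameters created under
# the meta device have no storage, so any nn.init calls made during module
# construction never write real values.
with torch.device("meta"):
    linear = nn.Linear(4, 4)  # the init done inside __init__ is a no-op here

assert linear.weight.is_meta  # no data was ever materialized

# Materialize (uninitialized) storage, then initialize explicitly, mirroring
# the to_empty(...) + init_weights() pattern quoted above.
linear.to_empty(device="cpu")
nn.init.trunc_normal_(linear.weight, std=0.02)
nn.init.zeros_(linear.bias)
```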

My question is: why do we want to instantiate the model with meta tensors?
If efficiency is not an issue, can we simply remove the with torch.device("meta"):?

@awgu
Contributor

awgu commented May 6, 2024

The advantage of meta-device init is that it is as fast as possible: the sharded parameters are directly initialized on GPU.

Any other flow requires something more, e.g. (1) initializing unsharded parameters on GPU and then sharding, or (2) initializing unsharded parameters on CPU, copying to GPU, and then sharding. For (1), you need to insert your sharding calls inline with your module construction, or else you will use too much GPU memory (e.g. you must construct one transformer block on GPU, shard it, construct the next transformer block, shard it, etc.). For (2), initializing parameters on CPU is slow (if you run initialization kernels), and the largest model that you can support is bottlenecked by CPU RAM size.

In some sense, the current model.init_weights() meta-device approach is a compromise: we require the user to define this method to initialize all model parameters/buffers, but in turn, the initialization is as fast as possible. So to answer your question, if efficiency is not an issue (and CPU RAM size is not a bottleneck), then yes, you could remove the with torch.device("meta") (and instead do either (1) or (2)).
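The compromise described above can be sketched without FSDP in a few lines. This is a hedged illustration that runs on a single CPU process; ToyModel and its init constants are made up for the example and are not torchtitan's model definition.

```python
import torch
import torch.nn as nn

# Hedged sketch of the meta-device compromise described above, with FSDP left
# out so it runs on a single CPU process. ToyModel is illustrative only.
class ToyModel(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def init_weights(self) -> None:
        # The contract: this method (re)initializes *all* parameters/buffers.
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

# 1. Construct on meta: no storage is allocated, so this is fast and is not
#    bounded by CPU RAM.
with torch.device("meta"):
    model = ToyModel()

# 2. (In torchtitan, fully_shard(...) would be applied here, so the next step
#    materializes only each rank's shard.)

# 3. Allocate real storage ("cuda" in torchtitan; "cpu" here for portability),
#    then run the explicit initialization once, directly on the target device.
model.to_empty(device="cpu")
model.init_weights()
```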

@XinDongol
Author

XinDongol commented May 6, 2024

Thanks for clarifying. Really helpful!

If my understanding is correct, CPU RAM is NOT a problem, and I want to do "(2)",

model = model_cls.from_model_args(model_config)
model.init_weights()
model = fully_shard(model, **fsdp_config)
model.to(device="cuda")
### training loop ###

Is this a correct way to do "(2)" ? @awgu

@awgu
Contributor

awgu commented May 7, 2024

@XinDongol A few clarifications:

model = model_cls.from_model_args(model_config)
# (1) If the `model_cls.__init__` did not already call `init_weights()` or similar
model.init_weights()
# (2) Apply FSDP with multiple FSDP calls, e.g. on each transformer block
for module in model.modules():
    if isinstance(module, TransformerBlock):
        fully_shard(module, **fsdp_config)
fully_shard(model, **fsdp_config)  # always call on root
# (3) No need to move to cuda explicitly

Regarding (1), the torchtitan Llama definition already calls init_weights() in the Transformer.__init__(), so there should be no need to call it again separately via model.init_weights() if we are doing CPU init.

self.init_weights()

Regarding (2), in case you were not already aware of the FSDP design: you should apply fully_shard to some submodules (generally the transformer blocks for a transformer architecture) in addition to the root module, both to achieve communication/computation overlap and to avoid excessive peak memory. Concretely, calling fully_shard(module) constructs one parameter group communicated together (e.g. all-gather parameters, reduce-scatter gradients) from module.parameters(), excluding those assigned to a nested fully_shard(submodule).

Regarding (3), each time you call fully_shard(module), the managed parameters/buffers will be moved to the mesh's corresponding device, and in our case, mesh.device_type == "cuda". This means we do not need to explicitly call model.to(device="cuda").

I was wondering if you could explain more about why you want to do CPU init. You may notice that the init time is quite long, especially for larger models.

@tianyu-l tianyu-l added the question Further information is requested label May 7, 2024
@XinDongol
Author

XinDongol commented May 8, 2024

@awgu Thanks for your clarifications!

Initialization on meta tensors is painful. For some architectures (e.g., Mamba), some parameters need complicated pre-computed initialization, and assigning values to meta/DTensor tensors is challenging. As a result, I would like to initialize the model before distributing it.

I just noticed that there is a newly added flag called create_seed_checkpoint. If my understanding is correct, it can be utilized to avoid initialization on meta/DTensor tensors.

@awgu
Contributor

awgu commented May 8, 2024

I would really appreciate some pointers to the complicated initialization to learn more about it.

And yes, I think that the seed checkpoint can be used to avoid the meta device init and instead try to init on CPU.

@XinDongol
Author

XinDongol commented May 8, 2024

@awgu
Contributor

awgu commented May 13, 2024

cc: @wanchaol @tianyu-l The above two pointers are good examples of real-model init methods that do not fit our current meta-device init flow. As far as I can tell, both would require some custom logic to use DTensor APIs to make it work (leading to some if <using distributed>: ... else: ...).
