Question on Model Init #312
The advantage of meta-device init is that it is as fast as possible: the sharded parameters are directly initialized on GPU. Any other flow requires something more, e.g. (1) initializing unsharded parameters on GPU and then sharding, or (2) initializing unsharded parameters on CPU, copying to GPU, and then sharding. For (1), you need to insert your sharding call inline with your module construction, or else you will use too much GPU memory (e.g. you must construct one transformer block on GPU, shard it, construct the next transformer block, shard it, etc.). For (2), initializing parameters on CPU is slow (if you run initialization kernels), and the largest model that you can support is bottlenecked by CPU RAM size. In some sense, the current …
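For concreteness, here is a minimal sketch of that meta-device flow with the FSDP2 `fully_shard` API. The `model_cls`/`model_config` names are placeholders from this thread, and the per-block loop assumes the model exposes its transformer blocks as `model.layers`:

```python
import torch
from torch.distributed.fsdp import fully_shard  # FSDP2 API location in recent PyTorch releases

# Construct on the meta device: parameters get shapes/dtypes but no storage.
with torch.device("meta"):
    model = model_cls.from_model_args(model_config)

# Shard each transformer block, then the root (still meta, so no memory is used yet).
for block in model.layers:
    fully_shard(block)
fully_shard(model)

# Allocate only the *sharded* parameters directly on GPU, then run the init kernels.
model.to_empty(device="cuda")
model.init_weights()  # must cover every parameter/buffer, since meta-device init was a no-op
```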
Thanks for clarifying, really helpful! If my understanding is correct and CPU RAM is NOT a problem, and I want to do (2):

```python
model = model_cls.from_model_args(model_config)
model.init_weights()
model = fully_shard(model, **fsdp_config)
model.to(device="cuda")
### training loop ###
```

Is this a correct way to do (2)? @awgu
@XinDongol A few clarifications:
Regarding (1), the torchtitan Llama definition already calls … (torchtitan/torchtitan/models/llama/model.py, line 373 in f72a2a0).

Regarding (2), in case you were not already aware of the FSDP design, you should apply …

Regarding (3), each time you call …

I was wondering if you could explain more why you want to do CPU init. You may notice that the init time is quite long, especially for larger models.
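For reference, flow (2) as described in the first comment would look roughly like the sketch below. This is an assumption-heavy illustration with the same placeholder names as above, not a verbatim recommendation from this thread; note that the whole unsharded model must fit in CPU RAM and then on a single GPU before sharding:

```python
# Flow (2): init unsharded on CPU, copy to GPU, then shard.
model = model_cls.from_model_args(model_config)  # real CPU tensors; init kernels run here (slow)
model.init_weights()

model.to(device="cuda")  # copy the full unsharded model to GPU: it must fit on one device

for block in model.layers:
    fully_shard(block, **fsdp_config)  # fully_shard shards in place and returns the module
fully_shard(model, **fsdp_config)
```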
@awgu Thanks for your clarifications! Initialization on meta tensors is painful. For some architectures (e.g., Mamba), there are parameters that need complicated pre-computed initialization, and assigning values to meta/DTensor tensors is challenging. As a result, I would like to initialize the model before distributing it. I just noticed that there is a newly added flag called …
I would really appreciate some pointers to the complicated initialization to learn more about it. And yes, I think that the seed checkpoint can be used to avoid the meta-device init and instead try to init on CPU.
Here are some examples. Initializing them as DTensors seems challenging.
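To illustrate the kind of initialization being discussed, here is a hedged sketch of a Mamba-style state-space parameter (following the S4D-real recipe from the Mamba reference code; the module and its dimensions are invented for illustration). The values are computed tensors rather than outputs of an in-place RNG kernel, which is what makes them awkward to reproduce on meta tensors or to write into an already-sharded DTensor:

```python
import math
import torch
import torch.nn as nn

class TinySSMInit(nn.Module):
    """Illustrative module (not from torchtitan): parameters whose init values
    are precomputed tensors, as in Mamba's A_log and dt_bias."""

    def __init__(self, d_inner: int = 8, d_state: int = 16):
        super().__init__()
        # S4D-real init: A[i, j] = j + 1, stored as log(A).
        A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A))
        # dt bias: sample a timescale, then apply an inverse-softplus transform.
        dt = torch.exp(
            torch.rand(d_inner) * (math.log(0.1) - math.log(1e-3)) + math.log(1e-3)
        )
        self.dt_bias = nn.Parameter(dt + torch.log(-torch.expm1(-dt)))

# Under `with torch.device("meta"):` these computed values are silently dropped,
# and re-creating them shard-by-shard on a DTensor inside init_weights() is the
# painful part.
```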
I noticed that there are two parts of the implementation that are related to model initialization.

1. Instantiating the model with meta tensors:
   torchtitan/train.py, lines 177 to 181 in f72a2a0
2. Doing explicit model initialization:
   torchtitan/train.py, lines 209 to 210 in f72a2a0
The issue is that any weight initialization done while instantiating the module is ineffective because of the meta tensors. As a result, we have to do all initialization explicitly in `model.init_weights()`. My question is: why do we want to instantiate the model with meta tensors? If efficiency is not an issue, can we simply remove the `with torch.device("meta"):`?