Make dataloader stateful? #291

Closed
XinDongol opened this issue May 1, 2024 · 9 comments · Fixed by #279
Labels: enhancement (New feature or request)

@XinDongol

Resuming from a checkpoint currently uses the same dataloader from the beginning. This may lead to issues for training.
We may need to resume the dataloader from a saved state so that already-sampled data is skipped.
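
A minimal sketch of the kind of behavior I mean, using a hypothetical wrapper that counts yielded batches and fast-forwards on resume (names are illustrative, not torchtitan's actual implementation; torchdata's StatefulDataLoader offers this kind of state natively):

```python
# Hypothetical sketch: a thin wrapper that exposes state_dict()/load_state_dict()
# and fast-forwards past already-consumed batches on resume. Naive (it skips by
# re-iterating and assumes resumption within the same epoch); illustrative only.
from torch.utils.data import DataLoader


class ResumableDataLoader:
    def __init__(self, dataloader: DataLoader):
        self.dataloader = dataloader
        self._batches_yielded = 0

    def __iter__(self):
        it = iter(self.dataloader)
        # Skip batches that were consumed before the checkpoint was taken.
        for _ in range(self._batches_yielded):
            next(it)
        for batch in it:
            self._batches_yielded += 1
            yield batch

    def state_dict(self):
        return {"batches_yielded": self._batches_yielded}

    def load_state_dict(self, state_dict):
        self._batches_yielded = state_dict["batches_yielded"]
```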

@tianyu-l
Contributor

tianyu-l commented May 1, 2024

Thanks for creating this issue! In fact, we recently started working on it in #279.

@XinDongol
Author

XinDongol commented May 1, 2024

Tested the branch and got:

    File "torchtitan/train.py", line 255, in main
      checkpoint.load()
    File "torchtitan/torchtitan/checkpoint.py", line 217, in load
      dcp.load(
    File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/utils.py", line 427, in inner_func
      return func(*args, **kwargs)
    File "/usr/local/lib/python3.10/dist-packages/torch/distributed/checkpoint/state_dict_loader.py", line 174, in load
      elem.load_state_dict(statetful_sd[key])
    File "torchtitan/torchtitan/checkpoint.py", line 79, in load_state_dict
      self.dataloader.load_state_dict(state_dict[self.rank_id])
  KeyError: 1

It is weird, but all ranks are receiving the same state_dict (rank 0's state_dict).

Not quite sure what is happening here, but this might be the reason:

https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict_loader.load
WARNING
All tensors in state_dict must be allocated on their destination device prior to calling this function.
All non-tensor data is loaded using torch.load() and modified in place on state_dict.

@XinDongol XinDongol reopened this May 1, 2024
@tianyu-l
Contributor

tianyu-l commented May 1, 2024

@XinDongol Thanks for trying it out! The work is still in progress and not ready.
Yes, we are aware of the issue you mentioned and we are working on distributed checkpointing to fix it. We hope to resolve this problem soon.

@XinDongol
Author

I was wondering whether you found the root cause of all ranks receiving the same state_dict for the dataloader?
I guess it is because the state_dict is not a DTensor? (not sure at all)

@tianyu-l
Contributor

tianyu-l commented May 1, 2024

I was wondering whether you found the root cause of all ranks receiving the same state_dict for the dataloader? I guess it is because the state_dict is not a DTensor? (not sure at all)

Yes. Currently, if it is not a DTensor, only rank 0's value is saved. After the fix, we'd like to be able to save values across all ranks when the keys of the state_dict differ per rank.
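
A sketch of the per-rank-key pattern being described, assuming the dataloader itself already exposes state_dict()/load_state_dict() (names are illustrative, not the exact #279 implementation):

```python
# Illustrative sketch: each rank stores its dataloader state under its own key,
# so a distributed checkpoint can hold a distinct non-tensor value per rank.
import torch.distributed as dist
from torch.distributed.checkpoint.stateful import Stateful


class DataLoaderWrapper(Stateful):
    def __init__(self, dataloader):
        self.dataloader = dataloader
        self.rank_id = dist.get_rank() if dist.is_initialized() else 0

    def state_dict(self):
        # Rank 0 contributes {"rank_0": ...}, rank 1 {"rank_1": ...}, etc.,
        # so no two ranks write to the same key.
        return {f"rank_{self.rank_id}": self.dataloader.state_dict()}

    def load_state_dict(self, state_dict):
        # On resume, each rank reads back only the entry it wrote.
        self.dataloader.load_state_dict(state_dict[f"rank_{self.rank_id}"])
```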

@XinDongol
Author

@tianyu-l @gokulavasan
Thanks for the reply.
One more note I want to mention here is that the current implementation does not support num_workers > 1.
If we set num_workers > 1, different workers will load duplicated samples.
This is a common issue with IterableDataset, as discussed here, but it can be solved with tricks (see the sketch below).
When fixing the state_dict issue of the dataloader, it would be great if you could take this into consideration.
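
For reference, a minimal sketch of the usual trick: shard the stream across workers with torch.utils.data.get_worker_info() so each worker yields a disjoint subset (generic IterableDataset, not torchtitan's actual dataset class):

```python
# Minimal sketch of the standard de-duplication trick for IterableDataset:
# each worker processes only every num_workers-th sample, offset by its id.
from torch.utils.data import IterableDataset, get_worker_info


class ShardedIterableDataset(IterableDataset):
    def __init__(self, samples):
        self.samples = samples

    def __iter__(self):
        info = get_worker_info()
        if info is None:  # single-process loading (num_workers=0)
            worker_id, num_workers = 0, 1
        else:
            worker_id, num_workers = info.id, info.num_workers
        for idx, sample in enumerate(self.samples):
            if idx % num_workers == worker_id:
                yield sample
```

Note that once each worker holds its own position in the stream, resuming also needs per-worker state, which adds to the checkpointing considerations above.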

@tianyu-l
Contributor

tianyu-l commented May 1, 2024

@XinDongol
Thanks for the note! I believe we are aware of the issue (@gokulavasan to double check).

The reason we didn't prioritize supporting num_workers > 1 is that the Llama training is GPU-bound, so even if we load data in the main process, this data loading work can be overlapped with the remaining GPU work of the last iteration. Besides, the time spent on data loading is almost negligible compared with the time spent on training.

For these reasons, we think it's better not to introduce the additional complexity. However, things may change if we are going to support multimodal training, as the loading time for images / video could be much longer. Happy to hear your thoughts on it.

@XinDongol
Author

XinDongol commented May 2, 2024

I agree. For very large models, that may be the case.

Torchtitan is currently doing on-the-fly tokenization.
I really like the idea of on-the-fly tokenization, which is great for SFT and makes changing the tokenizer very easy.
I did some profiling and found that on-the-fly tokenization is 3x slower than pre-tokenization when num_workers=1: https://github.com/XinDongol/on-the-fly-tokenization-profiling

Increasing num_workers makes on-the-fly tokenization even faster, because reading text is more I/O efficient than reading tokens.
I tried a 1B model and found that data loading time is about 10% of end-to-end time when num_workers=1 for torchtitan with on-the-fly tokenization.
Using num_workers=8 reduces it to 1%.
So I think supporting num_workers > 1 could still be helpful.
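
A rough sketch of one way such a data-loading share can be estimated (illustrative helper, not the exact profiling setup used above):

```python
# Rough way to measure the share of step time spent waiting on the dataloader.
# Synchronizes the GPU so asynchronous kernels count toward the step time.
import time
import torch


def measure_data_fraction(dataloader, train_step_fn, num_steps=100):
    data_time, total_time = 0.0, 0.0
    it = iter(dataloader)
    for _ in range(num_steps):
        t0 = time.perf_counter()
        batch = next(it)
        t1 = time.perf_counter()
        train_step_fn(batch)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        t2 = time.perf_counter()
        data_time += t1 - t0
        total_time += t2 - t0
    return data_time / total_time
```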

@tianyu-l tianyu-l added the enhancement New feature or request label May 3, 2024
@tianyu-l
Contributor

@XinDongol Appreciate your feedback a lot!

I tried a 1B model and found that data loading time is about 10% of end-to-end time when num_workers=1 for torchtitan with on-the-fly tokenization.

I wonder what log_freq you used for this experiment? Every log step inserts a CPU/GPU synchronization (for the loss), which reduces the overlap opportunity. If log_freq = 1 was used for this experiment, I wonder how much time data loading would take with log_freq = 10 or something else.
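
To make the synchronization point concrete, here is a small illustrative loop (toy model and values, not torchtitan's training loop): the .item() call blocks the CPU until the GPU has produced the loss, so doing it every step shrinks the window in which the next batch can be loaded while the GPU is still busy.

```python
# Toy sketch: .item() forces a CPU/GPU sync, so a larger log_freq leaves more
# room to overlap CPU-side data loading with asynchronous GPU compute.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(1024, 1024).to(device)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
log_freq = 10  # compare against log_freq = 1

for step in range(100):
    x = torch.randn(32, 1024, device=device)  # stands in for a real batch
    loss = model(x).pow(2).mean()              # GPU kernels launch asynchronously
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % log_freq == 0:
        # This is the CPU/GPU synchronization point: the host blocks here.
        print(f"step {step}: loss {loss.item():.4f}")
```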
