Supported features #571
Thank you for the comments! (1) Fused attention is on by default for training. We use "splash attention", which is a custom, faster version! (We're also working on accelerated attention for inference.) ttconnect is super cool, thanks for sending it!
Thanks for the answer. Looking forward to the DPO support. It would of course be fantastic if HuggingFace datasets could be supported natively. I have never really been able to run large non-streaming datasets from HF on the TPUs (disk-size issues on the VMs), but we have been able to wrap the HF datasets in datasets.distributed.split_dataset_by_node to stream them on multiple TPUs. I'm not sure I would be able to implement something like this in MaxText, though, and I'm not really sure at what level it should be implemented. Is there any chance you will support HF datasets in the future? In any case, any way of preprocessing the data before it is split across the TPUs would be extremely useful for running dataset-building experiments, both for sampling and for filtering based on a field in the dataset.
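The idea of filtering a stream before splitting it across hosts can be sketched in plain Python. This is a minimal, hypothetical illustration, not MaxText or HF code: `filtered` and `shard_for_host` are made-up helper names, and the toy list stands in for a streamed HF dataset. `shard_for_host` keeps one example out of every `world_size`, which mirrors what `split_dataset_by_node` does for iterable datasets whose shard count is not a multiple of the world size. With JAX, `rank` and `world_size` would typically come from `jax.process_index()` and `jax.process_count()`.

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")


def filtered(stream: Iterable[dict], keep: Callable[[dict], bool]) -> Iterator[dict]:
    """Hypothetical helper: drop examples before sharding, e.g. on a metadata field."""
    return (ex for ex in stream if keep(ex))


def shard_for_host(stream: Iterable[T], rank: int, world_size: int) -> Iterator[T]:
    """Hypothetical helper: keep every world_size-th example, starting at offset `rank`."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return islice(stream, rank, None, world_size)


# Toy stream standing in for a streamed HF dataset (assumption for illustration).
docs = [{"id": i, "lang": "no" if i % 3 else "en"} for i in range(12)]
kept = list(filtered(docs, lambda ex: ex["lang"] == "no"))
# Simulate 4 hosts; each host would only compute its own slice.
shards = [list(shard_for_host(kept, r, 4)) for r in range(4)]
```

Every host reads and filters the same stream and then keeps only its own slice, so the filter (and any random sampling, via a shared seed) must behave identically on all hosts for the shards to stay disjoint and complete.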
Yes, support for HF datasets in MaxText is on the way.
Thank you for tagging me on this. Yes, supporting HuggingFace datasets is in our plans. We have some implementations and are running performance evaluations to understand them better. I will update here when we have it out.
I mainly wanted to start by thanking you for making MaxText available. I have been using it for a few days, and my first impression is fantastic. Getting started was really easy, it seemed very stable, and the performance was excellent. It seems to scale very nicely.
There are a few things I have not been able to figure out yet, either because of a lack of documentation or simply because they are not implemented.
Is there any support for Flash Attention, or any plans to implement it? This has been a major area where GPUs have been ahead of TPUs. I have noticed that there is now at least an experimental implementation from the JAX team: https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py.
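The core idea behind flash attention is to process keys and values block by block with an "online" softmax, so the full row of attention scores is never stored at once. A toy pure-Python sketch for a single query vector is shown below. It is only an illustration of the technique, not the Pallas kernel linked above and not MaxText's splash attention; real kernels also tile the queries and keep the blocks in fast on-chip memory.

```python
import math
from typing import List, Sequence

Vector = List[float]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def naive_attention(q, keys, values) -> Vector:
    """Reference: softmax(q . k / sqrt(d))-weighted sum of the values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [dot(q, k) * scale for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim)]


def blockwise_attention(q, keys, values, block_size: int) -> Vector:
    """Same result, computed one key/value block at a time (online softmax).

    Only a running max `m`, a running normaliser `l` and an unnormalised
    accumulator `acc` are carried between blocks.
    """
    scale = 1.0 / math.sqrt(len(q))
    dim = len(values[0])
    m, l, acc = -math.inf, 0.0, [0.0] * dim
    for start in range(0, len(keys), block_size):
        k_blk = keys[start:start + block_size]
        v_blk = values[start:start + block_size]
        scores = [dot(q, k) * scale for k in k_blk]
        m_new = max(m, max(scores))
        correction = math.exp(m - m_new)  # rescale earlier blocks to the new max
        p = [math.exp(s - m_new) for s in scores]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pi * v[j] for pi, v in zip(p, v_blk))
               for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]


q = [0.5, -1.0, 2.0]
keys = [[float(i % 3), 1.0, -0.5 * i] for i in range(7)]
values = [[float(i), float(i * i) / 10.0] for i in range(7)]
ref = naive_attention(q, keys, values)
out = blockwise_attention(q, keys, values, block_size=3)
```

The rescaling by `correction` is what lets earlier partial sums be reused when a later block contains a larger score, which is why the blockwise result matches the naive one up to floating-point rounding.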
Training directly from tfds seemed straightforward. However, I was a bit confused about how to implement more advanced data loader features, for instance probability sampling as explained here. This can be somewhat tricky to do efficiently on multiple TPUs. What is the sensible approach here? Manually sampling into a tfds dataset does not seem very efficient. Are there external libraries that are compatible with MaxText?
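Probability sampling from several sources can be sketched as a seeded weighted interleave. The `sample_mixture` function below is a hypothetical helper, not part of MaxText or tfds: each output example is drawn from source `i` with probability proportional to `weights[i]`, and exhausted sources are dropped. Using the same seed on every host yields the same mixture order everywhere, so the mixed stream can then be sharded across hosts deterministically.

```python
import random
from itertools import islice
from typing import Iterable, Iterator, Sequence, TypeVar

T = TypeVar("T")


def sample_mixture(sources: Sequence[Iterable[T]], weights: Sequence[float],
                   seed: int) -> Iterator[T]:
    """Hypothetical helper: pick each example's source with probability ∝ weight.

    Exhausted sources are removed; random.choices accepts unnormalised
    weights, so the remaining weights are renormalised implicitly.
    """
    if len(sources) != len(weights):
        raise ValueError("need one weight per source")
    rng = random.Random(seed)  # same seed on every host => identical order
    iters = [iter(s) for s in sources]
    live = [i for i, w in enumerate(weights) if w > 0]
    while live:
        idx = rng.choices(live, weights=[weights[i] for i in live], k=1)[0]
        try:
            yield next(iters[idx])
        except StopIteration:
            live.remove(idx)


# Toy sources standing in for two datasets (assumption for illustration).
wiki = ({"src": "wiki", "n": i} for i in range(1000))
web = ({"src": "web", "n": i} for i in range(1000))
first = list(islice(sample_mixture([wiki, web], [0.2, 0.8], seed=0), 500))
```

If the pipeline stays entirely in tf.data, `tf.data.Dataset.sample_from_datasets` (with its `weights` and `seed` arguments) provides similar weighted interleaving.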
Are there plans for implementing DPO/RLHF?
I also shamelessly wanted to point you to my own repo: https://github.com/peregilk/ttconnect. It is a very simple bash script that ideally should be run on a VM in the same zone. It automatically opens synchronised tmux windows to all the VMs in the pod and lets you type the same command into all of them. This makes it even easier to go from a single TPU to pods.