
Feature/raft fine tuning #874

Open
wants to merge 57 commits into base: main

Conversation

efenocchi

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

What are the changes made in this PR?
This PR introduces the ability to load a dataset from Activeloop by calling load_deep_lake_dataset instead of load_dataset (see torchtune/torchtune/datasets/_utils.py). Additionally, it enables fine-tuning of available models using the RAFT technique (see torchtune/recipes/configs/llama3/8B_lora_single_device_deep_lake_raft.yaml).
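For context, a minimal sketch of what a `load_deep_lake_dataset` helper along these lines might look like. This is not the PR's exact code; it assumes only that `deeplake.load()` accepts a `hub://org/name` path, and keeps the import inside the function so `torchtune.datasets` stays importable without the optional dependency:

```python
def load_deep_lake_dataset(deep_lake_dataset: str, **load_kwargs):
    """Load a dataset from Activeloop Deep Lake (e.g. a 'hub://org/name' path).

    The deeplake import lives inside the function so that importing the
    datasets module does not require the optional dependency.
    """
    try:
        import deeplake
    except ImportError as e:
        raise ImportError(
            "load_deep_lake_dataset requires the optional dependency "
            "'deeplake'; install it with `pip install deeplake`."
        ) from e
    return deeplake.load(deep_lake_dataset, **load_kwargs)
```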

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
    • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

I also want to report an error in the main repository when running the integration tests in tests/recipes/test_eleuther_eval.py:
(screenshots attached: integration_test_error_1, integration_test_error_2)


pytorch-bot bot commented Apr 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/874

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 26, 2024
Contributor

@ebsmothers ebsmothers left a comment
Thanks for the PR! Appreciate you adding the unit tests as well. My main questions are around the addition of the new dependency and its handling on imports, need for a dataloader class, and the new configs.

Re the failing test: we do not currently have lm_eval in our optional dependencies due to some other transitive dependencies, so if you follow the suggestion in the console and `pip install "lm_eval==0.4.*"` I suspect you will be able to get that test to pass.

@@ -47,6 +47,7 @@ dev = [
"pytest-integration",
"tensorboard",
"wandb",
"deeplake"
Contributor

Does deeplake have any upstream dependencies? Wanna make sure we're aware of what we're pulling in here

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm not mistaken, deeplake requires: aioboto3, boto3, click, humbug, libdeeplake, lz4, nest-asyncio, numpy, pathos, pillow, pydantic, pyjwt, and tqdm (tested with pip show deeplake; tell me if something is wrong).

@@ -0,0 +1,94 @@
# Config for single device LoRA finetuning in lora_finetune_single_device.py
Contributor

I don't think you need to create separate configs for a new dataset. In general we would recommend either CLI overrides or tune cp + modifying the config locally, rather than adding an entirely new file (we already have a bunch of configs as it is, and we're trying not to proliferate them any more than strictly necessary).

Author

I thought it would be helpful to have an example, because it's not just another dataset; it's another technique to fine-tune with. I thought this example would make the technique easier to use.

@@ -525,4 +525,4 @@ def recipe_main(cfg: DictConfig) -> None:


if __name__ == "__main__":
sys.exit(recipe_main())
sys.exit(recipe_main())
Contributor

I know you checked the box in the summary saying you ran pre-commit hooks, but please just double-check. The missing newline at the end of this file makes me a bit suspicious. You can re-run on all files by following these instructions.

Author

Checked again; I added a newline at the end of the file, and all tests passed again.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import deeplake
Contributor

I think this will break if someone is importing from datasets without having installed dev dependencies

Author

How can we handle this part?

# """


class DeepLakeDataloader(Dataset):
Contributor

Sorry noob question, but why is just generic PyTorch dataloader not sufficient here? I don't see any custom sampling logic or anything like that, seems like the rest of this should be handled on the dataset side already?

Author

To access the dataset values we need a custom dataset class; otherwise we would not be able to get at the data without raising errors.
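To make the discussion concrete, here is a minimal sketch (an assumption about what DeepLakeDataloader roughly does, not the PR's code) of wrapping a loaded Deep Lake dataset so each row comes back as a plain dict. PyTorch's DataLoader only requires `__len__` and `__getitem__` on a map-style dataset, so no torch import is needed here; with a real Deep Lake dataset the column values would additionally need `.text()`/`.numpy()` conversion:

```python
class DeepLakeDataloader:
    """Duck-typed map-style dataset over a Deep Lake dataset.

    Assumes `ds` exposes `len(ds)`, a `tensors` collection of column names,
    and column access via `ds[name][idx]` (illustrative, not the exact API use).
    """

    def __init__(self, ds):
        self._ds = ds

    def __len__(self):
        return len(self._ds)

    def __getitem__(self, idx):
        # Return one row as a dict of column name -> value, so downstream
        # code can treat each sample like an ordinary mapping.
        return {name: self._ds[name][idx] for name in self._ds.tensors}
```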

@@ -87,6 +87,43 @@ def format(
return prompt


class RAFTInstructTemplate(InstructTemplate):
Contributor

What's the difference between this and the AlpacaInstructTemplate with no input above? It looks pretty much the same at first glance

Author
@efenocchi efenocchi May 4, 2024

Yes, if we assume we have no input, it's the same thing.

Author

I just refined the prompt to closely align with the structure outlined in the referenced paper
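To illustrate the kind of structure being discussed, here is a hypothetical RAFT-style prompt template (loosely following the paper's shape of context documents, then question, then a reasoned answer; the field names and wording are assumptions, not the PR's actual template):

```python
# Illustrative RAFT-style prompt: oracle + distractor documents as context,
# followed by the question, with the chain-of-thought answer as the target.
RAFT_TEMPLATE = (
    "Below is a question and a set of context documents. Some documents may "
    "be relevant and some may be distractors. Answer the question using the "
    "relevant context, reasoning step by step.\n\n"
    "### Context:\n{context}\n\n"
    "### Question:\n{question}\n\n"
    "### Answer:\n"
)


def format_raft_prompt(sample: dict) -> str:
    """Fill the template from a sample with 'context' and 'question' keys."""
    return RAFT_TEMPLATE.format(
        context=sample["context"], question=sample["question"]
    )
```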

from torchtune.modules.tokenizers import Tokenizer


class InstructDatasetDeepLakeRAFT(Dataset):
Contributor

High-level question on this class: is the only difference between this and InstructDataset the usage of load_deep_lake_dataset instead of load_dataset? If so, I wonder if this is something we should consider parametrizing rather than having to write an entirely new dataset class. cc @RdoubleA for any thoughts here

Author
@efenocchi efenocchi May 4, 2024

Yes, I think it can be parameterized to use load_deep_lake_dataset; can you give me some advice on how we can proceed here?
Also, in the _prepare_sample() class method we access a different input column name, "cot_answer".
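The parametrization idea could be sketched like this (hypothetical names; the point is that both a separate load function and the answer column become constructor arguments instead of a new class):

```python
from typing import Callable


class ParametrizedInstructDataset:
    """Instruct dataset taking the loader and answer column as parameters.

    `load_fn` stands in for either load_dataset or load_deep_lake_dataset,
    and `answer_column` covers cases like "output" vs "cot_answer".
    """

    def __init__(self, source: str, load_fn: Callable, answer_column: str = "output"):
        self._data = load_fn(source)
        self._answer_column = answer_column

    def __len__(self):
        return len(self._data)

    def __getitem__(self, idx):
        sample = self._data[idx]
        # A real _prepare_sample would format and tokenize here; this sketch
        # just surfaces the prompt and the configured answer column.
        return sample["prompt"], sample[self._answer_column]
```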

Comment on lines +11 to +14
activeloop_dataset = "hub://manufe/raft_format_dataset_biomedical" # Replace with your ActiveLoop dataset


def raft_dataset(
Contributor

What is this dataset exactly? Seems like raft_dataset is still the generic dataset type. But for dataset builders like these we typically point to a canonical dataset on the Hub (e.g. Alpaca). Is raft_format_dataset_biomedical the canonical dataset here? If so, is this something that's widely-used, or intended more as a demo?

Author

RAFT being a fairly new technique, there is no real reference dataset yet; this is a dataset we are using on a real project, and it can be used as a demo.

Author

Here too, perhaps we can replace the hard-coded dataset with an input taken from the YAML file or passed as a parameter on the CLI.
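That could look something like the following config fragment (field names are illustrative, not a final schema; torchtune configs do support key=value overrides on the command line, e.g. `tune run ... dataset.dataset_path=hub://...`):

```yaml
dataset:
  _component_: torchtune.datasets.raft_dataset
  # Default points at the demo dataset; override from the CLI for other data.
  dataset_path: hub://manufe/raft_format_dataset_biomedical
```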

@efenocchi
Author

Hi @ebsmothers,
thank you very much for the suggestions and for the quick response; sorry it took me so long to reply. If there is anything I can do to get this PR accepted, I will be happy to do it.
Have a good weekend.
Emanuele

Contributor

@ebsmothers ebsmothers left a comment

Hi @efenocchi thanks for making the updates and apologies for the delay on getting back to you. Given that

(a) we try to keep our core dependencies pretty minimal in torchtune, and
(b) integrating a dev dependency tightly into our core code (in this case torchtune/datasets) is something we try to avoid,

I think we may want to take a slightly different approach on this PR. We have another ongoing integration with ClearML Logger and I wonder if we can follow a similar path here. Namely, you can push the change to a fork, and we can highlight it as a community integration in our README. Then as we get more usage and requests to integrate we can look at merging into main. Basically the process outlined here.

Let me know if this type of integration makes sense to you. Thanks!
