
Compatibility with quantized embeddings #18

Open
kawlil opened this issue May 8, 2024 · 2 comments

Comments

@kawlil

kawlil commented May 8, 2024

Hi,

Firstly, thanks for the awesome work!

I want to use KTO with a quantized Mistral model, but I'm getting pickle errors when mp.spawn starts the worker processes. This is probably because quantization changes the Embedding layers to nn.Linear4bit instead of plain nn.Linear.

  File "~/HALOs/train.py", line 250, in main
    mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, tokenizer, train_iterator, eval_iterator, policy, reference_model), join=True)
  File "~/conda/env/halos/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 241, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "~/conda/env/halos/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    process.start()
  File "~/conda/env/halos/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function Embedding.forward at 0x7f75fa914160>: it's not the same object as torch.nn.modules.sparse.Embedding.forward

I think one workaround would be to use multiprocess instead of multiprocessing, since it serializes with dill instead of pickle, but I haven't gotten that to work yet. Do you have any suggestions?
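For context, here is a minimal stdlib-only sketch of how this kind of PicklingError can arise. It is not HALOs, PyTorch, or Hugging Face code: the Embedding class below is a hypothetical stand-in. pickle serializes a plain function by reference, recording its module and qualified name. When dumping, it checks that looking that name up again returns the very same object. If something has replaced the class attribute after the original function was captured (as wrapping or hooking a layer's forward can do), the check fails. This sketch does not establish which library patches Embedding.forward in this setup.

```python
import pickle


class Embedding:
    """Hypothetical stand-in for a layer class; not torch.nn.Embedding."""

    def forward(self, x):
        return x


# Capture the original function, as a wrapper or hook might do...
original_forward = Embedding.forward


def patched_forward(self, x):
    # ...then install a replacement that delegates to the original.
    return original_forward(self, x)


Embedding.forward = patched_forward

error_message = None
try:
    # pickle stores functions by module + qualified name; looking up
    # "Embedding.forward" now yields patched_forward, not original_forward,
    # so pickle refuses to serialize the original.
    pickle.dumps(original_forward)
except pickle.PicklingError as err:
    error_message = str(err)

print(error_message)
```

Run as a script, this should print a PicklingError message similar in form to the last line of the traceback above. A by-value serializer such as dill is the motivation for trying multiprocess, though whether it helps with a quantized model is untested here.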

@kawine
Collaborator

kawine commented May 8, 2024

Is the model you're using on Hugging Face? (Asking so I can reproduce the issue.)

@kawlil
Author

kawlil commented May 8, 2024

Yes, mistralai/Mistral-7B-Instruct-v0.1
