  File "~/HALOs/train.py", line 250, in main
    mp.spawn(worker_main, nprocs=world_size, args=(world_size, config, tokenizer, train_iterator, eval_iterator, policy, reference_model), join=True)
  File "~/conda/env/halos/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 241, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method="spawn")
  File "~/conda/env/halos/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
    process.start()
  File "~/conda/env/halos/lib/python3.10/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/context.py", line 288, in _Popen
    return Popen(process_obj)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "~/conda/env/halos/lib/python3.10/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
_pickle.PicklingError: Can't pickle <function Embedding.forward at 0x7f75fa914160>: it's not the same object as torch.nn.modules.sparse.Embedding.forward
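For reference, this "it's not the same object as ..." message can be reproduced without torch. The sketch below assumes the cause is an instance whose `forward` attribute has been replaced by a wrapper that copies the original method's name via `functools.wraps`; whether that is what actually happens to the quantized model here is only a guess. `ToyEmbedding`, `add_hook`, and `can_pickle` are made-up names for illustration, not anything from HALOs or torch.

```python
import functools
import pickle


class ToyEmbedding:
    """Hypothetical stand-in for a module class with a forward method."""

    def forward(self, x):
        return x


def add_hook(module):
    """Replace module.forward on the instance with a name-preserving wrapper."""
    original = module.forward  # bound method of this instance

    @functools.wraps(original)  # copies __module__ and __qualname__
    def wrapped(*args, **kwargs):
        return original(*args, **kwargs)

    # The wrapper is now stored in the instance __dict__, so pickle has to
    # serialize it along with the rest of the instance state.
    module.forward = wrapped
    return module


def can_pickle(obj):
    """Return True if pickle.dumps succeeds on obj, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


patched = add_hook(ToyEmbedding())

# Functions are pickled by reference (module + qualified name). The wrapper
# claims to be ToyEmbedding.forward, but looking that name up returns the
# original function, so pickle refuses to serialize it.
print("unpatched instance picklable:", can_pickle(ToyEmbedding()))
print("patched instance picklable:", can_pickle(patched))
```

If this is the mechanism, the failure comes from passing an already-patched model object through `mp.spawn`, rather than from the layer type itself.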
Hi,
Firstly, thanks for the awesome work!
I want to use KTO with a quantized Mistral model but am getting pickle errors from the multiprocessing thread, probably since that changes the Embedding layers to be nn.Linear4bit instead of just nn.Linear.
I'm thinking a workaround would be to use `multiprocess` instead of `multiprocessing` to use dill instead of pickle, but haven't been successful with that yet... do you have any suggestions?
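A different route from swapping the pickler (not something I have tried with HALOs) would be to avoid pickling the model at all: pass only a small, plain description of the model to each worker and build the model inside the worker. The sketch below uses the standard library only; `build_model`, `launch`, and the contents of `SPEC` are hypothetical placeholders, and `build_model` returns a dict instead of loading a real quantized model.

```python
import multiprocessing as mp
import pickle

# Hypothetical, picklable description of the model to load in each worker.
SPEC = {"model_name": "placeholder-model", "load_in_4bit": True}


def build_model(spec):
    """Stand-in for loading the quantized model; a real version would load it here."""
    return {"name": spec["model_name"], "load_in_4bit": spec["load_in_4bit"]}


def worker_main(rank, world_size, spec, results):
    # The model is created inside the child process, so it never has to
    # cross the process boundary through pickle.
    model = build_model(spec)
    results.put((rank, model["name"]))


def launch(world_size, spec):
    """Spawn world_size workers, passing only the picklable spec."""
    ctx = mp.get_context("spawn")
    results = ctx.Queue()
    procs = [
        ctx.Process(target=worker_main, args=(rank, world_size, spec, results))
        for rank in range(world_size)
    ]
    for p in procs:
        p.start()
    # Collect one result per worker; the timeout avoids hanging if a child dies.
    out = sorted(results.get(timeout=60) for _ in procs)
    for p in procs:
        p.join()
    return out
```

With the spawn start method, `launch(2, SPEC)` would need to be called from under an `if __name__ == "__main__":` guard in a script. The cost of this approach is that every worker loads its own copy of the model.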