From 2852c0a032c3b2af97d63ec21aa7ccda63f7f51e Mon Sep 17 00:00:00 2001 From: zhansu Date: Fri, 12 Jul 2024 07:33:48 -0400 Subject: [PATCH 01/37] add dpo training for arrow routing --- mttl/models/expert_model.py | 94 +++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index aa4117303..bd12d815d 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -27,6 +27,8 @@ from mttl.models.utils import EfficientCheckpointModule, prepare_model_for_kbit_training from mttl.utils import logger +import torch.nn.functional as F + torch.set_float32_matmul_precision("high") @@ -589,6 +591,98 @@ def set_routing_infos(self, batch, generate=False): self.model.info_container["routing_infos"] = RoutingInfo.from_batch(batch) +def calculate_DPO_loss( + original_prefered_logprob, + original_disprefered_logprob, + ref_prefered_logprob, + ref_disprefered_logprob, + beta=0.5, +): + """ + Calculate the DPO loss. + original_prefered_logprob: the logprob of the prefered expert in the original model + original_disprefered_logprob: the logprob of the disprefered expert in the original model + ref_prefered_logprob: the logprob of the prefered expert in the reference model + ref_disprefered_logprob: the logprob of the disprefered expert in the reference model + """ + + original_prefered_relative_logprob = ( + original_prefered_logprob - ref_prefered_logprob + ) + disprefered_relative_logprob = ( + original_disprefered_logprob - ref_disprefered_logprob + ) + + reward_accuracies = ( + (original_prefered_relative_logprob > disprefered_relative_logprob) + .float() + .mean(dim=-1) + ) + reward_margins = ( + original_prefered_relative_logprob - disprefered_relative_logprob + ).mean(dim=-1) + + loss = -F.logsigmoid( + beta * (original_prefered_relative_logprob - disprefered_relative_logprob) + ).mean(dim=-1) + return loss, reward_accuracies, reward_margins + + +def get_log_prob(logits, labels): + log_probs = F.log_softmax(logits, dim=-1) + return torch.gather(log_probs, -1, labels.unsqueeze(-1)).squeeze(-1).mean(-1) + + +class MultiExpertModelDPO(ExpertModel): + + def __init__(self, **kwargs): + self.multi_model = MultiExpertModel(**kwargs) + self.ref_multi_model = MultiExpertModel(**kwargs) + + def training_step(self, batch, _): + + prompt_prefered_ids = batch["prompt_prefered_ids"] + prompt_disprefered_ids = batch["prompt_disprefered_ids"] + + prompt_prefered_mask = batch["prompt_prefered_mask"] + prompt_disprefered_mask = batch["prompt_disprefered_mask"] + + # original model + model_prefered_log_prob = get_log_prob( + self.multi_model.model.forward( + prompt_prefered_ids, attention_mask=prompt_prefered_mask + ).logits + ) + + model_disprefered_log_prob = get_log_prob( + self.multi_model.model.forward( + prompt_disprefered_ids, attention_mask=prompt_disprefered_mask + ).logits + ) + + # reference model + ref_prefered_log_prob = get_log_prob( + self.ref_multi_model.model.forward( + prompt_prefered_ids, attention_mask=prompt_prefered_mask + ).logits + ) + + ref_disprefered_log_prob = get_log_prob( + self.ref_multi_model.model.forward( + prompt_disprefered_ids, attention_mask=prompt_disprefered_mask + ).logits + ) + + loss = calculate_DPO_loss( + model_prefered_log_prob, + model_disprefered_log_prob, + ref_prefered_log_prob, + ref_disprefered_log_prob, + beta=0.1, + ) + return loss + + class MoEModel(MultiExpertModel): def __init__(self, expert_library: ExpertLibrary = None, **kwargs): kwargs["top_k"] = kwargs["moe_top_k"] From fb47f71006db8c9a028779b81bf8c966560bfa46 Mon Sep 17 00:00:00 2001 From: zhansu Date: Fri, 12 Jul 2024 17:23:20 -0400 Subject: [PATCH 02/37] add dpo train with preference data --- mttl/config.py | 3 + mttl/models/expert_model.py | 128 ++++++++++++++++++++++++++++++++---- 2 files changed, 118 insertions(+), 13 deletions(-) diff --git a/mttl/config.py b/mttl/config.py index 04256472e..f8d79be97 100644 --- a/mttl/config.py +++ b/mttl/config.py @@ -346,3 +346,6 @@ def _set_defaults(self): self.soft_prompt_learn_kv: bool = False self.prompt_placement: str = "prefix" self.add_routing_token: bool = False + + # rl training + self.rl_training = "dpo" diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index bd12d815d..a461ce5a3 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -625,7 +625,8 @@ def calculate_DPO_loss( loss = -F.logsigmoid( beta * (original_prefered_relative_logprob - disprefered_relative_logprob) ).mean(dim=-1) - return loss, reward_accuracies, reward_margins + + return loss def get_log_prob(logits, labels): @@ -633,11 +634,16 @@ def get_log_prob(logits, labels): return torch.gather(log_probs, -1, labels.unsqueeze(-1)).squeeze(-1).mean(-1) -class MultiExpertModelDPO(ExpertModel): +class ExpertModelDPO(EfficientCheckpointModule): - def __init__(self, **kwargs): - self.multi_model = MultiExpertModel(**kwargs) - self.ref_multi_model = MultiExpertModel(**kwargs) + def __init__(self, expert_model, ref_expert_model, **kwargs): + super().__init__(**kwargs) + self.expert_model = expert_model + self.ref_expert_model = ref_expert_model + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + return optimizer def training_step(self, batch, _): @@ -649,28 +655,124 @@ def training_step(self, batch, _): # original model model_prefered_log_prob = get_log_prob( - self.multi_model.model.forward( + self.expert_model.model.forward( + prompt_prefered_ids, attention_mask=prompt_prefered_mask + ).logits, + labels=prompt_prefered_ids, + ) + + model_disprefered_log_prob = get_log_prob( + self.expert_model.model.forward( + prompt_disprefered_ids, attention_mask=prompt_disprefered_mask + ).logits, + labels=prompt_disprefered_ids, + ) + + # reference model + ref_prefered_log_prob = get_log_prob( + self.ref_expert_model.model.forward( + prompt_prefered_ids, attention_mask=prompt_prefered_mask + ).logits, + labels=prompt_prefered_ids, + ) + + ref_disprefered_log_prob = get_log_prob( + self.ref_expert_model.model.forward( + prompt_disprefered_ids, attention_mask=prompt_disprefered_mask + ).logits, + labels=prompt_disprefered_ids, + ) + + loss = calculate_DPO_loss( + model_prefered_log_prob, + model_disprefered_log_prob, + ref_prefered_log_prob, + ref_disprefered_log_prob, + beta=0.1, + ) + return loss + + def validation_step(self, batch, _): + prompt_prefered_ids = batch["prompt_prefered_ids"] + prompt_disprefered_ids = batch["prompt_disprefered_ids"] + + prompt_prefered_mask = batch["prompt_prefered_mask"] + prompt_disprefered_mask = batch["prompt_disprefered_mask"] + + # original model + model_prefered_log_prob = get_log_prob( + self.expert_model.model.forward( + prompt_prefered_ids, attention_mask=prompt_prefered_mask + ).logits, + labels=prompt_prefered_ids, + ) + + model_disprefered_log_prob = get_log_prob( + self.expert_model.model.forward( + prompt_disprefered_ids, attention_mask=prompt_disprefered_mask + ).logits, + labels=prompt_disprefered_ids, + ) + + # reference model + ref_prefered_log_prob = get_log_prob( + self.ref_expert_model.model.forward( + prompt_prefered_ids, attention_mask=prompt_prefered_mask + ).logits, + labels=prompt_prefered_ids, + ) + + ref_disprefered_log_prob = get_log_prob( + self.ref_expert_model.model.forward( + prompt_disprefered_ids, attention_mask=prompt_disprefered_mask + ).logits, + labels=prompt_disprefered_ids, + ) + + loss = calculate_DPO_loss( + model_prefered_log_prob, + model_disprefered_log_prob, + ref_prefered_log_prob, + ref_disprefered_log_prob, + beta=0.1, + ) + return loss + + def test_step(self, batch, _): + prompt_prefered_ids = batch["prompt_prefered_ids"] + prompt_disprefered_ids = batch["prompt_disprefered_ids"] + + prompt_prefered_mask = batch["prompt_prefered_mask"] + prompt_disprefered_mask = batch["prompt_disprefered_mask"] + + # original model + model_prefered_log_prob = get_log_prob( + self.expert_model.model.forward( prompt_prefered_ids, attention_mask=prompt_prefered_mask - ).logits + ).logits, + labels=prompt_prefered_ids, ) model_disprefered_log_prob = get_log_prob( - self.multi_model.model.forward( + self.expert_model.model.forward( prompt_disprefered_ids, attention_mask=prompt_disprefered_mask - ).logits + ).logits, + labels=prompt_disprefered_ids, ) # reference model ref_prefered_log_prob = get_log_prob( - self.ref_multi_model.model.forward( + self.ref_expert_model.model.forward( prompt_prefered_ids, attention_mask=prompt_prefered_mask - ).logits + ).logits, + labels=prompt_prefered_ids, ) ref_disprefered_log_prob = get_log_prob( - self.ref_multi_model.model.forward( + self.ref_expert_model.model.forward( prompt_disprefered_ids, attention_mask=prompt_disprefered_mask - ).logits + ).logits, + labels=prompt_disprefered_ids, ) loss = calculate_DPO_loss( From 665bfc4c297eafb4e0271cd320769cfc05742038 Mon Sep 17 00:00:00 2001 From: zhansu Date: Sat, 13 Jul 2024 02:09:10 -0400 Subject: [PATCH 03/37] add preference data module --- mttl/datamodule/preference_data_module.py | 94 ++++++++++ projects/modular_llm/train_dpo.py | 215 ++++++++++++++++++++++ 2 files changed, 309 insertions(+) create mode 100644 mttl/datamodule/preference_data_module.py create mode 100644 projects/modular_llm/train_dpo.py diff --git a/mttl/datamodule/preference_data_module.py b/mttl/datamodule/preference_data_module.py new file mode 100644 index 000000000..677771a8e --- /dev/null +++ b/mttl/datamodule/preference_data_module.py @@ -0,0 +1,94 @@ +from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule +from dataclasses import dataclass +from mttl.models.library.expert_library import DatasetLibrary +import torch + + +@dataclass +class DataCollatorForDPO(DefaultCollator): + def __call__(self, batch): + prompts = ["Instruct: " + item["prompt"] + "\n" for item in batch] + chosen_responses = ["Output: " + item["chosen"] for item in batch] + rejected_responses = ["Output: " + item["rejected"] for item in batch] + + prompt_ids = self.tokenizer.batch_encode_plus( + prompts, + padding=True, + return_tensors="pt", + max_length=self.max_input_length, + truncation=True, + )["input_ids"] + + prefered_ids = self.tokenizer.batch_encode_plus( + chosen_responses, + padding=True, + return_tensors="pt", + max_length=self.max_input_length, + truncation=True, + )["input_ids"] + + disprefered_ids = self.tokenizer.batch_encode_plus( + rejected_responses, + padding=True, + return_tensors="pt", + max_length=self.max_input_length, + truncation=True, + )["input_ids"] + + prompt_prefered_ids = torch.cat([prompt_ids, prefered_ids], dim=-1) + prompt_disprefered_ids = torch.cat([prompt_ids, disprefered_ids], dim=-1) + + prompt_prefered_mask = torch.cat( + [torch.ones_like(prompt_ids), torch.zeros_like(prefered_ids)], dim=-1 + ) + prompt_disprefered_mask = torch.cat( + [torch.ones_like(prompt_ids), torch.zeros_like(disprefered_ids)], dim=-1 + ) + + return { + "prompt_prefered_ids": prompt_prefered_ids, + "prompt_disprefered_ids": prompt_disprefered_ids, + "prompt_prefered_mask": prompt_prefered_mask, + "prompt_disprefered_mask": prompt_disprefered_mask, + } + + +@dataclass +class Preferencemodule(DefaultDataModule): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setup_dataset(self): + train_dataset = DatasetLibrary.pull_dataset_with_retry( + "jondurbin/truthy-dpo-v0.1" + )["train"] + + self.train_dataset, self.dev_dataset = self.create_train_valid_split( + train_dataset, 0.1 + ) + self.test_dataset = self.dev_dataset + + self.print_infos() + + @property + def collate_fn(self): + return DataCollatorForDPO( + tokenizer=self.tokenizer, + padding="longest", + max_input_length=self.config.max_input_length, + max_output_length=self.config.max_output_length, + return_tensors="pt", + model_family=self.config.model_family, + for_generation=self.for_generation, + ) + + +if __name__ == "__main__": + config = DatasetConfig(model="microsoft/phi-2") + datamodule = Preferencemodule(config) + train_dataloader = datamodule.train_dataloader() + val_dataloder = datamodule.val_dataloader() + for batch in val_dataloder: + print(batch) + breakpoint() diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py new file mode 100644 index 000000000..820f85c65 --- /dev/null +++ b/projects/modular_llm/train_dpo.py @@ -0,0 +1,215 @@ +import os +import shutil +import sys +from tempfile import TemporaryDirectory + +import torch +from pytorch_lightning import Trainer, seed_everything + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) + +from mttl.callbacks import LiveCheckpointCallback, NanoMMLUCallback, RougeCallback + +# from mttl.datamodule.base import get_datamodule +from mttl.models.expert_config import ExpertConfig +from mttl.models.expert_model import ExpertModel, MoEModel, ExpertModelDPO +from mttl.models.library.expert import Expert, load_expert +from mttl.models.library.expert_library import ExpertLibrary, LocalExpertLibrary +from mttl.models.monitors import get_monitors +from mttl.utils import ( + generate_random_string, + get_pl_loggers, + logger, + rank_zero_only_and_wait, + remote_login, + setup_logging, +) +from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule +from mttl.datamodule.preference_data_module import Preferencemodule +from projects.modular_llm.src.transfer_matrix import TransferMatrixConfig +from projects.modular_llm.src.transfer_matrix import run_eval as produce_transfer_matrix + + +def create_transfer_matrix(args, checkpoint): + ######################## + # create transfer matrix + config = TransferMatrixConfig() + for k, v in vars(args).items(): + if k in vars(config): + setattr(config, k, v) + config.eval_base = False + config.eval_metric = "rougeL" + + expert: Expert = load_expert(checkpoint) + expert.expert_info.expert_name = str(args.finetune_task_name) + expert.expert_info.expert_task_name = str(args.finetune_task_name) + temp_dir = TemporaryDirectory() + destination = temp_dir.name + LocalExpertLibrary.from_expert_dict({"checkpoint": expert}, destination=destination) + config.library_id = destination + config.finetune_task_name = ( + args.finetune_task_name.split(",") + if not isinstance(args.finetune_task_name, list) + else args.finetune_task_name + ) + if len(config.finetune_task_name) < 50: + produce_transfer_matrix(config, debug=False) + ######################## + temp_dir.cleanup() + + +def run_multitask(args: ExpertConfig): + seed_everything(args.seed, workers=True) + + # get directory of the current file + setup_logging(args.output_dir) + logger.info("Args: {}".format(args.to_json())) + + remote_login(args.remote_token) + expert_library = None + if args.library_id: + + @rank_zero_only_and_wait(before=False, after=True) + def create_library(args): + expert_library = ExpertLibrary.get_expert_library( + repo_id=args.library_id, + create=True, + destination_id=args.destination_library_id, + ) + return expert_library + + expert_library = create_library(args) + + loggers = get_pl_loggers(args) + # select dataloader + if args.model_modifier == "poly": + args.init_from_scratch = True + model_class = MoEModel + else: + model_class = ExpertModel + + config = DatasetConfig(model=args.model) + dm = Preferencemodule(config) + + # dm = get_datamodule(args) + # args.n_tasks = len(dm._task_names) + # args.task_names = dm._task_names + + model = model_class(**vars(args), tokenizer=dm.tokenizer) + + if args.rl_training == "dpo": + ref_model = model_class(**vars(args), tokenizer=dm.tokenizer) + module = ExpertModelDPO(model, ref_model) + + # get metric monitors for models + callbacks = get_monitors(args) + if "mbpp" in args.dataset: + monitor = "downstream/mbpp" + mode = "max" + else: + monitor = "val/loss" + mode = "min" + + checkpoint_callback = LiveCheckpointCallback( + dirpath=args.output_dir, + monitor=monitor, + save_last=True, + mode=mode, + save_each_epoch=args.save_each_epoch, + ) + callbacks.append(checkpoint_callback) + + val_check_interval = args.eval_every + if val_check_interval == -1 or val_check_interval is None: + val_check_interval = None + else: + val_check_interval = args.gradient_accumulation_steps * args.eval_every + if val_check_interval > len(dm.train_dataloader()): + val_check_interval = len(dm.train_dataloader()) + elif val_check_interval > args.total_steps and args.total_steps != -1: + val_check_interval = args.total_steps + + trainer = Trainer( + # devices=-1, + # accelerator="gpu", + logger=loggers, + num_sanity_val_steps=0, + default_root_dir=args.output_dir, + max_epochs=args.num_train_epochs, + max_steps=args.total_steps + 1 if args.total_steps != -1 else -1, + gradient_clip_val=args.max_grad_norm, + strategy=args.compute_strategy if args.compute_strategy else "auto", + callbacks=callbacks, + enable_checkpointing=False, + log_every_n_steps=args.gradient_accumulation_steps, + accumulate_grad_batches=args.gradient_accumulation_steps, + precision=( + int(args.precision) if args.precision in ["16", "32"] else args.precision + ), + val_check_interval=val_check_interval, + ) + + # initial validation only for a bunch of datasets... ? + if args.eval_before_training: + # validating before training fails with deepspeed + trainer.validate(module, dm) + + if args.do_train: + trainer.fit(module, dm) + + torch.cuda.empty_cache() + + # reload best model before pushing! + checkpoint = ( + checkpoint_callback.best_model_path or checkpoint_callback.last_model_path + ) + if args.compute_strategy == "deepspeed": + from deepspeed.utils.zero_to_fp32 import ( + convert_zero_checkpoint_to_fp32_state_dict, + ) + + new_path = checkpoint.replace(".ckpt", "_fp32.ckpt") + + @rank_zero_only_and_wait(before=True, after=True) + def convert_ckpt(path, new_path): + convert_zero_checkpoint_to_fp32_state_dict(path, new_path) + + convert_ckpt(checkpoint, new_path) + checkpoint = torch.load(new_path) + else: + checkpoint = torch.load(checkpoint)["state_dict"] + + module.load_state_dict(checkpoint) + trainer.test(module, dm) + + @rank_zero_only_and_wait(before=False, after=True) + def upload_library(expert_library, module): + if expert_library is not None: + # refresh expert library: so we dont overwrite the readme if the remote has changed. + expert_library.refresh_from_remote() + + if isinstance(module, MoEModel): + with expert_library.batched_commit(): + for expert_name in module.experts_names: + expert = module.get_expert_instance(expert_name) + expert_library.add_expert(expert, expert_name) + elif isinstance(module, ExpertModel): + expert = module.as_expert() + expert_name = ( + args.expert_name + or args.finetune_task_name + or generate_random_string() + ) + expert_library.add_expert(expert, expert_name) + else: + raise ValueError("Model class not recognized") + + upload_library(expert_library, module) + + if args.create_transfer_matrix: + create_transfer_matrix(args, checkpoint) + + +if __name__ == "__main__": + args = ExpertConfig.parse() + run_multitask(args) From 330813c435c4c1774c6e58ec931f358347d47a11 Mon Sep 17 00:00:00 2001 From: zhansu Date: Mon, 15 Jul 2024 05:14:32 -0400 Subject: [PATCH 04/37] add log for training and validation --- mttl/models/expert_model.py | 57 ++++++++++++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 4 deletions(-) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index a461ce5a3..5611f58fb 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -626,7 +626,7 @@ def calculate_DPO_loss( beta * (original_prefered_relative_logprob - disprefered_relative_logprob) ).mean(dim=-1) - return loss + return loss, reward_accuracies, reward_margins def get_log_prob(logits, labels): @@ -683,13 +683,30 @@ def training_step(self, batch, _): labels=prompt_disprefered_ids, ) - loss = calculate_DPO_loss( + loss, reward_accuracies, reward_margins = calculate_DPO_loss( model_prefered_log_prob, model_disprefered_log_prob, ref_prefered_log_prob, ref_disprefered_log_prob, beta=0.1, ) + self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + + self.log( + "train/reward_accuracies", + reward_accuracies, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + self.log( + "train/reward_margins", + reward_margins, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + return loss def validation_step(self, batch, _): @@ -729,13 +746,30 @@ def validation_step(self, batch, _): labels=prompt_disprefered_ids, ) - loss = calculate_DPO_loss( + loss, reward_accuracies, reward_margins = calculate_DPO_loss( model_prefered_log_prob, model_disprefered_log_prob, ref_prefered_log_prob, ref_disprefered_log_prob, beta=0.1, ) + + self.log("val/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + self.log( + "val/reward_accuracies", + reward_accuracies, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + self.log( + "val/reward_margins", + reward_margins, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + return loss def test_step(self, batch, _): @@ -775,13 +809,28 @@ def test_step(self, batch, _): labels=prompt_disprefered_ids, ) - loss = calculate_DPO_loss( + loss, reward_accuracies, reward_margins = calculate_DPO_loss( model_prefered_log_prob, model_disprefered_log_prob, ref_prefered_log_prob, ref_disprefered_log_prob, beta=0.1, ) + self.log("test/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + self.log( + "test/reward_accuracies", + reward_accuracies, + on_step=True, + on_epoch=True, + prog_bar=True, + ) + self.log( + "test/reward_margins", + reward_margins, + on_step=True, + on_epoch=True, + prog_bar=True, + ) return loss From ec73c5a79e80c168f66e25050d1a2755b41cee3f Mon Sep 17 00:00:00 2001 From: zhansu Date: Mon, 15 Jul 2024 12:56:57 -0400 Subject: [PATCH 05/37] wip --- mttl/models/expert_model.py | 29 +++++++++++++++++----- projects/modular_llm/train_dpo.py | 41 ++++++++++++++++++++++--------- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index 5611f58fb..b8ae239e2 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -11,7 +11,11 @@ from mttl.models.containers import add_expert_to_transformer from mttl.models.containers.expert_containers import ExpertContainer -from mttl.models.containers.selectors import Selector, SelectorConfig +from mttl.models.containers.selectors import ( + Selector, + SelectorConfig, + ArrowSelectorConfig, +) from mttl.models.expert_config import ExpertConfig from mttl.models.library.expert import Expert, ExpertInfo from mttl.models.library.expert_library import ExpertLibrary @@ -640,9 +644,23 @@ def __init__(self, expert_model, ref_expert_model, **kwargs): super().__init__(**kwargs) self.expert_model = expert_model self.ref_expert_model = ref_expert_model + self.trainable_param_names = kwargs.get("trainable_param_names", None) def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) + params = [] + # for param_name, param in self.named_parameters(): + # param.requires_grad = False + # if self.trainable_param_names and re.fullmatch( + # self.trainable_param_names, param_name + # ): + # param.requires_grad = True + # params.append(param) + + # logger.info(f"Setting {param_name} to trainable.") + optimizer = torch.optim.Adam( + filter(lambda p: p.requires_grad, self.parameters()), lr=1e-3 + ) + return optimizer def training_step(self, batch, _): @@ -865,15 +883,14 @@ def __init__(self, expert_library: ExpertLibrary = None, **kwargs): self.hparams.library_id ) for i, expert in enumerate(sorted(list(expert_library.keys()))): - self.add_expert_instance(expert_library[expert], expert_name=f"e{i}") - + self.add_expert_instance(expert_library[expert], expert_name=expert) self.moe_num_experts = i + 1 if isinstance( - self.selector_config, (ArrowConfig, HiddenStateComputerConfig) + self.selector_config, (ArrowSelectorConfig, HiddenStateComputerConfig) ): from projects.modular_llm.eval_library import patch_prototypes - patch_prototypes(self, expert_library, self.selector_config) + patch_prototypes(self, expert_library, self.hparams) def training_step(self, batch, _): loss = super().training_step(batch, _) diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py index 820f85c65..dc5bbea05 100644 --- a/projects/modular_llm/train_dpo.py +++ b/projects/modular_llm/train_dpo.py @@ -2,17 +2,22 @@ import shutil import sys from tempfile import TemporaryDirectory - +import copy import torch from pytorch_lightning import Trainer, seed_everything sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from mttl.callbacks import LiveCheckpointCallback, NanoMMLUCallback, RougeCallback +from mttl.callbacks import LiveCheckpointCallback # from mttl.datamodule.base import get_datamodule from mttl.models.expert_config import ExpertConfig -from mttl.models.expert_model import ExpertModel, MoEModel, ExpertModelDPO +from mttl.models.expert_model import ( + ExpertModel, + MultiExpertModel, + MoEModel, + ExpertModelDPO, +) from mttl.models.library.expert import Expert, load_expert from mttl.models.library.expert_library import ExpertLibrary, LocalExpertLibrary from mttl.models.monitors import get_monitors @@ -24,10 +29,11 @@ remote_login, setup_logging, ) -from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule +from mttl.datamodule.base import DatasetConfig from mttl.datamodule.preference_data_module import Preferencemodule from projects.modular_llm.src.transfer_matrix import TransferMatrixConfig from projects.modular_llm.src.transfer_matrix import run_eval as produce_transfer_matrix +from projects.modular_llm.eval_library import patch_prototypes def create_transfer_matrix(args, checkpoint): @@ -83,23 +89,34 @@ def create_library(args): loggers = get_pl_loggers(args) # select dataloader if args.model_modifier == "poly": - args.init_from_scratch = True model_class = MoEModel else: model_class = ExpertModel - config = DatasetConfig(model=args.model) dm = Preferencemodule(config) # dm = get_datamodule(args) # args.n_tasks = len(dm._task_names) # args.task_names = dm._task_names - - model = model_class(**vars(args), tokenizer=dm.tokenizer) + # if args.router_selector == "arrow_router": + args.trainable_param_names = None + ref_model = model_class( + **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library + ) if args.rl_training == "dpo": - ref_model = model_class(**vars(args), tokenizer=dm.tokenizer) - module = ExpertModelDPO(model, ref_model) + args.trainable_param_names = ".*prototypes.*" + model = model_class( + **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library + ) + # if args.library_id: + # model.add_experts_from_library(expert_library) + # patch_prototypes(model, expert_library, args) + + # # ref_model = copy.deepcopy(model) + # ref_model.add_experts_from_library(expert_library) + # patch_prototypes(ref_model, expert_library, args) + module = ExpertModelDPO(model, ref_model, **vars(args)) # get metric monitors for models callbacks = get_monitors(args) @@ -130,8 +147,8 @@ def create_library(args): val_check_interval = args.total_steps trainer = Trainer( - # devices=-1, - # accelerator="gpu", + devices=-1, + accelerator="gpu", logger=loggers, num_sanity_val_steps=0, default_root_dir=args.output_dir, From 93b5526a20b3f0706258d1ae5df6e2ac69627630 Mon Sep 17 00:00:00 2001 From: zhansu Date: Mon, 15 Jul 2024 13:02:26 -0400 Subject: [PATCH 06/37] fix isort --- mttl/datamodule/preference_data_module.py | 6 ++++-- mttl/models/expert_model.py | 5 ++--- projects/modular_llm/train_dpo.py | 13 +++++++------ 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/mttl/datamodule/preference_data_module.py b/mttl/datamodule/preference_data_module.py index 677771a8e..445ad39ac 100644 --- a/mttl/datamodule/preference_data_module.py +++ b/mttl/datamodule/preference_data_module.py @@ -1,8 +1,10 @@ -from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule from dataclasses import dataclass -from mttl.models.library.expert_library import DatasetLibrary + import torch +from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule +from mttl.models.library.expert_library import DatasetLibrary + @dataclass class DataCollatorForDPO(DefaultCollator): diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index b8ae239e2..0fba66597 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -6,15 +6,16 @@ from typing import Dict, List import torch +import torch.nn.functional as F from torch.optim.optimizer import Optimizer from transformers import PreTrainedModel from mttl.models.containers import add_expert_to_transformer from mttl.models.containers.expert_containers import ExpertContainer from mttl.models.containers.selectors import ( + ArrowSelectorConfig, Selector, SelectorConfig, - ArrowSelectorConfig, ) from mttl.models.expert_config import ExpertConfig from mttl.models.library.expert import Expert, ExpertInfo @@ -31,8 +32,6 @@ from mttl.models.utils import EfficientCheckpointModule, prepare_model_for_kbit_training from mttl.utils import logger -import torch.nn.functional as F - torch.set_float32_matmul_precision("high") diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py index dc5bbea05..1438f550d 100644 --- a/projects/modular_llm/train_dpo.py +++ b/projects/modular_llm/train_dpo.py @@ -1,22 +1,25 @@ +import copy import os import shutil import sys from tempfile import TemporaryDirectory -import copy + import torch from pytorch_lightning import Trainer, seed_everything sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from mttl.callbacks import LiveCheckpointCallback +from mttl.datamodule.base import DatasetConfig +from mttl.datamodule.preference_data_module import Preferencemodule # from mttl.datamodule.base import get_datamodule from mttl.models.expert_config import ExpertConfig from mttl.models.expert_model import ( ExpertModel, - MultiExpertModel, - MoEModel, ExpertModelDPO, + MoEModel, + MultiExpertModel, ) from mttl.models.library.expert import Expert, load_expert from mttl.models.library.expert_library import ExpertLibrary, LocalExpertLibrary @@ -29,11 +32,9 @@ remote_login, setup_logging, ) -from mttl.datamodule.base import DatasetConfig -from mttl.datamodule.preference_data_module import Preferencemodule +from projects.modular_llm.eval_library import patch_prototypes from projects.modular_llm.src.transfer_matrix import TransferMatrixConfig from projects.modular_llm.src.transfer_matrix import run_eval as produce_transfer_matrix -from projects.modular_llm.eval_library import patch_prototypes def create_transfer_matrix(args, checkpoint): From 1f97122e6e73eba9aa821f199cd26a8a1db87a78 Mon Sep 17 00:00:00 2001 From: zhansu Date: Tue, 16 Jul 2024 08:37:41 -0400 Subject: [PATCH 07/37] fix optimization with efficient checkpoint module --- mttl/models/expert_model.py | 19 ++----------------- 1 file changed, 2 insertions(+), 17 deletions(-) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index 0fba66597..823102630 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -644,23 +644,8 @@ def __init__(self, expert_model, ref_expert_model, **kwargs): self.expert_model = expert_model self.ref_expert_model = ref_expert_model self.trainable_param_names = kwargs.get("trainable_param_names", None) - - def configure_optimizers(self): - params = [] - # for param_name, param in self.named_parameters(): - # param.requires_grad = False - # if self.trainable_param_names and re.fullmatch( - # self.trainable_param_names, param_name - # ): - # param.requires_grad = True - # params.append(param) - - # logger.info(f"Setting {param_name} to trainable.") - optimizer = torch.optim.Adam( - filter(lambda p: p.requires_grad, self.parameters()), lr=1e-3 - ) - - return optimizer + # log hyperparameters + self.save_hyperparameters(kwargs) def training_step(self, batch, _): From b4fa78b716ead22487182d9a2a10f8cb8ade03ca Mon Sep 17 00:00:00 2001 From: zhansu Date: Tue, 16 Jul 2024 10:23:49 -0400 Subject: [PATCH 08/37] wip --- projects/modular_llm/train_dpo.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py index 1438f550d..ff3383a76 100644 --- a/projects/modular_llm/train_dpo.py +++ b/projects/modular_llm/train_dpo.py @@ -99,14 +99,13 @@ def create_library(args): # dm = get_datamodule(args) # args.n_tasks = len(dm._task_names) # args.task_names = dm._task_names - # if args.router_selector == "arrow_router": - args.trainable_param_names = None + # args.trainable_param_names = ".*prototypes.*" ref_model = model_class( **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library ) if args.rl_training == "dpo": - args.trainable_param_names = ".*prototypes.*" + # args.trainable_param_names = ".*prototypes.*" model = model_class( **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library ) @@ -117,7 +116,9 @@ def create_library(args): # # ref_model = copy.deepcopy(model) # ref_model.add_experts_from_library(expert_library) # patch_prototypes(ref_model, expert_library, args) - module = ExpertModelDPO(model, ref_model, **vars(args)) + module = ExpertModelDPO( + **vars(args), expert_model=model, ref_expert_model=ref_model + ) # get metric monitors for models callbacks = get_monitors(args) From 6b20a61da91905b5c0579541f710821ce75e1000 Mon Sep 17 00:00:00 2001 From: zhansu Date: Wed, 17 Jul 2024 07:05:49 -0400 Subject: [PATCH 09/37] fix dpo training --- mttl/models/expert_model.py | 16 ++++++++-------- projects/modular_llm/train_dpo.py | 7 +++---- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index 823102630..cad291e80 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -639,9 +639,9 @@ def get_log_prob(logits, labels): class ExpertModelDPO(EfficientCheckpointModule): - def __init__(self, expert_model, ref_expert_model, **kwargs): + def __init__(self, preference_model, ref_expert_model, **kwargs): super().__init__(**kwargs) - self.expert_model = expert_model + self.preference_model = preference_model self.ref_expert_model = ref_expert_model self.trainable_param_names = kwargs.get("trainable_param_names", None) # log hyperparameters @@ -657,14 +657,14 @@ def training_step(self, batch, _): # original model model_prefered_log_prob = get_log_prob( - self.expert_model.model.forward( + self.preference_model.model.forward( prompt_prefered_ids, attention_mask=prompt_prefered_mask ).logits, labels=prompt_prefered_ids, ) model_disprefered_log_prob = get_log_prob( - self.expert_model.model.forward( + self.preference_model.model.forward( prompt_disprefered_ids, attention_mask=prompt_disprefered_mask ).logits, labels=prompt_disprefered_ids, @@ -720,14 +720,14 @@ def validation_step(self, batch, _): # original model model_prefered_log_prob = get_log_prob( - self.expert_model.model.forward( + self.preference_model.model.forward( prompt_prefered_ids, attention_mask=prompt_prefered_mask ).logits, labels=prompt_prefered_ids, ) model_disprefered_log_prob = get_log_prob( - self.expert_model.model.forward( + self.preference_model.model.forward( prompt_disprefered_ids, attention_mask=prompt_disprefered_mask ).logits, labels=prompt_disprefered_ids, @@ -783,14 +783,14 @@ def test_step(self, batch, _): # original model model_prefered_log_prob = get_log_prob( - self.expert_model.model.forward( + self.preference_model.model.forward( prompt_prefered_ids, attention_mask=prompt_prefered_mask ).logits, labels=prompt_prefered_ids, ) model_disprefered_log_prob = get_log_prob( - self.expert_model.model.forward( + self.preference_model.model.forward( prompt_disprefered_ids, attention_mask=prompt_disprefered_mask ).logits, labels=prompt_disprefered_ids, diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py index ff3383a76..fcff951b8 100644 --- a/projects/modular_llm/train_dpo.py +++ b/projects/modular_llm/train_dpo.py @@ -99,13 +99,12 @@ def create_library(args): # dm = get_datamodule(args) # args.n_tasks = len(dm._task_names) # args.task_names = dm._task_names - # args.trainable_param_names = ".*prototypes.*" ref_model = model_class( **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library ) if args.rl_training == "dpo": - # args.trainable_param_names = ".*prototypes.*" + args.trainable_param_names = "^(?=.*preference_model)(?=.*prototypes).*" model = model_class( **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library ) @@ -117,7 +116,7 @@ def create_library(args): # ref_model.add_experts_from_library(expert_library) # patch_prototypes(ref_model, expert_library, args) module = ExpertModelDPO( - **vars(args), expert_model=model, ref_expert_model=ref_model + **vars(args), preference_model=model, ref_expert_model=ref_model ) # get metric monitors for models @@ -223,7 +222,7 @@ def upload_library(expert_library, module): else: raise ValueError("Model class not recognized") - upload_library(expert_library, module) + # upload_library(expert_library, module) if args.create_transfer_matrix: create_transfer_matrix(args, checkpoint) From 8c61ba3427cdf298892891a8f27dfd60813a9416 Mon Sep 17 00:00:00 2001 From: zhansu Date: Wed, 17 Jul 2024 08:14:07 -0400 Subject: [PATCH 10/37] add simpo training --- mttl/models/expert_model.py | 114 ++++++++++++++++++++++++++++++ projects/modular_llm/train_dpo.py | 17 +++-- 2 files changed, 126 insertions(+), 5 deletions(-) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index cad291e80..92521ab1a 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -637,6 +637,120 @@ def get_log_prob(logits, labels): return torch.gather(log_probs, -1, labels.unsqueeze(-1)).squeeze(-1).mean(-1) +class ExpertModelSimPO(EfficientCheckpointModule): + def __init__(self, preference_model, **kwargs): + super().__init__(**kwargs) + self.preference_model = preference_model + self.trainable_param_names = kwargs.get("trainable_param_names", None) + self.beta = kwargs.get("beta", 0.5) + self.loss_type = kwargs.get("loss_type", "sigmoid") + self.label_smoothing = kwargs.get("label_smoothing", 0.1) + # log hyperparameters + self.save_hyperparameters(kwargs) + + def simpo_loss( + self, original_prefered_logprob, original_disprefered_logprob, gamma_beta_ratio + ): + """ + Compute the SIMPO loss. + + ref: https://github.com/princeton-nlp/SimPO/blob/main/scripts/simpo_trainer.py + + args: original_prefered_logps: log probabiliteis of the prefered expert in the original model + original_disprefered_logps: log probabiliteis of the disprefered expert in the original model + """ + + pi_logratios = original_prefered_logprob - original_disprefered_logprob + logits = pi_logratios - gamma_beta_ratio + + if self.loss_type == "sigmoid": + losses = ( + -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing) + - F.logsigmoid(-self.beta * logits) * self.label_smoothing + ) + elif self.loss_type == "hinge": + losses = torch.relu(1 - self.beta * logits) + else: + raise ValueError( + f"Loss type {self.loss_type} not supported. Choose from ['sigmoid', 'hinge']" + ) + + chosen_rewards = ( + self.beta * original_prefered_logprob.to(self.accelerator.device).detach() + ) + + reject_rewards = ( + -self.beta + * original_disprefered_logprob.to(self.accelerator.device).detach() + ) + + return losses, chosen_rewards, reject_rewards + + def training_step(self, batch, _): + prompt_prefered_ids = batch["prompt_prefered_ids"] + prompt_disprefered_ids = batch["prompt_disprefered_ids"] + + prompt_prefered_mask = batch["prompt_prefered_mask"] + prompt_disprefered_mask = batch["prompt_disprefered_mask"] + + # original model + model_prefered_log_prob = get_log_prob( + self.preference_model.model.forward( + prompt_prefered_ids, attention_mask=prompt_prefered_mask + ).logits, + labels=prompt_prefered_ids, + ) + + model_disprefered_log_prob = get_log_prob( + self.preference_model.model.forward( + prompt_disprefered_ids, attention_mask=prompt_disprefered_mask + ).logits, + labels=prompt_disprefered_ids, + ) + + loss, chosen_rewards, rejected_rewards = self.simpo_loss( + model_prefered_log_prob, model_disprefered_log_prob, gamma_beta_ratio=0.1 + ) + self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + self.log("train/chosen_rewards", chosen_rewards, on_step=True, on_epoch=True) + self.log( + "train/rejected_rewards", rejected_rewards, on_step=True, on_epoch=True + ) + + return loss + + def validation_step(self, batch, _): + prompt_prefered_ids = batch["prompt_prefered_ids"] + prompt_disprefered_ids = batch["prompt_disprefered_ids"] + + prompt_prefered_mask = batch["prompt_prefered_mask"] + prompt_disprefered_mask = batch["prompt_disprefered_mask"] + + # original model + model_prefered_log_prob = get_log_prob( + self.preference_model.model.forward( + prompt_prefered_ids, attention_mask=prompt_prefered_mask + ).logits, + labels=prompt_prefered_ids, + ) + + model_disprefered_log_prob = get_log_prob( + self.preference_model.model.forward( + prompt_disprefered_ids, attention_mask=prompt_disprefered_mask + ).logits, + labels=prompt_disprefered_ids, + ) + + loss, chosen_rewards, rejected_rewards = self.simpo_loss( + model_prefered_log_prob, model_disprefered_log_prob, gamma_beta_ratio=0.1 + ) + self.log("val/loss", loss, on_step=True, on_epoch=True, prog_bar=True) + self.log("val/chosen_rewards", chosen_rewards, on_step=True, on_epoch=True) + self.log("val/rejected_rewards", rejected_rewards, on_step=True, on_epoch=True) + + return loss + + class ExpertModelDPO(EfficientCheckpointModule): def __init__(self, preference_model, ref_expert_model, **kwargs): diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py index fcff951b8..545a5b004 100644 --- a/projects/modular_llm/train_dpo.py +++ b/projects/modular_llm/train_dpo.py @@ -19,7 +19,7 @@ ExpertModel, ExpertModelDPO, MoEModel, - MultiExpertModel, + ExpertModelSimPO, ) from mttl.models.library.expert import Expert, load_expert from mttl.models.library.expert_library import ExpertLibrary, LocalExpertLibrary @@ -99,15 +99,15 @@ def create_library(args): # dm = get_datamodule(args) # args.n_tasks = len(dm._task_names) # args.task_names = dm._task_names - ref_model = model_class( + model = model_class( **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library ) - if args.rl_training == "dpo": args.trainable_param_names = "^(?=.*preference_model)(?=.*prototypes).*" - model = model_class( + ref_model = model_class( **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library ) + # if args.library_id: # model.add_experts_from_library(expert_library) # patch_prototypes(model, expert_library, args) @@ -118,7 +118,14 @@ def create_library(args): module = ExpertModelDPO( **vars(args), preference_model=model, ref_expert_model=ref_model ) - + elif args.rl_training == "simpo": + args.trainable_param_names = "^(?=.*preference_model)(?=.*prototypes).*" + model = model_class( + **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library + ) + module = ExpertModelSimPO(**vars(args), preference_model=model) + else: + module = model # get metric monitors for models callbacks = get_monitors(args) if "mbpp" in args.dataset: From 0f0be96d3e235eb20bcefe59e94ea7b9c915b31c Mon Sep 17 00:00:00 2001 From: wzm289 Date: Wed, 17 Jul 2024 14:58:18 +0200 Subject: [PATCH 11/37] fix --- mttl/models/expert_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index 92521ab1a..912c82f00 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -676,12 +676,12 @@ def simpo_loss( ) chosen_rewards = ( - self.beta * original_prefered_logprob.to(self.accelerator.device).detach() + self.beta * original_prefered_logprob.detach() ) reject_rewards = ( -self.beta - * original_disprefered_logprob.to(self.accelerator.device).detach() + * original_disprefered_logprob.detach() ) return losses, chosen_rewards, reject_rewards From 3de4297e112fa3564b9d5e42d2efe5335d4b4e4d Mon Sep 17 00:00:00 2001 From: zhansu Date: Wed, 17 Jul 2024 09:09:19 -0400 Subject: [PATCH 12/37] isort --- mttl/models/expert_model.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index 912c82f00..ba4d993f1 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -711,10 +711,10 @@ def training_step(self, batch, _): loss, chosen_rewards, rejected_rewards = self.simpo_loss( model_prefered_log_prob, model_disprefered_log_prob, gamma_beta_ratio=0.1 ) - self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) - self.log("train/chosen_rewards", chosen_rewards, on_step=True, on_epoch=True) + self.log("train/loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True) + self.log("train/chosen_rewards", chosen_rewards.mean(), on_step=True, on_epoch=True) self.log( - "train/rejected_rewards", rejected_rewards, on_step=True, on_epoch=True + "train/rejected_rewards", rejected_rewards.mean(), on_step=True, on_epoch=True ) return loss @@ -744,9 +744,9 @@ def validation_step(self, batch, _): loss, chosen_rewards, rejected_rewards = self.simpo_loss( model_prefered_log_prob, model_disprefered_log_prob, gamma_beta_ratio=0.1 ) - self.log("val/loss", loss, on_step=True, on_epoch=True, prog_bar=True) - self.log("val/chosen_rewards", chosen_rewards, on_step=True, on_epoch=True) - self.log("val/rejected_rewards", rejected_rewards, on_step=True, on_epoch=True) + self.log("val/loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True) + self.log("val/chosen_rewards", chosen_rewards.mean(), on_step=True, on_epoch=True) + self.log("val/rejected_rewards", rejected_rewards.mean(), on_step=True, on_epoch=True) return loss From d3b77a1d7b0672b8f0e698d22087bebd6fc6583d Mon Sep 17 00:00:00 2001 From: wzm289 Date: Wed, 17 Jul 2024 15:54:36 +0200 Subject: [PATCH 13/37] wip --- mttl/models/expert_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index ba4d993f1..3f7ac4a1b 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -717,7 +717,7 @@ def training_step(self, batch, _): "train/rejected_rewards", rejected_rewards.mean(), on_step=True, on_epoch=True ) - return loss + return loss.mean() def validation_step(self, batch, _): prompt_prefered_ids = batch["prompt_prefered_ids"] @@ -748,7 +748,7 @@ def validation_step(self, batch, _): self.log("val/chosen_rewards", chosen_rewards.mean(), on_step=True, on_epoch=True) self.log("val/rejected_rewards", rejected_rewards.mean(), on_step=True, on_epoch=True) - return loss + return loss.mean() class ExpertModelDPO(EfficientCheckpointModule): From 6601b677aa8ae456e7d6697b30cea007cc50014a Mon Sep 17 00:00:00 2001 From: wzm289 Date: Thu, 25 Jul 2024 00:05:45 +0200 Subject: [PATCH 14/37] get cluster dataset by embedding --- projects/modular_llm/get_clusters.py | 167 +++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 projects/modular_llm/get_clusters.py diff --git a/projects/modular_llm/get_clusters.py b/projects/modular_llm/get_clusters.py new file mode 100644 index 000000000..3b1537f07 --- /dev/null +++ b/projects/modular_llm/get_clusters.py @@ -0,0 +1,167 @@ +from datasets import load_dataset +from torch.utils.data import Dataset, DataLoader, Subset +from mttl.models.ranker.adapter_ranker import AdapterRankerHelper +from tqdm import tqdm +from nomic import atlas +import numpy as np +from sklearn.cluster import KMeans +import argparse +import logging +from mttl.datamodule.mt_seq_to_seq_module import FlanModule, FlanConfig +from sentence_transformers import SentenceTransformer +from datasets import concatenate_datasets +import os +import huggingface_hub + +huggingface_token = os.environ.get("HF_TOKEN") +huggingface_hub.login(token=huggingface_token) + + +logger = logging.getLogger(__name__) +logging.basicConfig(level=logging.INFO) + + +parser = argparse.ArgumentParser() +parser.add_argument("--batch_size", type=int, default=64) + +parser.add_argument("--subsample", type=float, default=0.2) + +parser.add_argument("--num_clusters", type=int, default=256) + +parser.add_argument("--dataset", type=str, default="orca") + +parser.add_argument("--output_file", type=str, default="orca_cluster.json") + +parser.add_argument("--encoding", type=str, default="classifier") + +parser.add_argument( + "--new_dataname", type=str, default="zhan1993/flan-10k-flat-cluster-embedding" +) + +parser.add_argument( + "--model", type=str, default="sentence-transformers/sentence-t5-xxl" +) +args = parser.parse_args() +np.random.seed(42) + + +def get_text_encode(text, model): + + if args.encoding == "classifier": + return model.get_text_encode(text).cpu().detach().numpy() + elif args.encoding == "embedding": + return model.encode(text) + + +if args.encoding == "classifier": + model = AdapterRankerHelper.get_ranker_instance( + ranker_model="classifier", + ranker_path="zhan1993/classifier_ranker_t5_v3", + ) +elif args.encoding == "embedding": + model = SentenceTransformer(args.model) + +# load the dataset + + +def get_orca_dataset(): + + dataset = load_dataset("Open-Orca/OpenOrca") + + # create the subsample of the dataset. + dataset_size = len(dataset["train"]) + indices = list(range(dataset_size)) + # random indices + np.random.shuffle(indices) + split = int(np.floor(args.subsample * dataset_size)) + subset_indices = indices[:split] + subset_dataset = Subset(dataset["train"], subset_indices) + + train_dataloader = DataLoader( + subset_dataset, batch_size=args.batch_size, num_workers=8 + ) + all_dataloader = DataLoader( + dataset["train"], batch_size=args.batch_size, num_workers=8 + ) + + return train_dataloader, all_dataloader, dataset["train"] + + +def get_flan_dataset(): + + flan = FlanModule( + FlanConfig( + model="EleutherAI/gpt-neo-125m", + model_family="gpt", + train_batch_size=4, + predict_batch_size=4, + dataset="sordonia/flan-10k-flat", + remove_phi_eval_tasks=True, + ) + ) + + dataset = flan.train_dataset + # create the subsample of the dataset. + dataset_size = len(dataset) + indices = list(range(dataset_size)) + # random indices + np.random.shuffle(indices) + split = int(np.floor(args.subsample * dataset_size)) + subset_indices = indices[:split] + subset_dataset = Subset(dataset, subset_indices) + + train_dataloader = DataLoader( + subset_dataset, batch_size=args.batch_size, num_workers=8 + ) + all_dataloader = flan.train_dataloader() + + all_dataset = concatenate_datasets( + [flan.train_dataset, flan.dev_dataset, flan.test_dataset] + ) + + return train_dataloader, all_dataloader, all_dataset + + +if __name__ == "__main__": + if args.dataset == "orca": + train_dataloader, all_dataloader, all_dataset = get_orca_dataset() + elif args.dataset == "flan": + train_dataloader, all_dataloader, all_dataset = get_flan_dataset() + + embedding_list = [] + + for i, batch in tqdm( + enumerate(train_dataloader), total=len(train_dataloader), desc="dataset" + ): + if "source" in batch: + embedding = get_text_encode(batch["source"], model) + else: + embedding = get_text_encode(batch["question"], model) + embedding_list.append(embedding) + + all_embedding = np.concatenate(embedding_list, axis=0).reshape(-1, 768) + logger.info(f"all_embedding shape: {all_embedding.shape}") + kmeans = KMeans( + n_clusters=args.num_clusters, + init="k-means++", + n_init=10, + random_state=42, + ).fit(all_embedding) + + # map the new item with kmeans cluster + + def add_cluster_id(example): + if "source" in example: + embedding = get_text_encode(example["source"], model) + else: + embedding = get_text_encode(example["question"], model) + embedding = embedding.reshape(1, -1) + example["cluster_id"] = str(kmeans.predict(embedding)[0]) + return example + + # all_dataset = all_dataset.select(list(range(100))) + + dataset = all_dataset.map(add_cluster_id) + + # Push the merged dataset back to Hugging Face Hub + dataset.push_to_hub(args.new_dataname) From d69f03b0e419f05afe31f54f3da23f95552a20c8 Mon Sep 17 00:00:00 2001 From: zhansu Date: Tue, 30 Jul 2024 04:03:42 -0400 Subject: [PATCH 15/37] add args for dpo training --- mttl/config.py | 1 + mttl/models/expert_model.py | 44 +++++++++++++++++++------------ projects/modular_llm/train_dpo.py | 8 ------ 3 files changed, 28 insertions(+), 25 deletions(-) diff --git a/mttl/config.py b/mttl/config.py index f8d79be97..79bc8751c 100644 --- a/mttl/config.py +++ b/mttl/config.py @@ -349,3 +349,4 @@ def _set_defaults(self): # rl training self.rl_training = "dpo" + self.beta = 0.5 diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index 3f7ac4a1b..a8ddc5a8d 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -599,7 +599,7 @@ def calculate_DPO_loss( original_disprefered_logprob, ref_prefered_logprob, ref_disprefered_logprob, - beta=0.5, + beta=2.0, ): """ Calculate the DPO loss. @@ -645,6 +645,7 @@ def __init__(self, preference_model, **kwargs): self.beta = kwargs.get("beta", 0.5) self.loss_type = kwargs.get("loss_type", "sigmoid") self.label_smoothing = kwargs.get("label_smoothing", 0.1) + self.gamma_beta_ratio = kwargs.get("gamma_beta_ratio", 0.5) # log hyperparameters self.save_hyperparameters(kwargs) @@ -675,14 +676,9 @@ def simpo_loss( f"Loss type {self.loss_type} not supported. Choose from ['sigmoid', 'hinge']" ) - chosen_rewards = ( - self.beta * original_prefered_logprob.detach() - ) + chosen_rewards = self.beta * original_prefered_logprob.detach() - reject_rewards = ( - -self.beta - * original_disprefered_logprob.detach() - ) + reject_rewards = self.beta * original_disprefered_logprob.detach() return losses, chosen_rewards, reject_rewards @@ -709,12 +705,19 @@ def training_step(self, batch, _): ) loss, chosen_rewards, rejected_rewards = self.simpo_loss( - model_prefered_log_prob, model_disprefered_log_prob, gamma_beta_ratio=0.1 + model_prefered_log_prob, + model_disprefered_log_prob, + gamma_beta_ratio=self.gamma_beta_ratio, ) self.log("train/loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True) - self.log("train/chosen_rewards", chosen_rewards.mean(), on_step=True, on_epoch=True) self.log( - "train/rejected_rewards", rejected_rewards.mean(), on_step=True, on_epoch=True + "train/chosen_rewards", chosen_rewards.mean(), on_step=True, on_epoch=True + ) + self.log( + "train/rejected_rewards", + rejected_rewards.mean(), + on_step=True, + on_epoch=True, ) return loss.mean() @@ -742,11 +745,17 @@ def validation_step(self, batch, _): ) loss, chosen_rewards, rejected_rewards = self.simpo_loss( - model_prefered_log_prob, model_disprefered_log_prob, gamma_beta_ratio=0.1 + model_prefered_log_prob, + model_disprefered_log_prob, + gamma_beta_ratio=self.gamma_beta_ratio, ) self.log("val/loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True) - self.log("val/chosen_rewards", chosen_rewards.mean(), on_step=True, on_epoch=True) - self.log("val/rejected_rewards", rejected_rewards.mean(), on_step=True, on_epoch=True) + self.log( + "val/chosen_rewards", chosen_rewards.mean(), on_step=True, on_epoch=True + ) + self.log( + "val/rejected_rewards", rejected_rewards.mean(), on_step=True, on_epoch=True + ) return loss.mean() @@ -758,6 +767,7 @@ def __init__(self, preference_model, ref_expert_model, **kwargs): self.preference_model = preference_model self.ref_expert_model = ref_expert_model self.trainable_param_names = kwargs.get("trainable_param_names", None) + self.beta = kwargs.get("beta", 2.0) # log hyperparameters self.save_hyperparameters(kwargs) @@ -804,7 +814,7 @@ def training_step(self, batch, _): model_disprefered_log_prob, ref_prefered_log_prob, ref_disprefered_log_prob, - beta=0.1, + beta=self.beta, ) self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) @@ -867,7 +877,7 @@ def validation_step(self, batch, _): model_disprefered_log_prob, ref_prefered_log_prob, ref_disprefered_log_prob, - beta=0.1, + beta=self.beta, ) self.log("val/loss", loss, on_step=True, on_epoch=True, prog_bar=True) @@ -930,7 +940,7 @@ def test_step(self, batch, _): model_disprefered_log_prob, ref_prefered_log_prob, ref_disprefered_log_prob, - beta=0.1, + beta=self.beta, ) self.log("test/loss", loss, on_step=True, on_epoch=True, prog_bar=True) self.log( diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py index 545a5b004..f6e1ea421 100644 --- a/projects/modular_llm/train_dpo.py +++ b/projects/modular_llm/train_dpo.py @@ -107,14 +107,6 @@ def create_library(args): ref_model = model_class( **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library ) - - # if args.library_id: - # model.add_experts_from_library(expert_library) - # patch_prototypes(model, expert_library, args) - - # # ref_model = copy.deepcopy(model) - # ref_model.add_experts_from_library(expert_library) - # patch_prototypes(ref_model, expert_library, args) module = ExpertModelDPO( **vars(args), preference_model=model, ref_expert_model=ref_model ) From 4a0d3d6e80c298ee1723460ef8792357f42a876c Mon Sep 17 00:00:00 2001 From: zhansu Date: Wed, 31 Jul 2024 10:26:54 -0400 Subject: [PATCH 16/37] add log --- mttl/models/expert_model.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index a8ddc5a8d..6ddc47fc9 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -709,10 +709,19 @@ def training_step(self, batch, _): model_disprefered_log_prob, gamma_beta_ratio=self.gamma_beta_ratio, ) + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + self.log( + "train/accuracies", + reward_accuracies.mean().cpu(), + on_step=True, + on_epoch=True, + ) self.log("train/loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True) self.log( "train/chosen_rewards", chosen_rewards.mean(), on_step=True, on_epoch=True ) + self.log( "train/rejected_rewards", rejected_rewards.mean(), @@ -749,6 +758,15 @@ def validation_step(self, batch, _): model_disprefered_log_prob, gamma_beta_ratio=self.gamma_beta_ratio, ) + + reward_accuracies = (chosen_rewards > rejected_rewards).float() + self.log( + "val/reward_accuracies", + reward_accuracies.mean(), + on_step=True, + on_epoch=True, + prog_bar=True, + ) self.log("val/loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True) self.log( "val/chosen_rewards", chosen_rewards.mean(), on_step=True, on_epoch=True From b73de95c45548bee858a924114a8e8dfbf00afe9 Mon Sep 17 00:00:00 2001 From: zhansu Date: Tue, 6 Aug 2024 22:46:48 -0400 Subject: [PATCH 17/37] wip --- mttl/datamodule/preference_data_module.py | 2 + projects/modular_llm/get_clusters.py | 108 +++++++++++++--------- 2 files changed, 66 insertions(+), 44 deletions(-) diff --git a/mttl/datamodule/preference_data_module.py b/mttl/datamodule/preference_data_module.py index 445ad39ac..fb7aa9071 100644 --- a/mttl/datamodule/preference_data_module.py +++ b/mttl/datamodule/preference_data_module.py @@ -47,6 +47,8 @@ def __call__(self, batch): [torch.ones_like(prompt_ids), torch.zeros_like(disprefered_ids)], dim=-1 ) + breakpoint() + return { "prompt_prefered_ids": prompt_prefered_ids, "prompt_disprefered_ids": prompt_disprefered_ids, diff --git a/projects/modular_llm/get_clusters.py b/projects/modular_llm/get_clusters.py index 3b1537f07..efd8cf32f 100644 --- a/projects/modular_llm/get_clusters.py +++ b/projects/modular_llm/get_clusters.py @@ -61,72 +61,92 @@ def get_text_encode(text, model): elif args.encoding == "embedding": model = SentenceTransformer(args.model) -# load the dataset +def get_dataset(args): + dataset = DatasetLibrary.pull_dataset(args.dataset, split="train") + # create the subsample of the dataset if cutoff is set. + if args.cutoff > 0: + dataset = dataset.shuffle(seed=args.seed) + dataset = dataset.select(range(args.cutoff)) -def get_orca_dataset(): - - dataset = load_dataset("Open-Orca/OpenOrca") - - # create the subsample of the dataset. - dataset_size = len(dataset["train"]) + dataset_size = len(dataset) indices = list(range(dataset_size)) - # random indices np.random.shuffle(indices) split = int(np.floor(args.subsample * dataset_size)) subset_indices = indices[:split] - subset_dataset = Subset(dataset["train"], subset_indices) + subset_dataset = dataset.select(subset_indices) train_dataloader = DataLoader( - subset_dataset, batch_size=args.batch_size, num_workers=8 + subset_dataset, batch_size=args.batch_size, num_workers=args.num_workers ) all_dataloader = DataLoader( - dataset["train"], batch_size=args.batch_size, num_workers=8 + dataset, batch_size=args.batch_size, num_workers=args.num_workers ) - return train_dataloader, all_dataloader, dataset["train"] + return train_dataloader, all_dataloader, dataset, subset_dataset +# def get_orca_dataset(): -def get_flan_dataset(): +# dataset = load_dataset("Open-Orca/OpenOrca") - flan = FlanModule( - FlanConfig( - model="EleutherAI/gpt-neo-125m", - model_family="gpt", - train_batch_size=4, - predict_batch_size=4, - dataset="sordonia/flan-10k-flat", - remove_phi_eval_tasks=True, - ) - ) +# # create the subsample of the dataset. +# dataset_size = len(dataset["train"]) +# indices = list(range(dataset_size)) +# # random indices +# np.random.shuffle(indices) +# split = int(np.floor(args.subsample * dataset_size)) +# subset_indices = indices[:split] +# subset_dataset = Subset(dataset["train"], subset_indices) - dataset = flan.train_dataset - # create the subsample of the dataset. - dataset_size = len(dataset) - indices = list(range(dataset_size)) - # random indices - np.random.shuffle(indices) - split = int(np.floor(args.subsample * dataset_size)) - subset_indices = indices[:split] - subset_dataset = Subset(dataset, subset_indices) +# train_dataloader = DataLoader( +# subset_dataset, batch_size=args.batch_size, num_workers=8 +# ) +# all_dataloader = DataLoader( +# dataset["train"], batch_size=args.batch_size, num_workers=8 +# ) - train_dataloader = DataLoader( - subset_dataset, batch_size=args.batch_size, num_workers=8 - ) - all_dataloader = flan.train_dataloader() +# return train_dataloader, all_dataloader, dataset["train"] - all_dataset = concatenate_datasets( - [flan.train_dataset, flan.dev_dataset, flan.test_dataset] - ) - return train_dataloader, all_dataloader, all_dataset +# def get_flan_dataset(): + +# flan = FlanModule( +# FlanConfig( +# model="EleutherAI/gpt-neo-125m", +# model_family="gpt", +# train_batch_size=4, +# predict_batch_size=4, +# dataset="sordonia/flan-10k-flat", +# remove_phi_eval_tasks=True, +# ) +# ) + +# dataset = flan.train_dataset +# # create the subsample of the dataset. +# dataset_size = len(dataset) +# indices = list(range(dataset_size)) +# # random indices +# np.random.shuffle(indices) +# split = int(np.floor(args.subsample * dataset_size)) +# subset_indices = indices[:split] +# subset_dataset = Subset(dataset, subset_indices) + +# train_dataloader = DataLoader( +# subset_dataset, batch_size=args.batch_size, num_workers=8 +# ) +# all_dataloader = flan.train_dataloader() + +# all_dataset = concatenate_datasets( +# [flan.train_dataset, flan.dev_dataset, flan.test_dataset] +# ) + +# return train_dataloader, all_dataloader, all_dataset if __name__ == "__main__": - if args.dataset == "orca": - train_dataloader, all_dataloader, all_dataset = get_orca_dataset() - elif args.dataset == "flan": - train_dataloader, all_dataloader, all_dataset = get_flan_dataset() + + train_dataloader, all_dataloader, all_dataset, subset_dataset = get_dataset(args) + breakpoint() embedding_list = [] From 7b9f417b4ee36c8cb9c3c988acad667ae41b135b Mon Sep 17 00:00:00 2001 From: zhansu Date: Wed, 7 Aug 2024 02:23:05 -0400 Subject: [PATCH 18/37] add length normalization --- mttl/datamodule/preference_data_module.py | 26 +++-- mttl/models/expert_model.py | 115 +++++++++++----------- 2 files changed, 77 insertions(+), 64 deletions(-) diff --git a/mttl/datamodule/preference_data_module.py b/mttl/datamodule/preference_data_module.py index fb7aa9071..04149fa6a 100644 --- a/mttl/datamodule/preference_data_module.py +++ b/mttl/datamodule/preference_data_module.py @@ -21,21 +21,23 @@ def __call__(self, batch): truncation=True, )["input_ids"] - prefered_ids = self.tokenizer.batch_encode_plus( + prefered_tokenize = self.tokenizer.batch_encode_plus( chosen_responses, padding=True, return_tensors="pt", max_length=self.max_input_length, truncation=True, - )["input_ids"] + ) + prefered_ids = prefered_tokenize["input_ids"] - disprefered_ids = self.tokenizer.batch_encode_plus( + disprefered_tokenize = self.tokenizer.batch_encode_plus( rejected_responses, padding=True, return_tensors="pt", max_length=self.max_input_length, truncation=True, - )["input_ids"] + ) + disprefered_ids = disprefered_tokenize["input_ids"] prompt_prefered_ids = torch.cat([prompt_ids, prefered_ids], dim=-1) prompt_disprefered_ids = torch.cat([prompt_ids, disprefered_ids], dim=-1) @@ -43,17 +45,21 @@ def __call__(self, batch): prompt_prefered_mask = torch.cat( [torch.ones_like(prompt_ids), torch.zeros_like(prefered_ids)], dim=-1 ) + # compute the each length of the prefered + prefered_y_len = prefered_tokenize["attention_mask"].sum(dim=1) + disprefered_y_len = disprefered_tokenize["attention_mask"].sum(dim=1) + prompt_disprefered_mask = torch.cat( [torch.ones_like(prompt_ids), torch.zeros_like(disprefered_ids)], dim=-1 ) - breakpoint() - return { "prompt_prefered_ids": prompt_prefered_ids, "prompt_disprefered_ids": prompt_disprefered_ids, "prompt_prefered_mask": prompt_prefered_mask, "prompt_disprefered_mask": prompt_disprefered_mask, + "prefered_y_len": prefered_y_len, + "disprefered_y_len": disprefered_y_len, } @@ -94,5 +100,11 @@ def collate_fn(self): train_dataloader = datamodule.train_dataloader() val_dataloder = datamodule.val_dataloader() for batch in val_dataloder: - print(batch) + prompt_prefered_mask = batch["prompt_prefered_mask"] + prompt_disprefered_mask = batch["prompt_disprefered_mask"] + + # get the length of the response + prefered_y_len = batch["prefered_y_len"] + disprefered_y_len = batch["disprefered_y_len"] + print(prefered_y_len, disprefered_y_len) breakpoint() diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index 6ddc47fc9..2516af623 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -650,7 +650,12 @@ def __init__(self, preference_model, **kwargs): self.save_hyperparameters(kwargs) def simpo_loss( - self, original_prefered_logprob, original_disprefered_logprob, gamma_beta_ratio + self, + original_prefered_logprob, + original_disprefered_logprob, + gamma_beta_ratio, + prefered_y_len, + disprefered_y_len, ): """ Compute the SIMPO loss. @@ -676,20 +681,27 @@ def simpo_loss( f"Loss type {self.loss_type} not supported. Choose from ['sigmoid', 'hinge']" ) - chosen_rewards = self.beta * original_prefered_logprob.detach() + # normalize the log probabilities with the length of the response + chosen_rewards = self.beta * original_prefered_logprob.detach() / prefered_y_len - reject_rewards = self.beta * original_disprefered_logprob.detach() + reject_rewards = ( + self.beta * original_disprefered_logprob.detach() / disprefered_y_len + ) return losses, chosen_rewards, reject_rewards - def training_step(self, batch, _): + def forward(self, batch): prompt_prefered_ids = batch["prompt_prefered_ids"] prompt_disprefered_ids = batch["prompt_disprefered_ids"] prompt_prefered_mask = batch["prompt_prefered_mask"] prompt_disprefered_mask = batch["prompt_disprefered_mask"] - # original model + # get the length of the response + prefered_y_len = batch["prefered_y_len"] + disprefered_y_len = batch["disprefered_y_len"] + + # get the log probabilities of the prefered and disprefered experts model_prefered_log_prob = get_log_prob( self.preference_model.model.forward( prompt_prefered_ids, attention_mask=prompt_prefered_mask @@ -708,72 +720,61 @@ def training_step(self, batch, _): model_prefered_log_prob, model_disprefered_log_prob, gamma_beta_ratio=self.gamma_beta_ratio, + prefered_y_len=prefered_y_len, + disprefered_y_len=disprefered_y_len, ) + return loss.mean(), chosen_rewards, rejected_rewards + + def training_step(self, batch, _): + + loss, chosen_rewards, rejected_rewards = self.forward(batch) reward_accuracies = (chosen_rewards > rejected_rewards).float() - self.log( - "train/accuracies", - reward_accuracies.mean().cpu(), - on_step=True, - on_epoch=True, - ) - self.log("train/loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True) - self.log( - "train/chosen_rewards", chosen_rewards.mean(), on_step=True, on_epoch=True - ) - self.log( - "train/rejected_rewards", - rejected_rewards.mean(), - on_step=True, - on_epoch=True, - ) + metrices = { + "loss": loss.mean(), + "reward_accuracies": reward_accuracies.mean().cpu(), + "chosen_rewards": chosen_rewards.mean(), + "rejected_rewards": rejected_rewards.mean(), + "reward_margins": (chosen_rewards - rejected_rewards).mean().cpu(), + } + + for key, value in metrices.items(): + self.log(f"train/{key}", value, on_step=True, on_epoch=True, prog_bar=True) return loss.mean() def validation_step(self, batch, _): - prompt_prefered_ids = batch["prompt_prefered_ids"] - prompt_disprefered_ids = batch["prompt_disprefered_ids"] - prompt_prefered_mask = batch["prompt_prefered_mask"] - prompt_disprefered_mask = batch["prompt_disprefered_mask"] + loss, chosen_rewards, rejected_rewards = self.forward(batch) - # original model - model_prefered_log_prob = get_log_prob( - self.preference_model.model.forward( - prompt_prefered_ids, attention_mask=prompt_prefered_mask - ).logits, - labels=prompt_prefered_ids, - ) + reward_accuracies = (chosen_rewards > rejected_rewards).float() - model_disprefered_log_prob = get_log_prob( - self.preference_model.model.forward( - prompt_disprefered_ids, attention_mask=prompt_disprefered_mask - ).logits, - labels=prompt_disprefered_ids, - ) + metrices = { + "loss": loss.mean(), + "reward_accuracies": reward_accuracies.mean().cpu(), + "chosen_rewards": chosen_rewards.mean(), + "rejected_rewards": rejected_rewards.mean(), + "reward_margins": (chosen_rewards - rejected_rewards).mean().cpu(), + } + for key, value in metrices.items(): + self.log(f"val/{key}", value, on_step=True, on_epoch=True, prog_bar=True) - loss, chosen_rewards, rejected_rewards = self.simpo_loss( - model_prefered_log_prob, - model_disprefered_log_prob, - gamma_beta_ratio=self.gamma_beta_ratio, - ) + return loss.mean() + + def test_step(self, batch, _): + loss, chosen_rewards, rejected_rewards = self.forward(batch) reward_accuracies = (chosen_rewards > rejected_rewards).float() - self.log( - "val/reward_accuracies", - reward_accuracies.mean(), - on_step=True, - on_epoch=True, - prog_bar=True, - ) - self.log("val/loss", loss.mean(), on_step=True, on_epoch=True, prog_bar=True) - self.log( - "val/chosen_rewards", chosen_rewards.mean(), on_step=True, on_epoch=True - ) - self.log( - "val/rejected_rewards", rejected_rewards.mean(), on_step=True, on_epoch=True - ) + metrices = { + "loss": loss.mean(), + "reward_accuracies": reward_accuracies.mean().cpu(), + "chosen_rewards": chosen_rewards.mean(), + "rejected_rewards": rejected_rewards.mean(), + "reward_margins": (chosen_rewards - rejected_rewards).mean().cpu(), + } + for key, value in metrices.items(): + self.log(f"test/{key}", value, on_step=True, on_epoch=True, prog_bar=True) return loss.mean() From 0566d96d9e4389df41dc8e520d7f6dc5fbbb9582 Mon Sep 17 00:00:00 2001 From: zhansu Date: Wed, 7 Aug 2024 09:20:37 -0400 Subject: [PATCH 19/37] fix --- mttl/models/expert_model.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index 2516af623..8e46447ee 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -666,6 +666,10 @@ def simpo_loss( original_disprefered_logps: log probabiliteis of the disprefered expert in the original model """ + # normalize the log probabilities with the length of the response + original_prefered_logprob = original_prefered_logprob / prefered_y_len + original_disprefered_logprob = original_disprefered_logprob / disprefered_y_len + pi_logratios = original_prefered_logprob - original_disprefered_logprob logits = pi_logratios - gamma_beta_ratio @@ -682,11 +686,9 @@ def simpo_loss( ) # normalize the log probabilities with the length of the response - chosen_rewards = self.beta * original_prefered_logprob.detach() / prefered_y_len + chosen_rewards = self.beta * original_prefered_logprob.detach() - reject_rewards = ( - self.beta * original_disprefered_logprob.detach() / disprefered_y_len - ) + reject_rewards = self.beta * original_disprefered_logprob.detach() return losses, chosen_rewards, reject_rewards From 1dba64d39ed8d77ec1d13088368738c736c28dbb Mon Sep 17 00:00:00 2001 From: wzm289 Date: Wed, 21 Aug 2024 04:24:38 +0200 Subject: [PATCH 20/37] make the reference model eval mode --- projects/modular_llm/train_dpo.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py index f6e1ea421..5a5d0f741 100644 --- a/projects/modular_llm/train_dpo.py +++ b/projects/modular_llm/train_dpo.py @@ -107,14 +107,13 @@ def create_library(args): ref_model = model_class( **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library ) + # eval mode + ref_model.eval() module = ExpertModelDPO( **vars(args), preference_model=model, ref_expert_model=ref_model ) elif args.rl_training == "simpo": args.trainable_param_names = "^(?=.*preference_model)(?=.*prototypes).*" - model = model_class( - **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library - ) module = ExpertModelSimPO(**vars(args), preference_model=model) else: module = model From 38049bb72cca43d72ce972c21cc9bca7c49d90e5 Mon Sep 17 00:00:00 2001 From: zhansu Date: Wed, 21 Aug 2024 03:56:26 -0400 Subject: [PATCH 21/37] fix train_batch_size --- projects/modular_llm/train_dpo.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py index 5a5d0f741..e53ddd905 100644 --- a/projects/modular_llm/train_dpo.py +++ b/projects/modular_llm/train_dpo.py @@ -93,7 +93,13 @@ def create_library(args): model_class = MoEModel else: model_class = ExpertModel - config = DatasetConfig(model=args.model) + config = DatasetConfig( + model=args.model, + train_batch_size=args.train_batch_size, + predict_batch_size=args.predict_batch_size, + max_input_length=args.max_input_length, + max_output_length=args.max_output_length, + ) dm = Preferencemodule(config) # dm = get_datamodule(args) From 75a80d99446b80ecaf9fc90655a3bfa15d816b36 Mon Sep 17 00:00:00 2001 From: zhansu Date: Wed, 21 Aug 2024 07:01:29 -0400 Subject: [PATCH 22/37] wip --- mttl/models/expert_model.py | 117 +++++++++++++++++++++++++++++++----- 1 file changed, 103 insertions(+), 14 deletions(-) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index 8e46447ee..c6f87473e 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -781,6 +781,53 @@ def test_step(self, batch, _): return loss.mean() +from functools import wraps + + +def gpu_memory_usage_decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # 清除GPU缓存,确保测量准确 + torch.cuda.empty_cache() + + # 获取函数执行前的GPU内存使用情况 + memory_allocated_before = torch.cuda.memory_allocated(device) + memory_reserved_before = torch.cuda.memory_reserved(device) + print( + f"[Before] Memory allocated: {memory_allocated_before / (1024 ** 2):.2f} MB" + ) + print( + f"[Before] Memory reserved: {memory_reserved_before / (1024 ** 2):.2f} MB" + ) + + # 执行目标函数 + result = func(*args, **kwargs) + + # 获取函数执行后的GPU内存使用情况 + memory_allocated_after = torch.cuda.memory_allocated(device) + memory_reserved_after = torch.cuda.memory_reserved(device) + print( + f"[After] Memory allocated: {memory_allocated_after / (1024 ** 2):.2f} MB" + ) + print(f"[After] Memory reserved: {memory_reserved_after / (1024 ** 2):.2f} MB") + + # 计算内存使用的差值 + memory_allocated_diff = memory_allocated_after - memory_allocated_before + memory_reserved_diff = memory_reserved_after - memory_reserved_before + print( + f"Memory allocated difference: {memory_allocated_diff / (1024 ** 2):.2f} MB" + ) + print( + f"Memory reserved difference: {memory_reserved_diff / (1024 ** 2):.2f} MB" + ) + + return result + + return wrapper + + class ExpertModelDPO(EfficientCheckpointModule): def __init__(self, preference_model, ref_expert_model, **kwargs): @@ -800,6 +847,14 @@ def training_step(self, batch, _): prompt_prefered_mask = batch["prompt_prefered_mask"] prompt_disprefered_mask = batch["prompt_disprefered_mask"] + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + # Measure GPU memory before forward pass + memory_allocated_before = torch.cuda.memory_allocated(device) + memory_reserved_before = torch.cuda.memory_reserved(device) + + # logits = self.preference_model.model.forward(prompt_prefered_ids).logits + + # loss = torch.mean(logits) # original model model_prefered_log_prob = get_log_prob( self.preference_model.model.forward( @@ -830,6 +885,39 @@ def training_step(self, batch, _): labels=prompt_disprefered_ids, ) + # Measure GPU memory after forward pass + memory_allocated_after = torch.cuda.memory_allocated(device) + memory_reserved_after = torch.cuda.memory_reserved(device) + + # Calculate the difference in memory usage + memory_allocated_diff = memory_allocated_after - memory_allocated_before + memory_reserved_diff = memory_reserved_after - memory_reserved_before + + print( + f"Memory allocated before forward pass: {memory_allocated_before / (1024 ** 2):.2f} MB" + ) + print( + f"Memory allocated after forward pass: {memory_allocated_after / (1024 ** 2):.2f} MB" + ) + print( + f"Memory allocated difference: {memory_allocated_diff / (1024 ** 2):.2f} MB" + ) + + print( + f"Memory reserved before forward pass: {memory_reserved_before / (1024 ** 2):.2f} MB" + ) + print( + f"Memory reserved after forward pass: {memory_reserved_after / (1024 ** 2):.2f} MB" + ) + print( + f"Memory reserved difference: {memory_reserved_diff / (1024 ** 2):.2f} MB" + ) + + loss = -F.logsigmoid( + self.beta * (model_prefered_log_prob - model_disprefered_log_prob) + - (ref_prefered_log_prob - ref_disprefered_log_prob) + ).mean() + loss, reward_accuracies, reward_margins = calculate_DPO_loss( model_prefered_log_prob, model_disprefered_log_prob, @@ -839,20 +927,21 @@ def training_step(self, batch, _): ) self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True) - self.log( - "train/reward_accuracies", - reward_accuracies, - on_step=True, - on_epoch=True, - prog_bar=True, - ) - self.log( - "train/reward_margins", - reward_margins, - on_step=True, - on_epoch=True, - prog_bar=True, - ) + # self.log( + # "train/reward_accuracies", + # reward_accuracies, + # on_step=True, + # on_epoch=True, + # prog_bar=True, + # ) + # self.log( + # "train/reward_margins", + # reward_margins, + # on_step=True, + # on_epoch=True, + # prog_bar=True, + # ) + # clear the gpu memory return loss From 3904f11fe3d28fe2ed94816f3856e5bed937eab4 Mon Sep 17 00:00:00 2001 From: zhansu Date: Wed, 28 Aug 2024 03:35:59 -0400 Subject: [PATCH 23/37] add ultrafeedback data --- mttl/datamodule/ultrafeedback_data_module.py | 148 +++++++++++++++++++ 1 file changed, 148 insertions(+) create mode 100644 mttl/datamodule/ultrafeedback_data_module.py diff --git a/mttl/datamodule/ultrafeedback_data_module.py b/mttl/datamodule/ultrafeedback_data_module.py new file mode 100644 index 000000000..bf1a26288 --- /dev/null +++ b/mttl/datamodule/ultrafeedback_data_module.py @@ -0,0 +1,148 @@ +from dataclasses import dataclass +from typing import Any +import torch + +from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule +from mttl.models.library.expert_library import DatasetLibrary + + +def is_openai_format(messages: Any) -> bool: + """ + Check if the input messages are in OpenAI format. + Args: + messages (`Any`): + Messages to check. + Returns: + `bool`: Whether the messages are in OpenAI format. + """ + if isinstance(messages, list) and all( + isinstance(message, dict) for message in messages + ): + return all("role" in message and "content" in message for message in messages) + return False + + +@dataclass +class UltrafeedbackCollator(DefaultCollator): + def __call__(self, batch): + + # For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue + # We therefore need to extract the N-1 turns to form the prompt + prompts = [] + chosen_responses = [] + rejected_responses = [] + for example in batch: + if "prompt" in example and is_openai_format(example["prompt"]): + prompt_messages = example["prompt"] + chosen_messages = example["chosen"] + rejected_messages = example["rejected"] + else: + prompt_messages = example["chosen"][:-1] + # Now we extract the final turn to define chosen/rejected responses + chosen_messages = example["chosen"][-1:] + rejected_messages = example["rejected"][-1:] + prompts.append( + self.tokenizer.apply_chat_template(prompt_messages, tokenize=False) + ) + chosen_responses.append( + self.tokenizer.apply_chat_template(chosen_messages, tokenize=False) + ) + rejected_responses.append( + self.tokenizer.apply_chat_template(rejected_messages, tokenize=False) + ) + + prompt_ids = self.tokenizer.batch_encode_plus( + prompts, + padding=True, + return_tensors="pt", + max_length=self.max_input_length, + truncation=True, + )["input_ids"] + + prefered_tokenize = self.tokenizer.batch_encode_plus( + chosen_responses, + padding=True, + return_tensors="pt", + max_length=self.max_input_length, + truncation=True, + ) + prefered_ids = prefered_tokenize["input_ids"] + + disprefered_tokenize = self.tokenizer.batch_encode_plus( + rejected_responses, + padding=True, + return_tensors="pt", + max_length=self.max_input_length, + truncation=True, + ) + disprefered_ids = disprefered_tokenize["input_ids"] + + prompt_prefered_ids = torch.cat([prompt_ids, prefered_ids], dim=-1) + prompt_disprefered_ids = torch.cat([prompt_ids, disprefered_ids], dim=-1) + + prompt_prefered_mask = torch.cat( + [torch.ones_like(prompt_ids), torch.zeros_like(prefered_ids)], dim=-1 + ) + # compute the each length of the prefered + prefered_y_len = prefered_tokenize["attention_mask"].sum(dim=1) + disprefered_y_len = disprefered_tokenize["attention_mask"].sum(dim=1) + + prompt_disprefered_mask = torch.cat( + [torch.ones_like(prompt_ids), torch.zeros_like(disprefered_ids)], dim=-1 + ) + + return { + "prompt_prefered_ids": prompt_prefered_ids, + "prompt_disprefered_ids": prompt_disprefered_ids, + "prompt_prefered_mask": prompt_prefered_mask, + "prompt_disprefered_mask": prompt_disprefered_mask, + "prefered_y_len": prefered_y_len, + "disprefered_y_len": disprefered_y_len, + } + + +@dataclass +class UltrafeedbackDPOmodule(DefaultDataModule): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setup_dataset(self): + dataset = DatasetLibrary.pull_dataset_with_retry( + "princeton-nlp/gemma2-ultrafeedback-armorm" + ) + + # format the ultrafeedback dataset to chatbot format + self.train_dataset = dataset["train"] + self.test_dataset = dataset["test"] + self.dev_dataset = self.test_dataset + + self.print_infos() + + @property + def collate_fn(self): + return UltrafeedbackCollator( + tokenizer=self.tokenizer, + padding="longest", + max_input_length=self.config.max_input_length, + max_output_length=self.config.max_output_length, + return_tensors="pt", + model_family=self.config.model_family, + for_generation=self.for_generation, + ) + + +if __name__ == "__main__": + config = DatasetConfig(model="microsoft/Phi-3-mini-4k-instruct") + datamodule = UltrafeedbackDPOmodule(config) + train_dataloader = datamodule.train_dataloader() + val_dataloder = datamodule.val_dataloader() + for batch in val_dataloder: + prompt_prefered_mask = batch["prompt_prefered_mask"] + prompt_disprefered_mask = batch["prompt_disprefered_mask"] + + # get the length of the response + prefered_y_len = batch["prefered_y_len"] + disprefered_y_len = batch["disprefered_y_len"] + print(prefered_y_len, disprefered_y_len) + breakpoint() From d1696619dc259971270bc559c9064eabb791e802 Mon Sep 17 00:00:00 2001 From: zhansu Date: Wed, 28 Aug 2024 03:40:31 -0400 Subject: [PATCH 24/37] add urtra data train --- projects/modular_llm/train_dpo.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py index e53ddd905..40f395cd0 100644 --- a/projects/modular_llm/train_dpo.py +++ b/projects/modular_llm/train_dpo.py @@ -7,11 +7,11 @@ import torch from pytorch_lightning import Trainer, seed_everything -sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from mttl.callbacks import LiveCheckpointCallback from mttl.datamodule.base import DatasetConfig from mttl.datamodule.preference_data_module import Preferencemodule +from mttl.datamodule.ultrafeedback_data_module import UltrafeedbackDPOmodule # from mttl.datamodule.base import get_datamodule from mttl.models.expert_config import ExpertConfig @@ -100,7 +100,11 @@ def create_library(args): max_input_length=args.max_input_length, max_output_length=args.max_output_length, ) - dm = Preferencemodule(config) + + if "ultrafeedback" in args.dataset: + dm = UltrafeedbackDPOmodule(config) + else: + dm = Preferencemodule(config) # dm = get_datamodule(args) # args.n_tasks = len(dm._task_names) From fe40732aa8c4ff2c50b2e01fcf97f861fb2c057a Mon Sep 17 00:00:00 2001 From: zhansu Date: Thu, 5 Sep 2024 11:26:19 -0400 Subject: [PATCH 25/37] add ultralfeedback sft --- mttl/datamodule/ultrafeedback_data_module.py | 67 +++++++++++++++++--- 1 file changed, 59 insertions(+), 8 deletions(-) diff --git a/mttl/datamodule/ultrafeedback_data_module.py b/mttl/datamodule/ultrafeedback_data_module.py index bf1a26288..0b18d72bb 100644 --- a/mttl/datamodule/ultrafeedback_data_module.py +++ b/mttl/datamodule/ultrafeedback_data_module.py @@ -23,7 +23,7 @@ def is_openai_format(messages: Any) -> bool: @dataclass -class UltrafeedbackCollator(DefaultCollator): +class UltrafeedbackDPOCollator(DefaultCollator): def __call__(self, batch): # For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue @@ -121,7 +121,58 @@ def setup_dataset(self): @property def collate_fn(self): - return UltrafeedbackCollator( + return UltrafeedbackDPOCollator( + tokenizer=self.tokenizer, + padding="longest", + max_input_length=self.config.max_input_length, + max_output_length=self.config.max_output_length, + return_tensors="pt", + model_family=self.config.model_family, + for_generation=self.for_generation, + ) + + +@dataclass +class UltrafeedbackSFTCollator(DefaultCollator): + def __call__(self, batch): + + # For SFT, the inputs are triples of (prompt, message), where `chosen` and `rejected` are the final turn of a dialogue + # We therefore need to extract the N-1 turns to form the prompt + prompts = [] + messages = [] + for example in batch: + prompt_messages = example["prompt"] + chosen_messages = example["messages"] + prompts.append(prompt_messages) + messages.append( + self.tokenizer.apply_chat_template(chosen_messages, tokenize=False) + ) + + return { + "sources_texts": prompts, + "labels_texts": messages, + } + + +@dataclass +class UltrafeedbackSFTmodule(DefaultDataModule): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setup_dataset(self): + dataset = DatasetLibrary.pull_dataset_with_retry("HuggingFaceH4/ultrachat_200k") + + # format the ultrafeedback dataset to chatbot format + self.train_dataset = dataset["train_sft"] + self.test_dataset = dataset["test_sft"] + self.dev_dataset = self.test_dataset + + self.print_infos() + + @property + def collate_fn(self): + return UltrafeedbackSFTCollator( tokenizer=self.tokenizer, padding="longest", max_input_length=self.config.max_input_length, @@ -134,15 +185,15 @@ def collate_fn(self): if __name__ == "__main__": config = DatasetConfig(model="microsoft/Phi-3-mini-4k-instruct") - datamodule = UltrafeedbackDPOmodule(config) + datamodule = UltrafeedbackSFTmodule(config) train_dataloader = datamodule.train_dataloader() val_dataloder = datamodule.val_dataloader() for batch in val_dataloder: - prompt_prefered_mask = batch["prompt_prefered_mask"] - prompt_disprefered_mask = batch["prompt_disprefered_mask"] + # prompt_prefered_mask = batch["prompt_prefered_mask"] + # prompt_disprefered_mask = batch["prompt_disprefered_mask"] # get the length of the response - prefered_y_len = batch["prefered_y_len"] - disprefered_y_len = batch["disprefered_y_len"] - print(prefered_y_len, disprefered_y_len) + # prefered_y_len = batch["prefered_y_len"] + # disprefered_y_len = batch["disprefered_y_len"] + print(batch) breakpoint() From 49ee7341ea19b69dad156ac764fa73fdd038b154 Mon Sep 17 00:00:00 2001 From: zhansu Date: Tue, 10 Sep 2024 07:03:29 -0400 Subject: [PATCH 26/37] training module with ultral sft --- mttl/datamodule/base.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mttl/datamodule/base.py b/mttl/datamodule/base.py index 3ded6181f..e088f80e8 100644 --- a/mttl/datamodule/base.py +++ b/mttl/datamodule/base.py @@ -655,6 +655,10 @@ def get_datamodule(args, for_generation=False, dataset_override=None): WinograndeMultiChoiceDataModule, ) + from mttl.datamodule.ultrafeedback_data_module import ( + UltrafeedbackSFTmodule, + ) + # refactor all the common arguments below into a dict common kwargs dataset = args.dataset if not dataset_override else dataset_override @@ -737,6 +741,11 @@ def get_datamodule(args, for_generation=False, dataset_override=None): augment_few_shot=args.augment_few_shot, ) dm = FlatMultiTaskModule(config, for_generation=for_generation) + elif "ultrachat" in dataset: + config = DatasetConfig( + **common_kwargs, + ) + dm = UltrafeedbackSFTmodule(config, for_generation=for_generation) elif "mmlu" in dataset: config = MMLUDataConfig( **common_kwargs, From 8f151884c047370c1e572ee324e6d294310da872 Mon Sep 17 00:00:00 2001 From: zhansu Date: Sat, 16 Nov 2024 01:45:09 -0500 Subject: [PATCH 27/37] fix save --- mttl/evaluators/loglike_evaluator.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mttl/evaluators/loglike_evaluator.py b/mttl/evaluators/loglike_evaluator.py index ee7b33adc..668081080 100644 --- a/mttl/evaluators/loglike_evaluator.py +++ b/mttl/evaluators/loglike_evaluator.py @@ -70,7 +70,11 @@ def evaluate( ) loss_per_option = loss_per_option.cpu() - if loss_per_option.dtype in [torch.bfloat16, torch.float16]: + if loss_per_option.dtype in [ + torch.bfloat16, + torch.float16, + torch.float32, + ]: loss_per_option = loss_per_option.float().numpy() loss_per_example = [ @@ -105,6 +109,5 @@ def evaluate( "predictions": all_predictions, "accuracy": float(np.mean(all_accuracies)) if all_accuracies else None, } - self.save_metrics(metrics, output_path) return metrics["accuracy"] From bf45a61e3be17615c55fdfbb723be313662d6c36 Mon Sep 17 00:00:00 2001 From: zhansu Date: Sat, 16 Nov 2024 01:45:22 -0500 Subject: [PATCH 28/37] fix arrow --- mttl/models/containers/selectors/arrow_selector.py | 2 -- mttl/models/library/expert_library.py | 3 ++- projects/modular_llm/eval_library.py | 1 - 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mttl/models/containers/selectors/arrow_selector.py b/mttl/models/containers/selectors/arrow_selector.py index fbfdb3877..3121f1586 100644 --- a/mttl/models/containers/selectors/arrow_selector.py +++ b/mttl/models/containers/selectors/arrow_selector.py @@ -13,7 +13,6 @@ def compute_arrow_embeddings( ab_only=True, tie_params=None, tie_op="concat", - add_base_proto=False, recompute_prototypes=False, ) -> str: from mttl.models.library.library_transforms import ( @@ -26,7 +25,6 @@ def compute_arrow_embeddings( ab_only=ab_only, tie_params=tie_params or "default", tie_op=tie_op, - add_base_proto=add_base_proto, ) ArrowTransform(cfg).transform( library, diff --git a/mttl/models/library/expert_library.py b/mttl/models/library/expert_library.py index 8065dccdc..96c88f544 100644 --- a/mttl/models/library/expert_library.py +++ b/mttl/models/library/expert_library.py @@ -15,7 +15,8 @@ CommitOperationCopy, CommitOperationDelete, ) -from huggingface_hub.errors import RepositoryNotFoundError + +# from huggingface_hub.errors import RepositoryNotFoundError from mttl.logging import logger from mttl.models.library.backend_engine import ( diff --git a/projects/modular_llm/eval_library.py b/projects/modular_llm/eval_library.py index 8a387cb89..af8b1eea9 100644 --- a/projects/modular_llm/eval_library.py +++ b/projects/modular_llm/eval_library.py @@ -117,7 +117,6 @@ def fetch_prototypes(args: EvaluationConfig, library: ExpertLibrary) -> str: ab_only=args.ab_only, tie_params=args.tie_params, tie_op=args.tie_op, - add_base_proto=args.add_base_proto, recompute_prototypes=args.recompute_prototypes, ) elif args.merge_or_route == "hidden": From 2f081900c572be7dd685899e7ecd33fbb7f54940 Mon Sep 17 00:00:00 2001 From: zhansu Date: Sun, 17 Nov 2024 12:05:23 -0500 Subject: [PATCH 29/37] clean --- projects/modular_llm/train_dpo.py | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py index 40f395cd0..9eb0d9266 100644 --- a/projects/modular_llm/train_dpo.py +++ b/projects/modular_llm/train_dpo.py @@ -37,34 +37,6 @@ from projects.modular_llm.src.transfer_matrix import run_eval as produce_transfer_matrix -def create_transfer_matrix(args, checkpoint): - ######################## - # create transfer matrix - config = TransferMatrixConfig() - for k, v in vars(args).items(): - if k in vars(config): - setattr(config, k, v) - config.eval_base = False - config.eval_metric = "rougeL" - - expert: Expert = load_expert(checkpoint) - expert.expert_info.expert_name = str(args.finetune_task_name) - expert.expert_info.expert_task_name = str(args.finetune_task_name) - temp_dir = TemporaryDirectory() - destination = temp_dir.name - LocalExpertLibrary.from_expert_dict({"checkpoint": expert}, destination=destination) - config.library_id = destination - config.finetune_task_name = ( - args.finetune_task_name.split(",") - if not isinstance(args.finetune_task_name, list) - else args.finetune_task_name - ) - if len(config.finetune_task_name) < 50: - produce_transfer_matrix(config, debug=False) - ######################## - temp_dir.cleanup() - - def run_multitask(args: ExpertConfig): seed_everything(args.seed, workers=True) @@ -232,9 +204,6 @@ def upload_library(expert_library, module): # upload_library(expert_library, module) - if args.create_transfer_matrix: - create_transfer_matrix(args, checkpoint) - if __name__ == "__main__": args = ExpertConfig.parse() From b33316b6ec7d7bebf4675dae553234e08835c1b3 Mon Sep 17 00:00:00 2001 From: zhansu Date: Sun, 17 Nov 2024 22:31:40 -0500 Subject: [PATCH 30/37] merge the code to the current version --- mttl/arguments.py | 3 +++ mttl/datamodule/preference_data_module.py | 6 +++--- mttl/datamodule/ultrafeedback_data_module.py | 8 ++++---- mttl/models/expert_model.py | 7 ++++--- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/mttl/arguments.py b/mttl/arguments.py index cc10e5788..29cd24231 100644 --- a/mttl/arguments.py +++ b/mttl/arguments.py @@ -410,6 +410,9 @@ class TrainingArgs(DataArgs): profile: bool = False # if 'True' will profile the model training + # dpo + rl_training: str = "dpo" + @property def dataset_config(self): if self.dataset_type is not None: diff --git a/mttl/datamodule/preference_data_module.py b/mttl/datamodule/preference_data_module.py index 04149fa6a..630a0b637 100644 --- a/mttl/datamodule/preference_data_module.py +++ b/mttl/datamodule/preference_data_module.py @@ -2,8 +2,8 @@ import torch -from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule -from mttl.models.library.expert_library import DatasetLibrary +from mttl.datamodule.base import DatasetConfig, DefaultCollator, DataModule +from mttl.models.library.dataset_library import DatasetLibrary @dataclass @@ -64,7 +64,7 @@ def __call__(self, batch): @dataclass -class Preferencemodule(DefaultDataModule): +class Preferencemodule(DataModule): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/mttl/datamodule/ultrafeedback_data_module.py b/mttl/datamodule/ultrafeedback_data_module.py index 0b18d72bb..df3ad6e57 100644 --- a/mttl/datamodule/ultrafeedback_data_module.py +++ b/mttl/datamodule/ultrafeedback_data_module.py @@ -2,8 +2,8 @@ from typing import Any import torch -from mttl.datamodule.base import DatasetConfig, DefaultCollator, DefaultDataModule -from mttl.models.library.expert_library import DatasetLibrary +from mttl.datamodule.base import DatasetConfig, DefaultCollator, DataModule +from mttl.models.library.dataset_library import DatasetLibrary def is_openai_format(messages: Any) -> bool: @@ -102,7 +102,7 @@ def __call__(self, batch): @dataclass -class UltrafeedbackDPOmodule(DefaultDataModule): +class UltrafeedbackDPOmodule(DataModule): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -155,7 +155,7 @@ def __call__(self, batch): @dataclass -class UltrafeedbackSFTmodule(DefaultDataModule): +class UltrafeedbackSFTmodule(DataModule): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/mttl/models/expert_model.py b/mttl/models/expert_model.py index 37b220b05..fc08c8f1c 100644 --- a/mttl/models/expert_model.py +++ b/mttl/models/expert_model.py @@ -32,6 +32,8 @@ Modifier, ) from mttl.models.modifiers.modify_model import modify_transformer +from mttl.models.lightning.base_module import LightningEfficientCheckpoint +import torch.nn.functional as F @contextlib.contextmanager @@ -64,7 +66,6 @@ def __init__( **loading_kwargs, ): super().__init__(config, model_object=model_object, **loading_kwargs) - if config.modifier_config is not None: modify_transformer(self.model, config.modifier_config) @@ -591,7 +592,7 @@ def get_log_prob(logits, labels): return torch.gather(log_probs, -1, labels.unsqueeze(-1)).squeeze(-1).mean(-1) -class ExpertModelSimPO(EfficientCheckpointModule): +class ExpertModelSimPO(LightningEfficientCheckpoint): def __init__(self, preference_model, **kwargs): super().__init__(**kwargs) self.preference_model = preference_model @@ -782,7 +783,7 @@ def wrapper(*args, **kwargs): return wrapper -class ExpertModelDPO(EfficientCheckpointModule): +class ExpertModelDPO(LightningEfficientCheckpoint): def __init__(self, preference_model, ref_expert_model, **kwargs): super().__init__(**kwargs) From 96bb1c4c97dcb64d1cecb07a540ea3f99bcee81a Mon Sep 17 00:00:00 2001 From: zhansu Date: Sun, 17 Nov 2024 22:31:52 -0500 Subject: [PATCH 31/37] clean the code --- projects/modular_llm/train_dpo.py | 42 ++++++++++++++++++------------- 1 file changed, 24 insertions(+), 18 deletions(-) diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py index 9eb0d9266..06ffd9527 100644 --- a/projects/modular_llm/train_dpo.py +++ b/projects/modular_llm/train_dpo.py @@ -8,13 +8,14 @@ from pytorch_lightning import Trainer, seed_everything -from mttl.callbacks import LiveCheckpointCallback +from mttl.models.lightning.callbacks import LiveCheckpointCallback from mttl.datamodule.base import DatasetConfig from mttl.datamodule.preference_data_module import Preferencemodule from mttl.datamodule.ultrafeedback_data_module import UltrafeedbackDPOmodule # from mttl.datamodule.base import get_datamodule -from mttl.models.expert_config import ExpertConfig +from mttl.arguments import ExpertConfig, MultiExpertConfig +from mttl.models.expert_model import ExpertModelConfig from mttl.models.expert_model import ( ExpertModel, ExpertModelDPO, @@ -24,17 +25,9 @@ from mttl.models.library.expert import Expert, load_expert from mttl.models.library.expert_library import ExpertLibrary, LocalExpertLibrary from mttl.models.monitors import get_monitors -from mttl.utils import ( - generate_random_string, - get_pl_loggers, - logger, - rank_zero_only_and_wait, - remote_login, - setup_logging, -) -from projects.modular_llm.eval_library import patch_prototypes -from projects.modular_llm.src.transfer_matrix import TransferMatrixConfig -from projects.modular_llm.src.transfer_matrix import run_eval as produce_transfer_matrix +from mttl.models.lightning.loggers import get_pl_loggers +from mttl.logging import logger, setup_logging +from mttl.utils import generate_random_string, rank_zero_only_and_wait, remote_login def run_multitask(args: ExpertConfig): @@ -81,13 +74,26 @@ def create_library(args): # dm = get_datamodule(args) # args.n_tasks = len(dm._task_names) # args.task_names = dm._task_names + + args.tokenizer = dm.tokenizer + model_config = ExpertModelConfig( + base_model=args.model, + task_name=args.finetune_task_name, + expert_name=args.expert_name, + modifier_config=args.modifier_config, + ) + model = model_class( - **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library + config=model_config, + expert_library=expert_library, + **vars(args), ) if args.rl_training == "dpo": - args.trainable_param_names = "^(?=.*preference_model)(?=.*prototypes).*" + # args.trainable_param_names = "^(?=.*preference_model)(?=.*prototypes).*" ref_model = model_class( - **vars(args), tokenizer=dm.tokenizer, expert_library=expert_library + config=model_config, + expert_library=expert_library, + **vars(args), ) # eval mode ref_model.eval() @@ -95,7 +101,7 @@ def create_library(args): **vars(args), preference_model=model, ref_expert_model=ref_model ) elif args.rl_training == "simpo": - args.trainable_param_names = "^(?=.*preference_model)(?=.*prototypes).*" + # args.trainable_param_names = "^(?=.*preference_model)(?=.*prototypes).*" module = ExpertModelSimPO(**vars(args), preference_model=model) else: module = model @@ -206,5 +212,5 @@ def upload_library(expert_library, module): if __name__ == "__main__": - args = ExpertConfig.parse() + args = MultiExpertConfig.parse() ## in case we only train the routing run_multitask(args) From 54da2226daff77eac75bbc39fdd1ef3fa0598198 Mon Sep 17 00:00:00 2001 From: zhansu Date: Mon, 18 Nov 2024 02:22:26 -0500 Subject: [PATCH 32/37] clean the code --- projects/modular_llm/train_dpo.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/projects/modular_llm/train_dpo.py b/projects/modular_llm/train_dpo.py index 06ffd9527..9fd4928e1 100644 --- a/projects/modular_llm/train_dpo.py +++ b/projects/modular_llm/train_dpo.py @@ -71,15 +71,8 @@ def create_library(args): else: dm = Preferencemodule(config) - # dm = get_datamodule(args) - # args.n_tasks = len(dm._task_names) - # args.task_names = dm._task_names - - args.tokenizer = dm.tokenizer model_config = ExpertModelConfig( base_model=args.model, - task_name=args.finetune_task_name, - expert_name=args.expert_name, modifier_config=args.modifier_config, ) From 147878e7464ba7a704060b394f1e6c4ea8bd709f Mon Sep 17 00:00:00 2001 From: zhansu Date: Mon, 18 Nov 2024 02:23:11 -0500 Subject: [PATCH 33/37] add orca datamodule --- mttl/datamodule/orca_data_module.py | 86 +++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) create mode 100644 mttl/datamodule/orca_data_module.py diff --git a/mttl/datamodule/orca_data_module.py b/mttl/datamodule/orca_data_module.py new file mode 100644 index 000000000..7b2c6a616 --- /dev/null +++ b/mttl/datamodule/orca_data_module.py @@ -0,0 +1,86 @@ +from dataclasses import dataclass + +from mttl.datamodule.base import DatasetConfig, DefaultCollator, DataModule +from mttl.models.library.dataset_library import DatasetLibrary +import ast + + +@dataclass +class OrcaDataModuleCollator(DefaultCollator): + def __call__(self, batch): + source_texts = [] + target_texts = [] + for item in batch: + if type(item["messages"]) == str: + item["messages"] = ast.literal_eval(item["messages"]) + source_texts.append("You are a helpful assistant.") + target_texts.append( + self.tokenizer.apply_chat_template(item["messages"], tokenize=False) + ) + + return { + "source_texts": source_texts, + "target_texts": target_texts, + } + + +@dataclass +class OrcaDataModule(DataModule): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setup_dataset(self): + train_dataset = DatasetLibrary.pull_dataset_with_retry( + "microsoft/orca-agentinstruct-1M-v1" + )[self.config.finetune_task_name] + + self.train_dataset, self.dev_dataset = self.create_train_valid_split( + train_dataset, 0.1 + ) + self.test_dataset = self.dev_dataset + + self.print_infos() + + @property + def collate_fn(self): + return OrcaDataModuleCollator( + tokenizer=self.tokenizer, + padding="longest", + max_input_length=self.config.max_input_length, + max_output_length=self.config.max_output_length, + return_tensors="pt", + model_family=self.config.model_family, + for_generation=self.for_generation, + ) + + +if __name__ == "__main__": + config = DatasetConfig( + model="microsoft/Phi-3-mini-4k-instruct", finetune_task_name="creative_content" + ) + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct") + message = [ + {"role": "system", "content": ""}, + { + "role": "user", + "content": '\n# Adapted CMake code snippet for finding a library, its include directories, and checking for a specific version\n# Additionally, handle different library names based on the platform and check for multiple components of the library\n\n```cmake\n# Define a function to check the version of the library\nfunction(check_mylib_version)\n # Assume MYLIB_VERSION_STRING is set by some mechanism after finding the library\n if (MYLIB_VERSION_STRING VERSION_LESS "2.0")\n message(WARNING "MyLib version is too old, at least version 2.0 is required.")\n set(MYLIB_VERSION_OK FALSE PARENT_SCOPE)\n else()\n set(MYLIB_VERSION_OK TRUE PARENT_SCOPE)\n endif()\nendfunction()\n\n# Potential names of the library based on the platform\nif (WIN32)\n set(MYLIB_POSSIBLE_NAMES mylib_win)\nelseif (UNIX)\n set(MYLIB_POSSIBLE_NAMES mylib_unix mylib_unix_alt)\nelse()\n set(MYLIB_POSSIBLE_NAMES mylib)\nendif()\n\n# Find the include directory and the library\nfind_path(MYLIB_INCLUDE_DIRS NAMES mylib.h)\nfind_library(MYLIB_LIBRARIES NAMES ${MYLIB_POSSIBLE_NAMES})\n\n# Handle multiple components of the library\nfind_library(MYLIB_EXTRA_LIBRARIES NAMES mylib_extra)\n\ninclude(FindPackageHandleStandardArgs)\nfind_package_handle_standard_args(MyLib DEFAULT_MSG\n MYLIB_LIBRARIES MYLIB_INCLUDE_DIRS MYLIB_EXTRA_LIBRARIES)\n\n# After finding the library, check its version\ncheck_mylib_version()\n\nif (MYLIB_INCLUDE_DIRS AND MYLIB_LIBRARIES AND MYLIB_EXTRA_LIBRARIES AND MYLIB_VERSION_OK)\n set(MYLIB_FOUND TRUE)\nendif()\n\nmark_as_advanced(MYLIB_INCLUDE_DIRS MYLIB_LIBRARIES MYLIB_EXTRA_LIBRARIES)\n\n# Output should include:\n# - Comments explaining the code execution for the scenario where mylib.h is not found,\n# but the mylib library and its extra component are found in non-standard locations.\n# - Indicate whether `MYLIB_FOUND` would be set to `TRUE` or not.\n# - Explain the version checking logic and its impact on `MYLIB_FOUND`.\n# - Discuss the handling of different library names based on the platform.\n# - Explain the importance of finding all components of the library.\n```\n\nInput:\n# Scenario where mylib.h is not found, but the mylib library and its extra component are found in non-standard locations, and the version found is 2.1.\n', + }, + { + "role": "assistant", + "content": "Output:\n\nIn the given scenario, the `find_path` command is used to search for the `mylib.h` header file. Since `mylib.h` is not found, `MYLIB_INCLUDE_DIRS` will not be set, and this will be a critical failure for the `find_package_handle_standard_args` function, which expects `MYLIB_INCLUDE_DIRS` to be set to a valid path.\n\nThe `find_library` command is used to locate the `mylib` library and its extra component. The library names are determined based on the platform. For Windows, it looks for `mylib_win`, for Unix-like systems, it looks for `mylib_unix` or `mylib_unix_alt`, and for other platforms, it defaults to `mylib`. Since the libraries are found in non-standard locations, `MYLIB_LIBRARIES` and `MYLIB_EXTRA_LIBRARIES` will be set to those locations.\n\nThe `find_package_handle_standard_args` function is called to handle the results of the find commands. It checks if the required variables (`MYLIB_LIBRARIES`, `MYLIB_INCLUDE_DIRS`, `MYLIB_EXTRA_LIBRARIES`) are set. If any of these are not found, it will print a standard error message, and `MYLIB_FOUND` will not be set to `TRUE`.\n\nAfter the find commands, the `check_mylib_version` function is called to check the version of the library. Since the version found is 2.1, which is not less than 2.0, `MYLIB_VERSION_OK` will be set to `TRUE`.\n\nHowever, despite the library and its extra component being found and the version being acceptable, `MYLIB_FOUND` will not be set to `TRUE` because `MYLIB_INCLUDE_DIRS` is not set (the header file was not found). The `MYLIB_FOUND` variable is only set to `TRUE` if all components (include directories, libraries, extra libraries) are found and the version is okay.\n\nThe version checking logic is important because it ensures that the library version meets the minimum requirements for compatibility with the project. If the version is too old, it may lack necessary features or bug fixes, and the `MYLIB_VERSION_OK` variable will be set to `FALSE`, which would prevent `MYLIB_FOUND` from being set to `TRUE`.\n\nHandling different library names based on the platform is crucial because libraries can have different naming conventions on different operating systems. This flexibility allows the CMake script to find the correct library files regardless of the platform it is running on.\n\nFinding all components of the library is important because a library typically consists of both binary files (the library itself) and header files (for including in the source code). If any component is missing, the library cannot be fully utilized, and the build process may fail. Therefore, it is essential to ensure that all parts of the library are present before proceeding with the build.", + }, + ] + + dataset = DatasetLibrary.pull_dataset_with_retry( + "microsoft/orca-agentinstruct-1M-v1" + ) + # output = tokenizer.apply_chat_template(message, tokenize=False) + datamodule = OrcaDataModule(config) + train_dataloader = datamodule.train_dataloader() + val_dataloder = datamodule.val_dataloader() + for batch in val_dataloder: + print(batch) + breakpoint() From 23b3b3677e2157e4c2c823623cbd4ebf59b2e556 Mon Sep 17 00:00:00 2001 From: zhansu Date: Mon, 18 Nov 2024 03:31:14 -0500 Subject: [PATCH 34/37] add orca sft training --- mttl/datamodule/base.py | 7 +++++++ mttl/datamodule/orca_data_module.py | 22 ++++++++++++++-------- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/mttl/datamodule/base.py b/mttl/datamodule/base.py index a4526f152..4747c752f 100644 --- a/mttl/datamodule/base.py +++ b/mttl/datamodule/base.py @@ -972,6 +972,8 @@ def get_datamodule(args, for_generation=False, dataset_override=None): UltrafeedbackSFTmodule, ) + from mttl.datamodule.orca_data_module import OrcaDataModule + # if we have a DataArgs object, we can directly create the datamodule if isinstance(args, DataArgs) and args.dataset_type is not None: dataset_config = args.dataset_config @@ -1072,6 +1074,11 @@ def get_datamodule(args, for_generation=False, dataset_override=None): **common_kwargs, ) dm = UltrafeedbackSFTmodule(config, for_generation=for_generation) + elif "orca" in dataset: + config = DatasetConfig( + **common_kwargs, + ) + dm = OrcaDataModule(config, for_generation=for_generation) elif "mmlu" in dataset: config = MMLUDataConfig( **common_kwargs, diff --git a/mttl/datamodule/orca_data_module.py b/mttl/datamodule/orca_data_module.py index 7b2c6a616..97fea486b 100644 --- a/mttl/datamodule/orca_data_module.py +++ b/mttl/datamodule/orca_data_module.py @@ -8,20 +8,26 @@ @dataclass class OrcaDataModuleCollator(DefaultCollator): def __call__(self, batch): - source_texts = [] - target_texts = [] + sources = [] + labels = [] for item in batch: if type(item["messages"]) == str: item["messages"] = ast.literal_eval(item["messages"]) - source_texts.append("You are a helpful assistant.") - target_texts.append( + + sources.append("You are a helpful assistant.") + labels.append( self.tokenizer.apply_chat_template(item["messages"], tokenize=False) ) + output_batch = ( + self.prepare_inputs_for_gpt_family(sources, labels) + if self.model_family == "gpt" + else self.prepare_inputs_for_seq2seq_family(sources, labels) + ) + + output_batch["sources_texts"] = sources + output_batch["labels_texts"] = labels - return { - "source_texts": source_texts, - "target_texts": target_texts, - } + return output_batch @dataclass From 19fb72c57a5087702238d7ea031ab3daed5501c5 Mon Sep 17 00:00:00 2001 From: zhansu Date: Tue, 19 Nov 2024 08:30:16 -0500 Subject: [PATCH 35/37] update orca --- mttl/datamodule/orca_data_module.py | 85 ++++++++++++++--------------- 1 file changed, 41 insertions(+), 44 deletions(-) diff --git a/mttl/datamodule/orca_data_module.py b/mttl/datamodule/orca_data_module.py index 97fea486b..599bf919f 100644 --- a/mttl/datamodule/orca_data_module.py +++ b/mttl/datamodule/orca_data_module.py @@ -4,30 +4,29 @@ from mttl.models.library.dataset_library import DatasetLibrary import ast - -@dataclass -class OrcaDataModuleCollator(DefaultCollator): - def __call__(self, batch): - sources = [] - labels = [] - for item in batch: - if type(item["messages"]) == str: - item["messages"] = ast.literal_eval(item["messages"]) - - sources.append("You are a helpful assistant.") - labels.append( - self.tokenizer.apply_chat_template(item["messages"], tokenize=False) - ) - output_batch = ( - self.prepare_inputs_for_gpt_family(sources, labels) - if self.model_family == "gpt" - else self.prepare_inputs_for_seq2seq_family(sources, labels) - ) - - output_batch["sources_texts"] = sources - output_batch["labels_texts"] = labels - - return output_batch +# @dataclass +# class OrcaDataModuleCollator(DefaultCollator): +# def __call__(self, batch): +# sources = [] +# labels = [] +# for item in batch: +# if type(item["messages"]) == str: +# item["messages"] = ast.literal_eval(item["messages"]) + +# sources.append("You are a helpful assistant.") +# labels.append( +# self.tokenizer.apply_chat_template(item["messages"], tokenize=False) +# ) +# output_batch = ( +# self.prepare_inputs_for_gpt_family(sources, labels) +# if self.model_family == "gpt" +# else self.prepare_inputs_for_seq2seq_family(sources, labels) +# ) + +# output_batch["sources_texts"] = sources +# output_batch["labels_texts"] = labels + +# return output_batch @dataclass @@ -38,8 +37,10 @@ def __init__(self, *args, **kwargs): def setup_dataset(self): train_dataset = DatasetLibrary.pull_dataset_with_retry( - "microsoft/orca-agentinstruct-1M-v1" - )[self.config.finetune_task_name] + "zhan1993/orca_sqs_dataset" + )["train"] + + train_dataset = train_dataset.filter(lambda example: example["task_name"] == self.config.finetune_task_name) self.train_dataset, self.dev_dataset = self.create_train_valid_split( train_dataset, 0.1 @@ -50,7 +51,7 @@ def setup_dataset(self): @property def collate_fn(self): - return OrcaDataModuleCollator( + return DefaultCollator( tokenizer=self.tokenizer, padding="longest", max_input_length=self.config.max_input_length, @@ -66,24 +67,20 @@ def collate_fn(self): model="microsoft/Phi-3-mini-4k-instruct", finetune_task_name="creative_content" ) from transformers import AutoTokenizer - + #<|assistant|> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct") - message = [ - {"role": "system", "content": ""}, - { - "role": "user", - "content": '\n# Adapted CMake code snippet for finding a library, its include directories, and checking for a specific version\n# Additionally, handle different library names based on the platform and check for multiple components of the library\n\n```cmake\n# Define a function to check the version of the library\nfunction(check_mylib_version)\n # Assume MYLIB_VERSION_STRING is set by some mechanism after finding the library\n if (MYLIB_VERSION_STRING VERSION_LESS "2.0")\n message(WARNING "MyLib version is too old, at least version 2.0 is required.")\n set(MYLIB_VERSION_OK FALSE PARENT_SCOPE)\n else()\n set(MYLIB_VERSION_OK TRUE PARENT_SCOPE)\n endif()\nendfunction()\n\n# Potential names of the library based on the platform\nif (WIN32)\n set(MYLIB_POSSIBLE_NAMES mylib_win)\nelseif (UNIX)\n set(MYLIB_POSSIBLE_NAMES mylib_unix mylib_unix_alt)\nelse()\n set(MYLIB_POSSIBLE_NAMES mylib)\nendif()\n\n# Find the include directory and the library\nfind_path(MYLIB_INCLUDE_DIRS NAMES mylib.h)\nfind_library(MYLIB_LIBRARIES NAMES ${MYLIB_POSSIBLE_NAMES})\n\n# Handle multiple components of the library\nfind_library(MYLIB_EXTRA_LIBRARIES NAMES mylib_extra)\n\ninclude(FindPackageHandleStandardArgs)\nfind_package_handle_standard_args(MyLib DEFAULT_MSG\n MYLIB_LIBRARIES MYLIB_INCLUDE_DIRS MYLIB_EXTRA_LIBRARIES)\n\n# After finding the library, check its version\ncheck_mylib_version()\n\nif (MYLIB_INCLUDE_DIRS AND MYLIB_LIBRARIES AND MYLIB_EXTRA_LIBRARIES AND MYLIB_VERSION_OK)\n set(MYLIB_FOUND TRUE)\nendif()\n\nmark_as_advanced(MYLIB_INCLUDE_DIRS MYLIB_LIBRARIES MYLIB_EXTRA_LIBRARIES)\n\n# Output should include:\n# - Comments explaining the code execution for the scenario where mylib.h is not found,\n# but the mylib library and its extra component are found in non-standard locations.\n# - Indicate whether `MYLIB_FOUND` would be set to `TRUE` or not.\n# - Explain the version checking logic and its impact on `MYLIB_FOUND`.\n# - Discuss the handling of different library names based on the platform.\n# - Explain the importance of finding all components of the library.\n```\n\nInput:\n# Scenario where mylib.h is not found, but the mylib library and its extra component are found in non-standard locations, and the version found is 2.1.\n', - }, - { - "role": "assistant", - "content": "Output:\n\nIn the given scenario, the `find_path` command is used to search for the `mylib.h` header file. Since `mylib.h` is not found, `MYLIB_INCLUDE_DIRS` will not be set, and this will be a critical failure for the `find_package_handle_standard_args` function, which expects `MYLIB_INCLUDE_DIRS` to be set to a valid path.\n\nThe `find_library` command is used to locate the `mylib` library and its extra component. The library names are determined based on the platform. For Windows, it looks for `mylib_win`, for Unix-like systems, it looks for `mylib_unix` or `mylib_unix_alt`, and for other platforms, it defaults to `mylib`. Since the libraries are found in non-standard locations, `MYLIB_LIBRARIES` and `MYLIB_EXTRA_LIBRARIES` will be set to those locations.\n\nThe `find_package_handle_standard_args` function is called to handle the results of the find commands. It checks if the required variables (`MYLIB_LIBRARIES`, `MYLIB_INCLUDE_DIRS`, `MYLIB_EXTRA_LIBRARIES`) are set. If any of these are not found, it will print a standard error message, and `MYLIB_FOUND` will not be set to `TRUE`.\n\nAfter the find commands, the `check_mylib_version` function is called to check the version of the library. Since the version found is 2.1, which is not less than 2.0, `MYLIB_VERSION_OK` will be set to `TRUE`.\n\nHowever, despite the library and its extra component being found and the version being acceptable, `MYLIB_FOUND` will not be set to `TRUE` because `MYLIB_INCLUDE_DIRS` is not set (the header file was not found). The `MYLIB_FOUND` variable is only set to `TRUE` if all components (include directories, libraries, extra libraries) are found and the version is okay.\n\nThe version checking logic is important because it ensures that the library version meets the minimum requirements for compatibility with the project. If the version is too old, it may lack necessary features or bug fixes, and the `MYLIB_VERSION_OK` variable will be set to `FALSE`, which would prevent `MYLIB_FOUND` from being set to `TRUE`.\n\nHandling different library names based on the platform is crucial because libraries can have different naming conventions on different operating systems. This flexibility allows the CMake script to find the correct library files regardless of the platform it is running on.\n\nFinding all components of the library is important because a library typically consists of both binary files (the library itself) and header files (for including in the source code). If any component is missing, the library cannot be fully utilized, and the build process may fail. Therefore, it is essential to ensure that all parts of the library are present before proceeding with the build.", - }, - ] - - dataset = DatasetLibrary.pull_dataset_with_retry( - "microsoft/orca-agentinstruct-1M-v1" - ) - # output = tokenizer.apply_chat_template(message, tokenize=False) + message = [ { "content": "Create an outline for a presentation that includes five engaging and informative slides to test the audience's understanding of the concepts discussed in your talk, such as p-value functions, compatibility intervals, and the misconceptions about p-values. Each slide should present a multiple-choice question related to these topics. Prepare talking points that explain the correct answers and why they are correct, ensuring the explanations are clear and suitable for an educational presentation.", "role": "user" }, { "content": "Title: Understanding P-Values and Statistical Inference\n\nSlide 1: Introduction to P-Values\n- Multiple-Choice Question: What does a p-value indicate in hypothesis testing?\n A) The probability that the null hypothesis is true\n B) The probability of observing the data, or something more extreme, if the null hypothesis is true\n C) The probability that the alternative hypothesis is true\n D) The likelihood of making a Type I error\n\nTalking Points:\n- The correct answer is B. The p-value is the probability of observing the data, or something more extreme, if the null hypothesis is true. It is not a direct statement about the truth of the null or alternative hypothesis (A and C are incorrect). While it is related to the risk of a Type I error, it is not the likelihood of making this error (D is incorrect), but rather a measure of evidence against the null hypothesis given the data.\n\nSlide 2: Understanding Compatibility Intervals\n- Multiple-Choice Question: What does a 95% compatibility interval (CI) for a parameter estimate indicate?\n A) 95% of the data falls within this interval\n B) There is a 95% chance that the interval contains the true parameter value\n C) If we were to repeat the experiment many times, 95% of the calculated intervals would contain the true parameter value\n D) The parameter has a 95% probability of being within this interval\n\nTalking Points:\n- The correct answer is C. A 95% compatibility interval means that if we were to repeat the experiment many times, we would expect 95% of the calculated intervals to contain the true parameter value. It does not mean that 95% of the data falls within this interval (A is incorrect), nor does it imply a probability statement about the parameter itself (B and D are incorrect), as frequentist intervals do not assign probabilities to fixed parameters.\n\nSlide 3: Misconceptions About P-Values\n- Multiple-Choice Question: Which of the following is a common misconception about p-values?\n A) A low p-value confirms the research hypothesis\n B) A high p-value proves the null hypothesis is true\n C) P-values can indicate the size of an effect\n D) All of the above\n\nTalking Points:\n- The correct answer is D. All of the options listed are common misconceptions about p-values. A low p-value does not confirm the research hypothesis (A is incorrect); it simply suggests that the data are unlikely under the null hypothesis. A high p-value does not prove the null hypothesis is true (B is incorrect); it suggests that the data are not sufficiently unusual under the null hypothesis. P-values do not provide information about the size or importance of an effect (C is incorrect); they only indicate how incompatible the data are with the null hypothesis.\n\nSlide 4: P-Value Functions\n- Multiple-Choice Question: What information does a p-value function provide that a single p-value does not?\n A) The probability that the null hypothesis is true for different parameter values\n B) The p-values for a range of different alternative hypotheses\n C) The likelihood of the data under various parameter values\n D) The confidence level of the compatibility interval\n\nTalking Points:\n- The correct answer is B. A p-value function provides the p-values for a range of different alternative hypotheses, not just a single point alternative. It does not give the probability that the null hypothesis is true for different parameter values (A is incorrect), nor does it directly provide the likelihood of the data (C is incorrect). The confidence level of the compatibility interval is a separate concept (D is incorrect).\n\nSlide 5: Interpreting P-Values and Compatibility Intervals Together\n- Multiple-Choice Question: How should p-values and compatibility intervals be used together to interpret statistical results?\n A) To determine the probability that the null hypothesis is true\n B) To assess the evidence against the null hypothesis and the range of parameter values that are compatible with the data\n C) To calculate the effect size and its significance\n D) To prove the research hypothesis\n\nTalking Points:\n- The correct answer is B. P-values and compatibility intervals should be used together to assess the evidence against the null hypothesis (provided by the p-value) and to understand the range of parameter values that are compatible with the data (provided by the compatibility interval). They do not determine the probability that the null hypothesis is true (A is incorrect), nor do they prove the research hypothesis (D is incorrect). While they can help contextualize the effect size, they do not calculate it (C is incorrect).", "role": "assistant" } ] + output = tokenizer.apply_chat_template(message, tokenize=False) + + # dataset = DatasetLibrary.pull_dataset_with_retry( + # "zhan1993/orca_sqs_dataset" + # ) + + # dataset = dataset.filter(lambda example: example["task_name"] == config.finetune_task_name) + # splited_dataset = dataset.map(get_source_target,remove_columns=["messages","split"]) + # breakpoint() + + datamodule = OrcaDataModule(config) train_dataloader = datamodule.train_dataloader() val_dataloder = datamodule.val_dataloader() From 9d9ab5b316081be7316c6daa1f18bb282350ae5e Mon Sep 17 00:00:00 2001 From: zhansu Date: Tue, 19 Nov 2024 09:03:24 -0500 Subject: [PATCH 36/37] chose finetune task --- mttl/datamodule/orca_data_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mttl/datamodule/orca_data_module.py b/mttl/datamodule/orca_data_module.py index 599bf919f..66ad54179 100644 --- a/mttl/datamodule/orca_data_module.py +++ b/mttl/datamodule/orca_data_module.py @@ -39,8 +39,8 @@ def setup_dataset(self): train_dataset = DatasetLibrary.pull_dataset_with_retry( "zhan1993/orca_sqs_dataset" )["train"] - - train_dataset = train_dataset.filter(lambda example: example["task_name"] == self.config.finetune_task_name) + if self.config.finetune_task_name: + train_dataset = train_dataset.filter(lambda example: example["task_name"] == self.config.finetune_task_name) self.train_dataset, self.dev_dataset = self.create_train_valid_split( train_dataset, 0.1 From 718b138396e501e77f3f867a7fe91fad7c272665 Mon Sep 17 00:00:00 2001 From: zhansu Date: Tue, 19 Nov 2024 09:12:37 -0500 Subject: [PATCH 37/37] format --- mttl/datamodule/orca_data_module.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/mttl/datamodule/orca_data_module.py b/mttl/datamodule/orca_data_module.py index 66ad54179..aa6248034 100644 --- a/mttl/datamodule/orca_data_module.py +++ b/mttl/datamodule/orca_data_module.py @@ -40,8 +40,9 @@ def setup_dataset(self): "zhan1993/orca_sqs_dataset" )["train"] if self.config.finetune_task_name: - train_dataset = train_dataset.filter(lambda example: example["task_name"] == self.config.finetune_task_name) - + train_dataset = train_dataset.filter( + lambda example: example["task_name"] == self.config.finetune_task_name + ) self.train_dataset, self.dev_dataset = self.create_train_valid_split( train_dataset, 0.1 ) @@ -67,9 +68,19 @@ def collate_fn(self): model="microsoft/Phi-3-mini-4k-instruct", finetune_task_name="creative_content" ) from transformers import AutoTokenizer - #<|assistant|> + + # <|assistant|> tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct") - message = [ { "content": "Create an outline for a presentation that includes five engaging and informative slides to test the audience's understanding of the concepts discussed in your talk, such as p-value functions, compatibility intervals, and the misconceptions about p-values. Each slide should present a multiple-choice question related to these topics. Prepare talking points that explain the correct answers and why they are correct, ensuring the explanations are clear and suitable for an educational presentation.", "role": "user" }, { "content": "Title: Understanding P-Values and Statistical Inference\n\nSlide 1: Introduction to P-Values\n- Multiple-Choice Question: What does a p-value indicate in hypothesis testing?\n A) The probability that the null hypothesis is true\n B) The probability of observing the data, or something more extreme, if the null hypothesis is true\n C) The probability that the alternative hypothesis is true\n D) The likelihood of making a Type I error\n\nTalking Points:\n- The correct answer is B. The p-value is the probability of observing the data, or something more extreme, if the null hypothesis is true. It is not a direct statement about the truth of the null or alternative hypothesis (A and C are incorrect). While it is related to the risk of a Type I error, it is not the likelihood of making this error (D is incorrect), but rather a measure of evidence against the null hypothesis given the data.\n\nSlide 2: Understanding Compatibility Intervals\n- Multiple-Choice Question: What does a 95% compatibility interval (CI) for a parameter estimate indicate?\n A) 95% of the data falls within this interval\n B) There is a 95% chance that the interval contains the true parameter value\n C) If we were to repeat the experiment many times, 95% of the calculated intervals would contain the true parameter value\n D) The parameter has a 95% probability of being within this interval\n\nTalking Points:\n- The correct answer is C. A 95% compatibility interval means that if we were to repeat the experiment many times, we would expect 95% of the calculated intervals to contain the true parameter value. It does not mean that 95% of the data falls within this interval (A is incorrect), nor does it imply a probability statement about the parameter itself (B and D are incorrect), as frequentist intervals do not assign probabilities to fixed parameters.\n\nSlide 3: Misconceptions About P-Values\n- Multiple-Choice Question: Which of the following is a common misconception about p-values?\n A) A low p-value confirms the research hypothesis\n B) A high p-value proves the null hypothesis is true\n C) P-values can indicate the size of an effect\n D) All of the above\n\nTalking Points:\n- The correct answer is D. All of the options listed are common misconceptions about p-values. A low p-value does not confirm the research hypothesis (A is incorrect); it simply suggests that the data are unlikely under the null hypothesis. A high p-value does not prove the null hypothesis is true (B is incorrect); it suggests that the data are not sufficiently unusual under the null hypothesis. P-values do not provide information about the size or importance of an effect (C is incorrect); they only indicate how incompatible the data are with the null hypothesis.\n\nSlide 4: P-Value Functions\n- Multiple-Choice Question: What information does a p-value function provide that a single p-value does not?\n A) The probability that the null hypothesis is true for different parameter values\n B) The p-values for a range of different alternative hypotheses\n C) The likelihood of the data under various parameter values\n D) The confidence level of the compatibility interval\n\nTalking Points:\n- The correct answer is B. A p-value function provides the p-values for a range of different alternative hypotheses, not just a single point alternative. It does not give the probability that the null hypothesis is true for different parameter values (A is incorrect), nor does it directly provide the likelihood of the data (C is incorrect). The confidence level of the compatibility interval is a separate concept (D is incorrect).\n\nSlide 5: Interpreting P-Values and Compatibility Intervals Together\n- Multiple-Choice Question: How should p-values and compatibility intervals be used together to interpret statistical results?\n A) To determine the probability that the null hypothesis is true\n B) To assess the evidence against the null hypothesis and the range of parameter values that are compatible with the data\n C) To calculate the effect size and its significance\n D) To prove the research hypothesis\n\nTalking Points:\n- The correct answer is B. P-values and compatibility intervals should be used together to assess the evidence against the null hypothesis (provided by the p-value) and to understand the range of parameter values that are compatible with the data (provided by the compatibility interval). They do not determine the probability that the null hypothesis is true (A is incorrect), nor do they prove the research hypothesis (D is incorrect). While they can help contextualize the effect size, they do not calculate it (C is incorrect).", "role": "assistant" } ] + message = [ + { + "content": "Create an outline for a presentation that includes five engaging and informative slides to test the audience's understanding of the concepts discussed in your talk, such as p-value functions, compatibility intervals, and the misconceptions about p-values. Each slide should present a multiple-choice question related to these topics. Prepare talking points that explain the correct answers and why they are correct, ensuring the explanations are clear and suitable for an educational presentation.", + "role": "user", + }, + { + "content": "Title: Understanding P-Values and Statistical Inference\n\nSlide 1: Introduction to P-Values\n- Multiple-Choice Question: What does a p-value indicate in hypothesis testing?\n A) The probability that the null hypothesis is true\n B) The probability of observing the data, or something more extreme, if the null hypothesis is true\n C) The probability that the alternative hypothesis is true\n D) The likelihood of making a Type I error\n\nTalking Points:\n- The correct answer is B. The p-value is the probability of observing the data, or something more extreme, if the null hypothesis is true. It is not a direct statement about the truth of the null or alternative hypothesis (A and C are incorrect). While it is related to the risk of a Type I error, it is not the likelihood of making this error (D is incorrect), but rather a measure of evidence against the null hypothesis given the data.\n\nSlide 2: Understanding Compatibility Intervals\n- Multiple-Choice Question: What does a 95% compatibility interval (CI) for a parameter estimate indicate?\n A) 95% of the data falls within this interval\n B) There is a 95% chance that the interval contains the true parameter value\n C) If we were to repeat the experiment many times, 95% of the calculated intervals would contain the true parameter value\n D) The parameter has a 95% probability of being within this interval\n\nTalking Points:\n- The correct answer is C. A 95% compatibility interval means that if we were to repeat the experiment many times, we would expect 95% of the calculated intervals to contain the true parameter value. It does not mean that 95% of the data falls within this interval (A is incorrect), nor does it imply a probability statement about the parameter itself (B and D are incorrect), as frequentist intervals do not assign probabilities to fixed parameters.\n\nSlide 3: Misconceptions About P-Values\n- Multiple-Choice Question: Which of the following is a common misconception about p-values?\n A) A low p-value confirms the research hypothesis\n B) A high p-value proves the null hypothesis is true\n C) P-values can indicate the size of an effect\n D) All of the above\n\nTalking Points:\n- The correct answer is D. All of the options listed are common misconceptions about p-values. A low p-value does not confirm the research hypothesis (A is incorrect); it simply suggests that the data are unlikely under the null hypothesis. A high p-value does not prove the null hypothesis is true (B is incorrect); it suggests that the data are not sufficiently unusual under the null hypothesis. P-values do not provide information about the size or importance of an effect (C is incorrect); they only indicate how incompatible the data are with the null hypothesis.\n\nSlide 4: P-Value Functions\n- Multiple-Choice Question: What information does a p-value function provide that a single p-value does not?\n A) The probability that the null hypothesis is true for different parameter values\n B) The p-values for a range of different alternative hypotheses\n C) The likelihood of the data under various parameter values\n D) The confidence level of the compatibility interval\n\nTalking Points:\n- The correct answer is B. A p-value function provides the p-values for a range of different alternative hypotheses, not just a single point alternative. It does not give the probability that the null hypothesis is true for different parameter values (A is incorrect), nor does it directly provide the likelihood of the data (C is incorrect). The confidence level of the compatibility interval is a separate concept (D is incorrect).\n\nSlide 5: Interpreting P-Values and Compatibility Intervals Together\n- Multiple-Choice Question: How should p-values and compatibility intervals be used together to interpret statistical results?\n A) To determine the probability that the null hypothesis is true\n B) To assess the evidence against the null hypothesis and the range of parameter values that are compatible with the data\n C) To calculate the effect size and its significance\n D) To prove the research hypothesis\n\nTalking Points:\n- The correct answer is B. P-values and compatibility intervals should be used together to assess the evidence against the null hypothesis (provided by the p-value) and to understand the range of parameter values that are compatible with the data (provided by the compatibility interval). They do not determine the probability that the null hypothesis is true (A is incorrect), nor do they prove the research hypothesis (D is incorrect). While they can help contextualize the effect size, they do not calculate it (C is incorrect).", + "role": "assistant", + }, + ] output = tokenizer.apply_chat_template(message, tokenize=False) # dataset = DatasetLibrary.pull_dataset_with_retry( @@ -80,7 +91,6 @@ def collate_fn(self): # splited_dataset = dataset.map(get_source_target,remove_columns=["messages","split"]) # breakpoint() - datamodule = OrcaDataModule(config) train_dataloader = datamodule.train_dataloader() val_dataloder = datamodule.val_dataloader()