From 07161da160e39164ec35b435f42cd4cae41d665c Mon Sep 17 00:00:00 2001 From: romit Date: Fri, 30 Jan 2026 05:33:37 +0000 Subject: [PATCH 1/4] Added per rank log file for ODM Signed-off-by: romit --- .../src/fms_acceleration_odm/odm/dataset.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py index c4d12177..6257c964 100644 --- a/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py +++ b/plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py @@ -142,6 +142,7 @@ def __init__( self.id2cat = dict(enumerate(self.category_list)) self.cat2id = {c: i for i, c in enumerate(self.category_list)} self.total_categories = len(self.category_list) + self.rank = os.environ.get("RANK", "0") # If not starting weights given, then all arms (categories) # are equally important. Weights based on the size of the datasets @@ -174,7 +175,7 @@ def __init__( self.output_dir = output_dir if not os.path.exists(self.output_dir): os.makedirs(self.output_dir) - self.log_file_path = os.path.join(self.output_dir, "odm.jsonl") + self.log_file_path = os.path.join(self.output_dir, f"odm_rank_{self.rank}.jsonl") logger.info( "Logs for online data mixing to be stored at {log_file_path}".format( log_file_path=self.log_file_path @@ -191,6 +192,7 @@ def __init__( "rewards": [0] * self.total_categories, "count": 0, "action": "", # one of sample or update + "rank": self.rank, } # Local RNG so every process can deterministically sample identical streams. @@ -274,6 +276,7 @@ def __next__(self): "action": "sample", } ) + return sample def load_state_dict(self, state_dict): @@ -548,13 +551,12 @@ def update_sampling_weights(self, model, accelerator, state): count = accelerator.reduce(count, reduction="sum") self._update_weights(count, rewards) - if accelerator and accelerator.is_main_process: - self.log_to_file( - { - "current_sampling_weights": self.sampling_weights.tolist(), - "current_sampling_ratio": self.sampling_ratio, - "rewards": rewards.tolist(), - "count": count.tolist(), - "action": "update", - } - ) + self.log_to_file( + { + "current_sampling_weights": self.sampling_weights.tolist(), + "current_sampling_ratio": self.sampling_ratio, + "rewards": rewards.tolist(), + "count": count.tolist(), + "action": "update", + } + ) From 3065efa710d009ff0b11d6cb7507ce1c4627e914 Mon Sep 17 00:00:00 2001 From: romit Date: Fri, 30 Jan 2026 06:42:02 +0000 Subject: [PATCH 2/4] Pinned transformers version Signed-off-by: romit --- plugins/online-data-mixing/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/online-data-mixing/pyproject.toml b/plugins/online-data-mixing/pyproject.toml index ba60ca7c..43b30250 100644 --- a/plugins/online-data-mixing/pyproject.toml +++ b/plugins/online-data-mixing/pyproject.toml @@ -30,6 +30,7 @@ dependencies = [ "datasets==4.*", "torchdata==0.11.0", "sentence-transformers==5.*", + "transformers>=4.55.0,<=4.55.4", ] [project.optional-dependencies] From 334a0b7f1b7175bf7e4e3188cedaa3980f7120c2 Mon Sep 17 00:00:00 2001 From: romit Date: Fri, 30 Jan 2026 07:00:26 +0000 Subject: [PATCH 3/4] Pinned transformers in framework package --- plugins/framework/pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/plugins/framework/pyproject.toml b/plugins/framework/pyproject.toml index 0c43df11..8a4d681a 100644 --- a/plugins/framework/pyproject.toml +++ b/plugins/framework/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "peft>=0.15.0", "accelerate @ git+https://github.com/huggingface/accelerate.git@5998f8625b8dfde9253c241233ff13bc2c18635d", "pandas", + "transformers>=4.55.0,<=4.55.4", ] [tool.hatch.build.targets.wheel] From 604fc555d239d84dfcafa136e651b35ee7b12a32 Mon Sep 17 00:00:00 2001 From: romit Date: Fri, 30 Jan 2026 08:30:58 +0000 Subject: [PATCH 4/4] Fixed CI/CD for other packages Signed-off-by: romit --- .../src/fms_acceleration_moe/utils/checkpoint_utils.py | 2 +- .../src/fms_acceleration_moe/utils/scattermoe.py | 4 ++-- .../src/fms_acceleration_moe/utils/scattermoe_state_dict.py | 2 +- .../src/fms_acceleration_peft/gptqmodel/utils/peft.py | 6 ++---- plugins/accelerated-peft/tests/test_gptqmodel.py | 4 +--- 5 files changed, 7 insertions(+), 11 deletions(-) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py index 03986806..bcef70ae 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/checkpoint_utils.py @@ -113,7 +113,7 @@ def save_fsdp_optimizer( ) sd_options = _prepare_sd_options(fsdp_plugin) # get the state dicts for model and optimize - (model_state_dict, optimizer_state_dict) = get_state_dict( + model_state_dict, optimizer_state_dict = get_state_dict( model, optimizer, options=sd_options ) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py index 44125acd..b1ecc4a2 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe.py @@ -389,7 +389,7 @@ def _maybe_scatter( # expect these products to be produced by an earlier # all-to-all gather call - (send_counts, recv_counts, bins, sorted_expert_idxs, sorted_scattered_idxs) = ( + send_counts, recv_counts, bins, sorted_expert_idxs, sorted_scattered_idxs = ( gather_products ) @@ -421,7 +421,7 @@ def forward(self, hidden_states: torch.Tensor): # compute the routing logits, weights, and expert assigments # - router_logits: will be passed out of forward, used for computing # routing loss. - (router_logits, routing_weights, selected_experts) = ( + router_logits, routing_weights, selected_experts = ( self._compute_routing_weights(hidden_states) ) diff --git a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py index e13f6ba5..c3f0c432 100644 --- a/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py +++ b/plugins/accelerated-moe/src/fms_acceleration_moe/utils/scattermoe_state_dict.py @@ -188,7 +188,7 @@ def _maybe_reshape_scattermoe_expert_weights( num_experts: int, intermediate_size: int, ): - (_is_w1, _is_w2, _is_w3) = [ + _is_w1, _is_w2, _is_w3 = [ f"{x}.weight" in scatter_key for x in PARAM_NAME_WEIGHT_SCATTERMOE ] diff --git a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/peft.py b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/peft.py index c73a1d8d..a9df0e0c 100644 --- a/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/peft.py +++ b/plugins/accelerated-peft/src/fms_acceleration_peft/gptqmodel/utils/peft.py @@ -163,10 +163,8 @@ def get_gptq_peft_model( model.model, model_id, adapter_name ) except Exception as exc: - raise NotImplementedError( - f"{model.__class__.__name__} not support \ - {peft_config.peft_type.value} peft type yet." - ) from exc + raise NotImplementedError(f"{model.__class__.__name__} not support \ + {peft_config.peft_type.value} peft type yet.") from exc return peft_model diff --git a/plugins/accelerated-peft/tests/test_gptqmodel.py b/plugins/accelerated-peft/tests/test_gptqmodel.py index e20db946..a8f3b98e 100644 --- a/plugins/accelerated-peft/tests/test_gptqmodel.py +++ b/plugins/accelerated-peft/tests/test_gptqmodel.py @@ -297,7 +297,5 @@ def test_quantizing_pretrained_model_outputs_match( target = torch.nn.functional.softmax(original_logits, dim=-1) target = target.view(BS * SEQLEN, -1) error = loss_fn(input, target) - assert error.lt( - LOSS_TOLERANCE - ), "Model logits don't match between both libraries \ + assert error.lt(LOSS_TOLERANCE), "Model logits don't match between both libraries \ after quantization"