Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def save_fsdp_optimizer(
)
sd_options = _prepare_sd_options(fsdp_plugin)
# get the state dicts for model and optimize
(model_state_dict, optimizer_state_dict) = get_state_dict(
model_state_dict, optimizer_state_dict = get_state_dict(
model, optimizer, options=sd_options
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def _maybe_scatter(

# expect these products to be produced by an earlier
# all-to-all gather call
(send_counts, recv_counts, bins, sorted_expert_idxs, sorted_scattered_idxs) = (
send_counts, recv_counts, bins, sorted_expert_idxs, sorted_scattered_idxs = (
gather_products
)

Expand Down Expand Up @@ -421,7 +421,7 @@ def forward(self, hidden_states: torch.Tensor):
# compute the routing logits, weights, and expert assigments
# - router_logits: will be passed out of forward, used for computing
# routing loss.
(router_logits, routing_weights, selected_experts) = (
router_logits, routing_weights, selected_experts = (
self._compute_routing_weights(hidden_states)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _maybe_reshape_scattermoe_expert_weights(
num_experts: int,
intermediate_size: int,
):
(_is_w1, _is_w2, _is_w3) = [
_is_w1, _is_w2, _is_w3 = [
f"{x}.weight" in scatter_key for x in PARAM_NAME_WEIGHT_SCATTERMOE
]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,8 @@ def get_gptq_peft_model(
model.model, model_id, adapter_name
)
except Exception as exc:
raise NotImplementedError(
f"{model.__class__.__name__} not support \
{peft_config.peft_type.value} peft type yet."
) from exc
raise NotImplementedError(f"{model.__class__.__name__} not support \
{peft_config.peft_type.value} peft type yet.") from exc

return peft_model

Expand Down
4 changes: 1 addition & 3 deletions plugins/accelerated-peft/tests/test_gptqmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,5 @@ def test_quantizing_pretrained_model_outputs_match(
target = torch.nn.functional.softmax(original_logits, dim=-1)
target = target.view(BS * SEQLEN, -1)
error = loss_fn(input, target)
assert error.lt(
LOSS_TOLERANCE
), "Model logits don't match between both libraries \
assert error.lt(LOSS_TOLERANCE), "Model logits don't match between both libraries \
after quantization"
1 change: 1 addition & 0 deletions plugins/framework/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"peft>=0.15.0",
"accelerate @ git+https://github.com/huggingface/accelerate.git@5998f8625b8dfde9253c241233ff13bc2c18635d",
"pandas",
"transformers>=4.55.0,<=4.55.4",
]

[tool.hatch.build.targets.wheel]
Expand Down
1 change: 1 addition & 0 deletions plugins/online-data-mixing/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"datasets==4.*",
"torchdata==0.11.0",
"sentence-transformers==5.*",
"transformers>=4.55.0,<=4.55.4",
]

[project.optional-dependencies]
Expand Down
24 changes: 13 additions & 11 deletions plugins/online-data-mixing/src/fms_acceleration_odm/odm/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(
self.id2cat = dict(enumerate(self.category_list))
self.cat2id = {c: i for i, c in enumerate(self.category_list)}
self.total_categories = len(self.category_list)
self.rank = os.environ.get("RANK", "0")

# If not starting weights given, then all arms (categories)
# are equally important. Weights based on the size of the datasets
Expand Down Expand Up @@ -174,7 +175,7 @@ def __init__(
self.output_dir = output_dir
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
self.log_file_path = os.path.join(self.output_dir, "odm.jsonl")
self.log_file_path = os.path.join(self.output_dir, f"odm_rank_{self.rank}.jsonl")
logger.info(
"Logs for online data mixing to be stored at {log_file_path}".format(
log_file_path=self.log_file_path
Expand All @@ -191,6 +192,7 @@ def __init__(
"rewards": [0] * self.total_categories,
"count": 0,
"action": "", # one of sample or update
"rank": self.rank,
}

# Local RNG so every process can deterministically sample identical streams.
Expand Down Expand Up @@ -274,6 +276,7 @@ def __next__(self):
"action": "sample",
}
)

return sample

def load_state_dict(self, state_dict):
Expand Down Expand Up @@ -548,13 +551,12 @@ def update_sampling_weights(self, model, accelerator, state):
count = accelerator.reduce(count, reduction="sum")

self._update_weights(count, rewards)
if accelerator and accelerator.is_main_process:
self.log_to_file(
{
"current_sampling_weights": self.sampling_weights.tolist(),
"current_sampling_ratio": self.sampling_ratio,
"rewards": rewards.tolist(),
"count": count.tolist(),
"action": "update",
}
)
self.log_to_file(
{
"current_sampling_weights": self.sampling_weights.tolist(),
"current_sampling_ratio": self.sampling_ratio,
"rewards": rewards.tolist(),
"count": count.tolist(),
"action": "update",
}
)