Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions examples/puzzletron/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ In this example, we compress the [Llama-3.1-8B-Instruct](https://huggingface.co/

```bash
pip install -e .[hf,puzzletron]
pip install -r requirements.txt
```

- For this example we are using 2x NVIDIA H100 80GB HBM3 to show multi-GPU steps. You can use also use s single GPU.
Expand Down Expand Up @@ -231,6 +232,24 @@ vllm bench latency --model path/to/model --load-format safetensors --trust-remot
vllm bench throughput --model path/to/model --input-len 2000 --output-len 100 --load-format safetensors --trust-remote-code
```

## Knowledge Distillation

To recover degradation in the quality of the compressed model, we can use knowledge distillation. This allows transferring the capabilities of the original model to the pruned one. For this, we will use [NeMo framework](https://github.com/NVIDIA-NeMo/NeMo) with the [nemo:25.07](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo?version=25.07) container.

First, convert the HF model to NeMo format:

```bash
python -m nemo_export/convert_hf_to_nemo --input-ckpt-path path/to/HF-model --output-ckpt-path path/to/save/model-nemo
```

Now you can utilize all the training features available in NeMo, including distillation. Please refer to the [NeMo distillation documentation](https://docs.nvidia.com/nemo-framework/user-guide/latest/model-optimization/distillation/distillation.html).

[Optional] Once distillation is complete, you can convert the distilled model back to the HuggingFace format.

```bash
python -m nemo_export/convert_nemo_to_hf --input-ckpt-path path/to/nemo-model --output-ckpt-path path/to/save/model-HF
```

## Advanced Usage

Modify `llama-3_1-8B_pruneffn_memory.yaml` file for advanced compression scenarios.
98 changes: 98 additions & 0 deletions examples/puzzletron/nemo_export/convert_hf_to_nemo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would move it to modelopt/torch/puzzletron/...., the same for convert_nemo_to_hf

examples should not keep the logic - should be just examples,

similarly, example/puzzletron/main.py should go to modelopt/torch/puzzletron/...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are cmdline scripts which we should keep in examples folder just like all other modelopt examples. ModelOpt installation would be somewhere in /usr/local/... and we dont want users to run script from there

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And the more things we add to modelopt library will require more dependencies for modelopt. Keeping example scripts separate means we can keep modelopt dependencies leaner and move extra dependencies to examples/<example_name>/requirements.txt

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer we keep it in examples, the scripts rely on the nemo dependency too

# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from pathlib import Path
from typing import Any

from nemo.collections import llm

from modelopt.torch.puzzletron.export.MCore.llama_nemotron import (
PuzzletronLlamaNemotronModel,
PuzzletronNemotronModelConfig,
)


def convert_model(
hf_model_path_local: str, output_path_nemo_local: str, overwrite: bool = False
) -> Any:
"""Convert a Puzzletron HuggingFace model to NeMo format.

Args:
hf_model_path_local: Path to the input Puzzletron HuggingFace model directory
output_path_nemo_local: Path where the converted Puzzletron NeMo model will be saved
overwrite: Whether to overwrite existing output directory
"""

model = PuzzletronLlamaNemotronModel(config=PuzzletronNemotronModelConfig)
# NOTE: API call to import_ckpt is here: https://github.com/NVIDIA-NeMo/NeMo/blob/294ddff187f68c055d87ffe9400e65975b38693d/nemo/collections/llm/api.py#L888
print(
f"calling import_ckpt with model: {model}, "
f"source: {hf_model_path_local}, "
f"output_path: {output_path_nemo_local}, "
f"overwrite: {overwrite}"
)
nemo2_path = llm.import_ckpt(
model=model,
source="hf://" + hf_model_path_local,
output_path=Path(output_path_nemo_local),
overwrite=overwrite,
)

print(f"Model saved to {nemo2_path}")
return nemo2_path


def main() -> None:
parser = argparse.ArgumentParser(
description="Convert Puzzletron HuggingFace model to NeMo format"
)
parser.add_argument(
"--input-ckpt-path",
"-i",
type=str,
required=True,
help="Path to the input Puzzletron HuggingFace model directory",
)
parser.add_argument(
"--output-ckpt-path",
"-o",
type=str,
required=True,
help="Path where the converted Puzzletron NeMo model will be saved",
)
parser.add_argument(
"--overwrite",
action="store_true",
default=False,
help="Whether to overwrite existing output directory (default: False)",
)

args = parser.parse_args()

# Validate input path
if not os.path.exists(args.input_ckpt_path):
raise FileNotFoundError(f"Input model path does not exist: {args.input_ckpt_path}")

# Create output directory if it doesn't exist
os.makedirs(os.path.dirname(args.output_ckpt_path), exist_ok=True)

print(f"Converting model from {args.input_ckpt_path} to {args.output_ckpt_path}")
convert_model(args.input_ckpt_path, args.output_ckpt_path, args.overwrite)


if __name__ == "__main__":
main()
96 changes: 96 additions & 0 deletions examples/puzzletron/nemo_export/convert_nemo_to_hf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
from pathlib import Path
from typing import Any

from nemo.collections import llm

from modelopt.torch.puzzletron.tools.checkpoint_utils_hf import copy_deci_lm_hf_code


def convert_model(
nemo_model_path_local: str, output_path_hf_local: str, overwrite: bool = False
) -> Any:
"""Convert a NeMo model to HuggingFace format.

Args:
nemo_model_path_local: Path to the input NeMo model file (.nemo)
output_path_hf_local: Path where the converted HuggingFace model will be saved
overwrite: Whether to overwrite existing output directory
"""

# NOTE: API call to export_ckpt is here: https://github.com/NVIDIA-NeMo/NeMo/blob/main/nemo/collections/llm/api.py#L987
print(
f"calling export_ckpt with path: {nemo_model_path_local}, "
f"target: hf, output_path: {output_path_hf_local}, "
f"target_model_name: PuzzletronLlamaNemotronModel, "
f"overwrite: {overwrite}"
)

hf_path = llm.export_ckpt(
path=nemo_model_path_local,
target="hf",
output_path=Path(output_path_hf_local),
target_model_name="PuzzletronLlamaNemotronModel",
overwrite=overwrite,
)

copy_deci_lm_hf_code(hf_path)

print(f"Model saved to {hf_path}")
return hf_path


def main() -> None:
parser = argparse.ArgumentParser(description="Convert NeMo model to HuggingFace format")
parser.add_argument(
"--input-ckpt-path",
"-i",
type=str,
required=True,
help="Path to the input NeMo model checkpoint",
)
parser.add_argument(
"--output-ckpt-path",
"-o",
type=str,
required=True,
help="Path where the converted Puzzletron HuggingFace model will be saved",
)
parser.add_argument(
"--overwrite",
action="store_true",
default=False,
help="Whether to overwrite existing output directory (default: False)",
)

args = parser.parse_args()

# Validate input path
if not os.path.exists(args.input_ckpt_path):
raise FileNotFoundError(f"Input model path does not exist: {args.input_ckpt_path}")

# Create output directory if it doesn't exist
os.makedirs(os.path.dirname(args.output_ckpt_path), exist_ok=True)
Comment on lines +88 to +89
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Edge case when output path has no directory component.

Same issue as in convert_hf_to_nemo.py — if output_ckpt_path has no directory component, os.path.dirname() returns an empty string.

Proposed fix
     # Create output directory if it doesn't exist
-    os.makedirs(os.path.dirname(args.output_ckpt_path), exist_ok=True)
+    output_dir = os.path.dirname(args.output_ckpt_path)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Create output directory if it doesn't exist
os.makedirs(os.path.dirname(args.output_ckpt_path), exist_ok=True)
# Create output directory if it doesn't exist
output_dir = os.path.dirname(args.output_ckpt_path)
if output_dir:
os.makedirs(output_dir, exist_ok=True)
🤖 Prompt for AI Agents
In `@examples/puzzletron/nemo_export/convert_nemo_to_hf.py` around lines 79 - 80,
The os.makedirs call uses os.path.dirname(args.output_ckpt_path) which can be an
empty string when the path has no directory component; guard against that by
computing a directory variable (e.g., out_dir =
os.path.dirname(args.output_ckpt_path) or out_dir = os.path.dirname(...) or '.'
if that result is empty) and call os.makedirs(out_dir, exist_ok=True) so you
never pass an empty string to os.makedirs.


print(f"Converting model from {args.input_ckpt_path} to {args.output_ckpt_path}")
convert_model(args.input_ckpt_path, args.output_ckpt_path, args.overwrite)


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions examples/puzzletron/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
lm-eval==0.4.9
Loading