Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions .github/workflows/pr-title.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Enforce Conventional Commits formatting on pull-request titles targeting main.
# Re-runs whenever the title can change (opened/edited/synchronize/reopened) so
# a fixed title clears the failed check without a manual re-run.
name: PR Title Convention

on:
  pull_request:
    types: [opened, edited, synchronize, reopened]
    branches: [main]

# Read-only: the job only inspects the PR title from the event payload.
permissions:
  pull-requests: read

jobs:
  check-title:
    name: Validate PR title
    runs-on: ubuntu-22.04
    timeout-minutes: 1
    steps:
      - name: Check conventional commit format
        env:
          # Pass the title through env (not inline interpolation) so a title
          # containing quotes/backticks cannot inject into the shell script.
          PR_TITLE: ${{ github.event.pull_request.title }}
        run: |
          # Allowed conventional commit types
          TYPES="feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert"

          # Pattern: type(optional-scope): description
          # OR: type!: description (breaking change)
          PATTERN="^($TYPES)(\(.+\))?\!?: .+"

          # Use printf rather than echo: bash's echo would swallow a title
          # that begins with -n/-e/-E, treating it as an option flag.
          if printf '%s\n' "$PR_TITLE" | grep -qP "$PATTERN"; then
            printf '%s\n' "PR title is valid: $PR_TITLE"
          else
            echo "::error::PR title does not follow Conventional Commits."
            echo ""
            printf '%s\n' "Got: $PR_TITLE"
            echo ""
            echo "Expected: <type>[optional scope]: <description>"
            echo ""
            echo "Allowed types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert"
            echo "Read more: https://www.conventionalcommits.org/en/v1.0.0/"
            echo ""
            echo "Examples:"
            echo "  feat: add new optimization algorithm"
            echo "  fix: resolve memory leak in model loading"
            echo "  ci(pruna): pin transformers version"
            echo ""
            exit 1
          fi
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ repos:
hooks:
- id: ty
name: type checking using ty
entry: uvx ty check .
entry: uvx ty check src/pruna
language: system
types: [python]
pass_filenames: false
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ invalid-return-type = "ignore" # mypy is more permissive with return types
invalid-parameter-default = "ignore" # mypy is more permissive with parameter defaults
no-matching-overload = "ignore" # mypy is more permissive with overloads
unresolved-reference = "ignore" # mypy is more permissive with references
possibly-unbound-import = "ignore"
possibly-missing-import = "ignore"
possibly-missing-attribute = "ignore"
missing-argument = "ignore"
unused-type-ignore-comment = "ignore"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I fully agree that we should ignore the comments for now, and I'll go through the code in a future PR to remove those ignore statements one by one because this frankly isn't checking anything...


[tool.coverage.run]
source = ["src/pruna"]
Expand Down Expand Up @@ -181,7 +183,7 @@ dev = [
"pytest-rerunfailures",
"coverage",
"docutils",
"ty==0.0.1a21",
"ty==0.0.18",
"types-PyYAML",
"logbar",
"pytest-xdist>=3.8.0",
Expand Down
4 changes: 2 additions & 2 deletions src/pruna/algorithms/c_translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ def __call__(
x_tensor = x["input_ids"]
else:
x_tensor = x
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))]
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))] # type: ignore[not-subscriptable]
return self.generator.generate_batch(token_list, min_length=min_length, max_length=max_length, *args, **kwargs) # type: ignore[operator]


Expand Down Expand Up @@ -468,7 +468,7 @@ def __call__(
x_tensor = x["input_ids"]
else:
x_tensor = x
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))]
token_list = [self.tokenizer.convert_ids_to_tokens(x_tensor[i]) for i in range(len(x_tensor))] # type: ignore[not-subscriptable]
return self.translator.translate_batch( # type: ignore[operator]
token_list,
min_decoding_length=min_decoding_length,
Expand Down
9 changes: 3 additions & 6 deletions src/pruna/algorithms/sage_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag as tags
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.target_modules import TARGET_MODULES_TYPE, TargetModules, map_targeted_nn_roots
from pruna.config.target_modules import TargetModules, map_targeted_nn_roots
from pruna.engine.save import SAVE_FUNCTIONS


Expand Down Expand Up @@ -91,10 +91,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
target_modules = smash_config["target_modules"]

if target_modules is None:
target_modules = self.get_model_dependent_hyperparameter_defaults(
model,
smash_config
)
target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config)

def apply_sage_attn(
root_name: str | None,
Expand Down Expand Up @@ -154,7 +151,7 @@ def get_model_dependent_hyperparameter_defaults(
self,
model: Any,
smash_config: SmashConfigPrefixWrapper,
) -> TARGET_MODULES_TYPE:
) -> dict[str, Any]:
"""
Provide default `target_modules` targeting all transformer modules.

Expand Down
2 changes: 1 addition & 1 deletion src/pruna/algorithms/torch_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
else:
modules_to_quantize = {torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.Linear}

quantized_model = torch.quantization.quantize_dynamic(
quantized_model = torch.quantization.quantize_dynamic( # type: ignore[deprecated]
model,
modules_to_quantize,
dtype=getattr(torch, smash_config["weight_bits"]),
Expand Down
2 changes: 1 addition & 1 deletion src/pruna/config/smash_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def load_from_json(self, path: str | Path) -> None:
setattr(self, name, config_dict.pop(name))

# Keep only values that still exist in the space, drop stale keys
supported_hparam_names = {hp.name for hp in SMASH_SPACE.get_hyperparameters()}
supported_hparam_names = {hp.name for hp in list(SMASH_SPACE.values())}
saved_values = {k: v for k, v in config_dict.items() if k in supported_hparam_names}

# Seed with the defaults, then overlay the saved values
Expand Down