Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
name: CI

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  lint-and-format:
    runs-on: ubuntu-latest

    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v3
        with:
          version: "latest"

      - name: Set up Python 3.12
        run: uv python install 3.12

      - name: Check code formatting with Ruff
        run: uvx ruff format --check --diff .

      - name: Lint with Ruff
        run: uvx ruff check .

      # Advisory step: `continue-on-error` lets the workflow pass while still
      # surfacing mypy failures as an annotation. The previous `|| true` made
      # the step unconditionally succeed, so failures were completely hidden.
      - name: Type check with mypy
        run: uvx mypy . --ignore-missing-imports
        continue-on-error: true

  yaml-lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v3
        with:
          version: "latest"

      - name: Set up Python 3.12
        run: uv python install 3.12

      # Advisory step, same rationale as the mypy step above: drop `|| true`
      # so yamllint failures are visible, and rely on continue-on-error.
      - name: Lint YAML files
        run: uvx yamllint .
        continue-on-error: true

27 changes: 27 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
      - id: check-added-large-files
      - id: check-merge-conflict
      - id: check-toml
      - id: check-json
      - id: debug-statements
      - id: mixed-line-ending

  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.1.13
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix]
      - id: ruff-format

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.8.0
    hooks:
      - id: mypy
        # `types-all` is a deprecated meta-package that no longer installs
        # cleanly; list the stub packages the codebase actually imports
        # (yaml -> types-PyYAML, pandas -> pandas-stubs) instead.
        additional_dependencies: [types-PyYAML, pandas-stubs]
        args: [--ignore-missing-imports]
42 changes: 23 additions & 19 deletions oellm/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Standard library imports
import yaml
import logging
import os
import re
Expand All @@ -12,6 +11,7 @@
from typing import Iterable

import pandas as pd
import yaml
from huggingface_hub import hf_hub_download, snapshot_download
from jsonargparse import auto_cli
from rich.console import Console
Expand All @@ -22,23 +22,23 @@ def ensure_singularity_image(image_name: str) -> None:
# TODO: switch to OELLM dataset repo once it is created
hf_repo = os.environ.get("HF_SIF_REPO", "timurcarstensen/testing")
image_path = Path(os.getenv("EVAL_BASE_DIR")) / image_name

try:
hf_hub_download(
repo_id=hf_repo,
filename=image_name,
repo_type="dataset",
local_dir=os.getenv("EVAL_BASE_DIR"),
)
logging.info("Successfully downloaded latest Singularity image from HuggingFace")
logging.info(
"Successfully downloaded latest Singularity image from HuggingFace"
)
except Exception as e:
logging.warning(
"Failed to fetch latest container image from HuggingFace: %s", str(e)
)
if image_path.exists():
logging.info(
"Using existing Singularity image at %s", image_path
)
logging.info("Using existing Singularity image at %s", image_path)
else:
raise RuntimeError(
f"No container image found at {image_path} and failed to download from HuggingFace. "
Expand Down Expand Up @@ -79,7 +79,7 @@ def _load_cluster_env() -> None:
"""
Loads the correct cluster environment variables from `clusters.yaml` based on the hostname.
"""
with open(Path(__file__).parent / "clusters.yaml", "r") as f:
with open(Path(__file__).parent / "clusters.yaml") as f:
clusters = yaml.safe_load(f)
hostname = socket.gethostname()

Expand Down Expand Up @@ -115,7 +115,7 @@ def _load_cluster_env() -> None:
if len(e.args) > 1:
raise ValueError(
f"Env. variable substitution for {k} failed. Missing keys: {', '.join(e.args)}"
)
) from e

missing_key: str = e.args[0]
os.environ[k] = str(v).format(
Expand Down Expand Up @@ -181,10 +181,9 @@ def _process_model_paths(models: Iterable[str]) -> dict[str, list[Path | str]]:
)

if "," in model:
model_kwargs = {
k: v
for k, v in [kv.split("=") for kv in model.split(",") if "=" in kv]
}
model_kwargs = dict(
[kv.split("=") for kv in model.split(",") if "=" in kv]
)

# The first element before the comma is the repository ID on the 🤗 Hub
repo_id = model.split(",")[0]
Expand Down Expand Up @@ -359,16 +358,20 @@ def schedule_evals(
expanded_rows.append(new_row)
df = pd.DataFrame(expanded_rows)
else:
logging.info("Skipping model path processing and validation (--skip-checks enabled)")
logging.info(
"Skipping model path processing and validation (--skip-checks enabled)"
)

elif models and tasks and n_shot is not None:
if not skip_checks:
model_path_map = _process_model_paths(models.split(","))
model_paths = [p for paths in model_path_map.values() for p in paths]
else:
logging.info("Skipping model path processing and validation (--skip-checks enabled)")
logging.info(
"Skipping model path processing and validation (--skip-checks enabled)"
)
model_paths = models.split(",")

tasks_list = tasks.split(",")

# cross product of model_paths and tasks into a dataframe
Expand Down Expand Up @@ -425,7 +428,7 @@ def schedule_evals(

logging.debug(f"Saved evaluation dataframe to temporary CSV: {csv_path}")

with open(Path(__file__).parent / "template.sbatch", "r") as f:
with open(Path(__file__).parent / "template.sbatch") as f:
sbatch_template = f.read()

# replace the placeholders in the template with the actual values
Expand Down Expand Up @@ -468,17 +471,18 @@ def schedule_evals(
)
logging.info("Job submitted successfully.")
logging.info(result.stdout)

# Provide helpful information about job monitoring and file locations
logging.info(f"📁 Evaluation directory: {evals_dir}")
logging.info(f"📄 SLURM script: {sbatch_script_path}")
logging.info(f"📋 Job configuration: {csv_path}")
logging.info(f"📜 SLURM logs will be stored in: {slurm_logs_dir}")
logging.info(f"📊 Results will be stored in: {evals_dir / 'results'}")

# Extract job ID from sbatch output for monitoring commands
import re
job_id_match = re.search(r'Submitted batch job (\d+)', result.stdout)

job_id_match = re.search(r"Submitted batch job (\d+)", result.stdout)
if job_id_match:
job_id = job_id_match.group(1)
logging.info(f"🔍 Monitor job status: squeue -j {job_id}")
Expand Down
32 changes: 31 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,34 @@ torchvision = [
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[tool.ruff]
line-length = 88
# CI installs and runs Python 3.12, and the code uses PEP 604 unions
# (`Path | str`, 3.10+ syntax). Targeting py38 would make the pyupgrade
# (UP) rules and syntax checks disagree with the interpreter actually used.
target-version = "py312"

[tool.ruff.lint]
select = [
    "E",   # pycodestyle errors
    "W",   # pycodestyle warnings
    "F",   # pyflakes
    "I",   # isort
    "B",   # flake8-bugbear
    "C4",  # flake8-comprehensions
    "UP",  # pyupgrade
]
ignore = [
    "E501",  # line too long
    "B008",  # do not perform function calls in argument defaults
    "C901",  # too complex
    "W191",  # indentation contains tabs
]

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
skip-magic-trailing-comma = false
line-ending = "auto"
Loading