
Commit 753af76

refactor: some cleanup, tests and deduplication
1 parent e0c4ceb commit 753af76

11 files changed: 238 additions & 107 deletions


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -12,3 +12,4 @@ __pycache__/
 *.pyc
 .pytest_cache/
 .ruff_cache/
+/data

app/backends/base.py

Lines changed: 1 addition & 2 deletions
@@ -270,13 +270,12 @@ async def check_connectivity(self) -> bool:
         """Check whether the backend is reachable. Returns True if healthy."""
         ...
 
-    def resolve_job_logs(
+    def resolve_job_logs(  # noqa: B027
         self,
         jobs: list[SnkmtJobResponse],
         workflow_files: list[WorkflowFileInfo] | None,
     ) -> None:
         """Set job.log for each job based on backend-specific log paths."""
-        return
 
     @abstractmethod
     async def cleanup(
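
With the trailing return removed, the docstring is the method's entire body, so flake8-bugbear's B027 check ("empty method in abstract base class with no abstract decorator") would fire; the noqa marks the no-op as intentional. A minimal sketch of the pattern, with hypothetical class and method names:

from abc import ABC, abstractmethod


class Backend(ABC):
    @abstractmethod
    async def cleanup(self) -> None:
        """Subclasses must implement cleanup."""

    def resolve_logs(self) -> None:  # noqa: B027
        """Optional hook: the default deliberately does nothing."""


class EchoBackend(Backend):
    async def cleanup(self) -> None:
        print("cleaning up")
    # resolve_logs is inherited as a no-op; overriding it is optional.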

app/backends/slurm_ssh.py

Lines changed: 32 additions & 40 deletions
@@ -48,6 +48,29 @@
 MONITOR_DEAD_SENTINEL = "DEAD"
 
 
+def _build_status_check_cmd(work_dir_quoted: str) -> str:
+    """Build shell command to check job exit status via .exitcode / PID probe."""
+    return (
+        f"test -f {work_dir_quoted}/.exitcode && cat {work_dir_quoted}/.exitcode "
+        f"|| (kill -0 $(cat {work_dir_quoted}/.pid 2>/dev/null) 2>/dev/null "
+        f"&& echo {MONITOR_RUNNING_SENTINEL} "
+        f"|| echo {MONITOR_DEAD_SENTINEL})"
+    )
+
+
+def _parse_job_status(status: str) -> int | None:
+    """Parse status check output → exit code, -1 for dead, or None for running."""
+    if status == MONITOR_RUNNING_SENTINEL:
+        return None
+    if status == MONITOR_DEAD_SENTINEL:
+        return -1
+    try:
+        return int(status)
+    except ValueError:
+        logger.warning("Unparsable job status value: %r, treating as running", status)
+        return None
+
+
 def _build_rsync_filter(cache_dirs: list[str]) -> str:
     """Build rsync --include/--exclude args from cache_dirs patterns.
 
@@ -84,8 +107,8 @@ def _add(arg: str) -> None:
 class SlurmSSHBackend(ComputeBackend):
     """
     Compute backend that connects to an HPC head node via SSH,
-    clones a workflow, runs Snakemake via pixi in a detached process,
-    and monitors via polling.
+    fetches a workflow into a bare repo and creates a worktree,
+    runs Snakemake via pixi in a detached process, and monitors via polling.
     """
 
     def __init__(self, config: SlurmSSHConfig) -> None:
@@ -314,16 +337,11 @@ async def monitor(
 
         while True:
             try:
-                # Read new log bytes, print a marker, then print the exit
-                # code (or "RUNNING" if still alive)
                 wd = shlex.quote(work_dir)
                 cmd = (
                     f"tail -c +{offset + 1} {wd}/.stdout.log 2>/dev/null; "
                     f"echo '{MONITOR_LOG_MARKER}'; "
-                    f"test -f {wd}/.exitcode && cat {wd}/.exitcode "
-                    f"|| (kill -0 $(cat {wd}/.pid 2>/dev/null) 2>/dev/null "
-                    f"&& echo {MONITOR_RUNNING_SENTINEL} "
-                    f"|| echo {MONITOR_DEAD_SENTINEL})"
+                    f"{_build_status_check_cmd(wd)}"
                 )
                 result = await self._run_ssh(cmd, check=False)
                 stdout = result.stdout or ""
@@ -336,32 +354,20 @@ async def monitor(
                 consecutive_errors = 0
 
                 if new_log_data:
-                    # Advance offset so next poll only reads new bytes
                     offset += len(new_log_data.encode("utf-8", errors="replace"))
                     for line in new_log_data.splitlines():
                         log_callback(line)
 
-                # "RUNNING" = still going
-                # "DEAD" = process killed without writing .exitcode
-                # number = exit code
-                if status_part == MONITOR_DEAD_SENTINEL:
+                exit_code = _parse_job_status(status_part)
+                if exit_code == -1:
                     logger.warning(
                         "Job %s: wrapper process died without writing "
                         ".exitcode (likely OOM or SIGKILL)",
                         job_id,
                     )
                     return -1
-                if status_part != MONITOR_RUNNING_SENTINEL:
-                    try:
-                        return int(status_part)
-                    except ValueError:
-                        logger.warning(
-                            "Unexpected status value: %r, treating as running",
-                            status_part,
-                        )
-                # Sleep for poll_interval before retrying
-                await asyncio.sleep(self._config.poll_interval)
-                continue
+                if exit_code is not None:
+                    return exit_code
 
             except (TimeoutError, OSError, asyncssh.Error) as exc:
                 consecutive_errors += 1
@@ -374,22 +380,8 @@ async def monitor(
     async def check_job_status(self, job_id: str, work_dir: str) -> int | None:
        """Check if a job process has finished without blocking."""
         wd = shlex.quote(work_dir)
-        cmd = (
-            f"test -f {wd}/.exitcode && cat {wd}/.exitcode "
-            f"|| (kill -0 $(cat {wd}/.pid 2>/dev/null) 2>/dev/null "
-            f"&& echo {MONITOR_RUNNING_SENTINEL} "
-            f"|| echo {MONITOR_DEAD_SENTINEL})"
-        )
-        result = await self._run_ssh(cmd, check=False)
-        status = (result.stdout or "").strip()
-        if status == MONITOR_RUNNING_SENTINEL:
-            return None
-        if status == MONITOR_DEAD_SENTINEL:
-            return -1
-        try:
-            return int(status)
-        except ValueError:
-            return None
+        result = await self._run_ssh(_build_status_check_cmd(wd), check=False)
+        return _parse_job_status((result.stdout or "").strip())
 
     async def check_connectivity(self) -> bool:
         """Check SSH connectivity and scratch filesystem health."""

app/routes/jobs.py

Lines changed: 15 additions & 5 deletions
@@ -76,12 +76,22 @@ def _build_outputs_response(
     return JobOutputsResponse(files=files)
 
 
-_RESPONSE_FIELDS = tuple(JobResponse.model_fields.keys())
-
-
 def _build_job_response(record: JobRecord) -> JobResponse:
-    """Build a JobResponse from a JobRecord using _RESPONSE_FIELDS."""
-    return JobResponse(**{f: getattr(record, f) for f in _RESPONSE_FIELDS})
+    """Build a JobResponse from a JobRecord with explicit field mapping."""
+    return JobResponse(
+        job_id=record.job_id,
+        status=record.status,
+        workflow=record.workflow,
+        configfile=record.configfile,
+        git_ref=record.git_ref,
+        git_sha=record.git_sha,
+        exit_code=record.exit_code,
+        created_at=record.created_at,
+        started_at=record.started_at,
+        completed_at=record.completed_at,
+        total_job_count=record.total_job_count,
+        jobs_finished=record.jobs_finished,
+    )
 
 
 @router.post("/jobs", response_model=JobResponse, status_code=201)
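
Compared with the old getattr loop, the explicit mapping is more verbose but type-checkable; the trade-off is that a field added to JobResponse later could be missed here. A possible regression-test sketch (import paths and the sample_record fixture are assumptions, not taken from this repo):

from app.models import JobResponse  # hypothetical import path
from app.routes.jobs import _build_job_response
from app.store import JobRecord  # hypothetical import path


def test_build_job_response_covers_all_fields(sample_record: JobRecord) -> None:
    # sample_record is an assumed pytest fixture providing a fully populated record.
    response = _build_job_response(sample_record)
    # Every declared response field should mirror the record; a newly added
    # required field would already fail inside _build_job_response itself.
    for field in JobResponse.model_fields:
        assert getattr(response, field) == getattr(sample_record, field)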

app/routes/snkmt.py

Lines changed: 2 additions & 3 deletions
@@ -11,7 +11,7 @@
 import sqlite3
 from collections.abc import Callable
 from pathlib import Path as FilePath
-from typing import Annotated, Any, TypeVar
+from typing import Annotated, Any
 
 from fastapi import APIRouter, Depends, HTTPException, Path
 
@@ -33,7 +33,6 @@
 
 router = APIRouter()
 
-_T = TypeVar("_T")
 
 # Job statuses where snkmt data is not yet available (before workflow execution starts).
 _PRE_EXECUTION_STATUSES = frozenset({JobStatus.PENDING, JobStatus.SETUP})
@@ -69,7 +68,7 @@ def _require_snkmt(
     return snkmt_db, record.work_dir, record.workflow_files
 
 
-async def _run_snkmt_query[T](job_id: str, query: Callable[[], _T]) -> _T:
+async def _run_snkmt_query[T](job_id: str, query: Callable[[], T]) -> T:
     """Run query in a thread and map sqlite3.DatabaseError to HTTP 502."""
     try:
         return await asyncio.to_thread(query)
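
The fix here is a spelling mismatch between two generics styles: the function declared a PEP 695 type parameter [T] but annotated with the module-level _T TypeVar, leaving T unused and binding the annotations to an unrelated variable. A self-contained illustration of the corrected style (Python 3.12+, hypothetical names):

import asyncio
from collections.abc import Callable


async def run_in_thread[T](query: Callable[[], T]) -> T:
    # The type parameter is declared and used in one place; no TypeVar needed.
    return await asyncio.to_thread(query)


async def main() -> None:
    n = await run_in_thread(lambda: 41 + 1)  # T is inferred as int
    print(n)


asyncio.run(main())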

app/snkmt.py

Lines changed: 11 additions & 8 deletions
@@ -85,11 +85,14 @@ class ErrorRow(TypedDict):
     rule_name: str | None
 
 
-def _rename_end_time_column(row: dict[str, Any]) -> dict[str, Any]:
-    """Rename DB column end_time → completed_at to match our naming."""
-    row = dict(row)
-    row["completed_at"] = row.pop("end_time", None)
-    return row
+def _rename_end_time_column(row: sqlite3.Row) -> dict[str, Any]:
+    """Rename DB column ``end_time`` → ``completed_at`` to match our TypedDicts.
+
+    Callers cast the result to the appropriate TypedDict (WorkflowRow or JobRow).
+    """
+    out = dict(row)
+    out["completed_at"] = out.pop("end_time", None)
+    return out
 
 
 def safe_json_loads(value: str | None) -> dict[str, Any] | None:
@@ -135,7 +138,7 @@ def get_workflow(conn: sqlite3.Connection, workflow_id: str) -> WorkflowRow | No
     ).fetchone()
     if row is None:
         return None
-    return cast("WorkflowRow", _rename_end_time_column(dict(row)))
+    return cast("WorkflowRow", _rename_end_time_column(row))
 
 
 def get_rules(conn: sqlite3.Connection, workflow_id: str) -> list[RuleRow]:
@@ -154,7 +157,7 @@ def get_jobs(conn: sqlite3.Connection, workflow_id: str) -> list[JobRow]:
         "WHERE j.workflow_id = ?",
         (workflow_id,),
     ).fetchall()
-    return [cast("JobRow", _rename_end_time_column(dict(r))) for r in rows]
+    return [cast("JobRow", _rename_end_time_column(r)) for r in rows]
 
 
 def get_job_files_by_snakemake_id(
@@ -254,7 +257,7 @@ def get_jobs_by_rule(
         "WHERE j.workflow_id = ? AND r.name = ?",  # noqa: S608 — parameterized query
         (workflow_id, rule_name),
     ).fetchall()
-    return [cast("JobRow", _rename_end_time_column(dict(r))) for r in rows]
+    return [cast("JobRow", _rename_end_time_column(r)) for r in rows]
 
 
 def get_rule_by_name(
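
Dropping dict(row) at every call site works because dict() accepts any mapping-like object: sqlite3.Row provides keys() and name-based __getitem__, which is all the conversion needs (it does require conn.row_factory = sqlite3.Row). A runnable sketch with an illustrative table:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.row_factory = sqlite3.Row
conn.execute("CREATE TABLE workflow (id TEXT, end_time TEXT)")
conn.execute("INSERT INTO workflow VALUES ('wf-1', '2024-01-01T00:00:00')")

row = conn.execute("SELECT * FROM workflow").fetchone()
out = dict(row)                                  # sqlite3.Row -> plain dict
out["completed_at"] = out.pop("end_time", None)  # rename as the helper does
print(out)  # {'id': 'wf-1', 'completed_at': '2024-01-01T00:00:00'}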

app/store.py

Lines changed: 2 additions & 18 deletions
@@ -142,24 +142,8 @@ def restore_from_disk(self) -> None:
                 kwargs = {k: data[k] for k in _PERSIST_FIELDS}
                 record = JobRecord(**kwargs)
                 self._jobs[record.job_id] = record
-            except json.JSONDecodeError as exc:
-                logger.warning(
-                    "Skipping corrupt job file (malformed JSON): %s: %s", json_path, exc
-                )
-            except KeyError as exc:
-                logger.warning(
-                    "Skipping corrupt job file (missing field %s): %s", exc, json_path
-                )
-            except ValueError as exc:
-                logger.warning(
-                    "Skipping corrupt job file (invalid field value): %s: %s",
-                    json_path,
-                    exc,
-                )
-            except TypeError as exc:
-                logger.warning(
-                    "Skipping corrupt job file (type error): %s: %s", json_path, exc
-                )
+            except (json.JSONDecodeError, KeyError, ValueError, TypeError) as exc:
+                logger.warning("Skipping corrupt job file: %s: %s", json_path, exc)
 
     def mark_setup(
         self, job_id: str, work_dir: str, git_ref: str | None, git_sha: str | None
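
One nuance in the collapsed tuple: json.JSONDecodeError is a subclass of ValueError, so ValueError alone would already catch it; keeping both is harmless and arguably self-documenting. A quick check:

import json

assert issubclass(json.JSONDecodeError, ValueError)

try:
    json.loads("{not valid json")
except (json.JSONDecodeError, KeyError, ValueError, TypeError) as exc:
    print(f"Skipping corrupt job file: {exc}")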
