diff --git a/tui/index.ts b/tui/index.ts index 2ced3f0..5a943b7 100644 --- a/tui/index.ts +++ b/tui/index.ts @@ -841,9 +841,22 @@ async function checkJobAlive(job: Job): Promise<boolean> { if (job.exec_mode !== "tmux" || job.tmux_sessions.length === 0) { return false; } + const results = await batchCheckJobsAlive([job]); + return results.get(job.id) ?? false; +} + +/** + * Batch check multiple jobs' tmux sessions in a single Python process. + * Returns a Map<string, boolean>. + */ +async function batchCheckJobsAlive(jobs: Job[]): Promise<Map<string, boolean>> { + const results = new Map<string, boolean>(); + const tmuxJobs = jobs.filter(j => j.exec_mode === "tmux" && j.tmux_sessions.length > 0); - const tmpFile = `/tmp/opensmi-check-${crypto.randomUUID()}.json`; - await Bun.write(tmpFile, JSON.stringify(job)); + if (tmuxJobs.length === 0) return results; + + const tmpFile = `/tmp/opensmi-batch-check-${crypto.randomUUID()}.json`; + await Bun.write(tmpFile, JSON.stringify(tmuxJobs)); const checkScript = ` import sys, json @@ -854,37 +867,39 @@ from opensmi.state import resolve_config_path import asyncio with open("${tmpFile}", "r") as f: - job_data = json.load(f) - -job = Job( - id=job_data["id"], - command=job_data["command"], - commands=job_data["commands"], - gpus=[tuple(g) for g in job_data["gpus"]], - requested_gpu_count=job_data["requested_gpu_count"], - dist_mode=job_data["dist_mode"], - exec_mode=job_data["exec_mode"], - tmux_sessions=job_data["tmux_sessions"], - status=job_data["status"], - submitted_at=job_data["submitted_at"], - started_at=job_data.get("started_at"), - finished_at=job_data.get("finished_at"), - exit_codes=job_data["exit_codes"], - error=job_data.get("error"), - user=job_data["user"], - restart_policy=job_data["restart_policy"], - retry_count=job_data["retry_count"], - max_retries=job_data["max_retries"], - tags=job_data["tags"], - queue_mode=job_data["queue_mode"], -) + jobs_data = json.load(f) cfg_path = resolve_config_path() cfg = load_config(cfg_path) async def main(): - alive = await 
check_job_alive(job, cfg) - print("true" if alive else "false") + results = {} + for jd in jobs_data: + job = Job( + id=jd["id"], + command=jd["command"], + commands=jd["commands"], + gpus=[tuple(g) for g in jd["gpus"]], + requested_gpu_count=jd["requested_gpu_count"], + dist_mode=jd["dist_mode"], + exec_mode=jd["exec_mode"], + tmux_sessions=jd["tmux_sessions"], + status=jd["status"], + submitted_at=jd["submitted_at"], + started_at=jd.get("started_at"), + finished_at=jd.get("finished_at"), + exit_codes=jd["exit_codes"], + error=jd.get("error"), + user=jd["user"], + restart_policy=jd["restart_policy"], + retry_count=jd["retry_count"], + max_retries=jd["max_retries"], + tags=jd["tags"], + queue_mode=jd["queue_mode"], + ) + alive = await check_job_alive(job, cfg) + results[jd["id"]] = alive + print(json.dumps(results)) asyncio.run(main()) `; @@ -905,13 +920,20 @@ asyncio.run(main()) } catch {} if (code !== 0) { - return false; + // Fallback: all unknown = not alive + for (const j of tmuxJobs) results.set(j.id, false); + return results; } - return stdout.trim() === "true"; + const parsed = JSON.parse(stdout.trim()); + for (const [id, alive] of Object.entries(parsed)) { + results.set(id, alive as boolean); + } + return results; } catch (e) { - console.error(`Failed to check job ${job.id} alive status:`, e); - return false; + console.error("Failed batch job alive check:", e); + for (const j of tmuxJobs) results.set(j.id, false); + return results; } } @@ -922,10 +944,12 @@ async function watchRunningJobs(): Promise<void> { return; } - // Check each running job's tmux session health + // Batch check all running jobs in a single Python process + const aliveMap = await batchCheckJobsAlive(runningJobs); + for (const job of runningJobs) { try { - const alive = await checkJobAlive(job); + const alive = aliveMap.get(job.id) ?? 
false; if (!alive) { // Tmux session terminated - determine restart behavior based on policy @@ -1003,6 +1027,23 @@ else: async function executeJobRemote(job: Job): Promise<void> { const tmuxSessions: string[] = []; + async function rollbackSessions() { + for (const session of tmuxSessions) { + try { + // Extract node from session name: opensmi-{id}-{node}[-gpu{n}] + const parts = session.replace(`opensmi-${job.id}-`, "").replace(/-gpu\d+$/, ""); + const node = parts || (job.gpus[0]?.[0]); + if (!node) continue; + await executeRemoteExec({ + node, + gpusCsv: "0", + mode: "direct", + command: `tmux kill-session -t ${session} 2>/dev/null || true`, + }); + } catch {} + } + } + if (job.dist_mode === "single") { const nodesByGpu = new Map(); @@ -1028,6 +1069,7 @@ }); if (!payload.ok) { + await rollbackSessions(); throw new Error(`Failed to execute on ${node}: ${payload.rawStderr.trim()}`); } @@ -1058,6 +1100,7 @@ }); if (!payload.ok) { + await rollbackSessions(); throw new Error(`Failed to execute on ${node}:GPU${gpu}: ${payload.rawStderr.trim()}`);