Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 77 additions & 34 deletions tui/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -841,9 +841,22 @@ async function checkJobAlive(job: Job): Promise<boolean> {
if (job.exec_mode !== "tmux" || job.tmux_sessions.length === 0) {
return false;
}
const results = await batchCheckJobsAlive([job]);
return results.get(job.id) ?? false;
}

/**
* Batch check multiple jobs' tmux sessions in a single Python process.
* Returns a Map<jobId, isAlive>.
*/
async function batchCheckJobsAlive(jobs: Job[]): Promise<Map<string, boolean>> {
const results = new Map<string, boolean>();
const tmuxJobs = jobs.filter(j => j.exec_mode === "tmux" && j.tmux_sessions.length > 0);

const tmpFile = `/tmp/opensmi-check-${crypto.randomUUID()}.json`;
await Bun.write(tmpFile, JSON.stringify(job));
if (tmuxJobs.length === 0) return results;

const tmpFile = `/tmp/opensmi-batch-check-${crypto.randomUUID()}.json`;
await Bun.write(tmpFile, JSON.stringify(tmuxJobs));

const checkScript = `
import sys, json
Expand All @@ -854,37 +867,39 @@ from opensmi.state import resolve_config_path
import asyncio

with open("${tmpFile}", "r") as f:
job_data = json.load(f)

job = Job(
id=job_data["id"],
command=job_data["command"],
commands=job_data["commands"],
gpus=[tuple(g) for g in job_data["gpus"]],
requested_gpu_count=job_data["requested_gpu_count"],
dist_mode=job_data["dist_mode"],
exec_mode=job_data["exec_mode"],
tmux_sessions=job_data["tmux_sessions"],
status=job_data["status"],
submitted_at=job_data["submitted_at"],
started_at=job_data.get("started_at"),
finished_at=job_data.get("finished_at"),
exit_codes=job_data["exit_codes"],
error=job_data.get("error"),
user=job_data["user"],
restart_policy=job_data["restart_policy"],
retry_count=job_data["retry_count"],
max_retries=job_data["max_retries"],
tags=job_data["tags"],
queue_mode=job_data["queue_mode"],
)
jobs_data = json.load(f)

cfg_path = resolve_config_path()
cfg = load_config(cfg_path)

async def main():
alive = await check_job_alive(job, cfg)
print("true" if alive else "false")
results = {}
for jd in jobs_data:
job = Job(
id=jd["id"],
command=jd["command"],
commands=jd["commands"],
gpus=[tuple(g) for g in jd["gpus"]],
requested_gpu_count=jd["requested_gpu_count"],
dist_mode=jd["dist_mode"],
exec_mode=jd["exec_mode"],
tmux_sessions=jd["tmux_sessions"],
status=jd["status"],
submitted_at=jd["submitted_at"],
started_at=jd.get("started_at"),
finished_at=jd.get("finished_at"),
exit_codes=jd["exit_codes"],
error=jd.get("error"),
user=jd["user"],
restart_policy=jd["restart_policy"],
retry_count=jd["retry_count"],
max_retries=jd["max_retries"],
tags=jd["tags"],
queue_mode=jd["queue_mode"],
)
alive = await check_job_alive(job, cfg)
results[jd["id"]] = alive
print(json.dumps(results))

asyncio.run(main())
`;
Expand All @@ -905,13 +920,20 @@ asyncio.run(main())
} catch {}

if (code !== 0) {
return false;
// Fallback: all unknown = not alive
for (const j of tmuxJobs) results.set(j.id, false);
return results;
}

return stdout.trim() === "true";
const parsed = JSON.parse(stdout.trim());
for (const [id, alive] of Object.entries(parsed)) {
results.set(id, alive as boolean);
}
return results;
} catch (e) {
console.error(`Failed to check job ${job.id} alive status:`, e);
return false;
console.error("Failed batch job alive check:", e);
for (const j of tmuxJobs) results.set(j.id, false);
return results;
}
}

Expand All @@ -922,10 +944,12 @@ async function watchRunningJobs(): Promise<void> {
return;
}

// Check each running job's tmux session health
// Batch check all running jobs in a single Python process
const aliveMap = await batchCheckJobsAlive(runningJobs);

for (const job of runningJobs) {
try {
const alive = await checkJobAlive(job);
const alive = aliveMap.get(job.id) ?? false;

if (!alive) {
// Tmux session terminated - determine restart behavior based on policy
Expand Down Expand Up @@ -1003,6 +1027,23 @@ else:
async function executeJobRemote(job: Job): Promise<void> {
const tmuxSessions: string[] = [];

/**
 * Best-effort teardown of tmux sessions already started for this job.
 * Invoked when a later step of remote execution fails, so partially
 * created sessions do not linger on the nodes. Individual kill
 * failures are deliberately swallowed (rollback must not throw).
 */
async function rollbackSessions() {
  for (const session of tmuxSessions) {
    try {
      // Session names follow opensmi-{id}-{node}[-gpu{n}]; strip the
      // job prefix and any trailing GPU suffix to recover the node.
      const prefix = `opensmi-${job.id}-`;
      const nodeName = session.replace(prefix, "").replace(/-gpu\d+$/, "");
      // Fall back to the job's first assigned GPU's node when the
      // session name did not yield one; with no node, skip the kill.
      const targetNode = nodeName || job.gpus[0]?.[0];
      if (!targetNode) continue;
      await executeRemoteExec({
        node: targetNode,
        gpusCsv: "0",
        mode: "direct",
        command: `tmux kill-session -t ${session} 2>/dev/null || true`,
      });
    } catch {}
  }
}

if (job.dist_mode === "single") {
const nodesByGpu = new Map<string, number[]>();

Expand All @@ -1028,6 +1069,7 @@ async function executeJobRemote(job: Job): Promise<void> {
});

if (!payload.ok) {
await rollbackSessions();
throw new Error(`Failed to execute on ${node}: ${payload.rawStderr.trim()}`);
}

Expand Down Expand Up @@ -1058,6 +1100,7 @@ async function executeJobRemote(job: Job): Promise<void> {
});

if (!payload.ok) {
await rollbackSessions();
throw new Error(`Failed to execute on ${node}:GPU${gpu}: ${payload.rawStderr.trim()}`);
}

Expand Down
Loading