Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,7 @@ WORKDIR /app
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt

CMD ["gunicorn", "-c", "gunicorn.conf.py", "app:app"]
# Ensure the logs directory exists
RUN mkdir -p logs

CMD ["python", "main.py"]
4 changes: 4 additions & 0 deletions agent/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""GH-style runner agent: polls Backend, dispatches jobs, sends results.

Distinct from Sandbox/runner/ which contains the Docker execution code.
"""
116 changes: 116 additions & 0 deletions agent/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""HTTP client for the Backend runner API."""
import requests

from . import config


class BackendClient:
    """Thin wrapper around requests, adding auth, base URL, error mapping.

    Error mapping applied by every call that goes through ``_request``:

    * 401, or an authenticated call made while no ``rk_token`` is set
      -> ``AuthError``
    * network failure, any 5xx, or any status outside the call's expected
      set -> ``TransientError``
    """

    class AuthError(Exception):
        """Raised when backend rejects auth (401) or no rk_token is set."""

    class TransientError(Exception):
        """Raised on 5xx, network errors or an unexpected status.

        Caller should retry.
        """

    def __init__(self,
                 base_url: str | None = None,
                 rk_token: str | None = None):
        """Create a client.

        Args:
            base_url: Backend root URL; falls back to ``config.BACKEND_URL``
                when omitted or empty.
            rk_token: Bearer token sent on authenticated calls. May be left
                unset for ``register`` (which does not need auth) and
                assigned to ``self.rk_token`` later.
        """
        self.base_url = base_url or config.BACKEND_URL
        self.rk_token = rk_token

    # ------- Public API -------

    def register(self, name: str, registration_token: str) -> dict:
        """Register this runner. Returns the `data` payload from backend.

        Sent without an Authorization header; the backend is expected to
        answer 201.
        """
        rv = self._request(
            "POST",
            "/runners/register",
            json_body={
                "registration_token": registration_token,
                "name": name,
            },
            need_auth=False,
            expected_statuses=(201, ),
        )
        return rv.json()["data"]

    def heartbeat(self, runner_id: str) -> None:
        """Send a heartbeat (expects 204). Raises AuthError on 401."""
        self._request(
            "POST",
            f"/runners/{runner_id}/heartbeat",
            expected_statuses=(204, ),
        )

    def next_job(self, runner_id: str) -> dict | None:
        """Poll for next job. Returns None if no job available (204)."""
        rv = self._request(
            "GET",
            f"/runners/{runner_id}/next-job",
            expected_statuses=(200, 204),
        )
        if rv.status_code == 204:
            return None
        return rv.json()["data"]

    def complete_job(self, runner_id: str, job_id: str, tasks: list) -> str:
        """Send result. Returns 'ok' / 'reclaimed' / 'not_found'.

        The three outcomes correspond to backend statuses 204 / 409 / 404.

        Raises TransientError on 5xx or network — caller should retry.
        """
        rv = self._request(
            "PUT",
            f"/runners/{runner_id}/jobs/{job_id}/complete",
            json_body={"tasks": tasks},
            expected_statuses=(204, 409, 404),
        )
        return {204: "ok", 409: "reclaimed", 404: "not_found"}[rv.status_code]

    def download_code(self, code_url: str, dest_path: str) -> None:
        """Download code zip from a presigned URL into ``dest_path``.

        The URL is fetched as-is (no base URL, no Authorization header) and
        streamed to disk in 64 KiB chunks.

        Raises:
            TransientError: on any ``requests`` failure, including a non-2xx
                status (via ``raise_for_status``).
        """
        try:
            with requests.get(code_url,
                              stream=True,
                              timeout=config.HTTP_REQUEST_TIMEOUT_SEC) as r:
                r.raise_for_status()
                with open(dest_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=64 * 1024):
                        f.write(chunk)
        except requests.RequestException as e:
            raise self.TransientError(f"code download failed: {e}") from e

    # ------- Internals -------

    def _request(
            self,
            method: str,
            path: str,
            *,
            json_body=None,
            need_auth=True,
            expected_statuses=(200, )) -> requests.Response:
        """Issue one request against the backend and map failures.

        Args:
            method: HTTP verb.
            path: Path beginning with ``/``, appended to the base URL.
            json_body: Optional JSON-serialisable request body.
            need_auth: When true, send ``Authorization: Bearer <rk_token>``.
            expected_statuses: Status codes the caller is prepared to handle.

        Returns:
            The response, whose status is guaranteed to be in
            ``expected_statuses``.

        Raises:
            AuthError: no token is set for an authenticated call, or the
                backend answered 401.
            TransientError: network error, 5xx, or any other status that is
                not in ``expected_statuses``.
        """
        headers = {}
        if need_auth:
            if not self.rk_token:
                raise self.AuthError("rk_token not set")
            headers["Authorization"] = f"Bearer {self.rk_token}"
        if json_body is not None:
            headers["Content-Type"] = "application/json"

        # Every path starts with "/", so drop a trailing slash from the base
        # URL; otherwise "http://host/" + "/runners/..." becomes
        # "http://host//runners/...".
        url = self.base_url.rstrip("/") + path
        try:
            rv = requests.request(
                method=method,
                url=url,
                json=json_body,
                headers=headers,
                timeout=config.HTTP_REQUEST_TIMEOUT_SEC,
            )
        except requests.RequestException as e:
            raise self.TransientError(f"network error: {e}") from e

        # 401 and 5xx are checked before expected_statuses so they always
        # map to their dedicated error, whatever the caller listed.
        if rv.status_code == 401:
            raise self.AuthError(rv.text)
        if rv.status_code >= 500:
            raise self.TransientError(f"backend {rv.status_code}: {rv.text}")
        if rv.status_code not in expected_statuses:
            raise self.TransientError(
                f"unexpected status {rv.status_code}: {rv.text}")
        return rv
32 changes: 32 additions & 0 deletions agent/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""Environment-driven configuration for the runner agent."""
import os
from pathlib import Path

# Backend URL — where to register, poll, send results
BACKEND_URL: str = os.getenv("BACKEND_URL", "http://web:8080")

# Shared registration secret (must match backend's RUNNER_REGISTRATION_TOKEN).
# The fallback value is a development placeholder; set the env var in any
# real deployment.
RUNNER_REGISTRATION_TOKEN: str = os.getenv(
    "RUNNER_REGISTRATION_TOKEN",
    "dev-only-registration-token-change-me",
)

# Display name shown in admin/listing. Defaults to hostname.
# NOTE: os.uname() is only available on Unix-like platforms.
RUNNER_NAME: str = os.getenv("RUNNER_NAME", os.uname().nodename)

# Tunings (defaults match what backend returns from register; override here is rarely needed)
# A non-integer value in either env var raises ValueError at import time.
HEARTBEAT_INTERVAL_SEC: int = int(os.getenv("HEARTBEAT_INTERVAL_SEC", "15"))
POLL_INTERVAL_SEC: int = int(os.getenv("POLL_INTERVAL_SEC", "3"))

# Result delivery retry policy
RESULT_RETRY_MAX_ATTEMPTS: int = 5
RESULT_RETRY_INITIAL_BACKOFF_SEC: float = 1.0
RESULT_RETRY_MAX_BACKOFF_SEC: float = 16.0

# HTTP timeouts
HTTP_REQUEST_TIMEOUT_SEC: int = 10

# Where to download code zip to (per-job temp dir)
CODE_DOWNLOAD_DIR: Path = Path(
    os.getenv("CODE_DOWNLOAD_DIR", "/tmp/runner-code"))
# parents=True so an override such as /var/lib/runner/code works even when
# the intermediate directories do not exist yet.
CODE_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
39 changes: 39 additions & 0 deletions agent/heartbeat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Heartbeat daemon thread: refreshes runner alive TTL on backend."""
import logging
import threading

from .client import BackendClient

log = logging.getLogger(__name__)


class HeartbeatThread(threading.Thread):
    """Daemon thread that sends a heartbeat to the backend on a fixed cadence.

    Every failure is logged and swallowed, so the loop keeps running until
    the shared shutdown event is set.
    """

    def __init__(
        self,
        client: BackendClient,
        runner_id: str,
        interval_sec: float,
        shutdown_event: threading.Event,
    ):
        """Store collaborators; the thread is a daemon named "heartbeat"."""
        super().__init__(daemon=True, name="heartbeat")
        self.client = client
        self.runner_id = runner_id
        self.interval_sec = interval_sec
        self.shutdown_event = shutdown_event

    def run(self) -> None:
        """Beat, sleep ``interval_sec``, repeat — until shutdown is requested."""
        while True:
            if self.shutdown_event.is_set():
                return
            self._beat_once()
            # Event.wait doubles as an interruptible sleep: it returns as
            # soon as shutdown is signalled instead of after the full interval.
            self.shutdown_event.wait(timeout=self.interval_sec)

    def _beat_once(self) -> None:
        """Send one heartbeat, logging (never raising) on any failure."""
        try:
            self.client.heartbeat(runner_id=self.runner_id)
        except BackendClient.AuthError as err:
            # Original author's note: an auth failure means the backend
            # forgot this runner (e.g., Redis loss) and re-registration
            # would be required; for now it is only logged.
            log.error(f"heartbeat auth failed: {err}")
        except BackendClient.TransientError as err:
            log.warning(f"heartbeat failed (transient): {err}")
        except Exception as err:  # defensive — the thread must not die
            log.exception(f"heartbeat unexpected error: {err}")
97 changes: 97 additions & 0 deletions agent/poller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
"""Poller daemon thread: pulls jobs from backend, hands to dispatcher."""
import logging
import tempfile
import threading

from dispatcher.constant import Language
from dispatcher.testdata import (
ensure_testdata,
get_problem_meta,
get_problem_root,
)
from .client import BackendClient

log = logging.getLogger(__name__)


def prepare_submission_dir_for_job(dispatcher, job: dict,
                                   client: BackendClient):
    """Stage everything the dispatcher needs to run ``job``.

    Steps, in order: read ``submission_id`` / ``problem_id`` / ``language``
    from the job, call ``ensure_testdata`` and ``get_problem_meta`` for the
    problem, download the code zip from ``job["code_url"]`` into a temporary
    file, then pass that file to ``dispatcher.prepare_submission_dir``.

    Per the original note, this reuses the existing
    dispatcher.prepare_submission_dir() — the same testdata fetching path as
    the old POST /submit handler.
    """
    sub_id = job["submission_id"]
    prob_id = job["problem_id"]
    lang = Language(job["language"])

    ensure_testdata(prob_id)
    problem_meta = get_problem_meta(prob_id, lang)

    # The temp file is removed automatically when this context exits.
    with tempfile.NamedTemporaryFile(suffix=".zip") as archive:
        zip_path = archive.name
        client.download_code(job["code_url"], zip_path)
        # download_code wrote through its own handle, so reopen by path to
        # read the downloaded bytes from the start.
        with open(zip_path, "rb") as code_zip:
            dispatcher.prepare_submission_dir(
                root_dir=dispatcher.SUBMISSION_DIR,
                submission_id=sub_id,
                meta=problem_meta,
                source=code_zip,
                testdata=get_problem_root(prob_id),
            )


class PollerThread(threading.Thread):
    """Daemon thread that pulls jobs from the backend and hands them to the dispatcher."""

    def __init__(
        self,
        client: BackendClient,
        runner_id: str,
        dispatcher,  # existing Dispatcher instance
        poll_interval_sec: float,
        shutdown_event: threading.Event,
    ):
        """Store collaborators; the thread is a daemon named "poller"."""
        super().__init__(daemon=True, name="poller")
        self.client = client
        self.runner_id = runner_id
        self.dispatcher = dispatcher
        self.poll_interval_sec = poll_interval_sec
        self.shutdown_event = shutdown_event

    def run(self) -> None:
        """Poll until shutdown.

        Each iteration either dispatches a job and loops again immediately,
        or sleeps: 0.5 s when the dispatcher has no capacity, otherwise
        ``poll_interval_sec`` (no job available, or the poll failed).
        """
        while not self.shutdown_event.is_set():
            if self.dispatcher.has_capacity():
                job = self._fetch_job()
                if job is not None:
                    self._dispatch(job)
                    continue
                pause = self.poll_interval_sec
            else:
                pause = 0.5
            # Interruptible sleep: returns early once shutdown is signalled.
            self.shutdown_event.wait(timeout=pause)

    def _fetch_job(self):
        """Ask the backend for the next job; return None on "no job" or any error."""
        try:
            return self.client.next_job(runner_id=self.runner_id)
        except BackendClient.TransientError as err:
            log.warning(f"next_job failed: {err}")
        except BackendClient.AuthError as err:
            log.error(f"next_job auth failed: {err}")
        except Exception as err:
            log.warning(f"next_job unexpected error: {err}")
        return None

    def _dispatch(self, job: dict) -> None:
        """Prepare the submission directory and hand the job to the dispatcher."""
        try:
            prepare_submission_dir_for_job(self.dispatcher, job, self.client)
            self.dispatcher.handle(
                submission_id=job["submission_id"],
                job_id=job["job_id"],
            )
            log.info(f"dispatched submission={job['submission_id']} "
                     f"job={job['job_id']}")
        except Exception as err:
            # No retry here — per the original note, the backend reclaims
            # the job after lease expiry.
            log.exception(
                f"failed to dispatch job {job.get('job_id')}: {err}")
33 changes: 33 additions & 0 deletions agent/registration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Self-registration on startup."""
from dataclasses import dataclass

from .client import BackendClient


@dataclass(frozen=True)
class RunnerCredentials:
    """Immutable record of what the backend returned at registration.

    Built by register_runner() from the register response payload.
    """

    # Value of the response's "runner_id" key.
    runner_id: str
    # Value of the response's "token" key.
    # NOTE(review): presumably the bearer token for later BackendClient
    # calls — confirm against the caller.
    token: str
    # Tunables taken from the response's optional "config" object;
    # register_runner() substitutes 15 / 3 / 8 when a key is absent.
    heartbeat_interval_sec: int
    poll_interval_sec: int
    max_concurrent_jobs: int


def register_runner(
    client: BackendClient,
    name: str,
    registration_token: str,
) -> RunnerCredentials:
    """Call backend's register endpoint and return RunnerCredentials.

    Args:
        client: Backend client used to perform the register call.
        name: Display name to register this runner under.
        registration_token: Shared registration secret sent to the backend.

    Returns:
        Credentials built from the response's ``runner_id``, ``token`` and
        optional ``config`` object. A missing or null ``config``, or a
        missing or null individual setting, falls back to the defaults
        15 s heartbeat / 3 s poll / 8 concurrent jobs.

    Raises:
        BackendClient.AuthError or TransientError on failure.
        KeyError: if the response lacks ``runner_id`` or ``token``.
    """
    rv = client.register(name=name, registration_token=registration_token)
    # `or {}` also covers an explicit `"config": null`, which
    # rv.get("config", {}) would pass through as None.
    cfg = rv.get("config") or {}

    def setting(key: str, default: int) -> int:
        """Return cfg[key], or `default` when the key is absent or null."""
        value = cfg.get(key)
        return default if value is None else value

    return RunnerCredentials(
        runner_id=rv["runner_id"],
        token=rv["token"],
        heartbeat_interval_sec=setting("heartbeat_interval_sec", 15),
        poll_interval_sec=setting("poll_interval_sec", 3),
        max_concurrent_jobs=setting("max_concurrent_jobs", 8),
    )
Loading
Loading