Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/serverless/worker.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ For more complex operations where you are downloading files or making changes to
return {"output": "Job completed successfully"}
```

## Stopping Individual Jobs

A worker can process more than one job concurrently. When a single request is cancelled, expires, or times out, the Runpod server signals the worker to stop just that request without affecting the worker's other in-progress jobs. The worker long-polls a dedicated stop channel so these signals arrive with low latency, and it cancels the task running the matching job, so a stopped job no longer consumes worker time.

No handler changes are required to support this. Async handlers that hold resources can clean up by catching `asyncio.CancelledError`, releasing those resources, and then re-raising the exception so the cancellation completes.

## See Also

- [Worker Fitness Checks](./worker_fitness_checks.md) - Validate your worker environment at startup
Expand Down
56 changes: 56 additions & 0 deletions runpod/serverless/modules/rp_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,62 @@
job_progress = JobsProgress()


def _job_stop_url() -> Optional[str]:
    """
    Build the URL of the worker's dedicated stop channel.

    The URL is derived from the job-take URL: the query string is dropped and
    the "/job-take/" path segment is swapped for "/job-stop/".

    Returns:
        The stop-channel URL, or None when the job-take URL (without its
        query string) does not contain "/job-take/".
    """
    # Everything before the first "?" is the path part of the job-take URL.
    take_path, _, _ = JOB_GET_URL.partition("?")
    if "/job-take/" in take_path:
        return take_path.replace("/job-take/", "/job-stop/")
    return None


async def get_stop_signals(session: ClientSession) -> List[str]:
    """
    Long-poll the dedicated stop channel for request ids the worker should stop.

    The server holds the request open until a stop signal is available or the
    poll times out, so cancellations and timeouts reach the worker without
    waiting for the next heartbeat.

    Args:
        session: The HTTP client session used to issue the GET request.

    Returns:
        A list of request ids to stop. Empty when no stop URL can be derived,
        the server answers 204, the body is not JSON, or the payload does not
        carry a list under "jobsToStop". Non-string entries are dropped.

    Raises:
        TooManyRequests: When the server answers with HTTP 429.
        Exception: Whatever ``response.raise_for_status()`` raises for other
            error statuses.
    """
    stop_url = _job_stop_url()
    if not stop_url:
        return []

    async with session.get(stop_url) as response:
        if response.status == 204:
            return []

        if response.status == 429:
            raise TooManyRequests(
                response.request_info,
                response.history,
                status=response.status,
                message=response.reason,
            )

        response.raise_for_status()

        if response.content_type != "application/json":
            return []

        try:
            payload = await response.json()
        except (aiohttp.ContentTypeError, ValueError):
            return []

        if not isinstance(payload, dict):
            return []

        # Only a real list is accepted: a JSON null would raise on iteration
        # and a JSON string would be split into one-character "job ids".
        job_ids = payload.get("jobsToStop")
        if not isinstance(job_ids, list):
            return []

        return [job_id for job_id in job_ids if isinstance(job_id, str)]


def _job_get_url(batch_size: int = 1):
"""
Prepare the URL for making a 'get' request to the serverless API (sls).
Expand Down
68 changes: 64 additions & 4 deletions runpod/serverless/modules/rp_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any, Dict, Set

from ...http_client import AsyncClientSession, ClientSession, TooManyRequests
from .rp_job import get_job, handle_job
from .rp_job import get_job, get_stop_signals, handle_job
from .rp_logger import RunPodLogger
from .worker_state import JobsProgress, IS_LOCAL_TEST

Expand Down Expand Up @@ -48,6 +48,12 @@ def __init__(self, config: Dict[str, Any]):
self.config = config
self.job_progress = JobsProgress() # Cache the singleton instance

# maps in-progress job ids to their running tasks so individual jobs
# can be stopped without killing the whole worker
self.jobs_tasks: Dict[str, asyncio.Task] = {}

self.stop_signals_fetcher = get_stop_signals

self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency)

self.concurrency_modifier = _default_concurrency_modifier
Expand All @@ -71,6 +77,9 @@ def __init__(self, config: Dict[str, Any]):
if jobs_handler := self.config.get("jobs_handler"):
self.jobs_handler = jobs_handler

if stop_signals_fetcher := self.config.get("stop_signals_fetcher"):
self.stop_signals_fetcher = stop_signals_fetcher

async def set_scale(self):
self.current_concurrency = self.concurrency_modifier(self.current_concurrency)

Expand Down Expand Up @@ -128,8 +137,9 @@ async def run(self):
# Create tasks for getting and running jobs.
jobtake_task = asyncio.create_task(self.get_jobs(session))
jobrun_task = asyncio.create_task(self.run_jobs(session))
jobstop_task = asyncio.create_task(self.monitor_stop_signals(session))

tasks = [jobtake_task, jobrun_task]
tasks = [jobtake_task, jobrun_task, jobstop_task]

# Concurrently run all tasks and wait for them to finish.
await asyncio.gather(*tasks)
Expand Down Expand Up @@ -226,9 +236,10 @@ async def run_jobs(self, session: ClientSession):
# Fetch as many jobs as the concurrency allows
while len(tasks) < self.current_concurrency and not self.jobs_queue.empty():
job = await self.jobs_queue.get()
# Create a new task for each job and add it to the task list
# Create a new task for each job and track it by job id
task = asyncio.create_task(self.handle_job(session, job))
tasks.add(task)
self.jobs_tasks[job["id"]] = task

# Wait for any job to finish
if tasks:
Expand All @@ -250,7 +261,51 @@ async def run_jobs(self, session: ClientSession):


# Ensure all remaining tasks finish before stopping
await asyncio.gather(*tasks)
await asyncio.gather(*tasks, return_exceptions=True)

async def monitor_stop_signals(self, session: ClientSession):
    """
    Long-polls the dedicated stop channel and stops signalled jobs.

    Runs in a loop while the worker is alive. The Runpod server signals a
    request to be stopped (for example when it is cancelled or times out)
    and this loop stops just that in-progress job, leaving the worker's
    other jobs running.

    Polls that return no job ids are paced so that consecutive polls start
    at least one second apart. A poll that was actually held open longer
    than that is re-issued immediately, but a fetcher that returns straight
    away (for example because no stop URL could be derived) would otherwise
    turn this loop into a busy spin.

    Args:
        session: The HTTP client session handed to the stop-signal fetcher.
    """
    min_empty_poll_interval = 1.0  # seconds between starts of empty polls
    loop = asyncio.get_running_loop()

    while self.is_alive():
        poll_started = loop.time()
        try:
            job_ids = await self.stop_signals_fetcher(session)
            for job_id in job_ids:
                await self.stop_job(job_id)

            if not job_ids:
                # Sleep only for whatever part of the interval the poll
                # itself did not already use up.
                idle_for = min_empty_poll_interval - (loop.time() - poll_started)
                if idle_for > 0:
                    await asyncio.sleep(idle_for)
        except TooManyRequests:
            await asyncio.sleep(5)  # debounce
        except asyncio.CancelledError:
            raise
        except Exception as error:
            log.debug(f"JobScaler.monitor_stop_signals | Error: {error}.")
            await asyncio.sleep(1)  # don't spin on persistent errors
        finally:
            # Yield to the event loop once per iteration.
            await asyncio.sleep(0)

async def stop_job(self, job_id: str) -> bool:
    """
    Stop a single in-progress job by cancelling its running task.

    Args:
        job_id: The id of the job to stop.

    Returns:
        True if a matching in-progress job was found and its cancellation
        was requested, False when no task is tracked for the id or the
        tracked task has already finished.
    """
    task = self.jobs_tasks.get(job_id)
    # A finished task cannot be cancelled, so it is reported the same way
    # as an unknown id instead of claiming the job was stopped.
    if task is None or task.done():
        log.debug(f"JobScaler.stop_job | No in-progress job for {job_id}.")
        return False

    log.info("Stopping job.", job_id)
    task.cancel()
    return True

async def handle_job(self, session: ClientSession, job: dict):
"""
Expand All @@ -268,11 +323,16 @@ async def handle_job(self, session: ClientSession, job: dict):
log.error(f"Error handling job: {err}", job["id"])
raise err

except asyncio.CancelledError:
log.info("Job stopped.", job["id"])
raise

finally:
# Inform Queue of a task completion
self.jobs_queue.task_done()

# Job is no longer in progress
self.job_progress.remove(job)
self.jobs_tasks.pop(job["id"], None)

log.debug("Finished Job", job["id"])
70 changes: 70 additions & 0 deletions tests/test_serverless/test_modules/test_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,76 @@ async def test_get_job_exception(self):
self.assertEqual(str(context.exception), "Unexpected error")


class TestGetStopSignals(IsolatedAsyncioTestCase):
    """Covers stop-URL derivation and the get_stop_signals poll."""

    STOP_TAKE_URL = "http://mock.url/v2/ep/job-take/pod?gpu=x"

    @staticmethod
    def _json_response(payload):
        """Return a mocked 200 JSON response whose body decodes to payload."""
        mocked = Mock(ClientResponse)
        mocked.status = 200
        mocked.content_type = "application/json"
        mocked.json = make_mocked_coro(return_value=payload)
        return mocked

    async def _fetch(self, response):
        """Poll the stop channel with response as the mocked GET result."""
        with patch("aiohttp.ClientSession") as mock_session, patch(
            "runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL
        ):
            mock_session.get.return_value.__aenter__.return_value = response
            return await rp_job.get_stop_signals(mock_session)

    def test_job_stop_url_derived_from_job_take(self):
        with patch("runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL):
            derived = rp_job._job_stop_url()
            assert derived == "http://mock.url/v2/ep/job-stop/pod"

    def test_job_stop_url_none_when_not_job_take(self):
        with patch("runpod.serverless.modules.rp_job.JOB_GET_URL", "http://mock.url/other"):
            derived = rp_job._job_stop_url()
            assert derived is None

    async def test_get_stop_signals_200(self):
        # The non-string entry (5) must be filtered out of the result.
        reply = self._json_response({"jobsToStop": ["a", "b", 5]})
        self.assertEqual(await self._fetch(reply), ["a", "b"])

    async def test_get_stop_signals_204(self):
        reply = Mock(ClientResponse)
        reply.status = 204
        self.assertEqual(await self._fetch(reply), [])

    async def test_get_stop_signals_429(self):
        reply = Mock(ClientResponse)
        reply.status = 429
        with self.assertRaises(TooManyRequests):
            await self._fetch(reply)

    async def test_get_stop_signals_no_url(self):
        # Without a job-take URL no request may be issued at all.
        with patch("aiohttp.ClientSession") as mock_session, patch(
            "runpod.serverless.modules.rp_job.JOB_GET_URL", "http://mock.url/other"
        ):
            signals = await rp_job.get_stop_signals(mock_session)
            self.assertEqual(signals, [])
            mock_session.get.assert_not_called()

    async def test_get_stop_signals_non_dict_payload(self):
        reply = self._json_response(["not", "a", "dict"])
        self.assertEqual(await self._fetch(reply), [])


class TestRunJob(IsolatedAsyncioTestCase):
"""Tests the run_job function"""

Expand Down
86 changes: 86 additions & 0 deletions tests/test_serverless/test_rp_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,92 @@ async def handler(_session, _config, job):

scaler.kill_worker()

@pytest.mark.asyncio
async def test_stop_job_cancels_inflight_task(job_scaler: PatchScaler):
    """Stopping an in-flight job cancels its task and clears its tracking."""
    target_id = "stop-me"
    worker = job_scaler.scaler
    handler_entered = asyncio.Event()
    cancelled_ids = []

    async def handler(_session, _config, job):
        # Block long enough that only a cancellation can end the job.
        handler_entered.set()
        try:
            await asyncio.sleep(10)
        except asyncio.CancelledError:
            cancelled_ids.append(job["id"])
            raise

    worker.jobs_handler = handler
    worker.current_concurrency = 1
    worker.jobs_queue = asyncio.Queue(maxsize=1)
    runner = asyncio.create_task(worker.run_jobs(None))

    await worker.jobs_queue.put(generate_job(target_id))
    await asyncio.wait_for(handler_entered.wait(), timeout=2)

    assert target_id in worker.jobs_tasks
    was_stopped = await worker.stop_job(target_id)
    assert was_stopped is True

    worker.kill_worker()
    await asyncio.wait_for(runner, timeout=2)

    assert cancelled_ids == [target_id]
    assert target_id not in worker.jobs_tasks
    assert job_scaler.progress.count == 0

    worker.kill_worker()


@pytest.mark.asyncio
async def test_stop_job_unknown_id_returns_false(job_scaler: PatchScaler):
    """stop_job reports False for an id that has no tracked task."""
    worker = job_scaler.scaler
    outcome = await worker.stop_job("does-not-exist")
    assert outcome is False
    worker.kill_worker()


@pytest.mark.asyncio
async def test_monitor_stop_signals_stops_jobs(job_scaler: PatchScaler):
    """Every id returned by the fetcher is forwarded to stop_job."""
    worker = job_scaler.scaler
    seen_ids = []

    async def fake_stop_job(job_id):
        seen_ids.append(job_id)
        return True

    async def fetcher(_session):
        # Hand out the signals once; later polls come back empty.
        return [] if seen_ids else ["job-a", "job-b"]

    worker.stop_job = fake_stop_job
    worker.stop_signals_fetcher = fetcher

    monitor = asyncio.create_task(worker.monitor_stop_signals(AsyncMock()))
    await asyncio.sleep(0.05)
    worker.kill_worker()
    await asyncio.wait_for(monitor, timeout=2)

    assert sorted(seen_ids) == ["job-a", "job-b"]


@pytest.mark.asyncio
async def test_monitor_stop_signals_survives_errors(job_scaler: PatchScaler):
    """A raising fetcher does not propagate out of the monitor loop."""
    worker = job_scaler.scaler
    attempts = 0

    async def fetcher(_session):
        nonlocal attempts
        attempts += 1
        raise RuntimeError("boom")

    worker.stop_signals_fetcher = fetcher

    monitor = asyncio.create_task(worker.monitor_stop_signals(AsyncMock()))
    await asyncio.sleep(0.05)
    worker.kill_worker()
    await asyncio.wait_for(monitor, timeout=2)

    assert attempts >= 1


@pytest.mark.asyncio
async def test_get_jobs_feeds_workers_end_to_end(job_scaler: PatchScaler):
scaler = job_scaler.scaler
Expand Down