def _job_stop_url() -> Optional[str]:
    """
    Prepare the URL for the worker's dedicated stop channel.

    Derived from the job-take URL so it points at the same endpoint and worker.
    Any query string on the job-take URL is dropped.

    Returns:
        The job-stop URL, or None when the job-take URL is not in the
        expected ``.../job-take/...`` form.
    """
    # Strip the query string first so "/job-take/" is only matched in the path.
    base_url = JOB_GET_URL.split("?")[0]
    if "/job-take/" not in base_url:
        return None
    return base_url.replace("/job-take/", "/job-stop/")


async def get_stop_signals(session: ClientSession) -> List[str]:
    """
    Long-poll the dedicated stop channel for request ids the worker should stop.

    The server holds the request open until a stop signal is available or the
    poll times out, so cancellations and timeouts reach the worker without
    waiting for the next heartbeat.

    Args:
        session: The HTTP client session used to issue the GET request.

    Returns:
        A list of request ids to stop. Empty when no stop URL can be derived,
        when the poll returned no signals (HTTP 204), or when the response
        body is not a JSON object carrying a ``jobsToStop`` list. Entries of
        that list which are not strings are skipped.

    Raises:
        TooManyRequests: When the server answers with HTTP 429.
        Exception: Whatever ``response.raise_for_status()`` raises for other
            error statuses.
    """
    stop_url = _job_stop_url()
    if not stop_url:
        return []

    async with session.get(stop_url) as response:
        if response.status == 204:
            return []

        if response.status == 429:
            raise TooManyRequests(
                response.request_info,
                response.history,
                status=response.status,
                message=response.reason,
            )

        response.raise_for_status()

        if response.content_type != "application/json":
            return []

        try:
            payload = await response.json()
        except (aiohttp.ContentTypeError, ValueError):
            return []

        if not isinstance(payload, dict):
            return []

        # Only a real list is accepted: a JSON null would raise TypeError when
        # iterated, and a bare string would be split into one-character "ids".
        job_ids = payload.get("jobsToStop")
        if not isinstance(job_ids, list):
            return []

        return [job_id for job_id in job_ids if isinstance(job_id, str)]
diff --git a/runpod/serverless/modules/rp_scale.py b/runpod/serverless/modules/rp_scale.py index 8c27543e..9e03a0d3 100644 --- a/runpod/serverless/modules/rp_scale.py +++ b/runpod/serverless/modules/rp_scale.py @@ -10,7 +10,7 @@ from typing import Any, Dict, Set from ...http_client import AsyncClientSession, ClientSession, TooManyRequests -from .rp_job import get_job, handle_job +from .rp_job import get_job, get_stop_signals, handle_job from .rp_logger import RunPodLogger from .worker_state import JobsProgress, IS_LOCAL_TEST @@ -48,6 +48,12 @@ def __init__(self, config: Dict[str, Any]): self.config = config self.job_progress = JobsProgress() # Cache the singleton instance + # maps in-progress job ids to their running tasks so individual jobs + # can be stopped without killing the whole worker + self.jobs_tasks: Dict[str, asyncio.Task] = {} + + self.stop_signals_fetcher = get_stop_signals + self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency) self.concurrency_modifier = _default_concurrency_modifier @@ -71,6 +77,9 @@ def __init__(self, config: Dict[str, Any]): if jobs_handler := self.config.get("jobs_handler"): self.jobs_handler = jobs_handler + if stop_signals_fetcher := self.config.get("stop_signals_fetcher"): + self.stop_signals_fetcher = stop_signals_fetcher + async def set_scale(self): self.current_concurrency = self.concurrency_modifier(self.current_concurrency) @@ -128,8 +137,9 @@ async def run(self): # Create tasks for getting and running jobs. jobtake_task = asyncio.create_task(self.get_jobs(session)) jobrun_task = asyncio.create_task(self.run_jobs(session)) + jobstop_task = asyncio.create_task(self.monitor_stop_signals(session)) - tasks = [jobtake_task, jobrun_task] + tasks = [jobtake_task, jobrun_task, jobstop_task] # Concurrently run both tasks and wait for both to finish. 
async def monitor_stop_signals(
    self, session: ClientSession, *, min_poll_interval: float = 1.0
):
    """
    Long-polls the dedicated stop channel and stops signalled jobs.

    Runs in a loop while the worker is alive. The Runpod server
    signals a request to be stopped (for example when it is cancelled or
    times out) and this loop stops just that in-progress job, leaving the
    worker's other jobs running.

    Args:
        session: The HTTP client session handed to the stop-signal fetcher.
        min_poll_interval: Minimum number of seconds one *empty* poll cycle
            may take. A fetcher that returns no signals faster than this
            (for example because no stop URL is available, or the server
            answers without holding the request open) is followed by a
            sleep for the remainder, so the loop never busy-spins. Polls
            that deliver signals are never delayed.
    """
    loop = asyncio.get_running_loop()

    while self.is_alive():
        poll_started = loop.time()
        try:
            job_ids = await self.stop_signals_fetcher(session)
            for job_id in job_ids:
                await self.stop_job(job_id)

            if not job_ids:
                # A genuine long-poll already took longer than the floor;
                # only an immediate empty answer is throttled here.
                remaining = min_poll_interval - (loop.time() - poll_started)
                if remaining > 0:
                    await asyncio.sleep(remaining)
        except TooManyRequests:
            await asyncio.sleep(5)  # debounce
        except asyncio.CancelledError:
            raise
        except Exception as error:
            log.debug(f"JobScaler.monitor_stop_signals | Error: {error}.")
            await asyncio.sleep(1)  # don't spin on persistent errors
        finally:
            # Always yield to the event loop once per iteration.
            await asyncio.sleep(0)


async def stop_job(self, job_id: str) -> bool:
    """
    Stop a single in-progress job by cancelling its running task.

    Only the task registered for ``job_id`` in ``self.jobs_tasks`` is
    cancelled; no other task is touched.

    NOTE(review): ``jobs_tasks`` is filled by ``run_jobs`` when it creates a
    job's task, so a stop signal for a job still waiting in ``jobs_queue``
    appears to find no entry and be dropped. Confirm whether such signals
    need to be remembered.

    Args:
        job_id: The id of the job to stop.

    Returns:
        True if a matching in-progress job was found and stopped,
        False otherwise.
    """
    task = self.jobs_tasks.get(job_id)
    if task is None:
        log.debug(f"JobScaler.stop_job | No in-progress job for {job_id}.")
        return False

    log.info("Stopping job.", job_id)
    task.cancel()
    return True
self.STOP_TAKE_URL + ): + mock_session.get.return_value.__aenter__.return_value = response + result = await rp_job.get_stop_signals(mock_session) + self.assertEqual(result, ["a", "b"]) + + async def test_get_stop_signals_204(self): + response = Mock(ClientResponse) + response.status = 204 + + with patch("aiohttp.ClientSession") as mock_session, patch( + "runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL + ): + mock_session.get.return_value.__aenter__.return_value = response + result = await rp_job.get_stop_signals(mock_session) + self.assertEqual(result, []) + + async def test_get_stop_signals_429(self): + response = Mock(ClientResponse) + response.status = 429 + + with patch("aiohttp.ClientSession") as mock_session, patch( + "runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL + ): + mock_session.get.return_value.__aenter__.return_value = response + with self.assertRaises(TooManyRequests): + await rp_job.get_stop_signals(mock_session) + + async def test_get_stop_signals_no_url(self): + with patch("aiohttp.ClientSession") as mock_session, patch( + "runpod.serverless.modules.rp_job.JOB_GET_URL", "http://mock.url/other" + ): + result = await rp_job.get_stop_signals(mock_session) + self.assertEqual(result, []) + mock_session.get.assert_not_called() + + async def test_get_stop_signals_non_dict_payload(self): + response = Mock(ClientResponse) + response.status = 200 + response.content_type = "application/json" + response.json = make_mocked_coro(return_value=["not", "a", "dict"]) + + with patch("aiohttp.ClientSession") as mock_session, patch( + "runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL + ): + mock_session.get.return_value.__aenter__.return_value = response + result = await rp_job.get_stop_signals(mock_session) + self.assertEqual(result, []) + + class TestRunJob(IsolatedAsyncioTestCase): """Tests the run_job function""" diff --git a/tests/test_serverless/test_rp_scale.py b/tests/test_serverless/test_rp_scale.py index 
@pytest.mark.asyncio
async def test_stop_job_cancels_inflight_task(job_scaler: PatchScaler):
    """stop_job cancels the running handler task and the job's bookkeeping is cleared."""
    scaler = job_scaler.scaler
    job_started = asyncio.Event()
    cancelled = []

    async def handler(_session, _config, job):
        # Signal that the job is running, then block until cancelled.
        job_started.set()
        try:
            await asyncio.sleep(10)
        except asyncio.CancelledError:
            # Record the cancellation so the test can assert it reached the handler.
            cancelled.append(job["id"])
            raise

    scaler.jobs_handler = handler
    scaler.current_concurrency = 1
    scaler.jobs_queue = asyncio.Queue(maxsize=1)
    run_task = asyncio.create_task(scaler.run_jobs(None))

    await scaler.jobs_queue.put(generate_job("stop-me"))
    # Wait until the handler has actually started before stopping the job.
    await asyncio.wait_for(job_started.wait(), timeout=2)

    assert "stop-me" in scaler.jobs_tasks
    assert await scaler.stop_job("stop-me") is True

    # Kill the worker and wait for run_jobs to return.
    scaler.kill_worker()
    await asyncio.wait_for(run_task, timeout=2)

    assert cancelled == ["stop-me"]
    assert "stop-me" not in scaler.jobs_tasks
    assert job_scaler.progress.count == 0

    scaler.kill_worker()


@pytest.mark.asyncio
async def test_stop_job_unknown_id_returns_false(job_scaler: PatchScaler):
    """stop_job returns False for an id that has no entry in jobs_tasks."""
    scaler = job_scaler.scaler
    assert await scaler.stop_job("does-not-exist") is False
    scaler.kill_worker()


@pytest.mark.asyncio
async def test_monitor_stop_signals_stops_jobs(job_scaler: PatchScaler):
    """monitor_stop_signals calls stop_job for every id the fetcher returns."""
    scaler = job_scaler.scaler
    stopped = []

    async def fake_stop_job(job_id):
        # Stand-in for JobScaler.stop_job that only records the requested id.
        stopped.append(job_id)
        return True

    async def fetcher(_session):
        # Deliver the two signals once; every later poll is empty.
        if not stopped:
            return ["job-a", "job-b"]
        return []

    scaler.stop_job = fake_stop_job
    scaler.stop_signals_fetcher = fetcher

    monitor_task = asyncio.create_task(scaler.monitor_stop_signals(AsyncMock()))
    # Give the monitor loop a moment to poll before shutting the worker down.
    await asyncio.sleep(0.05)
    scaler.kill_worker()
    await asyncio.wait_for(monitor_task, timeout=2)

    assert sorted(stopped) == ["job-a", "job-b"]


@pytest.mark.asyncio
async def test_monitor_stop_signals_survives_errors(job_scaler: PatchScaler):
    """monitor_stop_signals finishes without raising when the fetcher always raises."""
    scaler = job_scaler.scaler
    calls = {"value": 0}

    async def fetcher(_session):
        # Count the poll, then fail it.
        calls["value"] += 1
        raise RuntimeError("boom")

    scaler.stop_signals_fetcher = fetcher

    monitor_task = asyncio.create_task(scaler.monitor_stop_signals(AsyncMock()))
    await asyncio.sleep(0.05)
    scaler.kill_worker()
    # Awaiting the task would re-raise here if the fetcher error had escaped the loop.
    await asyncio.wait_for(monitor_task, timeout=2)

    assert calls["value"] >= 1