From b5fb475429328dba030c7bf0fd536f5a70439a3d Mon Sep 17 00:00:00 2001 From: zeke <40004347+KAJdev@users.noreply.github.com> Date: Fri, 5 Jun 2026 14:43:36 -0700 Subject: [PATCH 1/2] feat: add per-job stop capability to serverless worker --- .gitignore | 1 + docs/serverless/worker.md | 6 ++ runpod/serverless/modules/rp_ping.py | 34 +++++- runpod/serverless/modules/rp_scale.py | 53 ++++++++- runpod/serverless/modules/worker_state.py | 101 ++++++++++++++++++ .../test_serverless/test_modules/test_ping.py | 70 ++++++++++++ .../test_modules/test_state.py | 58 +++++++++- tests/test_serverless/test_rp_scale.py | 65 +++++++++++ 8 files changed, 382 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index eb63accf..da5fdfab 100644 --- a/.gitignore +++ b/.gitignore @@ -139,6 +139,7 @@ dmypy.json .pyre/ runpod/_version.py .runpod_jobs.pkl +.runpod_jobs_to_stop.pkl *.lock benchmark_results/ diff --git a/docs/serverless/worker.md b/docs/serverless/worker.md index 5373ed82..2f9f33db 100644 --- a/docs/serverless/worker.md +++ b/docs/serverless/worker.md @@ -58,6 +58,12 @@ For more complex operations where you are downloading files or making changes to return {"output": "Job completed successfully"} ``` +## Stopping Individual Jobs + +A worker can process more than one job concurrently. When a single request expires or times out, the Runpod server signals the worker to stop just that request without affecting the worker's other in-progress jobs. The worker receives these signals on its heartbeat and cancels the task running the matching job, so a stopped job no longer consumes worker time. + +No handler changes are required to support this. Handlers that hold resources can perform cleanup by catching `asyncio.CancelledError` in async handlers. + ## See Also - [Worker Fitness Checks](./worker_fitness_checks.md) - Validate your worker environment at startup diff --git a/runpod/serverless/modules/rp_ping.py b/runpod/serverless/modules/rp_ping.py index 838030b2..c6a05535 100644 --- a/runpod/serverless/modules/rp_ping.py +++ b/runpod/serverless/modules/rp_ping.py @@ -11,7 +11,7 @@ from runpod.http_client import SyncClientSession from runpod.serverless.modules.rp_logger import RunPodLogger -from runpod.serverless.modules.worker_state import WORKER_ID, JobsProgress +from runpod.serverless.modules.worker_state import WORKER_ID, JobsProgress, JobsToStop from runpod.version import __version__ as runpod_version log = RunPodLogger() @@ -108,5 +108,37 @@ def _send_ping(self): log.debug( f"Heartbeat Sent | URL: {result.url} | Status: {result.status_code}" ) + + self._handle_stop_signals(result) except requests.RequestException as err: log.error(f"Ping Request Error: {err}, attempting to restart ping.") + + @staticmethod + def _handle_stop_signals(result): + """ + Records any per-job stop signals returned by the Runpod server. + + The server may include a `jobsToStop` list of request ids in the ping + response when a request expires or times out. Those ids are persisted + so the worker loop can stop the matching in-progress jobs. + """ + if result.status_code != 200 or not result.content: + return + + try: + payload = result.json() + except ValueError: + return + + if not isinstance(payload, dict): + return + + job_ids = payload.get("jobsToStop") or [] + if not job_ids: + return + + jobs_to_stop = JobsToStop() + for job_id in job_ids: + if isinstance(job_id, str): + jobs_to_stop.add(job_id) + log.debug(f"Heartbeat | Stop signal received for job {job_id}") diff --git a/runpod/serverless/modules/rp_scale.py b/runpod/serverless/modules/rp_scale.py index 8c27543e..f85d2936 100644 --- a/runpod/serverless/modules/rp_scale.py +++ b/runpod/serverless/modules/rp_scale.py @@ -12,7 +12,7 @@ from ...http_client import AsyncClientSession, ClientSession, TooManyRequests from .rp_job import get_job, handle_job from .rp_logger import RunPodLogger -from .worker_state import JobsProgress, IS_LOCAL_TEST +from .worker_state import JobsProgress, JobsToStop, IS_LOCAL_TEST log = RunPodLogger() @@ -47,6 +47,11 @@ def __init__(self, config: Dict[str, Any]): self.current_concurrency = 1 self.config = config self.job_progress = JobsProgress() # Cache the singleton instance + self.jobs_to_stop = JobsToStop() # Cache the singleton instance + + # maps in-progress job ids to their running tasks so individual jobs + # can be stopped without killing the whole worker + self.jobs_tasks: Dict[str, asyncio.Task] = {} self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency) @@ -128,8 +133,9 @@ async def run(self): # Create tasks for getting and running jobs. jobtake_task = asyncio.create_task(self.get_jobs(session)) jobrun_task = asyncio.create_task(self.run_jobs(session)) + jobstop_task = asyncio.create_task(self.stop_jobs()) - tasks = [jobtake_task, jobrun_task] + tasks = [jobtake_task, jobrun_task, jobstop_task] # Concurrently run both tasks and wait for both to finish. await asyncio.gather(*tasks) @@ -226,9 +232,10 @@ async def run_jobs(self, session: ClientSession): # Fetch as many jobs as the concurrency allows while len(tasks) < self.current_concurrency and not self.jobs_queue.empty(): job = await self.jobs_queue.get() - # Create a new task for each job and add it to the task list + # Create a new task for each job and track it by job id task = asyncio.create_task(self.handle_job(session, job)) tasks.add(task) + self.jobs_tasks[job["id"]] = task # Wait for any job to finish if tasks: @@ -250,7 +257,40 @@ async def run_jobs(self, session: ClientSession): # Ensure all remaining tasks finish before stopping - await asyncio.gather(*tasks) + await asyncio.gather(*tasks, return_exceptions=True) + + async def stop_jobs(self): + """ + Watches for stop signals and stops the matching in-progress jobs. + + Runs in an infinite loop while the worker is alive. Stop signals are + recorded by the heartbeat when the Runpod server flags a request to + be stopped (for example when it expires or times out). + """ + while self.is_alive(): + for job_id in self.jobs_to_stop.pop_all(): + await self.stop_job(job_id) + await asyncio.sleep(1) + + async def stop_job(self, job_id: str) -> bool: + """ + Stop a single in-progress job by cancelling its running task. + + Args: + job_id: The id of the job to stop. + + Returns: + True if a matching in-progress job was found and stopped, + False otherwise. + """ + task = self.jobs_tasks.get(job_id) + if task is None: + log.debug(f"JobScaler.stop_job | No in-progress job for {job_id}.") + return False + + log.info("Stopping job.", job_id) + task.cancel() + return True async def handle_job(self, session: ClientSession, job: dict): """ @@ -268,11 +308,16 @@ async def handle_job(self, session: ClientSession, job: dict): log.error(f"Error handling job: {err}", job["id"]) raise err + except asyncio.CancelledError: + log.info("Job stopped.", job["id"]) + raise + finally: # Inform Queue of a task completion self.jobs_queue.task_done() # Job is no longer in progress self.job_progress.remove(job) + self.jobs_tasks.pop(job["id"], None) log.debug("Finished Job", job["id"]) diff --git a/runpod/serverless/modules/worker_state.py b/runpod/serverless/modules/worker_state.py index b546ce02..ea1e5f4a 100644 --- a/runpod/serverless/modules/worker_state.py +++ b/runpod/serverless/modules/worker_state.py @@ -207,3 +207,104 @@ def get_job_count(self) -> int: Returns the number of jobs. """ return len(self) + + +# ---------------------------------------------------------------------------- # +# Stop Signals # +# ---------------------------------------------------------------------------- # +class JobsToStop(Set[str]): + """Track job ids that have been signalled to stop. + + The heartbeat process records stop signals received from the Runpod + server here, and the worker loop reads them to stop the matching + in-progress jobs. State is persisted so it can cross the process + boundary between the heartbeat and the worker loop. + """ + + _instance = None + _STATE_DIR = os.getcwd() + _STATE_FILE = os.path.join(_STATE_DIR, ".runpod_jobs_to_stop.pkl") + + def __new__(cls): + if JobsToStop._instance is None: + os.makedirs(cls._STATE_DIR, exist_ok=True) + JobsToStop._instance = set.__new__(cls) + set.__init__(JobsToStop._instance) + JobsToStop._instance._load_state() + return JobsToStop._instance + + def __init__(self): + # singleton, never clear data on re-init + pass + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}>: {sorted(self)}" + + def _load_state(self): + """Load stop signals from pickle file with file locking.""" + try: + if ( + os.path.exists(self._STATE_FILE) + and os.path.getsize(self._STATE_FILE) > 0 + ): + with FileLock(self._STATE_FILE + ".lock"): + with open(self._STATE_FILE, "rb") as f: + try: + loaded = pickle.load(f) + super().clear() + for job_id in loaded: + set.add(self, job_id) + except (EOFError, pickle.UnpicklingError): + log.debug( + "JobsToStop: Failed to load state file, starting empty" + ) + except FileNotFoundError: + log.debug("JobsToStop: No state file found, starting empty") + + def _save_state(self): + """Save stop signals to pickle file with atomic write and file locking.""" + try: + with FileLock(self._STATE_FILE + ".lock"): + with tempfile.NamedTemporaryFile( + dir=self._STATE_DIR, delete=False, mode="wb" + ) as temp_f: + pickle.dump(set(self), temp_f) + os.replace(temp_f.name, self._STATE_FILE) + except Exception as e: + log.error(f"Failed to save stop signal state: {e}") + + def clear(self) -> None: + super().clear() + self._save_state() + + def add(self, element: Any): + """Records a stop signal for the given job id.""" + if isinstance(element, Job): + element = element.id + + if not isinstance(element, str): + raise TypeError("Only job id strings can be added to JobsToStop.") + + result = super().add(element) + self._save_state() + return result + + def remove(self, element: Any): + """Removes a stop signal for the given job id.""" + if isinstance(element, Job): + element = element.id + + if not isinstance(element, str): + raise TypeError("Only job id strings can be removed from JobsToStop.") + + result = super().discard(element) + self._save_state() + return result + + def pop_all(self) -> Set[str]: + """Returns all pending stop signals and clears them.""" + self._load_state() + pending = set(self) + if pending: + self.clear() + return pending diff --git a/tests/test_serverless/test_modules/test_ping.py b/tests/test_serverless/test_modules/test_ping.py index c40870f8..a9eddd28 100644 --- a/tests/test_serverless/test_modules/test_ping.py +++ b/tests/test_serverless/test_modules/test_ping.py @@ -274,6 +274,76 @@ def test_send_ping_request_exception(self, mock_env, mock_worker_id, mock_sessio "Ping Request Error: Connection error, attempting to restart ping." ) + def test_send_ping_records_stop_signals(self, mock_env, mock_worker_id, mock_session, mock_jobs): + """Stop signals in the ping response are recorded for the worker loop""" + heartbeat = Heartbeat() + + mock_response = MagicMock() + mock_response.url = "https://test.com/ping/test_worker_123" + mock_response.status_code = 200 + mock_response.content = b'{"jobsToStop": ["job1", "job2"]}' + mock_response.json.return_value = {"jobsToStop": ["job1", "job2"]} + mock_session.get.return_value = mock_response + + with patch("runpod.serverless.modules.rp_ping.JobsToStop") as mock_stop: + with patch("runpod.serverless.modules.rp_ping.runpod_version", "1.0.0"): + heartbeat._send_ping() + + instance = mock_stop.return_value + instance.add.assert_any_call("job1") + instance.add.assert_any_call("job2") + + def test_send_ping_no_stop_signals(self, mock_env, mock_worker_id, mock_session, mock_jobs): + """Empty or missing stop signals do not touch the stop store""" + heartbeat = Heartbeat() + + mock_response = MagicMock() + mock_response.url = "https://test.com/ping/test_worker_123" + mock_response.status_code = 200 + mock_response.content = b"{}" + mock_response.json.return_value = {} + mock_session.get.return_value = mock_response + + with patch("runpod.serverless.modules.rp_ping.JobsToStop") as mock_stop: + with patch("runpod.serverless.modules.rp_ping.runpod_version", "1.0.0"): + heartbeat._send_ping() + + mock_stop.return_value.add.assert_not_called() + + def test_send_ping_ignores_invalid_json(self, mock_env, mock_worker_id, mock_session, mock_jobs): + """A non-JSON ping body is ignored without error""" + heartbeat = Heartbeat() + + mock_response = MagicMock() + mock_response.url = "https://test.com/ping/test_worker_123" + mock_response.status_code = 200 + mock_response.content = b"not json" + mock_response.json.side_effect = ValueError("no json") + mock_session.get.return_value = mock_response + + with patch("runpod.serverless.modules.rp_ping.JobsToStop") as mock_stop: + with patch("runpod.serverless.modules.rp_ping.runpod_version", "1.0.0"): + heartbeat._send_ping() + + mock_stop.return_value.add.assert_not_called() + + def test_send_ping_ignores_non_dict_payload(self, mock_env, mock_worker_id, mock_session, mock_jobs): + """A JSON list payload does not trigger stop handling""" + heartbeat = Heartbeat() + + mock_response = MagicMock() + mock_response.url = "https://test.com/ping/test_worker_123" + mock_response.status_code = 200 + mock_response.content = b"[]" + mock_response.json.return_value = [] + mock_session.get.return_value = mock_response + + with patch("runpod.serverless.modules.rp_ping.JobsToStop") as mock_stop: + with patch("runpod.serverless.modules.rp_ping.runpod_version", "1.0.0"): + heartbeat._send_ping() + + mock_stop.return_value.add.assert_not_called() + def test_custom_pool_connections(self, mock_env, mock_worker_id, mock_session): """Test initialization with custom pool connections and retries""" heartbeat = Heartbeat(pool_connections=20, retries=5) diff --git a/tests/test_serverless/test_modules/test_state.py b/tests/test_serverless/test_modules/test_state.py index 94772bde..86799a0a 100644 --- a/tests/test_serverless/test_modules/test_state.py +++ b/tests/test_serverless/test_modules/test_state.py @@ -6,6 +6,7 @@ from runpod.serverless.modules.worker_state import ( Job, JobsProgress, + JobsToStop, IS_LOCAL_TEST, WORKER_ID, ) @@ -222,4 +223,59 @@ async def test_file_persistence_after_clear(self): # Verify that no jobs remain assert jobs2.get_job_count() == 0, "Jobs should be cleared in persistent state" - assert jobs2.get_job_list() is None, "Job list should be None after clear" \ No newline at end of file + assert jobs2.get_job_list() is None, "Job list should be None after clear" + + +class TestJobsToStop(unittest.IsolatedAsyncioTestCase): + """Tests for JobsToStop class""" + + async def asyncSetUp(self): + self.stops = JobsToStop() + self.stops.clear() + + def test_singleton(self): + stops2 = JobsToStop() + self.assertEqual(self.stops, stops2) + + async def test_add_and_pop_all(self): + assert not len(self.stops) + + self.stops.add("job-1") + self.stops.add("job-2") + assert len(self.stops) == 2 + + pending = self.stops.pop_all() + assert pending == {"job-1", "job-2"} + assert not len(self.stops) + + async def test_add_job_object(self): + self.stops.add(Job(id="job-3")) + assert self.stops.pop_all() == {"job-3"} + + async def test_add_rejects_non_string(self): + with self.assertRaises(TypeError): + self.stops.add(123) + + async def test_remove(self): + self.stops.add("job-4") + self.stops.remove("job-4") + assert not len(self.stops) + + async def test_remove_rejects_non_string(self): + with self.assertRaises(TypeError): + self.stops.remove(123) + + async def test_repr(self): + self.stops.add("job-5") + assert "job-5" in repr(self.stops) + + async def test_pop_all_empty(self): + assert self.stops.pop_all() == set() + + async def test_state_persistence(self): + self.stops.add("persist-1") + + JobsToStop._instance = None + stops2 = JobsToStop() + + assert stops2.pop_all() == {"persist-1"} \ No newline at end of file diff --git a/tests/test_serverless/test_rp_scale.py b/tests/test_serverless/test_rp_scale.py index c1a029b9..306a8a78 100644 --- a/tests/test_serverless/test_rp_scale.py +++ b/tests/test_serverless/test_rp_scale.py @@ -238,6 +238,71 @@ async def handler(_session, _config, job): scaler.kill_worker() +@pytest.mark.asyncio +async def test_stop_job_cancels_inflight_task(job_scaler: PatchScaler): + scaler = job_scaler.scaler + job_started = asyncio.Event() + cancelled = [] + + async def handler(_session, _config, job): + job_started.set() + try: + await asyncio.sleep(10) + except asyncio.CancelledError: + cancelled.append(job["id"]) + raise + + scaler.jobs_handler = handler + scaler.current_concurrency = 1 + scaler.jobs_queue = asyncio.Queue(maxsize=1) + run_task = asyncio.create_task(scaler.run_jobs(None)) + + await scaler.jobs_queue.put(generate_job("stop-me")) + await asyncio.wait_for(job_started.wait(), timeout=2) + + assert "stop-me" in scaler.jobs_tasks + assert await scaler.stop_job("stop-me") is True + + scaler.kill_worker() + await asyncio.wait_for(run_task, timeout=2) + + assert cancelled == ["stop-me"] + assert "stop-me" not in scaler.jobs_tasks + assert job_scaler.progress.count == 0 + + scaler.kill_worker() + + +@pytest.mark.asyncio +async def test_stop_job_unknown_id_returns_false(job_scaler: PatchScaler): + scaler = job_scaler.scaler + assert await scaler.stop_job("does-not-exist") is False + scaler.kill_worker() + + +@pytest.mark.asyncio +async def test_stop_jobs_consumes_stop_signals(job_scaler: PatchScaler): + scaler = job_scaler.scaler + stopped = [] + + async def fake_stop_job(job_id): + stopped.append(job_id) + return True + + scaler.stop_job = fake_stop_job + scaler.jobs_to_stop.clear() + scaler.jobs_to_stop.add("job-a") + scaler.jobs_to_stop.add("job-b") + + stop_task = asyncio.create_task(scaler.stop_jobs()) + await asyncio.sleep(0.05) + scaler.kill_worker() + await asyncio.wait_for(stop_task, timeout=2) + + assert sorted(stopped) == ["job-a", "job-b"] + assert scaler.jobs_to_stop.pop_all() == set() + + @pytest.mark.asyncio async def test_get_jobs_feeds_workers_end_to_end(job_scaler: PatchScaler): scaler = job_scaler.scaler From e769ac115fa4800da7843c5e8ade39d9f1fa408c Mon Sep 17 00:00:00 2001 From: zeke <40004347+KAJdev@users.noreply.github.com> Date: Fri, 5 Jun 2026 14:58:10 -0700 Subject: [PATCH 2/2] refactor: deliver job stop signals over a dedicated low-latency channel --- .gitignore | 1 - docs/serverless/worker.md | 2 +- runpod/serverless/modules/rp_job.py | 56 ++++++++++ runpod/serverless/modules/rp_ping.py | 34 +----- runpod/serverless/modules/rp_scale.py | 39 ++++--- runpod/serverless/modules/worker_state.py | 101 ------------------ .../test_serverless/test_modules/test_job.py | 70 ++++++++++++ .../test_serverless/test_modules/test_ping.py | 70 ------------ .../test_modules/test_state.py | 58 +--------- tests/test_serverless/test_rp_scale.py | 35 ++++-- 10 files changed, 184 insertions(+), 282 deletions(-) diff --git a/.gitignore b/.gitignore index da5fdfab..eb63accf 100644 --- a/.gitignore +++ b/.gitignore @@ -139,7 +139,6 @@ dmypy.json .pyre/ runpod/_version.py .runpod_jobs.pkl -.runpod_jobs_to_stop.pkl *.lock benchmark_results/ diff --git a/docs/serverless/worker.md b/docs/serverless/worker.md index 2f9f33db..9fdf8b72 100644 --- a/docs/serverless/worker.md +++ b/docs/serverless/worker.md @@ -60,7 +60,7 @@ For more complex operations where you are downloading files or making changes to ## Stopping Individual Jobs -A worker can process more than one job concurrently. When a single request expires or times out, the Runpod server signals the worker to stop just that request without affecting the worker's other in-progress jobs. The worker receives these signals on its heartbeat and cancels the task running the matching job, so a stopped job no longer consumes worker time. +A worker can process more than one job concurrently. When a single request is cancelled, expires, or times out, the Runpod server signals the worker to stop just that request without affecting the worker's other in-progress jobs. The worker long-polls a dedicated stop channel so these signals arrive with low latency, and it cancels the task running the matching job, so a stopped job no longer consumes worker time. No handler changes are required to support this. Handlers that hold resources can perform cleanup by catching `asyncio.CancelledError` in async handlers. diff --git a/runpod/serverless/modules/rp_job.py b/runpod/serverless/modules/rp_job.py index 614c45e5..e85ad146 100644 --- a/runpod/serverless/modules/rp_job.py +++ b/runpod/serverless/modules/rp_job.py @@ -26,6 +26,62 @@ job_progress = JobsProgress() +def _job_stop_url() -> Optional[str]: + """ + Prepare the URL for the worker's dedicated stop channel. + + Derived from the job-take URL so it points at the same endpoint and worker. + Returns None when the job-take URL is not in the expected form. + """ + base_url = JOB_GET_URL.split("?")[0] + if "/job-take/" not in base_url: + return None + return base_url.replace("/job-take/", "/job-stop/") + + +async def get_stop_signals(session: ClientSession) -> List[str]: + """ + Long-poll the dedicated stop channel for request ids the worker should stop. + + The server holds the request open until a stop signal is available or the + poll times out, so cancellations and timeouts reach the worker without + waiting for the next heartbeat. + + Returns: + A list of request ids to stop. Empty when the poll returned no signals. + """ + stop_url = _job_stop_url() + if not stop_url: + return [] + + async with session.get(stop_url) as response: + if response.status == 204: + return [] + + if response.status == 429: + raise TooManyRequests( + response.request_info, + response.history, + status=response.status, + message=response.reason, + ) + + response.raise_for_status() + + if response.content_type != "application/json": + return [] + + try: + payload = await response.json() + except (aiohttp.ContentTypeError, ValueError): + return [] + + if not isinstance(payload, dict): + return [] + + return [job_id for job_id in payload.get("jobsToStop", []) if isinstance(job_id, str)] + + def _job_get_url(batch_size: int = 1): """ Prepare the URL for making a 'get' request to the serverless API (sls). diff --git a/runpod/serverless/modules/rp_ping.py b/runpod/serverless/modules/rp_ping.py index c6a05535..838030b2 100644 --- a/runpod/serverless/modules/rp_ping.py +++ b/runpod/serverless/modules/rp_ping.py @@ -11,7 +11,7 @@ from runpod.http_client import SyncClientSession from runpod.serverless.modules.rp_logger import RunPodLogger -from runpod.serverless.modules.worker_state import WORKER_ID, JobsProgress, JobsToStop +from runpod.serverless.modules.worker_state import WORKER_ID, JobsProgress from runpod.version import __version__ as runpod_version log = RunPodLogger() @@ -108,37 +108,5 @@ def _send_ping(self): log.debug( f"Heartbeat Sent | URL: {result.url} | Status: {result.status_code}" ) - - self._handle_stop_signals(result) except requests.RequestException as err: log.error(f"Ping Request Error: {err}, attempting to restart ping.") - - @staticmethod - def _handle_stop_signals(result): - """ - Records any per-job stop signals returned by the Runpod server. - - The server may include a `jobsToStop` list of request ids in the ping - response when a request expires or times out. Those ids are persisted - so the worker loop can stop the matching in-progress jobs. - """ - if result.status_code != 200 or not result.content: - return - - try: - payload = result.json() - except ValueError: - return - - if not isinstance(payload, dict): - return - - job_ids = payload.get("jobsToStop") or [] - if not job_ids: - return - - jobs_to_stop = JobsToStop() - for job_id in job_ids: - if isinstance(job_id, str): - jobs_to_stop.add(job_id) - log.debug(f"Heartbeat | Stop signal received for job {job_id}") diff --git a/runpod/serverless/modules/rp_scale.py b/runpod/serverless/modules/rp_scale.py index f85d2936..9e03a0d3 100644 --- a/runpod/serverless/modules/rp_scale.py +++ b/runpod/serverless/modules/rp_scale.py @@ -10,9 +10,9 @@ from typing import Any, Dict, Set from ...http_client import AsyncClientSession, ClientSession, TooManyRequests -from .rp_job import get_job, handle_job +from .rp_job import get_job, get_stop_signals, handle_job from .rp_logger import RunPodLogger -from .worker_state import JobsProgress, JobsToStop, IS_LOCAL_TEST +from .worker_state import JobsProgress, IS_LOCAL_TEST log = RunPodLogger() @@ -47,12 +47,13 @@ def __init__(self, config: Dict[str, Any]): self.current_concurrency = 1 self.config = config self.job_progress = JobsProgress() # Cache the singleton instance - self.jobs_to_stop = JobsToStop() # Cache the singleton instance # maps in-progress job ids to their running tasks so individual jobs # can be stopped without killing the whole worker self.jobs_tasks: Dict[str, asyncio.Task] = {} + self.stop_signals_fetcher = get_stop_signals + self.jobs_queue = asyncio.Queue(maxsize=self.current_concurrency) self.concurrency_modifier = _default_concurrency_modifier @@ -76,6 +77,9 @@ def __init__(self, config: Dict[str, Any]): if jobs_handler := self.config.get("jobs_handler"): self.jobs_handler = jobs_handler + if stop_signals_fetcher := self.config.get("stop_signals_fetcher"): + self.stop_signals_fetcher = stop_signals_fetcher + async def set_scale(self): self.current_concurrency = self.concurrency_modifier(self.current_concurrency) @@ -133,7 +137,7 @@ async def run(self): # Create tasks for getting and running jobs. jobtake_task = asyncio.create_task(self.get_jobs(session)) jobrun_task = asyncio.create_task(self.run_jobs(session)) - jobstop_task = asyncio.create_task(self.stop_jobs()) + jobstop_task = asyncio.create_task(self.monitor_stop_signals(session)) tasks = [jobtake_task, jobrun_task, jobstop_task] @@ -259,18 +263,29 @@ async def run_jobs(self, session: ClientSession): # Ensure all remaining tasks finish before stopping await asyncio.gather(*tasks, return_exceptions=True) - async def stop_jobs(self): + async def monitor_stop_signals(self, session: ClientSession): """ - Watches for stop signals and stops the matching in-progress jobs. + Long-polls the dedicated stop channel and stops signalled jobs. - Runs in an infinite loop while the worker is alive. Stop signals are - recorded by the heartbeat when the Runpod server flags a request to - be stopped (for example when it expires or times out). + Runs in an infinite loop while the worker is alive. The Runpod server + signals a request to be stopped (for example when it is cancelled or + times out) and this loop stops just that in-progress job, leaving the + worker's other jobs running. """ while self.is_alive(): - for job_id in self.jobs_to_stop.pop_all(): - await self.stop_job(job_id) - await asyncio.sleep(1) + try: + job_ids = await self.stop_signals_fetcher(session) + for job_id in job_ids: + await self.stop_job(job_id) + except TooManyRequests: + await asyncio.sleep(5) # debounce + except asyncio.CancelledError: + raise + except Exception as error: + log.debug(f"JobScaler.monitor_stop_signals | Error: {error}.") + await asyncio.sleep(1) # don't spin on persistent errors + finally: + await asyncio.sleep(0) async def stop_job(self, job_id: str) -> bool: """ diff --git a/runpod/serverless/modules/worker_state.py b/runpod/serverless/modules/worker_state.py index ea1e5f4a..b546ce02 100644 --- a/runpod/serverless/modules/worker_state.py +++ b/runpod/serverless/modules/worker_state.py @@ -207,104 +207,3 @@ def get_job_count(self) -> int: Returns the number of jobs. """ return len(self) - - -# ---------------------------------------------------------------------------- # -# Stop Signals # -# ---------------------------------------------------------------------------- # -class JobsToStop(Set[str]): - """Track job ids that have been signalled to stop. - - The heartbeat process records stop signals received from the Runpod - server here, and the worker loop reads them to stop the matching - in-progress jobs. State is persisted so it can cross the process - boundary between the heartbeat and the worker loop. - """ - - _instance = None - _STATE_DIR = os.getcwd() - _STATE_FILE = os.path.join(_STATE_DIR, ".runpod_jobs_to_stop.pkl") - - def __new__(cls): - if JobsToStop._instance is None: - os.makedirs(cls._STATE_DIR, exist_ok=True) - JobsToStop._instance = set.__new__(cls) - set.__init__(JobsToStop._instance) - JobsToStop._instance._load_state() - return JobsToStop._instance - - def __init__(self): - # singleton, never clear data on re-init - pass - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}>: {sorted(self)}" - - def _load_state(self): - """Load stop signals from pickle file with file locking.""" - try: - if ( - os.path.exists(self._STATE_FILE) - and os.path.getsize(self._STATE_FILE) > 0 - ): - with FileLock(self._STATE_FILE + ".lock"): - with open(self._STATE_FILE, "rb") as f: - try: - loaded = pickle.load(f) - super().clear() - for job_id in loaded: - set.add(self, job_id) - except (EOFError, pickle.UnpicklingError): - log.debug( - "JobsToStop: Failed to load state file, starting empty" - ) - except FileNotFoundError: - log.debug("JobsToStop: No state file found, starting empty") - - def _save_state(self): - """Save stop signals to pickle file with atomic write and file locking.""" - try: - with FileLock(self._STATE_FILE + ".lock"): - with tempfile.NamedTemporaryFile( - dir=self._STATE_DIR, delete=False, mode="wb" - ) as temp_f: - pickle.dump(set(self), temp_f) - os.replace(temp_f.name, self._STATE_FILE) - except Exception as e: - log.error(f"Failed to save stop signal state: {e}") - - def clear(self) -> None: - super().clear() - self._save_state() - - def add(self, element: Any): - """Records a stop signal for the given job id.""" - if isinstance(element, Job): - element = element.id - - if not isinstance(element, str): - raise TypeError("Only job id strings can be added to JobsToStop.") - - result = super().add(element) - self._save_state() - return result - - def remove(self, element: Any): - """Removes a stop signal for the given job id.""" - if isinstance(element, Job): - element = element.id - - if not isinstance(element, str): - raise TypeError("Only job id strings can be removed from JobsToStop.") - - result = super().discard(element) - self._save_state() - return result - - def pop_all(self) -> Set[str]: - """Returns all pending stop signals and clears them.""" - self._load_state() - pending = set(self) - if pending: - self.clear() - return pending diff --git a/tests/test_serverless/test_modules/test_job.py b/tests/test_serverless/test_modules/test_job.py index 7abcefe3..6a1df3de 100644 --- a/tests/test_serverless/test_modules/test_job.py +++ b/tests/test_serverless/test_modules/test_job.py @@ -150,6 +150,76 @@ async def test_get_job_exception(self): self.assertEqual(str(context.exception), "Unexpected error") +class TestGetStopSignals(IsolatedAsyncioTestCase): + """Tests for the get_stop_signals function.""" + + STOP_TAKE_URL = "http://mock.url/v2/ep/job-take/pod?gpu=x" + + def test_job_stop_url_derived_from_job_take(self): + with patch("runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL): + assert rp_job._job_stop_url() == "http://mock.url/v2/ep/job-stop/pod" + + def test_job_stop_url_none_when_not_job_take(self): + with patch("runpod.serverless.modules.rp_job.JOB_GET_URL", "http://mock.url/other"): + assert rp_job._job_stop_url() is None + + async def test_get_stop_signals_200(self): + response = Mock(ClientResponse) + response.status = 200 + response.content_type = "application/json" + response.json = make_mocked_coro(return_value={"jobsToStop": ["a", "b", 5]}) + + with patch("aiohttp.ClientSession") as mock_session, patch( + "runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL + ): + mock_session.get.return_value.__aenter__.return_value = response + result = await rp_job.get_stop_signals(mock_session) + self.assertEqual(result, ["a", "b"]) + + async def test_get_stop_signals_204(self): + response = Mock(ClientResponse) + response.status = 204 + + with patch("aiohttp.ClientSession") as mock_session, patch( + "runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL + ): + mock_session.get.return_value.__aenter__.return_value = response + result = await rp_job.get_stop_signals(mock_session) + self.assertEqual(result, []) + + async def test_get_stop_signals_429(self): + response = Mock(ClientResponse) + response.status = 429 + + with patch("aiohttp.ClientSession") as mock_session, patch( + "runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL + ): + mock_session.get.return_value.__aenter__.return_value = response + with self.assertRaises(TooManyRequests): + await rp_job.get_stop_signals(mock_session) + + async def test_get_stop_signals_no_url(self): + with patch("aiohttp.ClientSession") as mock_session, patch( + "runpod.serverless.modules.rp_job.JOB_GET_URL", "http://mock.url/other" + ): + result = await rp_job.get_stop_signals(mock_session) + self.assertEqual(result, []) + mock_session.get.assert_not_called() + + async def test_get_stop_signals_non_dict_payload(self): + response = Mock(ClientResponse) + response.status = 200 + response.content_type = "application/json" + response.json = make_mocked_coro(return_value=["not", "a", "dict"]) + + with patch("aiohttp.ClientSession") as mock_session, patch( + "runpod.serverless.modules.rp_job.JOB_GET_URL", self.STOP_TAKE_URL + ): + mock_session.get.return_value.__aenter__.return_value = response + result = await rp_job.get_stop_signals(mock_session) + self.assertEqual(result, []) + + class TestRunJob(IsolatedAsyncioTestCase): """Tests the run_job function""" diff --git a/tests/test_serverless/test_modules/test_ping.py b/tests/test_serverless/test_modules/test_ping.py index a9eddd28..c40870f8 100644 --- a/tests/test_serverless/test_modules/test_ping.py +++ b/tests/test_serverless/test_modules/test_ping.py @@ -274,76 +274,6 @@ def test_send_ping_request_exception(self, mock_env, mock_worker_id, mock_sessio "Ping Request Error: Connection error, attempting to restart ping." ) - def test_send_ping_records_stop_signals(self, mock_env, mock_worker_id, mock_session, mock_jobs): - """Stop signals in the ping response are recorded for the worker loop""" - heartbeat = Heartbeat() - - mock_response = MagicMock() - mock_response.url = "https://test.com/ping/test_worker_123" - mock_response.status_code = 200 - mock_response.content = b'{"jobsToStop": ["job1", "job2"]}' - mock_response.json.return_value = {"jobsToStop": ["job1", "job2"]} - mock_session.get.return_value = mock_response - - with patch("runpod.serverless.modules.rp_ping.JobsToStop") as mock_stop: - with patch("runpod.serverless.modules.rp_ping.runpod_version", "1.0.0"): - heartbeat._send_ping() - - instance = mock_stop.return_value - instance.add.assert_any_call("job1") - instance.add.assert_any_call("job2") - - def test_send_ping_no_stop_signals(self, mock_env, mock_worker_id, mock_session, mock_jobs): - """Empty or missing stop signals do not touch the stop store""" - heartbeat = Heartbeat() - - mock_response = MagicMock() - mock_response.url = "https://test.com/ping/test_worker_123" - mock_response.status_code = 200 - mock_response.content = b"{}" - mock_response.json.return_value = {} - mock_session.get.return_value = mock_response - - with patch("runpod.serverless.modules.rp_ping.JobsToStop") as mock_stop: - with patch("runpod.serverless.modules.rp_ping.runpod_version", "1.0.0"): - heartbeat._send_ping() - - mock_stop.return_value.add.assert_not_called() - - def test_send_ping_ignores_invalid_json(self, mock_env, mock_worker_id, mock_session, mock_jobs): - """A non-JSON ping body is ignored without error""" - heartbeat = Heartbeat() - - mock_response = MagicMock() - mock_response.url = "https://test.com/ping/test_worker_123" - mock_response.status_code = 200 - mock_response.content = b"not json" - mock_response.json.side_effect = ValueError("no json") - mock_session.get.return_value = mock_response - - with patch("runpod.serverless.modules.rp_ping.JobsToStop") as mock_stop: - with patch("runpod.serverless.modules.rp_ping.runpod_version", "1.0.0"): - heartbeat._send_ping() - - mock_stop.return_value.add.assert_not_called() - - def test_send_ping_ignores_non_dict_payload(self, mock_env, mock_worker_id, mock_session, mock_jobs): - """A JSON list payload does not trigger stop handling""" - heartbeat = Heartbeat() - - mock_response = MagicMock() - mock_response.url = "https://test.com/ping/test_worker_123" - mock_response.status_code = 200 - mock_response.content = b"[]" - mock_response.json.return_value = [] - mock_session.get.return_value = mock_response - - with patch("runpod.serverless.modules.rp_ping.JobsToStop") as mock_stop: - with patch("runpod.serverless.modules.rp_ping.runpod_version", "1.0.0"): - heartbeat._send_ping() - - mock_stop.return_value.add.assert_not_called() - def test_custom_pool_connections(self, mock_env, mock_worker_id, mock_session): """Test initialization with custom pool connections and retries""" heartbeat = Heartbeat(pool_connections=20, retries=5) diff --git a/tests/test_serverless/test_modules/test_state.py b/tests/test_serverless/test_modules/test_state.py index 86799a0a..94772bde 100644 --- a/tests/test_serverless/test_modules/test_state.py +++ b/tests/test_serverless/test_modules/test_state.py @@ -6,7 +6,6 @@ from runpod.serverless.modules.worker_state import ( Job, JobsProgress, - JobsToStop, IS_LOCAL_TEST, WORKER_ID, ) @@ -223,59 +222,4 @@ async def test_file_persistence_after_clear(self): # Verify that no jobs remain assert jobs2.get_job_count() == 0, "Jobs should be cleared in persistent state" - assert jobs2.get_job_list() is None, "Job list should be None after clear" - - -class TestJobsToStop(unittest.IsolatedAsyncioTestCase): - """Tests for JobsToStop class""" - - async def asyncSetUp(self): - self.stops = JobsToStop() - self.stops.clear() - - def test_singleton(self): - stops2 = JobsToStop() - self.assertEqual(self.stops, stops2) - - async def test_add_and_pop_all(self): - assert not len(self.stops) - - self.stops.add("job-1") - self.stops.add("job-2") - assert len(self.stops) == 2 - - pending = self.stops.pop_all() - assert pending == {"job-1", "job-2"} - assert not len(self.stops) - - async def test_add_job_object(self): - self.stops.add(Job(id="job-3")) - assert self.stops.pop_all() == {"job-3"} - - async def test_add_rejects_non_string(self): - with self.assertRaises(TypeError): - self.stops.add(123) - - async def test_remove(self): - self.stops.add("job-4") - self.stops.remove("job-4") - assert not len(self.stops) - - async def test_remove_rejects_non_string(self): - with self.assertRaises(TypeError): - self.stops.remove(123) - - async def test_repr(self): - self.stops.add("job-5") - assert "job-5" in repr(self.stops) - - async def test_pop_all_empty(self): - assert self.stops.pop_all() == set() - - async def test_state_persistence(self): - self.stops.add("persist-1") - - JobsToStop._instance = None - stops2 = JobsToStop() - - assert stops2.pop_all() == {"persist-1"} \ No newline at end of file + assert jobs2.get_job_list() is None, "Job list should be None after clear" \ No newline at end of file diff --git a/tests/test_serverless/test_rp_scale.py b/tests/test_serverless/test_rp_scale.py index 306a8a78..39069db4 100644 --- a/tests/test_serverless/test_rp_scale.py +++ b/tests/test_serverless/test_rp_scale.py @@ -281,7 +281,7 @@ async def test_stop_job_unknown_id_returns_false(job_scaler: PatchScaler): @pytest.mark.asyncio -async def test_stop_jobs_consumes_stop_signals(job_scaler: PatchScaler): +async def test_monitor_stop_signals_stops_jobs(job_scaler: PatchScaler): scaler = job_scaler.scaler stopped = [] @@ -289,18 +289,39 @@ async def fake_stop_job(job_id): stopped.append(job_id) return True + async def fetcher(_session): + if not stopped: + return ["job-a", "job-b"] + return [] + scaler.stop_job = fake_stop_job - scaler.jobs_to_stop.clear() - scaler.jobs_to_stop.add("job-a") - scaler.jobs_to_stop.add("job-b") + scaler.stop_signals_fetcher = fetcher - stop_task = asyncio.create_task(scaler.stop_jobs()) + monitor_task = asyncio.create_task(scaler.monitor_stop_signals(AsyncMock())) await asyncio.sleep(0.05) scaler.kill_worker() - await asyncio.wait_for(stop_task, timeout=2) + await asyncio.wait_for(monitor_task, timeout=2) assert sorted(stopped) == ["job-a", "job-b"] - assert scaler.jobs_to_stop.pop_all() == set() + + +@pytest.mark.asyncio +async def test_monitor_stop_signals_survives_errors(job_scaler: PatchScaler): + scaler = job_scaler.scaler + calls = {"value": 0} + + async def fetcher(_session): + calls["value"] += 1 + raise RuntimeError("boom") + + scaler.stop_signals_fetcher = fetcher + + monitor_task = asyncio.create_task(scaler.monitor_stop_signals(AsyncMock())) + await asyncio.sleep(0.05) + scaler.kill_worker() + await asyncio.wait_for(monitor_task, timeout=2) + + assert calls["value"] >= 1 @pytest.mark.asyncio