-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy paththread_pool.py
More file actions
160 lines (126 loc) · 5.83 KB
/
thread_pool.py
File metadata and controls
160 lines (126 loc) · 5.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
thread_pool.py - Custom Fixed-Size Thread Pool
Provides a pool of worker threads that pull tasks off a shared queue. This
avoids the overhead of spawning a new OS thread per request (expensive for
short-lived connections) while keeping the implementation simple enough to
reason about.
Concurrency model
-----------------
Producer (main thread) → queue.Queue → N worker threads
Each worker loops forever waiting on Queue.get(). When it receives None (the
"poison pill" sentinel), it knows the pool is shutting down and exits cleanly.
We use blocking get() without a timeout deliberately: workers should sleep
(not spin-wait) when idle, which is exactly what a blocking get achieves at
zero CPU cost.
"""
import logging
import queue
import threading
from typing import Callable, Optional
# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)
# A task is a callable that takes no arguments; the socket is captured in a
# closure created by the server when a connection is accepted.
Task = Callable[[], None]
# The poison-pill sentinel value that signals workers to exit.
# Workers compare with `is`, so plain None is sufficient and unambiguous
# as long as None is never a legitimate task.
_STOP_SENTINEL: None = None
class ThreadPool:
    """
    A fixed-size pool of daemon worker threads backed by a Queue.

    Tasks submitted via submit() are executed by the next idle worker.
    Shutdown enqueues one poison-pill sentinel per worker so every thread
    receives exactly one stop signal.

    Args:
        num_workers: Number of threads to create. A good default is 10 for
                     a development server; production would tune this to the
                     expected concurrency level.

    Raises:
        ValueError: If num_workers is less than 1.
    """

    def __init__(self, num_workers: int = 10) -> None:
        if num_workers < 1:
            raise ValueError(f"num_workers must be >= 1, got {num_workers}")
        self._num_workers: int = num_workers
        # Unbounded queue — in production you would cap this to avoid
        # unlimited memory growth under extreme load.
        self._task_queue: queue.Queue = queue.Queue()
        self._workers: list[threading.Thread] = []
        self._started: bool = False
        # Set once shutdown() has been called. Guards against tasks being
        # enqueued *behind* the poison pills (where they would silently
        # never run) and against duplicate sentinel rounds.
        self._shutting_down: bool = False
        # Tracks how many workers are actively executing a task right now.
        # Protected by _workers_lock so reads from any thread are consistent.
        self._active_workers: int = 0
        self._workers_lock: threading.Lock = threading.Lock()

    # ── Lifecycle ─────────────────────────────────────────────────────────

    def start(self) -> None:
        """Spawn all worker threads and begin accepting tasks.

        Idempotent for a running pool: a second call logs a warning and
        returns.

        Raises:
            RuntimeError: If the pool has already been shut down — freshly
                spawned workers would immediately consume stale sentinels
                and exit, so this is always a caller bug.
        """
        if self._shutting_down:
            raise RuntimeError("cannot start a ThreadPool after shutdown()")
        if self._started:
            logger.warning("ThreadPool.start() called on an already-running pool")
            return
        for i in range(self._num_workers):
            thread = threading.Thread(
                target=self._worker_loop,
                name=f"worker-{i}",
                # Daemon threads are automatically killed when the main thread
                # exits, which acts as a safety net if shutdown() is skipped.
                daemon=True,
            )
            thread.start()
            self._workers.append(thread)
        self._started = True
        logger.info("ThreadPool started with %d workers", self._num_workers)

    def submit(self, task: Task) -> None:
        """
        Enqueue a task for execution by the next available worker.

        This call returns immediately; the task runs asynchronously.

        Raises:
            RuntimeError: If the pool has been shut down. Such a task would
                sit behind the stop sentinels and never execute; failing
                loudly prevents silently lost work.
        """
        if self._shutting_down:
            raise RuntimeError("cannot submit to a ThreadPool after shutdown()")
        self._task_queue.put(task)

    @property
    def active_workers(self) -> int:
        """Number of worker threads currently executing a task (not idle)."""
        with self._workers_lock:
            return self._active_workers

    @property
    def queue_depth(self) -> int:
        """
        Number of tasks waiting in the queue (not yet picked up by a worker).

        A persistently non-zero value means the pool is undersized for the
        current load — the first thing to check when diagnosing latency spikes.
        """
        return self._task_queue.qsize()

    def shutdown(self, wait: bool = True) -> None:
        """
        Signal all workers to exit by enqueuing one poison pill per thread.

        Idempotent: only the first call enqueues sentinels, so repeated
        calls cannot leave stray pills in the queue. The ``wait`` behaviour
        is honoured on every call.

        Args:
            wait: If True, block until every worker thread has terminated.
                  Pass False during tests or when you need a fast exit.
        """
        logger.info("ThreadPool shutting down (wait=%s)…", wait)
        if not self._shutting_down:
            self._shutting_down = True
            # One sentinel per worker guarantees each thread receives exactly
            # one stop signal even if some threads are idle and some are
            # processing tasks.
            for _ in self._workers:
                self._task_queue.put(_STOP_SENTINEL)
        if wait:
            for thread in self._workers:
                thread.join()
        logger.info("ThreadPool shutdown complete")

    # ── Internal ──────────────────────────────────────────────────────────

    def _worker_loop(self) -> None:
        """
        Infinite loop run by each worker thread.

        Pulls tasks off the queue one at a time. Exceptions inside a task are
        caught here so a single badly-behaved handler cannot kill the thread
        and permanently reduce pool capacity.
        """
        thread_name = threading.current_thread().name
        logger.debug("%s started", thread_name)
        while True:
            task: Optional[Task] = self._task_queue.get()
            if task is _STOP_SENTINEL:
                logger.debug("%s received stop sentinel, exiting", thread_name)
                self._task_queue.task_done()
                break
            # Increment *before* the try so the finally-decrement always has
            # a matching increment; the original incremented inside the try,
            # which could in principle underflow the counter.
            with self._workers_lock:
                self._active_workers += 1
            try:
                task()
            except Exception:
                # Log with traceback but keep the worker alive.
                logger.exception("%s: unhandled exception in task", thread_name)
            finally:
                with self._workers_lock:
                    self._active_workers -= 1
                self._task_queue.task_done()
        logger.debug("%s exited", thread_name)