Skip to content

Commit 13963a4

Browse files
dtsong and claude authored
fix: remove CPython internals from PriorityThreadPoolExecutor (#19)
* fix: remove CPython internals from PriorityThreadPoolExecutor (#4)

  Replace the _WorkItem/_work_queue subclass hack with a dispatcher-wrapper pattern: a PriorityQueue feeds work in priority order to a standard ThreadPoolExecutor via a daemon thread. This eliminates all CPython internal dependencies and fixes compatibility with Python 3.14+.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: harden PriorityThreadPoolExecutor error handling

  - Guard _chain_future against cancelled dest and internal exceptions
  - Add try/except in dispatcher to propagate errors to proxy futures
  - Reject submit() after shutdown() with RuntimeError
  - Use float('inf') sentinel priority to never preempt queued work
  - Add 30s timeout to dispatcher join to prevent deadlock on crash
  - Add tests for all new error paths

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: replace fragile dir() check with explicit proxy init in dispatcher

  Initialize proxy = None before the try block and check `is not None` instead of using `"proxy" in dir()`, which doesn't reliably reflect local variables and retains stale references across loop iterations.

  Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 770e01a commit 13963a4

2 files changed

Lines changed: 272 additions & 26 deletions

File tree

data_diff/thread_utils.py

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,81 @@
11
import itertools
2+
import threading
23
from collections import deque
34
from collections.abc import Callable, Iterable, Iterator
4-
from concurrent.futures import ThreadPoolExecutor
5-
from concurrent.futures.thread import _WorkItem
5+
from concurrent.futures import Future, ThreadPoolExecutor
66
from queue import PriorityQueue
77
from time import sleep
88
from typing import Any
99

1010
import attrs
1111

12-
13-
class AutoPriorityQueue(PriorityQueue):
14-
"""Overrides PriorityQueue to automatically get the priority from _WorkItem.kwargs
15-
16-
We also assign a unique id for each item, to avoid making comparisons on _WorkItem.
17-
As a side effect, items with the same priority are returned FIFO.
18-
"""
19-
20-
_counter = itertools.count().__next__
21-
22-
def put(self, item: _WorkItem | None, block=True, timeout=None) -> None:
23-
priority = item.kwargs.pop("priority") if item is not None else 0
24-
super().put((-priority, self._counter(), item), block, timeout)
25-
26-
def get(self, block=True, timeout=None) -> _WorkItem | None:
27-
_p, _c, work_item = super().get(block, timeout)
28-
return work_item
12+
_SENTINEL = object()
13+
14+
15+
def _chain_future(source: Future, dest: Future) -> None:
16+
"""Propagate the outcome (result, exception, or cancellation) from source to dest."""
17+
if dest.cancelled():
18+
return
19+
try:
20+
if source.cancelled():
21+
dest.cancel()
22+
elif exc := source.exception():
23+
dest.set_exception(exc)
24+
else:
25+
dest.set_result(source.result())
26+
except Exception as exc:
27+
try:
28+
dest.set_exception(exc)
29+
except Exception:
30+
pass
2931

3032

31-
class PriorityThreadPoolExecutor(ThreadPoolExecutor):
32-
"""Overrides ThreadPoolExecutor to use AutoPriorityQueue
33+
class PriorityThreadPoolExecutor:
34+
"""Thread pool that executes tasks in priority order.
3335
34-
XXX WARNING: Might break in future versions of Python
36+
Uses a dispatcher thread to pull work from a PriorityQueue and
37+
submit it to a standard ThreadPoolExecutor. No CPython internals.
3538
"""
3639

37-
def __init__(self, *args) -> None:
38-
super().__init__(*args)
39-
self._work_queue = AutoPriorityQueue()
40+
def __init__(self, max_workers: int | None = None) -> None:
41+
self._inner = ThreadPoolExecutor(max_workers=max_workers)
42+
self._queue: PriorityQueue = PriorityQueue()
43+
self._counter = itertools.count().__next__
44+
self._shutdown = False
45+
self._dispatcher = threading.Thread(target=self._dispatch, daemon=True)
46+
self._dispatcher.start()
47+
48+
def _dispatch(self) -> None:
49+
while True:
50+
proxy = None
51+
try:
52+
_priority, _count, item = self._queue.get()
53+
if item is _SENTINEL:
54+
break
55+
fn, args, kwargs, proxy = item
56+
inner_future = self._inner.submit(fn, *args, **kwargs)
57+
inner_future.add_done_callback(lambda f, p=proxy: _chain_future(f, p))
58+
except Exception as exc:
59+
if proxy is not None and not proxy.done():
60+
try:
61+
proxy.set_exception(exc)
62+
except Exception:
63+
pass
64+
65+
def submit(self, fn, /, *args, priority: int = 0, **kwargs) -> Future:
66+
if self._shutdown:
67+
raise RuntimeError("cannot submit after shutdown")
68+
proxy = Future()
69+
self._queue.put((-priority, self._counter(), (fn, args, kwargs, proxy)))
70+
return proxy
71+
72+
def shutdown(self, wait: bool = True) -> None:
73+
self._shutdown = True
74+
self._queue.put((float("inf"), self._counter(), _SENTINEL))
75+
self._dispatcher.join(timeout=30)
76+
if self._dispatcher.is_alive():
77+
raise RuntimeError("PriorityThreadPoolExecutor dispatcher did not shut down within 30s")
78+
self._inner.shutdown(wait=wait)
4079

4180

4281
@attrs.define(frozen=False, init=False)
@@ -47,7 +86,7 @@ class ThreadedYielder(Iterable):
4786
Priority for the iterator can be provided via the keyword argument 'priority'. (higher runs first)
4887
"""
4988

50-
_pool: ThreadPoolExecutor
89+
_pool: PriorityThreadPoolExecutor
5190
_futures: deque
5291
_yield: deque = attrs.field(alias="_yield") # Python keyword!
5392
_exception: None = None

tests/test_thread_utils.py

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
import threading
2+
from concurrent.futures import Future
3+
4+
import pytest
5+
6+
from data_diff.thread_utils import (
7+
PriorityThreadPoolExecutor,
8+
ThreadedYielder,
9+
_chain_future,
10+
)
11+
12+
13+
class TestPriorityThreadPoolExecutor:
    """Behavioural tests for the public API of PriorityThreadPoolExecutor."""

    def test_priority_ordering(self):
        """Tasks with a larger priority value run ahead of smaller ones."""
        release = threading.Event()
        order = []
        executor = PriorityThreadPoolExecutor(max_workers=1)

        # Occupy the only worker so the following tasks pile up behind it.
        executor.submit(lambda: release.wait(), priority=0)

        # Deliberately submitted out of order.
        for prio in (1, 3, 2):
            executor.submit(lambda prio=prio: order.append(prio), priority=prio)

        # Unblock the worker and let everything drain.
        release.set()
        executor.shutdown(wait=True)

        assert order == [3, 2, 1]

    def test_fifo_within_same_priority(self):
        """Tasks sharing one priority keep their submission order."""
        release = threading.Event()
        order = []
        executor = PriorityThreadPoolExecutor(max_workers=1)
        executor.submit(lambda: release.wait(), priority=0)

        for idx in range(5):
            executor.submit(lambda idx=idx: order.append(idx), priority=1)

        release.set()
        executor.shutdown(wait=True)

        assert order == [0, 1, 2, 3, 4]

    def test_submit_returns_future_with_result(self):
        """The returned Future yields the callable's return value."""
        executor = PriorityThreadPoolExecutor(max_workers=2)
        fut = executor.submit(lambda: 42)
        assert fut.result(timeout=5) == 42
        executor.shutdown()

    def test_submit_returns_future_with_exception(self):
        """An exception raised by the callable is re-raised by the Future."""
        executor = PriorityThreadPoolExecutor(max_workers=2)
        fut = executor.submit(lambda: 1 / 0)
        with pytest.raises(ZeroDivisionError):
            fut.result(timeout=5)
        executor.shutdown()

    def test_concurrent_submit(self):
        """submit() may be called from many threads at once."""
        executor = PriorityThreadPoolExecutor(max_workers=4)
        seen = []
        guard = threading.Lock()

        def record(value):
            with guard:
                seen.append(value)

        submitters = []
        for idx in range(20):
            worker = threading.Thread(target=lambda idx=idx: executor.submit(record, idx, priority=0))
            submitters.append(worker)
            worker.start()

        for worker in submitters:
            worker.join()

        executor.shutdown(wait=True)
        assert sorted(seen) == list(range(20))

    def test_shutdown_with_pending_work(self):
        """shutdown(wait=True) returns only after queued tasks have run."""
        seen = []
        executor = PriorityThreadPoolExecutor(max_workers=1)

        for idx in range(10):
            executor.submit(lambda idx=idx: seen.append(idx), priority=0)

        executor.shutdown(wait=True)
        assert sorted(seen) == list(range(10))

    def test_no_cpython_internals_imported(self):
        """The module namespace must not expose _WorkItem."""
        import data_diff.thread_utils as mod

        assert not hasattr(mod, "_WorkItem")

    def test_submit_forwards_args_and_kwargs(self):
        """Positional and keyword arguments reach the callable unchanged."""
        executor = PriorityThreadPoolExecutor(max_workers=1)
        fut = executor.submit(lambda a, b, c=None: (a, b, c), 1, 2, c=3)
        assert fut.result(timeout=5) == (1, 2, 3)
        executor.shutdown()

    def test_submit_after_shutdown_raises(self):
        """Once shut down, submit() fails with RuntimeError."""
        executor = PriorityThreadPoolExecutor(max_workers=1)
        executor.shutdown()
        with pytest.raises(RuntimeError, match="cannot submit after shutdown"):
            executor.submit(lambda: None)

    def test_shutdown_drains_high_priority_work(self):
        """Shutting down does not drop queued work of higher priority."""
        release = threading.Event()
        seen = []
        executor = PriorityThreadPoolExecutor(max_workers=1)
        executor.submit(lambda: release.wait(), priority=0)

        for idx in range(5):
            executor.submit(lambda idx=idx: seen.append(idx), priority=10)

        release.set()
        executor.shutdown(wait=True)
        assert sorted(seen) == list(range(5))
132+
133+
134+
class TestChainFuture:
    """Unit tests for the _chain_future helper."""

    def test_propagates_result(self):
        """A finished source's result is copied to dest."""
        src = Future()
        dst = Future()
        src.set_result(42)
        _chain_future(src, dst)
        assert dst.result() == 42

    def test_propagates_exception(self):
        """A failed source's exception is copied to dest."""
        src = Future()
        dst = Future()
        src.set_exception(ValueError("oops"))
        _chain_future(src, dst)
        with pytest.raises(ValueError, match="oops"):
            dst.result()

    def test_skips_cancelled_dest(self):
        """Chaining into an already-cancelled dest is a silent no-op."""
        src = Future()
        dst = Future()
        dst.cancel()
        src.set_result(42)
        # Must complete without raising.
        _chain_future(src, dst)
159+
160+
161+
class TestThreadedYielder:
    """Tests for ThreadedYielder running on top of the priority pool."""

    def test_basic_yield(self):
        """Items produced by every submitted function are yielded."""
        yielder = ThreadedYielder(max_workers=2)
        yielder.submit(lambda: [1, 2, 3])
        yielder.submit(lambda: [4, 5, 6])

        collected = list(yielder)
        assert sorted(collected) == [1, 2, 3, 4, 5, 6]

    def test_priority_behavior(self):
        """Submissions with a higher priority are scheduled first."""
        release = threading.Event()
        yielder = ThreadedYielder(max_workers=1)

        def hold_worker():
            # Keeps the single worker busy until the event is set.
            release.wait()
            return []

        yielder.submit(hold_worker, priority=0)

        # Queued out of order on purpose.
        yielder.submit(lambda: ["low"], priority=1)
        yielder.submit(lambda: ["high"], priority=3)
        yielder.submit(lambda: ["mid"], priority=2)

        release.set()
        collected = list(yielder)
        # Expected in descending priority.
        assert collected == ["high", "mid", "low"]

    def test_yield_list_mode(self):
        """With yield_list=True each result is yielded as one list item."""
        yielder = ThreadedYielder(max_workers=1, yield_list=True)
        yielder.submit(lambda: [1, 2, 3])

        collected = list(yielder)
        assert collected == [[1, 2, 3]]

    def test_exception_propagation(self):
        """An exception raised by a submitted function surfaces on iteration."""
        yielder = ThreadedYielder(max_workers=1)
        yielder.submit(lambda: (_ for _ in ()).throw(ValueError("boom")))

        with pytest.raises(ValueError, match="boom"):
            list(yielder)

0 commit comments

Comments
 (0)