Skip to content

Commit 5373343

Browse files
committed
Make PyUnstable_Code_SetExtra/GetExtra thread-safe
1 parent 0bbdb4e commit 5373343

File tree

3 files changed

+197
-28
lines changed

3 files changed

+197
-28
lines changed

Lib/test/test_free_threading/test_code.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,39 @@
11
import unittest
22

3+
try:
4+
import ctypes
5+
except ImportError:
6+
ctypes = None
7+
38
from threading import Thread
49
from unittest import TestCase
510

611
from test.support import threading_helper
12+
from test.support.threading_helper import run_concurrently
13+
14+
if ctypes is not None:
15+
capi = ctypes.pythonapi
16+
17+
freefunc = ctypes.CFUNCTYPE(None, ctypes.c_voidp)
18+
19+
RequestCodeExtraIndex = capi.PyUnstable_Eval_RequestCodeExtraIndex
20+
RequestCodeExtraIndex.argtypes = (freefunc,)
21+
RequestCodeExtraIndex.restype = ctypes.c_ssize_t
22+
23+
SetExtra = capi.PyUnstable_Code_SetExtra
24+
SetExtra.argtypes = (ctypes.py_object, ctypes.c_ssize_t, ctypes.c_voidp)
25+
SetExtra.restype = ctypes.c_int
26+
27+
GetExtra = capi.PyUnstable_Code_GetExtra
28+
GetExtra.argtypes = (
29+
ctypes.py_object,
30+
ctypes.c_ssize_t,
31+
ctypes.POINTER(ctypes.c_voidp),
32+
)
33+
GetExtra.restype = ctypes.c_int
34+
35+
NTHREADS = 20
36+
737

838
@threading_helper.requires_working_threading()
939
class TestCode(TestCase):
@@ -25,6 +55,78 @@ def run_in_thread():
2555
for thread in threads:
2656
thread.join()
2757

58+
@unittest.skipUnless(ctypes, "ctypes is required")
59+
def test_request_code_extra_index_concurrent(self):
60+
"""Test concurrent calls to RequestCodeExtraIndex"""
61+
results = []
62+
63+
def worker():
64+
idx = RequestCodeExtraIndex(freefunc(0))
65+
self.assertGreaterEqual(idx, 0)
66+
results.append(idx)
67+
68+
run_concurrently(worker_func=worker, nthreads=NTHREADS)
69+
70+
# Every thread must get a unique index.
71+
self.assertEqual(len(results), NTHREADS)
72+
self.assertEqual(len(set(results)), NTHREADS)
73+
74+
@unittest.skipUnless(ctypes, "ctypes is required")
75+
def test_code_extra_all_ops_concurrent(self):
76+
"""Test concurrent RequestCodeExtraIndex + SetExtra + GetExtra"""
77+
LOOP = 100
78+
79+
def f():
80+
pass
81+
82+
code = f.__code__
83+
84+
def worker():
85+
idx = RequestCodeExtraIndex(freefunc(0))
86+
self.assertGreaterEqual(idx, 0)
87+
88+
for i in range(LOOP):
89+
SetExtra(code, idx, ctypes.c_voidp(i + 1))
90+
91+
for _ in range(LOOP):
92+
extra = ctypes.c_voidp()
93+
GetExtra(code, idx, extra)
94+
# The slot was set by this thread, so the value must
95+
# be the last one written.
96+
self.assertEqual(extra.value, LOOP)
97+
98+
run_concurrently(worker_func=worker, nthreads=NTHREADS)
99+
100+
@unittest.skipUnless(ctypes, "ctypes is required")
101+
def test_code_extra_set_get_concurrent(self):
102+
"""Test concurrent SetExtra + GetExtra on a shared index"""
103+
LOOP = 100
104+
105+
def f():
106+
pass
107+
108+
code = f.__code__
109+
110+
idx = RequestCodeExtraIndex(freefunc(0))
111+
self.assertGreaterEqual(idx, 0)
112+
113+
def worker():
114+
for i in range(LOOP):
115+
SetExtra(code, idx, ctypes.c_voidp(i + 1))
116+
117+
for _ in range(LOOP):
118+
extra = ctypes.c_voidp()
119+
GetExtra(code, idx, extra)
120+
# Value is set by any writer thread.
121+
self.assertTrue(1 <= extra.value <= LOOP)
122+
123+
run_concurrently(worker_func=worker, nthreads=NTHREADS)
124+
125+
# Every thread's last write is LOOP, so the final value must be LOOP.
126+
extra = ctypes.c_voidp()
127+
GetExtra(code, idx, extra)
128+
self.assertEqual(extra.value, LOOP)
129+
28130

29131
# Allow running this test module directly, outside the regrtest runner.
if __name__ == "__main__":
    unittest.main()

Objects/codeobject.c

Lines changed: 77 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,6 +1575,12 @@ typedef struct {
15751575
} _PyCodeObjectExtra;
15761576

15771577

1578+
// Number of bytes needed for a _PyCodeObjectExtra holding n extra slots.
// NOTE(review): the "- 1" presumably accounts for a one-element ce_extras
// array already embedded in _PyCodeObjectExtra — confirm against the struct
// definition above.  Callers guarantee n >= 1 (see the assert in SetExtra),
// so the subtraction cannot underflow.
static inline size_t
code_extra_size(Py_ssize_t n)
{
    return sizeof(_PyCodeObjectExtra) + (n - 1) * sizeof(void *);
}
1583+
15781584
// Read the extra-data slot *index* of *code* into *extra.
//
// Returns 0 and stores NULL when the slot was never set or *index* is out
// of range; returns -1 (after PyErr_BadInternalCall) when *code* is not a
// code object.  The read is lock-free: it takes no lock and synchronizes
// with concurrent SetExtra calls only through the acquire load below.
int
PyUnstable_Code_GetExtra(PyObject *code, Py_ssize_t index, void **extra)
{
    if (!PyCode_Check(code)) {
        PyErr_BadInternalCall();
        return -1;
    }

    PyCodeObject *o = (PyCodeObject *)code;
    // Default to NULL so every early-return path reports "no value".
    *extra = NULL;

    if (index < 0) {
        return 0;
    }

    // Lock-free read; pairs with release store in SetExtra.
    _PyCodeObjectExtra *co_extra = FT_ATOMIC_LOAD_PTR_ACQUIRE(o->co_extra);
    if (co_extra != NULL && index < co_extra->ce_size) {
        *extra = co_extra->ce_extras[index];
    }

    return 0;
}
15971607

// Store *extra* in slot *index* of *code*, freeing any previous value with
// the freefunc registered for that index.
//
// Returns 0 on success, -1 (with an exception set) on error.  Writers
// serialize on the code object's critical section; readers in GetExtra are
// lock-free, so the slot array is replaced copy-on-write and published with
// a release store rather than mutated in place.
int
PyUnstable_Code_SetExtra(PyObject *code, Py_ssize_t index, void *extra)
{
    PyInterpreterState *interp = _PyInterpreterState_GET();

    // co_extra_user_count increases monotonically and is published with a
    // release store, so once an index is valid it remains valid.
    Py_ssize_t user_count = FT_ATOMIC_LOAD_SSIZE_ACQUIRE(
        interp->co_extra_user_count);

    if (!PyCode_Check(code) || index < 0 || index >= user_count) {
        PyErr_BadInternalCall();
        return -1;
    }

    PyCodeObject *o = (PyCodeObject *) code;
    int res = 0;

    Py_BEGIN_CRITICAL_SECTION(o);

    _PyCodeObjectExtra *old_extra = (_PyCodeObjectExtra *) o->co_extra;
    Py_ssize_t old_size = (old_extra == NULL) ? 0 : old_extra->ce_size;

    // user_count > index is checked above.
    Py_ssize_t new_size = old_size > index ? old_size : user_count;
    assert(new_size > 0 && new_size > index);

    // Free-threaded builds require copy-on-write to avoid mutating
    // co_extra while lock-free readers in GetExtra may be traversing it.
    // GIL builds could realloc in place, but SetExtra is called rarely
    // and co_extra is small, so use the same path for simplicity.
    _PyCodeObjectExtra *co_extra = PyMem_Malloc(code_extra_size(new_size));
    if (co_extra == NULL) {
        PyErr_NoMemory();
        res = -1;
        goto done;
    }

    co_extra->ce_size = new_size;

    // Copy existing extras from the old buffer.
    if (old_size > 0) {
        memcpy(co_extra->ce_extras, old_extra->ce_extras,
               old_size * sizeof(void *));
    }

    // NULL-initialize new slots.
    for (Py_ssize_t i = old_size; i < new_size; i++) {
        co_extra->ce_extras[i] = NULL;
    }

    if (old_extra != NULL && index < old_size &&
        old_extra->ce_extras[index] != NULL)
    {
        // Free the old extra value if a free function was registered.
        // We assume the caller ensures no other thread is concurrently
        // using the old value.
        freefunc free = interp->co_extra_freefuncs[index];
        if (free != NULL) {
            free(old_extra->ce_extras[index]);
        }
    }

    co_extra->ce_extras[index] = extra;

    // Publish pointer and slot writes to lock-free readers.
    FT_ATOMIC_STORE_PTR_RELEASE(o->co_extra, co_extra);

    if (old_extra != NULL) {
#ifdef Py_GIL_DISABLED
        // Defer container free for lock-free readers.
        _PyMem_FreeDelayed(old_extra, code_extra_size(old_size));
#else
        PyMem_Free(old_extra);
#endif
    }

done:;
    Py_END_CRITICAL_SECTION();
    return res;
}
16391690

16401691

Python/ceval.c

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
// Reserve a new per-interpreter co_extra index and register *free* as the
// destructor for values later stored at that index via SetExtra.
//
// Returns the new index, or -1 when no more indices are available.
// Allocation is serialized by a per-interpreter mutex on free-threaded
// builds (the GIL suffices otherwise).
Py_ssize_t
PyUnstable_Eval_RequestCodeExtraIndex(freefunc free)
{
    PyInterpreterState *interp = _PyInterpreterState_GET();
    Py_ssize_t new_index;

#ifdef Py_GIL_DISABLED
    struct _py_code_state *state = &interp->code_state;
    FT_MUTEX_LOCK(&state->mutex);
#endif

    if (interp->co_extra_user_count >= MAX_CO_EXTRA_USERS - 1) {
#ifdef Py_GIL_DISABLED
        FT_MUTEX_UNLOCK(&state->mutex);
#endif
        return -1;
    }

    new_index = interp->co_extra_user_count;
    interp->co_extra_freefuncs[new_index] = free;

    // Publish freefuncs[new_index] before making the index visible.
    FT_ATOMIC_STORE_SSIZE_RELEASE(interp->co_extra_user_count, new_index + 1);

#ifdef Py_GIL_DISABLED
    FT_MUTEX_UNLOCK(&state->mutex);
#endif
    return new_index;
}
35033519

0 commit comments

Comments
 (0)