Skip to content

Commit cb5faa1

Browse files
author
Emma Ai
committed
add test on thread safety
1 parent 4e90a70 commit cb5faa1

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

numexpr/tests/test_numexpr.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,6 +1201,7 @@ def run(self):
12011201
test.join()
12021202

12031203
def test_multithread(self):
1204+
12041205
import threading
12051206

12061207
# Running evaluate() from multiple threads shouldn't crash
@@ -1218,6 +1219,77 @@ def work(n):
12181219
for t in threads:
12191220
t.join()
12201221

1222+
def test_thread_safety(self):
1223+
"""
1224+
Expected output
1225+
1226+
When not safe (before the pr this test is commited)
1227+
AssertionError: Thread-0 failed: result does not match expected
1228+
1229+
When safe (after the pr this test is commited)
1230+
Should pass without failure
1231+
"""
1232+
import threading
1233+
import time
1234+
1235+
barrier = threading.Barrier(4)
1236+
1237+
# Function that each thread will run with different expressions
1238+
def thread_function(a_value, b_value, expression, expected_result, results, index):
1239+
validate(expression, local_dict={"a": a_value, "b": b_value})
1240+
# Wait for all threads to reach this point
1241+
# such that they all set _numexpr_last
1242+
barrier.wait()
1243+
1244+
# Simulate some work or a context switch delay
1245+
time.sleep(0.1)
1246+
1247+
result = re_evaluate(local_dict={"a": a_value, "b": b_value})
1248+
results[index] = np.array_equal(result, expected_result)
1249+
1250+
def test_thread_safety_with_numexpr():
1251+
num_threads = 4
1252+
array_size = 1000000
1253+
1254+
expressions = [
1255+
"a + b",
1256+
"a - b",
1257+
"a * b",
1258+
"a / b"
1259+
]
1260+
1261+
a_value = [np.full(array_size, i + 1) for i in range(num_threads)]
1262+
b_value = [np.full(array_size, (i + 1) * 2) for i in range(num_threads)]
1263+
1264+
expected_results = [
1265+
a_value[i] + b_value[i] if expr == "a + b" else
1266+
a_value[i] - b_value[i] if expr == "a - b" else
1267+
a_value[i] * b_value[i] if expr == "a * b" else
1268+
a_value[i] / b_value[i] if expr == "a / b" else None
1269+
for i, expr in enumerate(expressions)
1270+
]
1271+
1272+
results = [None] * num_threads
1273+
threads = []
1274+
1275+
# Create and start threads with different expressions
1276+
for i in range(num_threads):
1277+
thread = threading.Thread(
1278+
target=thread_function,
1279+
args=(a_value[i], b_value[i], expressions[i], expected_results[i], results, i)
1280+
)
1281+
threads.append(thread)
1282+
thread.start()
1283+
1284+
for thread in threads:
1285+
thread.join()
1286+
1287+
for i in range(num_threads):
1288+
if not results[i]:
1289+
self.fail(f"Thread-{i} failed: result does not match expected")
1290+
1291+
test_thread_safety_with_numexpr()
1292+
12211293

12221294
# The worker function for the subprocess (needs to be here because Windows
12231295
# has problems pickling nested functions with the multiprocess module :-/)

0 commit comments

Comments
 (0)