@@ -1201,6 +1201,7 @@ def run(self):
12011201 test .join ()
12021202
12031203 def test_multithread (self ):
1204+
12041205 import threading
12051206
12061207 # Running evaluate() from multiple threads shouldn't crash
@@ -1218,6 +1219,77 @@ def work(n):
12181219 for t in threads :
12191220 t .join ()
12201221
1222+ def test_thread_safety (self ):
1223+ """
1224+ Expected output
1225+
1226+ When not safe (before the pr this test is commited)
1227+ AssertionError: Thread-0 failed: result does not match expected
1228+
1229+ When safe (after the pr this test is commited)
1230+ Should pass without failure
1231+ """
1232+ import threading
1233+ import time
1234+
1235+ barrier = threading .Barrier (4 )
1236+
1237+ # Function that each thread will run with different expressions
1238+ def thread_function (a_value , b_value , expression , expected_result , results , index ):
1239+ validate (expression , local_dict = {"a" : a_value , "b" : b_value })
1240+ # Wait for all threads to reach this point
1241+ # such that they all set _numexpr_last
1242+ barrier .wait ()
1243+
1244+ # Simulate some work or a context switch delay
1245+ time .sleep (0.1 )
1246+
1247+ result = re_evaluate (local_dict = {"a" : a_value , "b" : b_value })
1248+ results [index ] = np .array_equal (result , expected_result )
1249+
1250+ def test_thread_safety_with_numexpr ():
1251+ num_threads = 4
1252+ array_size = 1000000
1253+
1254+ expressions = [
1255+ "a + b" ,
1256+ "a - b" ,
1257+ "a * b" ,
1258+ "a / b"
1259+ ]
1260+
1261+ a_value = [np .full (array_size , i + 1 ) for i in range (num_threads )]
1262+ b_value = [np .full (array_size , (i + 1 ) * 2 ) for i in range (num_threads )]
1263+
1264+ expected_results = [
1265+ a_value [i ] + b_value [i ] if expr == "a + b" else
1266+ a_value [i ] - b_value [i ] if expr == "a - b" else
1267+ a_value [i ] * b_value [i ] if expr == "a * b" else
1268+ a_value [i ] / b_value [i ] if expr == "a / b" else None
1269+ for i , expr in enumerate (expressions )
1270+ ]
1271+
1272+ results = [None ] * num_threads
1273+ threads = []
1274+
1275+ # Create and start threads with different expressions
1276+ for i in range (num_threads ):
1277+ thread = threading .Thread (
1278+ target = thread_function ,
1279+ args = (a_value [i ], b_value [i ], expressions [i ], expected_results [i ], results , i )
1280+ )
1281+ threads .append (thread )
1282+ thread .start ()
1283+
1284+ for thread in threads :
1285+ thread .join ()
1286+
1287+ for i in range (num_threads ):
1288+ if not results [i ]:
1289+ self .fail (f"Thread-{ i } failed: result does not match expected" )
1290+
1291+ test_thread_safety_with_numexpr ()
1292+
12211293
12221294# The worker function for the subprocess (needs to be here because Windows
12231295# has problems pickling nested functions with the multiprocess module :-/)
0 commit comments