Skip to content

Commit b4d66a1

Browse files
committed
fix: improve reliability of multithreaded tests
1 parent 5445542 commit b4d66a1

File tree

1 file changed

+26
-16
lines changed

1 file changed

+26
-16
lines changed

time_based_storage/tests/test_thread_safe.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,22 @@ def test_wait_for_data(self):
7272
"""Test the wait_for_data functionality with multiple threads."""
7373
received_data = set() # Use a set to track unique values
7474
event = threading.Event()
75-
75+
data_count = 10
76+
7677
def consumer():
77-
while not event.is_set():
78-
if self.storage.wait_for_data(timeout=0.1):
78+
while not event.is_set() or len(received_data) < data_count:
79+
if self.storage.wait_for_data(timeout=0.2): # Increased timeout for reliability
7980
values = self.storage.get_all()
8081
received_data.update(values) # Add unique values to the set
82+
if len(received_data) >= data_count:
83+
break
8184

8285
def producer():
83-
for i in range(10):
86+
for i in range(data_count):
8487
self.storage.add(datetime.now(), i)
8588
time.sleep(0.1)
89+
# Wait a bit to ensure consumer can process the last item
90+
time.sleep(0.5)
8691
event.set()
8792

8893
consumer_thread = threading.Thread(target=consumer)
@@ -91,11 +96,11 @@ def producer():
9196
consumer_thread.start()
9297
producer_thread.start()
9398

94-
producer_thread.join()
95-
consumer_thread.join()
99+
producer_thread.join(timeout=5) # Add timeout to avoid hanging
100+
consumer_thread.join(timeout=5) # Add timeout to avoid hanging
96101

97-
self.assertEqual(len(received_data), 10)
98-
self.assertEqual(set(range(10)), received_data) # Verify we got all expected values
102+
self.assertEqual(len(received_data), data_count)
103+
self.assertEqual(set(range(data_count)), received_data) # Verify we got all expected values
99104

100105

101106
class TestThreadSafeTimeBasedStorageHeap(unittest.TestCase):
@@ -165,17 +170,22 @@ def test_wait_for_data(self):
165170
"""Test the wait_for_data functionality with multiple threads."""
166171
received_data = set() # Use a set to track unique values
167172
event = threading.Event()
168-
173+
data_count = 10
174+
169175
def consumer():
170-
while not event.is_set():
171-
if self.storage.wait_for_data(timeout=0.1):
176+
while not event.is_set() or len(received_data) < data_count:
177+
if self.storage.wait_for_data(timeout=0.2): # Increased timeout for reliability
172178
values = self.storage.get_all()
173179
received_data.update(values) # Add unique values to the set
180+
if len(received_data) >= data_count:
181+
break
174182

175183
def producer():
176-
for i in range(10):
184+
for i in range(data_count):
177185
self.storage.add(datetime.now(), i)
178186
time.sleep(0.1)
187+
# Wait a bit to ensure consumer can process the last item
188+
time.sleep(0.5)
179189
event.set()
180190

181191
consumer_thread = threading.Thread(target=consumer)
@@ -184,11 +194,11 @@ def producer():
184194
consumer_thread.start()
185195
producer_thread.start()
186196

187-
producer_thread.join()
188-
consumer_thread.join()
197+
producer_thread.join(timeout=5) # Add timeout to avoid hanging
198+
consumer_thread.join(timeout=5) # Add timeout to avoid hanging
189199

190-
self.assertEqual(len(received_data), 10)
191-
self.assertEqual(set(range(10)), received_data) # Verify we got all expected values
200+
self.assertEqual(len(received_data), data_count)
201+
self.assertEqual(set(range(data_count)), received_data) # Verify we got all expected values
192202

193203

194204
if __name__ == "__main__":

0 commit comments

Comments
 (0)