@@ -72,17 +72,22 @@ def test_wait_for_data(self):
7272 """Test the wait_for_data functionality with multiple threads."""
7373 received_data = set () # Use a set to track unique values
7474 event = threading .Event ()
75-
75+ data_count = 10
76+
7677 def consumer ():
77- while not event .is_set ():
78- if self .storage .wait_for_data (timeout = 0.1 ):
78+ while not event .is_set () or len ( received_data ) < data_count :
79+ if self .storage .wait_for_data (timeout = 0.2 ): # Increased timeout for reliability
7980 values = self .storage .get_all ()
8081 received_data .update (values ) # Add unique values to the set
82+ if len (received_data ) >= data_count :
83+ break
8184
8285 def producer ():
83- for i in range (10 ):
86+ for i in range (data_count ):
8487 self .storage .add (datetime .now (), i )
8588 time .sleep (0.1 )
89+ # Wait a bit to ensure consumer can process the last item
90+ time .sleep (0.5 )
8691 event .set ()
8792
8893 consumer_thread = threading .Thread (target = consumer )
@@ -91,11 +96,11 @@ def producer():
9196 consumer_thread .start ()
9297 producer_thread .start ()
9398
94- producer_thread .join ()
95- consumer_thread .join ()
99+ producer_thread .join (timeout = 5 ) # Add timeout to avoid hanging
100+ consumer_thread .join (timeout = 5 ) # Add timeout to avoid hanging
96101
97- self .assertEqual (len (received_data ), 10 )
98- self .assertEqual (set (range (10 )), received_data ) # Verify we got all expected values
102+ self .assertEqual (len (received_data ), data_count )
103+ self .assertEqual (set (range (data_count )), received_data ) # Verify we got all expected values
99104
100105
101106class TestThreadSafeTimeBasedStorageHeap (unittest .TestCase ):
@@ -165,17 +170,22 @@ def test_wait_for_data(self):
165170 """Test the wait_for_data functionality with multiple threads."""
166171 received_data = set () # Use a set to track unique values
167172 event = threading .Event ()
168-
173+ data_count = 10
174+
169175 def consumer ():
170- while not event .is_set ():
171- if self .storage .wait_for_data (timeout = 0.1 ):
176+ while not event .is_set () or len ( received_data ) < data_count :
177+ if self .storage .wait_for_data (timeout = 0.2 ): # Increased timeout for reliability
172178 values = self .storage .get_all ()
173179 received_data .update (values ) # Add unique values to the set
180+ if len (received_data ) >= data_count :
181+ break
174182
175183 def producer ():
176- for i in range (10 ):
184+ for i in range (data_count ):
177185 self .storage .add (datetime .now (), i )
178186 time .sleep (0.1 )
187+ # Wait a bit to ensure consumer can process the last item
188+ time .sleep (0.5 )
179189 event .set ()
180190
181191 consumer_thread = threading .Thread (target = consumer )
@@ -184,11 +194,11 @@ def producer():
184194 consumer_thread .start ()
185195 producer_thread .start ()
186196
187- producer_thread .join ()
188- consumer_thread .join ()
197+ producer_thread .join (timeout = 5 ) # Add timeout to avoid hanging
198+ consumer_thread .join (timeout = 5 ) # Add timeout to avoid hanging
189199
190- self .assertEqual (len (received_data ), 10 )
191- self .assertEqual (set (range (10 )), received_data ) # Verify we got all expected values
200+ self .assertEqual (len (received_data ), data_count )
201+ self .assertEqual (set (range (data_count )), received_data ) # Verify we got all expected values
192202
193203
194204if __name__ == "__main__" :
0 commit comments