@@ -135,6 +135,90 @@ def module_start_status_message(self, instance, inputs):
135135 assert status_messages [2 ].message == "Predict starting!"
136136
137137
@pytest.mark.anyio
async def test_concurrent_status_message_providers():
    """Run two streamified programs concurrently, each with its own status
    message provider, and verify that neither run's messages leak into the
    other and that the global callback settings are left untouched.
    """

    class MyProgram(dspy.Module):
        def __init__(self):
            self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
            self.predict = dspy.Predict("question->answer")

        def __call__(self, x: str):
            question = self.generate_question(x=x)
            return self.predict(question=question)

    class MyStatusMessageProvider1(StatusMessageProvider):
        def tool_start_status_message(self, instance, inputs):
            return "Provider1: Tool starting!"

        def tool_end_status_message(self, outputs):
            return "Provider1: Tool finished!"

        def module_start_status_message(self, instance, inputs):
            if isinstance(instance, dspy.Predict):
                return "Provider1: Predict starting!"

    class MyStatusMessageProvider2(StatusMessageProvider):
        def tool_start_status_message(self, instance, inputs):
            return "Provider2: Tool starting!"

        def tool_end_status_message(self, outputs):
            return "Provider2: Tool finished!"

        def module_start_status_message(self, instance, inputs):
            if isinstance(instance, dspy.Predict):
                return "Provider2: Predict starting!"

    # Snapshot the global callbacks so we can assert they are not mutated.
    original_callbacks = list(dspy.settings.callbacks)

    lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}, {"answer": "green"}, {"answer": "yellow"}])

    async def collect_status_messages(provider, query):
        # Stream a single program run and return only the status messages.
        with dspy.context(lm=lm):
            streaming_program = dspy.streamify(MyProgram(), status_message_provider=provider)
            stream = streaming_program(query)
            return [chunk.message async for chunk in stream if isinstance(chunk, StatusMessage)]

    # Both runs execute concurrently on the same event loop.
    messages_1, messages_2 = await asyncio.gather(
        collect_status_messages(MyStatusMessageProvider1(), "sky"),
        collect_status_messages(MyStatusMessageProvider2(), "ocean"),
    )

    # Each run must see exactly its own provider's messages, in order.
    assert messages_1 == [
        "Provider1: Tool starting!",
        "Provider1: Tool finished!",
        "Provider1: Predict starting!",
    ]
    assert messages_2 == [
        "Provider2: Tool starting!",
        "Provider2: Tool finished!",
        "Provider2: Predict starting!",
    ]

    # Per-call providers must not have modified the global callback list.
    assert dspy.settings.callbacks == original_callbacks
221+
138222@pytest .mark .llm_call
139223@pytest .mark .anyio
140224async def test_stream_listener_chat_adapter (lm_for_test ):
0 commit comments