Skip to content

Commit 7fffca1

Browse files
FIX: streamify was appending StatusStreamingCallback directly to the shared settings.callbacks list (#9073)
* Tests for streamify status message settings leak * fix: streamify should deepcopy settings before appending a callback * fix: shallow copy instead of deepcopy * fix: New test for concurrent status message provider in 2 threads, with original callback verification --------- Co-authored-by: Emile Riberdy <emileriberdy@gmail.com>
1 parent 31c5062 commit 7fffca1

File tree

2 files changed

+85
-1
lines changed

2 files changed

+85
-1
lines changed

dspy/streaming/streamify.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ async def use_streaming():
161161
elif not iscoroutinefunction(program):
162162
program = asyncify(program)
163163

164-
callbacks = settings.callbacks
164+
callbacks = list(settings.callbacks)
165165
status_streaming_callback = StatusStreamingCallback(status_message_provider)
166166
if not any(isinstance(c, StatusStreamingCallback) for c in callbacks):
167167
callbacks.append(status_streaming_callback)

tests/streaming/test_streaming.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,90 @@ def module_start_status_message(self, instance, inputs):
135135
assert status_messages[2].message == "Predict starting!"
136136

137137

138+
@pytest.mark.anyio
139+
async def test_concurrent_status_message_providers():
140+
class MyProgram(dspy.Module):
141+
def __init__(self):
142+
self.generate_question = dspy.Tool(lambda x: f"What color is the {x}?", name="generate_question")
143+
self.predict = dspy.Predict("question->answer")
144+
145+
def __call__(self, x: str):
146+
question = self.generate_question(x=x)
147+
return self.predict(question=question)
148+
149+
class MyStatusMessageProvider1(StatusMessageProvider):
150+
def tool_start_status_message(self, instance, inputs):
151+
return "Provider1: Tool starting!"
152+
153+
def tool_end_status_message(self, outputs):
154+
return "Provider1: Tool finished!"
155+
156+
def module_start_status_message(self, instance, inputs):
157+
if isinstance(instance, dspy.Predict):
158+
return "Provider1: Predict starting!"
159+
160+
class MyStatusMessageProvider2(StatusMessageProvider):
161+
def tool_start_status_message(self, instance, inputs):
162+
return "Provider2: Tool starting!"
163+
164+
def tool_end_status_message(self, outputs):
165+
return "Provider2: Tool finished!"
166+
167+
def module_start_status_message(self, instance, inputs):
168+
if isinstance(instance, dspy.Predict):
169+
return "Provider2: Predict starting!"
170+
171+
# Store the original callbacks to verify they're not modified
172+
original_callbacks = list(dspy.settings.callbacks)
173+
174+
lm = dspy.utils.DummyLM([{"answer": "red"}, {"answer": "blue"}, {"answer": "green"}, {"answer": "yellow"}])
175+
176+
# Results storage for each thread
177+
results = {}
178+
179+
async def run_with_provider1():
180+
with dspy.context(lm=lm):
181+
program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider1())
182+
output = program("sky")
183+
184+
status_messages = []
185+
async for value in output:
186+
if isinstance(value, StatusMessage):
187+
status_messages.append(value.message)
188+
189+
results["provider1"] = status_messages
190+
191+
async def run_with_provider2():
192+
with dspy.context(lm=lm):
193+
program = dspy.streamify(MyProgram(), status_message_provider=MyStatusMessageProvider2())
194+
output = program("ocean")
195+
196+
status_messages = []
197+
async for value in output:
198+
if isinstance(value, StatusMessage):
199+
status_messages.append(value.message)
200+
201+
results["provider2"] = status_messages
202+
203+
# Run both tasks concurrently
204+
await asyncio.gather(run_with_provider1(), run_with_provider2())
205+
206+
# Verify provider1 got its expected messages
207+
assert len(results["provider1"]) == 3
208+
assert results["provider1"][0] == "Provider1: Tool starting!"
209+
assert results["provider1"][1] == "Provider1: Tool finished!"
210+
assert results["provider1"][2] == "Provider1: Predict starting!"
211+
212+
# Verify provider2 got its expected messages
213+
assert len(results["provider2"]) == 3
214+
assert results["provider2"][0] == "Provider2: Tool starting!"
215+
assert results["provider2"][1] == "Provider2: Tool finished!"
216+
assert results["provider2"][2] == "Provider2: Predict starting!"
217+
218+
# Verify that the global callbacks were not modified
219+
assert dspy.settings.callbacks == original_callbacks
220+
221+
138222
@pytest.mark.llm_call
139223
@pytest.mark.anyio
140224
async def test_stream_listener_chat_adapter(lm_for_test):

0 commit comments

Comments
 (0)