Skip to content

Commit a55fcf4

Browse files
authored
added prompt ack event and updated setPrompt (#13)
* added prompt ack event and updated setPrompt
1 parent a7697fc commit a55fcf4

9 files changed

Lines changed: 193 additions & 11 deletions

File tree

decart/lipsync/client.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424

2525
class RealtimeLipsyncClient:
26-
2726
DECART_LIPSYNC_ENDPOINT = "/router/lipsync/ws"
2827
VIDEO_FPS = 25
2928

decart/realtime/client.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Callable, Optional
2+
import asyncio
23
import logging
34
import uuid
45
from aiortc import MediaStreamTrack
@@ -81,7 +82,23 @@ def _emit_error(self, error: DecartSDKError) -> None:
8182
async def set_prompt(self, prompt: str, enrich: bool = True) -> None:
8283
if not prompt or not prompt.strip():
8384
raise InvalidInputError("Prompt cannot be empty")
84-
await self._manager.send_message(PromptMessage(type="prompt", prompt=prompt))
85+
86+
event, result = self._manager.register_prompt_wait(prompt)
87+
88+
try:
89+
await self._manager.send_message(
90+
PromptMessage(type="prompt", prompt=prompt, enhance_prompt=enrich)
91+
)
92+
93+
try:
94+
await asyncio.wait_for(event.wait(), timeout=15.0)
95+
except asyncio.TimeoutError:
96+
raise DecartSDKError("Prompt acknowledgment timed out")
97+
98+
if not result["success"]:
99+
raise DecartSDKError(result["error"] or "Prompt failed")
100+
finally:
101+
self._manager.unregister_prompt_wait(prompt)
85102

86103
def is_connected(self) -> bool:
87104
return self._manager.is_connected()

decart/realtime/messages.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Literal, Union, Annotated
1+
from typing import Literal, Optional, Union, Annotated
22
from pydantic import BaseModel, Field, TypeAdapter
33

44
try:
@@ -42,9 +42,18 @@ class SessionIdMessage(BaseModel):
4242
server_ip: str
4343

4444

45+
class PromptAckMessage(BaseModel):
46+
"""Acknowledgment for prompt update from server."""
47+
48+
type: Literal["prompt_ack"]
49+
prompt: str
50+
success: bool
51+
error: Optional[str] = None
52+
53+
4554
# Discriminated union for incoming messages
4655
IncomingMessage = Annotated[
47-
Union[AnswerMessage, IceCandidateMessage, SessionIdMessage],
56+
Union[AnswerMessage, IceCandidateMessage, SessionIdMessage, PromptAckMessage],
4857
Field(discriminator="type"),
4958
]
5059

@@ -67,6 +76,7 @@ class PromptMessage(BaseModel):
6776

6877
type: Literal["prompt"]
6978
prompt: str
79+
enhance_prompt: bool = True
7080

7181

7282
# Outgoing message union (no discriminator needed - we know what we're sending)

decart/realtime/webrtc_connection.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
OfferMessage,
2222
IceCandidateMessage,
2323
IceCandidatePayload,
24+
PromptAckMessage,
2425
OutgoingMessage,
2526
)
2627
from .types import ConnectionState
@@ -36,7 +37,6 @@ def __init__(
3637
on_error: Optional[Callable[[Exception], None]] = None,
3738
customize_offer: Optional[Callable] = None,
3839
):
39-
4040
self._pc: Optional[RTCPeerConnection] = None
4141
self._ws: Optional[aiohttp.ClientWebSocketResponse] = None
4242
self._session: Optional[aiohttp.ClientSession] = None
@@ -47,6 +47,7 @@ def __init__(
4747
self._customize_offer = customize_offer
4848
self._ws_task: Optional[asyncio.Task] = None
4949
self._ice_candidates_queue: list[RTCIceCandidate] = []
50+
self._pending_prompts: dict[str, tuple[asyncio.Event, dict]] = {}
5051

5152
async def connect(
5253
self,
@@ -176,6 +177,8 @@ async def _handle_message(self, data: dict) -> None:
176177
await self._handle_ice_candidate(message.candidate)
177178
elif message.type == "session_id":
178179
logger.debug(f"Session ID: {message.session_id}")
180+
elif message.type == "prompt_ack":
181+
self._handle_prompt_ack(message)
179182

180183
async def _handle_answer(self, sdp: str) -> None:
181184
logger.debug("Received answer from server")
@@ -207,6 +210,23 @@ async def _handle_ice_candidate(self, candidate_data: IceCandidatePayload) -> No
207210
logger.debug("Queuing ICE candidate (no remote description yet)")
208211
self._ice_candidates_queue.append(candidate)
209212

213+
def _handle_prompt_ack(self, message: PromptAckMessage) -> None:
214+
logger.debug(f"Received prompt_ack for: {message.prompt}, success: {message.success}")
215+
if message.prompt in self._pending_prompts:
216+
event, result = self._pending_prompts[message.prompt]
217+
result["success"] = message.success
218+
result["error"] = message.error
219+
event.set()
220+
221+
def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]:
222+
event = asyncio.Event()
223+
result: dict = {"success": False, "error": None}
224+
self._pending_prompts[prompt] = (event, result)
225+
return event, result
226+
227+
def unregister_prompt_wait(self, prompt: str) -> None:
228+
self._pending_prompts.pop(prompt, None)
229+
210230
async def _send_message(self, message: OutgoingMessage) -> None:
211231
if not self._ws or self._ws.closed:
212232
raise RuntimeError("WebSocket not connected")

decart/realtime/webrtc_manager.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
from typing import Optional, Callable
34
from dataclasses import dataclass
@@ -84,3 +85,9 @@ def is_connected(self) -> bool:
8485

8586
def get_connection_state(self) -> ConnectionState:
8687
return self._connection.state
88+
89+
def register_prompt_wait(self, prompt: str) -> tuple[asyncio.Event, dict]:
90+
return self._connection.register_prompt_wait(prompt)
91+
92+
def unregister_prompt_wait(self, prompt: str) -> None:
93+
self._connection.unregister_prompt_wait(prompt)

examples/lipsync_file.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,6 @@ async def process_lipsync(video_path: str, audio_path: str, output_path: str):
7272
)
7373
for i in range(frame_count):
7474
try:
75-
7675
video_frame, audio_frame = await client.get_synced_output(timeout=1.0)
7776
bgr_frame = cv2.cvtColor(video_frame, cv2.COLOR_RGB2BGR)
7877
out.write(bgr_frame)

examples/realtime_synthetic.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,11 @@ def on_error(error):
126126
await asyncio.sleep(5)
127127

128128
print("\n🎨 Changing style to 'Cyberpunk city'...")
129-
await realtime_client.set_prompt("Cyberpunk city")
129+
try:
130+
await realtime_client.set_prompt("Cyberpunk city")
131+
print("✓ Prompt set successfully")
132+
except Exception as e:
133+
print(f"⚠️ Failed to set prompt: {e}")
130134

131135
await asyncio.sleep(5)
132136

tests/test_realtime_unit.py

Lines changed: 128 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,23 @@ def test_realtime_models_available():
4444
@pytest.mark.asyncio
4545
async def test_realtime_client_creation_with_mock():
4646
"""Test client creation with mocked WebRTC"""
47+
import asyncio
48+
4749
client = DecartClient(api_key="test-key")
4850

4951
with patch("decart.realtime.client.WebRTCManager") as mock_manager_class:
5052
mock_manager = AsyncMock()
5153
mock_manager.connect = AsyncMock(return_value=True)
5254
mock_manager.is_connected = MagicMock(return_value=True)
5355
mock_manager.get_connection_state = MagicMock(return_value="connected")
56+
mock_manager.send_message = AsyncMock()
57+
58+
prompt_event = asyncio.Event()
59+
prompt_result = {"success": True, "error": None}
60+
prompt_event.set()
61+
62+
mock_manager.register_prompt_wait = MagicMock(return_value=(prompt_event, prompt_result))
63+
mock_manager.unregister_prompt_wait = MagicMock()
5464
mock_manager_class.return_value = mock_manager
5565

5666
mock_track = MagicMock()
@@ -76,13 +86,24 @@ async def test_realtime_client_creation_with_mock():
7686

7787
@pytest.mark.asyncio
7888
async def test_realtime_set_prompt_with_mock():
79-
"""Test set_prompt with mocked WebRTC"""
89+
"""Test set_prompt with mocked WebRTC and prompt_ack"""
90+
import asyncio
91+
8092
client = DecartClient(api_key="test-key")
8193

8294
with patch("decart.realtime.client.WebRTCManager") as mock_manager_class:
8395
mock_manager = AsyncMock()
8496
mock_manager.connect = AsyncMock(return_value=True)
8597
mock_manager.send_message = AsyncMock()
98+
99+
prompt_event = asyncio.Event()
100+
prompt_result = {"success": True, "error": None}
101+
102+
def register_prompt_wait(prompt):
103+
return prompt_event, prompt_result
104+
105+
mock_manager.register_prompt_wait = MagicMock(side_effect=register_prompt_wait)
106+
mock_manager.unregister_prompt_wait = MagicMock()
86107
mock_manager_class.return_value = mock_manager
87108

88109
mock_track = MagicMock()
@@ -99,12 +120,19 @@ async def test_realtime_set_prompt_with_mock():
99120
),
100121
)
101122

123+
async def set_event():
124+
await asyncio.sleep(0.01)
125+
prompt_event.set()
126+
127+
asyncio.create_task(set_event())
102128
await realtime_client.set_prompt("New prompt")
103129

104-
mock_manager.send_message.assert_called_once()
130+
mock_manager.send_message.assert_called()
105131
call_args = mock_manager.send_message.call_args[0][0]
106132
assert call_args.type == "prompt"
107133
assert call_args.prompt == "New prompt"
134+
assert call_args.enhance_prompt is True
135+
mock_manager.unregister_prompt_wait.assert_called_with("New prompt")
108136

109137

110138
@pytest.mark.asyncio
@@ -152,3 +180,101 @@ def on_error(error):
152180
realtime_client._emit_error(test_error)
153181
assert len(errors) == 1
154182
assert errors[0].message == "Test error"
183+
184+
185+
@pytest.mark.asyncio
186+
async def test_realtime_set_prompt_timeout():
187+
"""Test set_prompt raises on timeout"""
188+
import asyncio
189+
190+
client = DecartClient(api_key="test-key")
191+
192+
with patch("decart.realtime.client.WebRTCManager") as mock_manager_class:
193+
mock_manager = AsyncMock()
194+
mock_manager.connect = AsyncMock(return_value=True)
195+
mock_manager.send_message = AsyncMock()
196+
197+
prompt_event = asyncio.Event()
198+
prompt_result = {"success": False, "error": None}
199+
200+
def register_prompt_wait(prompt):
201+
return prompt_event, prompt_result
202+
203+
mock_manager.register_prompt_wait = MagicMock(side_effect=register_prompt_wait)
204+
mock_manager.unregister_prompt_wait = MagicMock()
205+
mock_manager_class.return_value = mock_manager
206+
207+
mock_track = MagicMock()
208+
209+
from decart.realtime.types import RealtimeConnectOptions
210+
211+
realtime_client = await RealtimeClient.connect(
212+
base_url=client.base_url,
213+
api_key=client.api_key,
214+
local_track=mock_track,
215+
options=RealtimeConnectOptions(
216+
model=models.realtime("mirage"),
217+
on_remote_stream=lambda t: None,
218+
),
219+
)
220+
221+
from decart.errors import DecartSDKError
222+
223+
# Mock asyncio.wait_for to immediately raise TimeoutError
224+
with patch("asyncio.wait_for", side_effect=asyncio.TimeoutError):
225+
with pytest.raises(DecartSDKError) as exc_info:
226+
await realtime_client.set_prompt("New prompt")
227+
228+
assert "timed out" in str(exc_info.value)
229+
mock_manager.unregister_prompt_wait.assert_called_with("New prompt")
230+
231+
232+
@pytest.mark.asyncio
233+
async def test_realtime_set_prompt_server_error():
234+
"""Test set_prompt raises on server error"""
235+
import asyncio
236+
237+
client = DecartClient(api_key="test-key")
238+
239+
with patch("decart.realtime.client.WebRTCManager") as mock_manager_class:
240+
mock_manager = AsyncMock()
241+
mock_manager.connect = AsyncMock(return_value=True)
242+
mock_manager.send_message = AsyncMock()
243+
244+
prompt_event = asyncio.Event()
245+
prompt_result = {"success": False, "error": "Server rejected prompt"}
246+
247+
def register_prompt_wait(prompt):
248+
return prompt_event, prompt_result
249+
250+
mock_manager.register_prompt_wait = MagicMock(side_effect=register_prompt_wait)
251+
mock_manager.unregister_prompt_wait = MagicMock()
252+
mock_manager_class.return_value = mock_manager
253+
254+
mock_track = MagicMock()
255+
256+
from decart.realtime.types import RealtimeConnectOptions
257+
258+
realtime_client = await RealtimeClient.connect(
259+
base_url=client.base_url,
260+
api_key=client.api_key,
261+
local_track=mock_track,
262+
options=RealtimeConnectOptions(
263+
model=models.realtime("mirage"),
264+
on_remote_stream=lambda t: None,
265+
),
266+
)
267+
268+
async def set_event():
269+
await asyncio.sleep(0.01)
270+
prompt_event.set()
271+
272+
asyncio.create_task(set_event())
273+
274+
from decart.errors import DecartSDKError
275+
276+
with pytest.raises(DecartSDKError) as exc_info:
277+
await realtime_client.set_prompt("New prompt")
278+
279+
assert "Server rejected prompt" in str(exc_info.value)
280+
mock_manager.unregister_prompt_wait.assert_called_with("New prompt")

uv.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)