Skip to content

Commit 009c4f0

Browse files
committed
fix: typings
1 parent f5c30c2 commit 009c4f0

File tree

2 files changed: 29 additions & 19 deletions

2 files changed: 29 additions & 19 deletions

src/inworld_sdk/http_client.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,7 @@ async def request(
5151
path: str,
5252
data: Optional[dict] = None,
5353
stream: bool = False,
54-
) -> Union[dict | ResponseWrapper]:
54+
) -> Union[dict, ResponseWrapper]:
5555
requestData = (
5656
json.dumps(data) if method != "get" and data and len(data.keys()) > 0 else None
5757
)

src/inworld_sdk/tts.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import base64
22
import io
33
import json
4-
from typing import AsyncGenerator, Optional
4+
from typing import Any, AsyncGenerator, cast, Dict, List, Optional
55

66
from .http_client import HttpClient
7+
from .http_client import ResponseWrapper
78
from .typings.tts import AudioConfig
89
from .typings.tts import TTSLanguageCodes
910
from .typings.tts import TTSVoices
@@ -29,7 +30,7 @@ def __init__(
2930
self.__voice = voice or "Emma"
3031

3132
@property
32-
def audioConfig(self) -> AudioConfig:
33+
def audioConfig(self) -> Optional[AudioConfig]:
3334
"""Get default audio config"""
3435
return self.__audioConfig
3536

@@ -49,7 +50,7 @@ def languageCode(self, languageCode: TTSLanguageCodes):
4950
self.__languageCode = languageCode
5051

5152
@property
52-
def modelId(self) -> str:
53+
def modelId(self) -> Optional[str]:
5354
"""Get default model ID"""
5455
return self.__modelId
5556

@@ -75,7 +76,7 @@ async def synthesizeSpeech(
7576
languageCode: Optional[TTSLanguageCodes] = None,
7677
modelId: Optional[str] = None,
7778
audioConfig: Optional[AudioConfig] = None,
78-
) -> dict:
79+
) -> Dict[str, Any]:
7980
"""Synthesize speech"""
8081
data = {
8182
"input": {"text": input},
@@ -91,11 +92,12 @@ async def synthesizeSpeech(
9192
if modelId or self.__modelId:
9293
data["modelId"] = modelId or self.__modelId
9394

94-
return await self.__client.request(
95+
response = await self.__client.request(
9596
"post",
9697
"/tts/v1alpha/text:synthesize-sync",
9798
data=data,
9899
)
100+
return cast(Dict[str, Any], response)
99101

100102
async def synthesizeSpeechAsWav(
101103
self,
@@ -117,7 +119,10 @@ async def synthesizeSpeechAsWav(
117119
audioConfig=audioConfig,
118120
)
119121

120-
decoded_audio = base64.b64decode(response.get("audioContent"))
122+
audio_content = response.get("audioContent")
123+
if not audio_content:
124+
raise ValueError("No audio content in response")
125+
decoded_audio = base64.b64decode(audio_content)
121126

122127
return io.BytesIO(decoded_audio)
123128

@@ -128,7 +133,7 @@ async def synthesizeSpeechStream(
128133
languageCode: Optional[TTSLanguageCodes] = None,
129134
modelId: Optional[str] = None,
130135
audioConfig: Optional[AudioConfig] = None,
131-
) -> AsyncGenerator[dict, None]:
136+
) -> AsyncGenerator[Dict[str, Any], None]:
132137
"""Synthesize speech as a stream"""
133138
data = {
134139
"input": {"text": input},
@@ -144,13 +149,16 @@ async def synthesizeSpeechStream(
144149
if modelId or self.__modelId:
145150
data["modelId"] = modelId or self.__modelId
146151

147-
response = None
152+
response: Optional[ResponseWrapper] = None
148153
try:
149-
response = await self.__client.request(
150-
"post",
151-
"/tts/v1alpha/text:synthesize",
152-
data=data,
153-
stream=True,
154+
response = cast(
155+
ResponseWrapper,
156+
await self.__client.request(
157+
"post",
158+
"/tts/v1alpha/text:synthesize",
159+
data=data,
160+
stream=True,
161+
),
154162
)
155163

156164
async for chunk in response.content:
@@ -184,8 +192,9 @@ async def synthesizeSpeechStreamAsWav(
184192
languageCode=languageCode,
185193
audioConfig=audioConfig,
186194
):
187-
if chunk and chunk.get("audioContent") is not None:
188-
decoded_audio = base64.b64decode(chunk.get("audioContent"))
195+
audio_content = chunk.get("audioContent")
196+
if audio_content is not None:
197+
decoded_audio = base64.b64decode(audio_content)
189198
yield io.BytesIO(decoded_audio)
190199
except Exception:
191200
raise
@@ -194,13 +203,14 @@ async def voices(
194203
self,
195204
languageCode: Optional[TTSLanguageCodes] = None,
196205
modelId: Optional[str] = None,
197-
) -> list[VoiceResponse]:
206+
) -> List[VoiceResponse]:
198207
"""Get voices"""
199-
data = {}
208+
data: Dict[str, Any] = {}
200209
if languageCode:
201210
data["languageCode"] = languageCode
202211
if modelId:
203212
data["modelId"] = modelId
204213

205214
response = await self.__client.request("get", "/tts/v1alpha/voices", data=data)
206-
return response.get("voices")
215+
voices = response.get("voices", [])
216+
return cast(List[VoiceResponse], voices)

0 commit comments

Comments
 (0)