Skip to content

Commit 1449c55

Browse files
committed
Update app to use Direct CSM by default
1 parent c2f9875 commit 1449c55

4 files changed

Lines changed: 81 additions & 74 deletions

File tree

app/api/voice_generator.py

Lines changed: 15 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ def load_model(self, device: str = None) -> bool:
130130
return True
131131
except DirectCSMError as e:
132132
logger.warning(f"Failed to load Direct CSM implementation: {e}")
133+
if not config.DIRECT_CSM_FALLBACK_TO_STANDARD:
134+
logger.error("Direct CSM fallback is disabled, failing")
135+
raise RuntimeError(f"Failed to load Direct CSM and fallback is disabled: {e}")
133136
logger.info("Falling back to standard CSM model")
134137
self.direct_csm = None
135138

@@ -161,6 +164,9 @@ def load_model(self, device: str = None) -> bool:
161164
return True
162165
except DirectCSMError as e:
163166
logger.warning(f"Failed to load Direct CSM implementation on CPU: {e}")
167+
if not config.DIRECT_CSM_FALLBACK_TO_STANDARD:
168+
logger.error("Direct CSM fallback is disabled, failing")
169+
raise RuntimeError(f"Failed to load Direct CSM on CPU and fallback is disabled: {e}")
164170
logger.info("Falling back to standard CSM model on CPU")
165171
self.direct_csm = None
166172

@@ -252,6 +258,7 @@ def generate(
252258
if self.direct_csm is not None:
253259
try:
254260
# Generate speech using direct CSM
261+
logger.info("Using Direct CSM for voice generation")
255262
audio, sample_rate = self.direct_csm.generate_speech(
256263
text=text,
257264
speaker_id=speaker_id,
@@ -277,12 +284,16 @@ def generate(
277284
return output_path, url
278285

279286
except DirectCSMError as e:
280-
logger.warning(f"Direct CSM failed: {e}, falling back to standard CSM model")
281-
# Fall back to standard CSM model
287+
logger.warning(f"Direct CSM failed: {e}")
288+
if not config.DIRECT_CSM_FALLBACK_TO_STANDARD:
289+
logger.error("Direct CSM fallback is disabled, failing")
290+
return None, None
291+
logger.info("Falling back to standard CSM model")
282292

283293
# If direct CSM failed or is not available, use the standard CSM model
284294
if self.model is not None:
285295
# Generate speech
296+
logger.info("Using standard CSM model for voice generation")
286297
audio, sample_rate = self.model.generate_speech(
287298
text=text,
288299
speaker_id=speaker_id,
@@ -309,77 +320,9 @@ def generate(
309320
else:
310321
logger.error("No model available for voice generation")
311322
return None, None
312-
323+
313324
except Exception as e:
314-
logger.error(f"Error generating voice: {e}")
315-
316-
# Try again with CPU if we were using CUDA and it failed
317-
if device == "cuda" or device == "auto":
318-
logger.info("Attempting to fall back to CPU after error")
319-
try:
320-
# Try direct CSM on CPU if available
321-
if self.direct_csm is not None:
322-
try:
323-
# Generate speech using direct CSM on CPU
324-
audio, sample_rate = self.direct_csm.generate_speech(
325-
text=text,
326-
speaker_id=speaker_id,
327-
temperature=temperature,
328-
top_k=top_k,
329-
device="cpu"
330-
)
331-
332-
# Save the audio
333-
self.direct_csm.save_audio(audio, sample_rate, output_path)
334-
335-
# Check if file was created
336-
if not os.path.exists(output_path):
337-
logger.error(f"Output file not created: {output_path}")
338-
raise DirectCSMError("Output file not created")
339-
340-
# Create URL for accessing the file
341-
relative_path = os.path.relpath(output_path, self.output_dir)
342-
url = f"/voices/{relative_path}"
343-
344-
logger.info(f"Voice generated successfully with Direct CSM on CPU: {output_path}")
345-
return output_path, url
346-
347-
except DirectCSMError as cpu_e:
348-
logger.warning(f"Direct CSM on CPU failed: {cpu_e}, falling back to standard CSM model on CPU")
349-
350-
# If direct CSM failed or is not available, use the standard CSM model on CPU
351-
if self.model is not None:
352-
# Generate speech on CPU
353-
audio, sample_rate = self.model.generate_speech(
354-
text=text,
355-
speaker_id=speaker_id,
356-
temperature=temperature,
357-
top_k=top_k,
358-
device="cpu"
359-
)
360-
361-
# Save the audio
362-
self.model.save_audio(audio, sample_rate, output_path)
363-
364-
# Check if file was created
365-
if not os.path.exists(output_path):
366-
logger.error(f"Output file not created: {output_path}")
367-
return None, None
368-
369-
# Create URL for accessing the file
370-
relative_path = os.path.relpath(output_path, self.output_dir)
371-
url = f"/voices/{relative_path}"
372-
373-
logger.info(f"Voice generated successfully with standard CSM on CPU: {output_path}")
374-
return output_path, url
375-
else:
376-
logger.error("No model available for voice generation on CPU")
377-
return None, None
378-
379-
except Exception as cpu_e:
380-
logger.error(f"Error generating voice on CPU: {cpu_e}")
381-
return None, None
382-
325+
logger.error(f"Error generating voice: {str(e)}")
383326
return None, None
384327

385328
def list_available_voices(self) -> List[Dict[str, Any]]:

app/core/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@
3939
MODEL_PATH = "/home/tdeshane/.cache/huggingface/hub/models--sesame--csm-1b/snapshots/03ab46ff5cfdcc783cc76fcf9ea6fd0838503093/ckpt.pt"
4040

4141
# Direct CSM settings
42-
USE_DIRECT_CSM = os.environ.get("USE_DIRECT_CSM", "true").lower() == "true"
42+
USE_DIRECT_CSM = os.environ.get("USE_DIRECT_CSM", "true").lower() == "true" # Enable by default
4343
DIRECT_CSM_PATH = os.environ.get("DIRECT_CSM_PATH", "/home/tdeshane/tts_poc/voice_poc/csm")
44+
DIRECT_CSM_FALLBACK_TO_STANDARD = os.environ.get("DIRECT_CSM_FALLBACK_TO_STANDARD", "true").lower() == "true"
4445

4546
# Output settings
4647
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", "/tmp/echoforge/voices")

main.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,26 @@ async def startup_event():
6969
voices_dir.mkdir(parents=True, exist_ok=True)
7070

7171
logger.info(f"Initialized data directories: {data_dir}")
72+
73+
# Initialize voice generator with Direct CSM
74+
from app.api.voice_generator import voice_generator
75+
try:
76+
# Get device from environment or use auto
77+
device = os.environ.get("ECHOFORGE_DEVICE", "auto")
78+
logger.info(f"Initializing voice generator with device: {device}")
79+
80+
# Initialize the voice generator
81+
voice_generator.initialize(device=device)
82+
83+
# Log whether we're using Direct CSM
84+
if voice_generator.direct_csm is not None:
85+
logger.info("Voice generator initialized with Direct CSM")
86+
elif voice_generator.model is not None:
87+
logger.info("Voice generator initialized with standard CSM model")
88+
else:
89+
logger.warning("Voice generator initialized but no model is loaded")
90+
except Exception as e:
91+
logger.error(f"Failed to initialize voice generator: {e}")
7292

7393

7494
def parse_args():
@@ -98,6 +118,18 @@ def parse_args():
98118
default="auto",
99119
help="Device to use for TTS (auto, cuda, cpu)"
100120
)
121+
parser.add_argument(
122+
"--direct-csm",
123+
action="store_true",
124+
default=True,
125+
help="Use Direct CSM implementation (default: True)"
126+
)
127+
parser.add_argument(
128+
"--no-direct-csm",
129+
action="store_false",
130+
dest="direct_csm",
131+
help="Disable Direct CSM implementation"
132+
)
101133
parser.add_argument(
102134
"--debug",
103135
action="store_true",
@@ -115,13 +147,16 @@ def parse_args():
115147
os.environ["ECHOFORGE_MODEL_PATH"] = args.model_path
116148
os.environ["ECHOFORGE_DEVICE"] = args.device
117149

150+
# Set Direct CSM environment variable
151+
os.environ["USE_DIRECT_CSM"] = str(args.direct_csm).lower()
152+
118153
# Configure logging level based on debug mode
119154
if args.debug:
120155
logging.getLogger().setLevel(logging.DEBUG)
121156
logger.debug("Debug mode enabled")
122157

123158
# Start the server
124-
logger.info(f"Starting server on {args.host}:{args.port}")
159+
logger.info(f"Starting server on {args.host}:{args.port} with Direct CSM {'enabled' if args.direct_csm else 'disabled'}")
125160
uvicorn.run(
126161
"main:app",
127162
host=args.host,

run.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,25 @@ def parse_arguments():
101101
help=f"Serve on {config.PUBLIC_HOST} to make the app publicly accessible"
102102
)
103103

104+
# Direct CSM arguments
105+
parser.add_argument(
106+
"--direct-csm",
107+
action="store_true",
108+
help="Enable Direct CSM implementation (default: enabled)"
109+
)
110+
111+
parser.add_argument(
112+
"--no-direct-csm",
113+
action="store_false",
114+
dest="direct_csm",
115+
help="Disable Direct CSM implementation"
116+
)
117+
118+
parser.add_argument(
119+
"--direct-csm-path",
120+
help=f"Path to Direct CSM implementation (default: {config.DIRECT_CSM_PATH})"
121+
)
122+
104123
# Auth arguments - support both styles for compatibility
105124
parser.add_argument(
106125
"--auth",
@@ -158,6 +177,15 @@ def main():
158177
if args.password or args.auth_pass:
159178
os.environ["AUTH_PASSWORD"] = args.auth_pass or args.password
160179

180+
# Direct CSM settings
181+
if hasattr(args, 'direct_csm'):
182+
os.environ["USE_DIRECT_CSM"] = str(args.direct_csm).lower()
183+
logger.info(f"Direct CSM is {'enabled' if args.direct_csm else 'disabled'}")
184+
185+
if args.direct_csm_path:
186+
os.environ["DIRECT_CSM_PATH"] = args.direct_csm_path
187+
logger.info(f"Using Direct CSM path: {args.direct_csm_path}")
188+
161189
# Set appropriate environment variables based on arguments
162190
if public_serving:
163191
os.environ["ALLOW_PUBLIC_SERVING"] = "true"

0 commit comments

Comments
 (0)