Add multi-text streaming synthesis to the TTS client and to scripts/tts/talk.py.

* riva/client/tts.py: synthesize_online() accepts a single string or any iterable of strings and sends
  one SynthesizeSpeechRequest per text through a request generator.
* scripts/tts/talk.py: new --text_file option; it is only honoured together with --stream. The
  non-streaming path keeps passing the single --text string to synthesize().

NOTE(review): line breaks and leading whitespace of this patch were reconstructed from a
whitespace-collapsed copy, and the "index" hashes are those of the original patch. Confirm the context
lines against the target tree before applying.

diff --git a/common b/common
index 1301af4..da00396 160000
--- a/common
+++ b/common
@@ -1 +1 @@
-Subproject commit 1301af41cbf429dda8204b22d817c0e17cf8b369
+Subproject commit da003961535613d109a10269d56796918be78d61
diff --git a/riva/client/tts.py b/riva/client/tts.py
index 904860c..84bd497 100644
--- a/riva/client/tts.py
+++ b/riva/client/tts.py
@@ -1,7 +1,7 @@
 # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: MIT
 
-from typing import Generator, Optional, Union
+from typing import Generator, Iterable, Optional, Union
 
 from grpc._channel import _MultiThreadedRendezvous
 
@@ -97,7 +97,7 @@ def synthesize(
 
     def synthesize_online(
         self,
-        text: str,
+        text: Union[str, Iterable[str]],
         voice_name: Optional[str] = None,
         language_code: str = 'en-US',
         encoding: AudioEncoding = AudioEncoding.LINEAR_PCM,
@@ -112,7 +112,9 @@ def synthesize_online(
         becoming available.
 
         Args:
-            text (:obj:`str`): An input text.
+            text (:obj:`Union[str, Iterable[str]]`): An input text. A single string is synthesized as one text.
+                Any other iterable of strings (a list, a tuple, a generator, ...) is sent to the server as one
+                request per text, in iteration order.
             voice_name (:obj:`str`, `optional`): A name of the voice, e.g. ``"English-US-Female-1"``. You may find
                 available voices in server logs or in server model directory. If this parameter is :obj:`None`, then
                 a server will select the first available model with correct :param:`language_code` value.
@@ -132,7 +134,7 @@ def synthesize_online(
             future object by calling ``result()`` method.
         """
         req = rtts.SynthesizeSpeechRequest(
-            text=text,
+            text="",
             language_code=language_code,
             sample_rate_hz=sample_rate_hz,
             encoding=encoding,
@@ -147,6 +149,22 @@ def synthesize_online(
             req.zero_shot_data.encoding = audio_prompt_encoding
             req.zero_shot_data.quality = zero_shot_quality
 
-        add_custom_dictionary_to_config(req, custom_dictionary)
-
-        return self.stub.SynthesizeOnline(req, metadata=self.auth.get_auth_metadata())
+        add_custom_dictionary_to_config(req, custom_dictionary)
+
+        # Validate before building the generator: a generator body does not run until gRPC starts consuming it.
+        if isinstance(text, str):
+            texts = [text]
+        elif isinstance(text, Iterable):
+            texts = text
+        else:
+            raise ValueError(f"Invalid text type: {type(text)}")
+
+        def request_generator(texts):
+            for t in texts:
+                # Yield a fresh copy per text instead of mutating the shared ``req`` between yields.
+                text_req = rtts.SynthesizeSpeechRequest()
+                text_req.CopyFrom(req)
+                text_req.text = t
+                yield text_req
+
+        return self.stub.SynthesizeOnline(request_generator(texts), metadata=self.auth.get_auth_metadata())
diff --git a/scripts/tts/talk.py b/scripts/tts/talk.py
index 2df233b..8119d33 100644
--- a/scripts/tts/talk.py
+++ b/scripts/tts/talk.py
@@ -35,6 +35,7 @@ def parse_args() -> argparse.Namespace:
     group.add_argument("--text", type=str, help="Text input to synthesize.")
     group.add_argument("--list-devices", action="store_true", help="List output audio devices indices.")
     group.add_argument("--list-voices", action="store_true", help="List available voices.")
+    group.add_argument("--text_file", type=Path, default=None, help="A file path to a list of texts to synthesize.")
     parser.add_argument(
         "--voice",
         help="A voice name to use. If this parameter is missing, then the server will try a first available model "
@@ -138,10 +139,12 @@ def main() -> None:
         print(json.dumps(tts_models, indent=4))
         return
 
-    if not args.text:
+    if not args.text and not args.text_file:
         print("No input text provided")
         return
-
+    if args.text_file is not None and not args.stream:
+        print("Streaming synthesis is required when using a text list")
+        return
     try:
         if args.output_device is not None or args.play_audio:
             sound_stream = riva.client.audio_io.SoundCallBack(
@@ -153,6 +156,18 @@ def main() -> None:
             out_f.setsampwidth(sampwidth)
             out_f.setframerate(args.sample_rate_hz)
 
+        if args.text_file is not None:
+            text_list = []
+            with open(args.text_file, 'r', encoding='utf-8') as file:
+                for line in file:
+                    # Expected line format is "<id>|<text>"; a line without "|" is taken as the text itself.
+                    fields = line.split("|")
+                    entry = (fields[1] if len(fields) > 1 else fields[0]).strip()
+                    if entry:
+                        text_list.append(entry)
+        else:
+            text_list = [args.text]
+
         custom_dictionary_input = {}
         if args.custom_dictionary is not None:
             custom_dictionary_input = read_file_to_dict(args.custom_dictionary)
@@ -161,7 +176,7 @@ def main() -> None:
         start = time.time()
         if args.stream:
             responses = service.synthesize_online(
-                args.text, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz,
+                text_list, args.voice, args.language_code, sample_rate_hz=args.sample_rate_hz,
                 encoding=(AudioEncoding.OGGOPUS if args.encoding == "OGGOPUS" else AudioEncoding.LINEAR_PCM),
                 zero_shot_audio_prompt_file=args.zero_shot_audio_prompt_file,
                 zero_shot_quality=(20 if args.zero_shot_quality is None else args.zero_shot_quality),