diff --git a/docker-compose.gpu.yml b/docker-compose.gpu.yml index f019fb5..6f128d4 100644 --- a/docker-compose.gpu.yml +++ b/docker-compose.gpu.yml @@ -6,6 +6,9 @@ services: service: wyoming-piper build: dockerfile: GPU.Dockerfile + volumes: + - ./piper/__main__.py:/usr/local/lib/python3.10/dist-packages/wyoming_piper/__main__.py + - ./piper/process.py:/usr/local/lib/python3.10/dist-packages/wyoming_piper/process.py deploy: resources: reservations: diff --git a/piper/__main__.py b/piper/__main__.py new file mode 100644 index 0000000..ea81455 --- /dev/null +++ b/piper/__main__.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +import argparse +import asyncio +import json +import logging +from functools import partial +from pathlib import Path +from typing import Any, Dict, Set + +from wyoming.info import Attribution, Info, TtsProgram, TtsVoice +from wyoming.server import AsyncServer + +from .download import find_voice, get_voices +from .handler import PiperEventHandler +from .process import PiperProcessManager + +_LOGGER = logging.getLogger(__name__) + + +async def main() -> None: + """Main entry point.""" + parser = argparse.ArgumentParser() + parser.add_argument( + "--piper", + required=True, + help="Path to piper executable", + ) + parser.add_argument( + "--voice", + required=True, + help="Default Piper voice to use (e.g., en_US-lessac-medium)", + ) + parser.add_argument("--uri", default="stdio://", help="unix:// or tcp://") + parser.add_argument( + "--data-dir", + required=True, + action="append", + help="Data directory to check for downloaded models", + ) + parser.add_argument( + "--download-dir", + help="Directory to download voices into (default: first data dir)", + ) + # + parser.add_argument( + "--speaker", type=str, help="Name or id of speaker for default voice" + ) + parser.add_argument("--noise-scale", type=float, help="Generator noise") + parser.add_argument("--length-scale", type=float, help="Phoneme length") + parser.add_argument("--noise-w", type=float, help="Phoneme width noise") + # + parser.add_argument( + "--auto-punctuation", default=".?!", help="Automatically add punctuation" + ) + parser.add_argument("--samples-per-chunk", type=int, default=1024) + parser.add_argument( + "--max-piper-procs", + type=int, + default=1, + help="Maximum number of piper process to run simultaneously (default: 1)", + ) + # + parser.add_argument( + "--update-voices", + action="store_true", + help="Download latest voices.json during startup", + ) + parser.add_argument( + "--cuda", + action="store_true", + help="Use GPU" + ) + # + parser.add_argument("--debug", action="store_true", help="Log DEBUG messages") + args = parser.parse_args() + + if not args.download_dir: + # Default to first data directory + args.download_dir = args.data_dir[0] + + logging.basicConfig(level=logging.DEBUG if args.debug else logging.INFO) + + # Load voice info + voices_info = get_voices(args.download_dir, update_voices=args.update_voices) + + # Resolve aliases for backwards compatibility with old voice names + aliases_info: Dict[str, Any] = {} + for voice_info in voices_info.values(): + for voice_alias in voice_info.get("aliases", []): + aliases_info[voice_alias] = {"_is_alias": True, **voice_info} + + voices_info.update(aliases_info) + voices = [ + TtsVoice( + name=voice_name, + description=get_description(voice_info), + attribution=Attribution( + name="rhasspy", url="https://github.com/rhasspy/piper" + ), + installed=True, + languages=[ + voice_info.get("language", {}).get( + "code", + voice_info.get("espeak", {}).get("voice", voice_name.split("_")[0]), + ) + ], + # + # Don't send speakers for now because it overflows StreamReader buffers + # speakers=[ + # TtsVoiceSpeaker(name=speaker_name) + # for speaker_name in voice_info["speaker_id_map"] + # ] + # if voice_info.get("speaker_id_map") + # else None, + ) + for voice_name, voice_info in voices_info.items() + if not voice_info.get("_is_alias", False) + ] + + custom_voice_names: Set[str] = set() + if args.voice not in voices_info: + custom_voice_names.add(args.voice) + + for data_dir in args.data_dir: + data_dir = Path(data_dir) + if not data_dir.is_dir(): + continue + + for onnx_path in data_dir.glob("*.onnx"): + custom_voice_name = onnx_path.stem + if custom_voice_name not in voices_info: + custom_voice_names.add(custom_voice_name) + + for custom_voice_name in custom_voice_names: + # Add custom voice info + custom_voice_path, custom_config_path = find_voice( + custom_voice_name, args.data_dir + ) + with open(custom_config_path, "r", encoding="utf-8") as custom_config_file: + custom_config = json.load(custom_config_file) + custom_name = custom_config.get("dataset", custom_voice_path.stem) + custom_quality = custom_config.get("audio", {}).get("quality") + if custom_quality: + description = f"{custom_name} ({custom_quality})" + else: + description = custom_name + + lang_code = custom_config.get("language", {}).get("code") + if not lang_code: + lang_code = custom_config.get("espeak", {}).get("voice") + if not lang_code: + lang_code = custom_voice_path.stem.split("_")[0] + + voices.append( + TtsVoice( + name=custom_name, + description=description, + attribution=Attribution(name="", url=""), + installed=True, + languages=[lang_code], + ) + ) + + wyoming_info = Info( + tts=[ + TtsProgram( + name="piper", + description="A fast, local, neural text to speech engine", + attribution=Attribution( + name="rhasspy", url="https://github.com/rhasspy/piper" + ), + installed=True, + voices=sorted(voices, key=lambda v: v.name), + ) + ], + ) + + process_manager = PiperProcessManager(args, voices_info) + + # Make sure default voice is loaded. + # Other voices will be loaded on-demand. + await process_manager.get_process() + + # Start server + server = AsyncServer.from_uri(args.uri) + + _LOGGER.info("Ready") + await server.run( + partial( + PiperEventHandler, + wyoming_info, + args, + process_manager, + ) + ) + + +# ----------------------------------------------------------------------------- + + +def get_description(voice_info: Dict[str, Any]): + """Get a human readable description for a voice.""" + name = voice_info["name"] + name = " ".join(name.split("_")) + quality = voice_info["quality"] + + return f"{name} ({quality})" + + +# ----------------------------------------------------------------------------- + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + pass diff --git a/piper/process.py b/piper/process.py new file mode 100644 index 0000000..63d7047 --- /dev/null +++ b/piper/process.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python3 +import argparse +import asyncio +import json +import logging +import tempfile +import time +from dataclasses import dataclass +from typing import Any, Dict, Optional + +from .download import ensure_voice_exists, find_voice + +_LOGGER = logging.getLogger(__name__) + + +@dataclass +class PiperProcess: + """Info for a running Piper process (one voice).""" + + name: str + proc: "asyncio.subprocess.Process" + config: Dict[str, Any] + wav_dir: tempfile.TemporaryDirectory + last_used: int = 0 + + def get_speaker_id(self, speaker: str) -> Optional[int]: + """Get speaker by name or id.""" + return _get_speaker_id(self.config, speaker) + + @property + def is_multispeaker(self) -> bool: + """True if model has more than one speaker.""" + return _is_multispeaker(self.config) + + +def _get_speaker_id(config: Dict[str, Any], speaker: str) -> Optional[int]: + """Get speaker by name or id.""" + speaker_id_map = config.get("speaker_id_map", {}) + speaker_id = speaker_id_map.get(speaker) + if speaker_id is None: + try: + # Try to interpret as an id + speaker_id = int(speaker) + except ValueError: + pass + + return speaker_id + + +def _is_multispeaker(config: Dict[str, Any]) -> bool: + """True if model has more than one speaker.""" + return config.get("num_speakers", 1) > 1 + + +# ----------------------------------------------------------------------------- + + +class PiperProcessManager: + """Manager of running Piper processes.""" + + def __init__(self, args: argparse.Namespace, voices_info: Dict[str, Any]): + self.voices_info = voices_info + self.args = args + self.processes: Dict[str, PiperProcess] = {} + self.processes_lock = asyncio.Lock() + + async def get_process(self, voice_name: Optional[str] = None) -> PiperProcess: + """Get a running Piper process or start a new one if necessary.""" + voice_speaker: Optional[str] = None + if voice_name is None: + # Default voice + voice_name = self.args.voice + + if voice_name == self.args.voice: + # Default speaker + voice_speaker = self.args.speaker + + assert voice_name is not None + + # Resolve alias + voice_info = self.voices_info.get(voice_name, {}) + voice_name = voice_info.get("key", voice_name) + assert voice_name is not None + + piper_proc = self.processes.get(voice_name) + if (piper_proc is None) or (piper_proc.proc.returncode is not None): + # Remove if stopped + self.processes.pop(voice_name, None) + + # Start new Piper process + if self.args.max_piper_procs > 0: + # Restrict number of running processes + while len(self.processes) >= self.args.max_piper_procs: + # Stop least recently used process + lru_proc_name, lru_proc = sorted( + self.processes.items(), key=lambda kv: kv[1].last_used + )[0] + _LOGGER.debug("Stopping process for: %s", lru_proc_name) + self.processes.pop(lru_proc_name, None) + if lru_proc.proc.returncode is None: + try: + lru_proc.proc.terminate() + await lru_proc.proc.wait() + except Exception: + _LOGGER.exception("Unexpected error stopping piper process") + + _LOGGER.debug( + "Starting process for: %s (%s/%s)", + voice_name, + len(self.processes) + 1, + self.args.max_piper_procs, + ) + + ensure_voice_exists( + voice_name, + self.args.data_dir, + self.args.download_dir, + self.voices_info, + ) + + onnx_path, config_path = find_voice(voice_name, self.args.data_dir) + with open(config_path, "r", encoding="utf-8") as config_file: + config = json.load(config_file) + + wav_dir = tempfile.TemporaryDirectory() + piper_args = [ + "--model", + str(onnx_path), + "--config", + str(config_path), + "--output_dir", + str(wav_dir.name), + "--json-input", # piper 1.1+ + ] + + if voice_speaker is not None: + if _is_multispeaker(config): + speaker_id = _get_speaker_id(config, voice_speaker) + if speaker_id is not None: + piper_args.extend(["--speaker", str(speaker_id)]) + + if self.args.noise_scale: + piper_args.extend(["--noise-scale", str(self.args.noise_scale)]) + + if self.args.length_scale: + piper_args.extend(["--length-scale", str(self.args.length_scale)]) + + if self.args.noise_w: + piper_args.extend(["--noise-w", str(self.args.noise_w)]) + + if self.args.cuda: + piper_args.extend(["--cuda"]) + + _LOGGER.debug( + "Starting piper process: %s args=%s", self.args.piper, piper_args + ) + piper_proc = PiperProcess( + name=voice_name, + proc=await asyncio.create_subprocess_exec( + self.args.piper, + *piper_args, + stdin=asyncio.subprocess.PIPE, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.DEVNULL, + ), + config=config, + wav_dir=wav_dir, + ) + self.processes[voice_name] = piper_proc + + # Update used + piper_proc.last_used = time.monotonic_ns() + + return piper_proc +