diff --git a/README.md b/README.md index 4eb9be5..88dc35d 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ docker-compose up -d ui && docker-compose ps && docker-compose logs -f Server ```bash +# navigate to /docs for API definitions docker-compose up -d server && docker-compose ps && docker-compose logs -f ``` @@ -102,7 +103,10 @@ tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-s ```bash # You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio. # Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16. + +# navigate to /docs for API definitions poetry run python serving.py + poetry run python app.py ``` diff --git a/app.py b/app.py index 37ee06a..abddc0b 100644 --- a/app.py +++ b/app.py @@ -13,14 +13,14 @@ from fam.llm.utils import check_audio_file #### setup model -TTS_MODEL = tyro.cli(TTS) +TTS_MODEL = tyro.cli(TTS, args=["--telemetry_origin", "webapp"]) #### setup interface RADIO_CHOICES = ["Preset voices", "Upload target voice (atleast 30s)"] MAX_CHARS = 220 PRESET_VOICES = { # female - "Bria": "https://cdn.themetavoice.xyz/speakers%2Fbria.mp3", + "Bria": "https://cdn.themetavoice.xyz/speakers/bria.mp3", # male "Alex": "https://cdn.themetavoice.xyz/speakers/alex.mp3", "Jacob": "https://cdn.themetavoice.xyz/speakers/jacob.wav", diff --git a/fam/llm/fast_inference.py b/fam/llm/fast_inference.py index 20d593c..f813e0b 100644 --- a/fam/llm/fast_inference.py +++ b/fam/llm/fast_inference.py @@ -46,6 +46,7 @@ def __init__( output_dir: str = "outputs", quantisation_mode: Optional[Literal["int4", "int8"]] = None, first_stage_path: Optional[str] = None, + telemetry_origin: Optional[str] = None, ): """ Initialise the TTS model. @@ -60,6 +61,7 @@ def __init__( - int4 for int4 weight-only quantisation, - int8 for int8 weight-only quantisation. first_stage_path: path to first-stage LLM checkpoint. If provided, this will override the one grabbed from Hugging Face via `model_name`. + telemetry_origin: A string identifier that specifies the origin of the telemetry data sent to PostHog. """ # NOTE: this needs to come first so that we don't change global state when we want to use @@ -104,6 +106,7 @@ def __init__( self._seed = seed self._quantisation_mode = quantisation_mode self._model_name = model_name + self._telemetry_origin = telemetry_origin def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str: """ @@ -183,6 +186,7 @@ def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3. "seed": self._seed, "first_stage_ckpt": self._first_stage_ckpt, "gpu": torch.cuda.get_device_name(0), + "telemetry_origin": self._telemetry_origin, }, ) ) diff --git a/serving.py b/serving.py index 94119b0..636c0ed 100644 --- a/serving.py +++ b/serving.py @@ -1,4 +1,3 @@ -import json import logging import shlex import subprocess @@ -12,7 +11,7 @@ import tyro import uvicorn from attr import dataclass -from fastapi import Request +from fastapi import File, Form, HTTPException, UploadFile, status from fastapi.responses import Response from fam.llm.fast_inference import TTS @@ -50,55 +49,55 @@ class _GlobalState: GlobalState = _GlobalState() -@dataclass(frozen=True) -class TTSRequest: - text: str - speaker_ref_path: Optional[str] = None - guidance: float = 3.0 - top_p: float = 0.95 - top_k: Optional[int] = None - - @app.get("/health") async def health_check(): return {"status": "ok"} @app.post("/tts", response_class=Response) -async def text_to_speech(req: Request): - audiodata = await req.body() - payload = None +async def text_to_speech( + text: str = Form(...), + speaker_ref_path: Optional[str] = Form(None), + guidance: float = Form(3.0), + top_p: float = Form(0.95), + audiodata: Optional[UploadFile] = File(None), +): + # Ensure at least one of speaker_ref_path or audiodata is provided + if not audiodata and not speaker_ref_path: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Either an audio file or a speaker reference path must be provided.", + ) + wav_out_path = None try: - headers = req.headers - payload = headers["X-Payload"] - payload = json.loads(payload) - tts_req = TTSRequest(**payload) with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp: - if tts_req.speaker_ref_path is None: + if speaker_ref_path is None: wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp) check_audio_file(wav_path) else: # TODO: fix - wav_path = tts_req.speaker_ref_path + wav_path = speaker_ref_path if wav_path is None: warnings.warn("Running without speaker reference") - assert tts_req.guidance is None + assert guidance is None wav_out_path = GlobalState.tts.synthesise( - text=tts_req.text, + text=text, spk_ref_path=wav_path, - top_p=tts_req.top_p, - guidance_scale=tts_req.guidance, + top_p=top_p, + guidance_scale=guidance, ) with open(wav_out_path, "rb") as f: return Response(content=f.read(), media_type="audio/wav") except Exception as e: # traceback_str = "".join(traceback.format_tb(e.__traceback__)) - logger.exception(f"Error processing request {payload}") + logger.exception( + f"Error processing request. text: {text}, speaker_ref_path: {speaker_ref_path}, guidance: {guidance}, top_p: {top_p}" + ) return Response( content="Something went wrong. Please try again in a few mins or contact us on Discord", status_code=500, @@ -108,9 +107,9 @@ async def text_to_speech(req: Request): Path(wav_out_path).unlink(missing_ok=True) -def _convert_audiodata_to_wav_path(audiodata, wav_tmp): +def _convert_audiodata_to_wav_path(audiodata: UploadFile, wav_tmp): with tempfile.NamedTemporaryFile() as unknown_format_tmp: - if unknown_format_tmp.write(audiodata) == 0: + if unknown_format_tmp.write(audiodata.read()) == 0: return None unknown_format_tmp.flush() @@ -129,7 +128,11 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp): logging.root.setLevel(logging.INFO) GlobalState.config = tyro.cli(ServingConfig) - GlobalState.tts = TTS(seed=GlobalState.config.seed, quantisation_mode=GlobalState.config.quantisation_mode) + GlobalState.tts = TTS( + seed=GlobalState.config.seed, + quantisation_mode=GlobalState.config.quantisation_mode, + telemetry_origin="api_server", + ) app.add_middleware( fastapi.middleware.cors.CORSMiddleware,