Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add telemetry origin #141

Merged
merged 5 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ docker-compose up -d ui && docker-compose ps && docker-compose logs -f

Server
```bash
# navigate to <URL>/docs for API definitions
docker-compose up -d server && docker-compose ps && docker-compose logs -f
```

Expand Down Expand Up @@ -102,7 +103,10 @@ tts.synthesise(text="This is a demo of text to speech by MetaVoice-1B, an open-s
```bash
# You can use `--quantisation_mode int4` or `--quantisation_mode int8` for experimental faster inference. This will degrade the quality of the audio.
# Note: int8 is slower than bf16/fp16 for undebugged reasons. If you want fast, try int4 which is roughly 2x faster than bf16/fp16.

# navigate to <URL>/docs for API definitions
poetry run python serving.py

poetry run python app.py
```

Expand Down
4 changes: 2 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
from fam.llm.utils import check_audio_file

#### setup model
TTS_MODEL = tyro.cli(TTS)
TTS_MODEL = tyro.cli(TTS, args=["--telemetry_origin", "webapp"])

#### setup interface
RADIO_CHOICES = ["Preset voices", "Upload target voice (atleast 30s)"]
MAX_CHARS = 220
PRESET_VOICES = {
# female
"Bria": "https://cdn.themetavoice.xyz/speakers%2Fbria.mp3",
"Bria": "https://cdn.themetavoice.xyz/speakers/bria.mp3",
# male
"Alex": "https://cdn.themetavoice.xyz/speakers/alex.mp3",
"Jacob": "https://cdn.themetavoice.xyz/speakers/jacob.wav",
Expand Down
4 changes: 4 additions & 0 deletions fam/llm/fast_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def __init__(
output_dir: str = "outputs",
quantisation_mode: Optional[Literal["int4", "int8"]] = None,
first_stage_path: Optional[str] = None,
telemetry_origin: Optional[str] = None,
):
"""
Initialise the TTS model.
Expand All @@ -60,6 +61,7 @@ def __init__(
- int4 for int4 weight-only quantisation,
- int8 for int8 weight-only quantisation.
first_stage_path: path to first-stage LLM checkpoint. If provided, this will override the one grabbed from Hugging Face via `model_name`.
telemetry_origin: A string identifier that specifies the origin of the telemetry data sent to PostHog.
"""

# NOTE: this needs to come first so that we don't change global state when we want to use
Expand Down Expand Up @@ -104,6 +106,7 @@ def __init__(
self._seed = seed
self._quantisation_mode = quantisation_mode
self._model_name = model_name
self._telemetry_origin = telemetry_origin

def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.0, temperature=1.0) -> str:
"""
Expand Down Expand Up @@ -183,6 +186,7 @@ def synthesise(self, text: str, spk_ref_path: str, top_p=0.95, guidance_scale=3.
"seed": self._seed,
"first_stage_ckpt": self._first_stage_ckpt,
"gpu": torch.cuda.get_device_name(0),
"telemetry_origin": self._telemetry_origin,
},
)
)
Expand Down
59 changes: 31 additions & 28 deletions serving.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import logging
import shlex
import subprocess
Expand All @@ -12,7 +11,7 @@
import tyro
import uvicorn
from attr import dataclass
from fastapi import Request
from fastapi import File, Form, HTTPException, UploadFile, status
from fastapi.responses import Response

from fam.llm.fast_inference import TTS
Expand Down Expand Up @@ -50,55 +49,55 @@ class _GlobalState:
GlobalState = _GlobalState()


@dataclass(frozen=True)
class TTSRequest:
text: str
speaker_ref_path: Optional[str] = None
guidance: float = 3.0
top_p: float = 0.95
top_k: Optional[int] = None


@app.get("/health")
async def health_check():
return {"status": "ok"}


@app.post("/tts", response_class=Response)
async def text_to_speech(req: Request):
audiodata = await req.body()
payload = None
async def text_to_speech(
text: str = Form(...),
speaker_ref_path: Optional[str] = Form(None),
guidance: float = Form(3.0),
top_p: float = Form(0.95),
audiodata: Optional[UploadFile] = File(None),
):
# Ensure at least one of speaker_ref_path or audiodata is provided
if not audiodata and not speaker_ref_path:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Either an audio file or a speaker reference path must be provided.",
)

wav_out_path = None

try:
headers = req.headers
payload = headers["X-Payload"]
payload = json.loads(payload)
tts_req = TTSRequest(**payload)
with tempfile.NamedTemporaryFile(suffix=".wav") as wav_tmp:
if tts_req.speaker_ref_path is None:
if speaker_ref_path is None:
wav_path = _convert_audiodata_to_wav_path(audiodata, wav_tmp)
check_audio_file(wav_path)
else:
# TODO: fix
wav_path = tts_req.speaker_ref_path
wav_path = speaker_ref_path

if wav_path is None:
warnings.warn("Running without speaker reference")
assert tts_req.guidance is None
assert guidance is None

wav_out_path = GlobalState.tts.synthesise(
text=tts_req.text,
text=text,
spk_ref_path=wav_path,
top_p=tts_req.top_p,
guidance_scale=tts_req.guidance,
top_p=top_p,
guidance_scale=guidance,
)

with open(wav_out_path, "rb") as f:
return Response(content=f.read(), media_type="audio/wav")
except Exception as e:
# traceback_str = "".join(traceback.format_tb(e.__traceback__))
logger.exception(f"Error processing request {payload}")
logger.exception(
f"Error processing request. text: {text}, speaker_ref_path: {speaker_ref_path}, guidance: {guidance}, top_p: {top_p}"
)
return Response(
content="Something went wrong. Please try again in a few mins or contact us on Discord",
status_code=500,
Expand All @@ -108,9 +107,9 @@ async def text_to_speech(req: Request):
Path(wav_out_path).unlink(missing_ok=True)


def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
def _convert_audiodata_to_wav_path(audiodata: UploadFile, wav_tmp):
with tempfile.NamedTemporaryFile() as unknown_format_tmp:
if unknown_format_tmp.write(audiodata) == 0:
if unknown_format_tmp.write(audiodata.read()) == 0:
return None
unknown_format_tmp.flush()

Expand All @@ -129,7 +128,11 @@ def _convert_audiodata_to_wav_path(audiodata, wav_tmp):
logging.root.setLevel(logging.INFO)

GlobalState.config = tyro.cli(ServingConfig)
GlobalState.tts = TTS(seed=GlobalState.config.seed, quantisation_mode=GlobalState.config.quantisation_mode)
GlobalState.tts = TTS(
seed=GlobalState.config.seed,
quantisation_mode=GlobalState.config.quantisation_mode,
telemetry_origin="api_server",
)

app.add_middleware(
fastapi.middleware.cors.CORSMiddleware,
Expand Down
Loading