diff --git a/reginald/cli.py b/reginald/cli.py
index 0b9c79d..c2f4ab5 100644
--- a/reginald/cli.py
+++ b/reginald/cli.py
@@ -32,6 +32,8 @@
     "streaming": "Whether to use streaming for the chat interaction.",
     "slack_app_token": "Slack app token for the bot.",
     "slack_bot_token": "Slack bot token for the bot.",
+    "host": "Host to listen on.",
+    "port": "Port to listen on.",
 }
 
 cli = typer.Typer(context_settings={"help_option_names": ["-h", "--help"]})
@@ -261,6 +263,12 @@ def app(
     device: Annotated[
         str, typer.Option(envvar="LLAMA_INDEX_DEVICE", help=HELP_TEXT["device"])
     ] = DEFAULT_ARGS["device"],
+    host: Annotated[
+        str, typer.Option(envvar="REGINALD_HOST", help=HELP_TEXT["host"])
+    ] = DEFAULT_ARGS["host"],
+    port: Annotated[
+        int, typer.Option(envvar="REGINALD_PORT", help=HELP_TEXT["port"])
+    ] = DEFAULT_ARGS["port"],
 ) -> None:
     """
     Sets up the response model and then creates a
@@ -273,6 +281,8 @@ def app(
     set_up_logging_config(level=20)
     main(
         cli="app",
+        host=host,
+        port=port,
         model=model,
         model_name=model_name,
         mode=mode,
diff --git a/reginald/defaults.py b/reginald/defaults.py
index 7aa705a..413bcd5 100644
--- a/reginald/defaults.py
+++ b/reginald/defaults.py
@@ -19,4 +19,6 @@
     "is_path": False,
     "n_gpu_layers": 0,
     "device": "auto",
+    "host": "0.0.0.0",
+    "port": 8000,
 }
diff --git a/reginald/models/app.py b/reginald/models/app.py
index beedae0..5d1c1a2 100644
--- a/reginald/models/app.py
+++ b/reginald/models/app.py
@@ -2,6 +2,7 @@
 from fastapi import FastAPI
 from pydantic import BaseModel
 
+from reginald.defaults import DEFAULT_ARGS
 from reginald.models.setup_llm import setup_llm
 
 
@@ -43,8 +44,17 @@ async def channel_mention(query: Query):
     return app
 
 
-def run_reginald_app(**kwargs) -> None:
+def run_reginald_app(
+    host: str | None = DEFAULT_ARGS["host"],
+    port: int | None = DEFAULT_ARGS["port"],
+    **kwargs
+) -> None:
+    if host is None:
+        host = DEFAULT_ARGS["host"]
+    if port is None:
+        port = DEFAULT_ARGS["port"]
+
     # set up response model
     response_model = setup_llm(**kwargs)
     app: FastAPI = create_reginald_app(response_model)
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    uvicorn.run(app, host=host, port=port)
diff --git a/reginald/run.py b/reginald/run.py
index 75a3cd6..26728c2 100644
--- a/reginald/run.py
+++ b/reginald/run.py
@@ -11,6 +11,8 @@ def main(
     which_index: str | None = None,
     slack_app_token: str | None = None,
     slack_bot_token: str | None = None,
+    host: str | None = None,
+    port: int | None = None,
     **kwargs,
 ):
     # initialise logging
@@ -40,7 +42,9 @@ def main(
     elif cli == "app":
         from reginald.models.app import run_reginald_app
 
-        run_reginald_app(data_dir=data_dir, which_index=which_index, **kwargs)
+        run_reginald_app(
+            host=host, port=port, data_dir=data_dir, which_index=which_index, **kwargs
+        )
 
     elif cli == "chat":
         import warnings
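
For context (not part of the patch): a minimal sketch of how the new host/port plumbing would be exercised once this change lands. run_reginald_app, DEFAULT_ARGS, and the REGINALD_HOST/REGINALD_PORT environment variables are taken from the diff above; the "reginald app" command name, the concrete host/port values, and the assumption that setup_llm() works from its defaults are illustrative only.

    # Sketch only: assumes the patch above is applied and that setup_llm()
    # can build a response model without extra keyword arguments (its
    # signature is not shown in this diff).
    from reginald.models.app import run_reginald_app

    if __name__ == "__main__":
        # host/port now flow through to uvicorn.run(); any other keyword
        # arguments are still forwarded unchanged to setup_llm().
        run_reginald_app(host="127.0.0.1", port=8080)

    # Roughly equivalent via the CLI, using the new options or env vars
    # (command name and values illustrative):
    #   reginald app --host 127.0.0.1 --port 8080
    #   REGINALD_HOST=127.0.0.1 REGINALD_PORT=8080 reginald app
    # With neither supplied, the server falls back to
    # DEFAULT_ARGS["host"] ("0.0.0.0") and DEFAULT_ARGS["port"] (8000).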