diff --git a/scripts/olmo2_modal_openai.py b/scripts/olmo2_modal_openai.py index 570d928ed..ab21e31ad 100644 --- a/scripts/olmo2_modal_openai.py +++ b/scripts/olmo2_modal_openai.py @@ -33,6 +33,7 @@ # the weights from HuggingFace directly into a local directory when building the # container image. + def download_model_to_image(model_dir, model_name, model_revision): from huggingface_hub import snapshot_download from transformers.utils import move_cache @@ -46,6 +47,7 @@ def download_model_to_image(model_dir, model_name, model_revision): ) move_cache() + # ## Set up the container image # Our first order of business is to define the environment our server will run in @@ -76,7 +78,7 @@ def download_model_to_image(model_dir, model_name, model_revision): .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"}) .run_function( download_model_to_image, - timeout=60 * MINUTES, # typically much faster but set high to be conservative + timeout=60 * MINUTES, # typically much faster but set high to be conservative kwargs={ "model_dir": MODEL_DIR, "model_name": MODEL_NAME, @@ -100,14 +102,15 @@ def download_model_to_image(model_dir, model_name, model_revision): # app = modal.App(APP_NAME) + @app.function( image=vllm_image, gpu=GPU_CONFIG, - keep_warm=0, # Spin down entirely when idle + keep_warm=0, # Spin down entirely when idle container_idle_timeout=5 * MINUTES, timeout=24 * HOURS, allow_concurrent_inputs=1000, - secrets=[modal.Secret.from_name("example-secret-token")], # contains MODAL_TOKEN used below + secrets=[modal.Secret.from_name("example-secret-token")], # contains MODAL_TOKEN used below ) @modal.asgi_app() def serve(): @@ -144,7 +147,7 @@ def serve(): # This example uses a token defined in the Modal secret linked above, # as described here: https://modal.com/docs/guide/secrets - async def is_authenticated(api_key = fastapi.Security(http_bearer)): + async def is_authenticated(api_key=fastapi.Security(http_bearer)): if api_key.credentials != os.getenv("MODAL_TOKEN"): raise fastapi.HTTPException( status_code=fastapi.status.HTTP_401_UNAUTHORIZED, @@ -167,17 +170,13 @@ async def is_authenticated(api_key = fastapi.Security(http_bearer)): enforce_eager=False, # capture the graph for faster inference, but slower cold starts (30s > 20s) ) - engine = AsyncLLMEngine.from_engine_args( - engine_args, usage_context=UsageContext.OPENAI_API_SERVER - ) + engine = AsyncLLMEngine.from_engine_args(engine_args, usage_context=UsageContext.OPENAI_API_SERVER) model_config = get_model_config(engine) request_logger = RequestLogger(max_log_len=2048) - base_model_paths = [ - BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME) - ] + base_model_paths = [BaseModelPath(name=MODEL_NAME.split("/")[1], model_path=MODEL_NAME)] api_server.chat = lambda s: OpenAIServingChat( engine,