diff --git a/docs/yaml.md b/docs/yaml.md index 42181a0848..e09686beae 100644 --- a/docs/yaml.md +++ b/docs/yaml.md @@ -34,6 +34,17 @@ build: cuda: "11.8" ``` +### `cudnn` + +Cog automatically picks the correct version of cuDNN to install, but this lets you override it for whatever reason. + +For example: + +```yaml +build: + cudnn: "9" +``` + ### `gpu` Enable GPUs for this model. When enabled, the [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) base image will be used, and Cog will automatically figure out what versions of CUDA and cuDNN to use based on the version of Python, PyTorch, and Tensorflow that you are using. @@ -47,6 +58,23 @@ build: When you use `cog run` or `cog predict`, Cog will automatically pass the `--gpus=all` flag to Docker. When you run a Docker image built with Cog, you'll need to pass this option to `docker run`. +### `pre_install` + +A list of setup commands to run in the environment before your Python packages are installed. +Do note this command is deprecated, but supported for backwards compatibility. + +For example: + +```yaml +build: + system_packages: + - cowsay + pre_install: + - "cowsay moo" + python_packages: + - torch==2.4 +``` + ### `python_packages` A list of Python packages to install from the PyPi package index, in the format `package==version`. For example: @@ -173,3 +201,30 @@ predict: "predict.py:Predictor" ``` See [the Python API documentation for more information](python.md). + +## `train` + +The pointer to the `train()` function which defines your model is trained. + +For example: + +```yaml +train: "train.py:train" +``` + +See [the train API documentation for more information](training.md). + +## `concurrency` + +The concurrency settings for the model. +- `max` (int): The maximum number of concurrent predictions. +- `default_target`: (int) The default target for number of concurrent predictions. This setting can be used by an autoscaler to determine when to scale a deployment of a model up or down. + +For example: + +```yaml +concurrency: + max: 32 + default_target: 32 +``` +