-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
65 lines (54 loc) · 2.01 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import shlex

import click
from rich import print
from zenml.integrations.mlflow.mlflow_utils import get_tracking_uri

from configs import configs
from pipelines.test import deployment_inference_pipeline
from pipelines.train import train_and_register_model_pipeline
from steps import produce_results
from utils.utils import set_random_seed
@click.command()
@click.option(
    "--config",
    "-c",
    type=click.Choice(
        [configs.CMD_TRAIN_AND_REGISTER, configs.CMD_DEPLOY_AND_TEST, configs.CMD_EXECUTE_ALL]
    ),
    default=configs.CMD_EXECUTE_ALL,
    help="Optionally you can choose to only run specific pipelines.",
)
@click.option(
    "--toml-config-file",
    "-t",
    default=configs.TOML_DEFAULT_FILE_NAME,
    type=str,
    help="Select among possible toml configs located in 'configs/specific/**/*.toml'.",
)
@click.option(
    "--results",
    "-r",
    is_flag=True,
    default=False,
    help="If set, the results will be produced (metrics calculated and artifacts saved).",
)
def main(config: str, toml_config_file: str, results: bool) -> None:
    """Run the train-and-register and/or deployment-inference pipelines.

    Use --config to run only one stage (both run by default) and --results to
    also produce results afterwards. Finally, prints commands for inspecting
    the runs in the MLflow UI and the Optuna dashboard.
    """
    # Seed before any pipeline is launched.
    set_random_seed(configs.RANDOM_SEED)

    # Set the TOML config file as an environment variable (parsed in the pipelines).
    os.environ[configs.TOML_ENV_NAME] = toml_config_file

    # CMD_EXECUTE_ALL enables both stages; each specific command enables only its own.
    do_train_and_register = config in (
        configs.CMD_TRAIN_AND_REGISTER,
        configs.CMD_EXECUTE_ALL,
    )
    do_deploy_and_test = config in (
        configs.CMD_DEPLOY_AND_TEST,
        configs.CMD_EXECUTE_ALL,
    )

    if do_train_and_register:
        train_and_register_model_pipeline()
    if do_deploy_and_test:
        deployment_inference_pipeline()
    if results:
        produce_results()

    # shlex.quote leaves plain paths unchanged but quotes ones containing spaces
    # or shell metacharacters, so the printed commands can be pasted as-is.
    tracking_uri = shlex.quote(str(get_tracking_uri()))
    optuna_storage = shlex.quote(f"sqlite:///{configs.DB_PATH}")
    print(
        "\nYou can run:\n "
        f"[italic green] mlflow ui --backend-store-uri {tracking_uri} [/italic green]\n"
        "--> to inspect your experiment runs within the MLflow UI.\n"
        f"[italic green] optuna-dashboard {optuna_storage} [/italic green]\n"
        "--> to inspect your optuna study run within the Optuna dashboard.\n"
    )
if __name__ == "__main__":
    # Invoke the CLI only when this file is executed directly, not when imported.
    main()