- is a numpy array as would be provided, e.g., to sklearn's predict method.
-
- Outputs:
- - A *-base.onnx file that implements state.model given state.inputs
- """
-
- def __init__(self):
- super().__init__(
- unique_name="hummingbird_conversion",
- monitor_message="Converting model to ONNX with Hummingbird",
- )
-
- def fire(self, state: build.State):
- # TODO: Temporarily inlined to avoid warning message in hummingbird-ml<=0.46.
- import hummingbird.ml # pylint: disable=import-error
- from hummingbird.ml.exceptions import ( # pylint: disable=import-error
- ConstantError,
- MissingConverter,
- MissingBackend,
- )
-
- if not is_supported_model(state.model):
- msg = f"""
- The current stage (ConvertHummingbirdModel) is only compatible with
- certain scikit-learn, xgboost, and lightgbm models, however the stage
- received an unsupported model of type {type(state.model)}.
-
-            Supported scikit-learn models:
- - sklearn.ensemble.ExtraTreesClassifier
- - sklearn.ensemble.GradientBoostingClassifier
- - sklearn.ensemble.IsolationForest
- - sklearn.ensemble.RandomForestClassifier
- - sklearn.ensemble.RandomForestRegressor
- - sklearn.linear_model.SGDClassifier
- - sklearn.naive_bayes.BernoulliNB
- - sklearn.naive_bayes.GaussianNB
- - sklearn.naive_bayes.MultinomialNB
- - sklearn.neighbors.KNeighborsClassifier
- - sklearn.neural_network.MLPClassifier
- - sklearn.pipeline.Pipeline
- - sklearn.preprocessing.StandardScaler
- - sklearn.svm.LinearSVC
- - sklearn.tree.DecisionTreeClassifier
-
- Supported xgboost models:
- - xgboost.XGBClassifier
- - xgboost.XGBRegressor
-
- Supported lightgbm models:
- - lightgbm.LGBMClassifier
- - lightgbm.LGBMRegressor
- """
- raise exp.StageError(msg)
-
-        # TODO: By default the strategy will be chosen with Hummingbird's logic.
- # Ideally, this would also be a parameter.
- tree_implementation_strategy = "gemm" # or "tree_trav" or "perf_tree_trav"
-
- inputs = state.inputs
- if inputs is None:
- raise exp.StageError(
- "Hummingbird conversion requires inputs to be provided,"
- " however `inputs` is None."
- )
- test_X = inputs["input_0"]
- batch_size = test_X.shape[0]
- if test_X.dtype == np.float64:
- raise exp.StageError(
- "Fitting a model with float64 inputs can cause issues"
- " with conversion and compilation. This can be corrected by changing"
- " code like model.fit(X, y) to model.fit(X.astype(numpy.float32), y)."
- )
-
- extra_config = {
- "onnx_target_opset": state.config.onnx_opset,
- "tree_implementation": tree_implementation_strategy,
- "batch_size": batch_size,
- }
-
- try:
- onnx_model = hummingbird.ml.convert(
- state.model, "onnx", test_X, extra_config=extra_config
- ).model
- except (
- RuntimeError,
- IndexError,
- ValueError,
- ConstantError,
- MissingConverter,
- MissingBackend,
- ) as e:
- raise exp.StageError(f"Hummingbird conversion failed with error: {e}")
-
- input_dims = {
- "input_0": [
- batch_size,
- onnx_model.graph.input[0].type.tensor_type.shape.dim[1].dim_value,
- ]
- }
- if len(onnx_model.graph.output) > 1:
- output_dims = {
- "variable": [batch_size],
- onnx_model.graph.output[1].name: [
- batch_size,
- onnx_model.graph.output[1].type.tensor_type.shape.dim[1].dim_value,
- ],
- }
- else:
- output_dims = {"variable": [batch_size]}
-
- # Concretize symbolic shape parameter
- onnx_model = onnx.tools.update_model_dims.update_inputs_outputs_dims(
- onnx_model, input_dims, output_dims
- )
-
- # Save output node names
- state.expected_output_names = export.get_output_names(onnx_model)
-
- output_path = export.base_onnx_file(state)
- os.makedirs(export.onnx_dir(state))
- onnx.save(onnx_model, output_path)
-
- np.save(state.original_inputs_file, state.inputs)
-
- state.intermediate_results = [output_path]
- stats = fs.Stats(state.cache_dir, state.config.build_name, state.evaluation_id)
- stats.save_model_eval_stat(
- fs.Keys.ONNX_FILE,
- output_path,
- )
-
- return state
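
For context on the conversion logic above, here is a minimal standalone sketch (not part of the deleted file) of the same Hummingbird-to-ONNX flow. It assumes hummingbird-ml, onnx, numpy, and scikit-learn are installed; the model, opset choice, and output file name are illustrative only.

```python
# Illustrative sketch of the conversion performed by ConvertHummingbirdModel.fire()
import numpy as np
import onnx
import hummingbird.ml
from sklearn.ensemble import RandomForestClassifier

# Fit a small model on float32 inputs (float64 inputs are rejected by the stage above)
X = np.random.rand(8, 4).astype(np.float32)
y = np.random.randint(0, 2, size=8)
model = RandomForestClassifier(n_estimators=10).fit(X, y)

extra_config = {
    "onnx_target_opset": 14,        # hypothetical opset choice
    "tree_implementation": "gemm",  # or "tree_trav" / "perf_tree_trav"
    "batch_size": X.shape[0],
}

# convert() returns a container whose .model attribute is the ONNX ModelProto
onnx_model = hummingbird.ml.convert(model, "onnx", X, extra_config=extra_config).model
onnx.save(onnx_model, "random_forest-base.onnx")
```
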
diff --git a/src/turnkeyml/build/ignition.py b/src/turnkeyml/build/ignition.py
deleted file mode 100644
index a06d8ddc..00000000
--- a/src/turnkeyml/build/ignition.py
+++ /dev/null
@@ -1,534 +0,0 @@
-from typing import Optional, List, Tuple, Union, Dict, Any, Type, Callable
-import sys
-import os
-import copy
-import torch
-import onnx
-import turnkeyml.common.build as build
-import turnkeyml.common.filesystem as filesystem
-import turnkeyml.common.exceptions as exp
-import turnkeyml.common.printing as printing
-import turnkeyml.common.tf_helpers as tf_helpers
-import turnkeyml.build.onnx_helpers as onnx_helpers
-import turnkeyml.build.tensor_helpers as tensor_helpers
-import turnkeyml.build.export as export
-import turnkeyml.build.stage as stage
-import turnkeyml.build.hummingbird as hummingbird
-import turnkeyml.build.sequences as sequences
-from turnkeyml.version import __version__ as turnkey_version
-
-
-def lock_config(
- model: build.UnionValidModelInstanceTypes,
- build_name: Optional[str] = None,
- sequence: stage.Sequence = None,
- onnx_opset: Optional[int] = None,
- device: Optional[str] = None,
-) -> build.Config:
- """
- Process the user's configuration arguments to build_model():
- 1. Raise exceptions for illegal arguments
- 2. Replace unset arguments with default values
- 3. Lock the configuration into an immutable object
- """
-
- # The default model name is the name of the python file that calls build_model()
- auto_name = False
- if build_name is None:
- build_name = os.path.basename(sys.argv[0])
- auto_name = True
-
- if sequence is None:
- # The value ["default"] indicates that build_model() will be assigning some
- # default sequence later in the program
- stage_names = ["default"]
- else:
- stage_names = sequence.get_names()
-
- # Detect and validate ONNX opset
- if isinstance(model, str) and model.endswith(".onnx"):
- onnx_file_opset = onnx_helpers.get_opset(onnx.load(model))
-
- if onnx_opset is not None and onnx_opset != onnx_file_opset:
- raise ValueError(
- "When using a '.onnx' file as input, the onnx_opset argument must "
- "be None or exactly match the ONNX opset of the '.onnx' file. However, the "
- f"'.onnx' file has opset {onnx_file_opset}, while onnx_opset was set "
- f"to {onnx_opset}"
- )
-
- opset_to_use = onnx_file_opset
- else:
- if onnx_opset is None:
- opset_to_use = build.DEFAULT_ONNX_OPSET
- else:
- opset_to_use = onnx_opset
-
- if device is None:
- device_to_use = build.DEFAULT_DEVICE
- else:
- device_to_use = device
-
- # Store the args that should be immutable
- config = build.Config(
- build_name=build_name,
- auto_name=auto_name,
- sequence=stage_names,
- onnx_opset=opset_to_use,
- device=device_to_use,
- )
-
- return config
-
-
-def decode_version_number(version: str) -> Dict[str, int]:
- numbers = [int(x) for x in version.split(".")]
-    return {"major": numbers[0], "minor": numbers[1], "patch": numbers[2]}
-
-
-def validate_cached_model(
- config: build.Config,
- model_type: build.ModelType,
- state: build.State,
- model: build.UnionValidModelInstanceTypes = None,
- inputs: Optional[Dict[str, Any]] = None,
-) -> List[str]:
- """
- Verify whether anything in the call to build_model() changed
- We require the user to resolve the discrepancy when such a
- change occurs, so the purpose of this function is simply to
- detect these conditions and raise an appropriate error.
- If this function returns without raising an exception then
- the cached model is valid to use in the build.
- """
-
- result = []
-
- current_version_decoded = decode_version_number(turnkey_version)
- state_version_decoded = decode_version_number(state.turnkey_version)
-
- out_of_date: Union[str, bool] = False
- if current_version_decoded["major"] > state_version_decoded["major"]:
- out_of_date = "major"
- elif current_version_decoded["minor"] > state_version_decoded["minor"]:
- out_of_date = "minor"
-
- if out_of_date:
- msg = (
- f"Your build {state.config.build_name} was previously built against "
- f"turnkey version {state.turnkey_version}, "
- f"however you are now using turnkey version {turnkey_version}. The previous build is "
- f"incompatible with this version of turnkey, as indicated by the {out_of_date} "
- "version number changing. See **docs/versioning.md** for details."
- )
- result.append(msg)
-
- if model is not None:
- model_changed = state.model_hash != build.hash_model(model, model_type)
- else:
- model_changed = False
-
- if inputs is not None:
- (
- input_shapes_changed,
- input_dtypes_changed,
- ) = tensor_helpers.check_shapes_and_dtypes(
- inputs,
- state.expected_input_shapes,
- state.expected_input_dtypes,
- expect_downcast=state.downcast_applied,
- raise_error=False,
- )
- else:
- input_shapes_changed = False
- input_dtypes_changed = False
-
- changed_args = []
- for key in vars(state.config):
- if vars(config)[key] != vars(state.config)[key]:
- changed_args.append((key, vars(config)[key], vars(state.config)[key]))
-
- # Show an error if the model changed
- build_conditions_changed = (
- model_changed
- or input_shapes_changed
- or input_dtypes_changed
- or len(changed_args) > 0
- )
- if build_conditions_changed:
- # Show an error if build_name is not specified for different models on the same script
- if state.uid == build.unique_id():
- msg = (
- "You are building multiple different models in the same script "
- "without specifying a unique build_model(..., build_name=) for each build."
- )
- result.append(msg)
-
- if model_changed:
- msg = (
- f'Model "{config.build_name}" changed since the last time it was built.'
- )
- result.append(msg)
-
- if input_shapes_changed:
- input_shapes, _ = build.get_shapes_and_dtypes(inputs)
- msg = (
- f'Input shape of model "{config.build_name}" changed from '
- f"{state.expected_input_shapes} to {input_shapes} "
- f"since the last time it was built."
- )
- result.append(msg)
-
- if input_dtypes_changed:
- _, input_dtypes = build.get_shapes_and_dtypes(inputs)
- msg = (
- f'Input data type of model "{config.build_name}" changed from '
- f"{state.expected_input_dtypes} to {input_dtypes} "
- f"since the last time it was built."
- )
- result.append(msg)
-
- if len(changed_args) > 0:
- for key_name, current_arg, previous_arg in changed_args:
- msg = (
- f'build_model() argument "{key_name}" for build '
- f"{config.build_name} changed from "
- f"{previous_arg} to {current_arg} since the last build."
- )
- result.append(msg)
- else:
- if (
- state.build_status == build.FunctionStatus.ERROR
- or state.build_status == build.FunctionStatus.INCOMPLETE
- or state.build_status == build.FunctionStatus.KILLED
- ) and turnkey_version == state.turnkey_version:
- msg = (
- "build_model() has detected that you already attempted building "
- "this model with the exact same model, inputs, options, and version of "
- "turnkey, and that build failed."
- )
- result.append(msg)
-
- return result
-
-
-def _begin_fresh_build(
- state_args: Dict,
- state_type: Type = build.State,
-) -> build.State:
- # Wipe everything in this model's build directory, except for the stats file,
-    # and start with a fresh State.
- stats = filesystem.Stats(state_args["cache_dir"], state_args["config"].build_name)
-
- build_dir = build.output_dir(
- state_args["cache_dir"], state_args["config"].build_name
- )
-
- filesystem.rmdir(
- build_dir,
- excludes=[
- stats.file,
- os.path.join(build_dir, filesystem.BUILD_MARKER),
- ],
- )
- state = state_type(**state_args)
- state.save()
-
- return state
-
-
-def _rebuild_if_needed(
- problem_report: str, state_args: Dict, state_type: Type = build.State
-):
- build_name = state_args["config"].build_name
- msg = (
- f"build_model() discovered a cached build of {build_name}, but decided to "
- "rebuild for the following reasons: \n\n"
- f"{problem_report} \n\n"
- "build_model() will now rebuild your model to ensure correctness. You can change this "
- "policy by setting the build_model(rebuild=...) argument."
- )
- printing.log_warning(msg)
-
- return _begin_fresh_build(state_args, state_type=state_type)
-
-
-def load_or_make_state(
- config: build.Config,
- evaluation_id: str,
- cache_dir: str,
- rebuild: str,
- model_type: build.ModelType,
- monitor: bool,
- model: build.UnionValidModelInstanceTypes = None,
- inputs: Optional[Dict[str, Any]] = None,
- state_type: Type = build.State,
- cache_validation_func: Callable = validate_cached_model,
- extra_state_args: Optional[Dict] = None,
-) -> build.State:
- """
- Decide whether we can load the model from the model cache
- (return a valid State instance) or whether we need to rebuild it (return
- a new State instance).
- """
-
- # Put all the args for making a new State instance into a dict
- # to help the following code be cleaner
- state_args = {
- "model": model,
- "inputs": inputs,
- "monitor": monitor,
- "rebuild": rebuild,
- "evaluation_id": evaluation_id,
- "cache_dir": cache_dir,
- "config": config,
- "model_type": model_type,
- }
-
- # Ensure that `rebuild` has a valid value
- if rebuild not in build.REBUILD_OPTIONS:
- raise ValueError(
- f"Received `rebuild` argument with value {rebuild}, "
- f"however the only allowed values of `rebuild` are {build.REBUILD_OPTIONS}"
- )
-
- # Allow customizations of turnkey to supply additional args
- if extra_state_args is not None:
- state_args.update(extra_state_args)
-
- if rebuild == "always":
- return _begin_fresh_build(state_args, state_type)
- else:
- # Try to load state and check if model successfully built before
- if os.path.isfile(build.state_file(cache_dir, config.build_name)):
- try:
- state = build.load_state(
- cache_dir,
- config.build_name,
- state_type=state_type,
- )
-
- except exp.StateError as e:
- problem = (
- "- build_model() failed to load "
- f"{build.state_file(cache_dir, config.build_name)}"
- )
-
- if rebuild == "if_needed":
- return _rebuild_if_needed(problem, state_args, state_type)
- else:
- # Give the rebuild="never" users a chance to address the problem
- raise exp.CacheError(e)
-
- if (
- model_type == build.ModelType.UNKNOWN
- and state.build_status == build.FunctionStatus.SUCCESSFUL
- ):
- msg = (
- "Model caching is disabled for successful builds against custom Sequences. "
- "Your model will rebuild whenever you call build_model() on it."
- )
- printing.log_warning(msg)
-
- return _begin_fresh_build(state_args, state_type)
- elif (
- model_type == build.ModelType.UNKNOWN
- and state.build_status == build.FunctionStatus.INCOMPLETE
- ):
- msg = (
- f"Model {config.build_name} was partially built in a previous call to "
- "build_model(). This call to build_model() found that partial build and "
- "is loading it from the build cache."
- )
-
- printing.log_info(msg)
- else:
- cache_problems = cache_validation_func(
- config=config,
- model_type=model_type,
- state=state,
- model=model,
- inputs=inputs,
- )
-
- if len(cache_problems) > 0:
- cache_problems = [f"- {msg}" for msg in cache_problems]
- problem_report = "\n".join(cache_problems)
-
- if rebuild == "if_needed":
- return _rebuild_if_needed(
- problem_report, state_args, state_type
- )
- if rebuild == "never":
- msg = (
- "build_model() discovered a cached build of "
- f"{config.build_name}, and found that it "
- "is likely invalid for the following reasons: \n\n"
- f"{problem_report} \n\n"
- "build_model() will raise a SkipBuild exception because you have "
- "set rebuild=never. "
- )
- printing.log_warning(msg)
-
- raise exp.SkipBuild(
- "Skipping this build, by raising an exception, because it previously "
- "failed and the `rebuild` argument is set to `never`."
- )
-
- # Ensure the model and inputs are part of the state
- # This is useful when loading models that still need to be built
- state.save_when_setting_attribute = False
- if state.model is None:
- state.model = model
- if state.inputs is None:
- state.inputs = inputs
- state.save_when_setting_attribute = True
-
- return state
-
- else:
- # No state file found, so we have to build
- return _begin_fresh_build(state_args, state_type)
-
-
-export_map = {
- build.ModelType.PYTORCH: export.ExportPytorchModel(),
- build.ModelType.KERAS: export.ExportKerasModel(),
- build.ModelType.ONNX_FILE: export.ReceiveOnnxModel(),
- build.ModelType.HUMMINGBIRD: hummingbird.ConvertHummingbirdModel(),
-}
-
-
-def validate_inputs(inputs: Dict):
- """
- Check the model's inputs and make sure they are legal. Raise an exception
- if they are not legal.
- TODO: it may be wise to validate the inputs against the model, or at least
- the type of model, as well.
- """
-
- if inputs is None:
- msg = """
- build_model() requires model inputs. Check your call to build_model() to make sure
- you are passing the inputs argument.
- """
- raise exp.IntakeError(msg)
-
- if not isinstance(inputs, dict):
- msg = f"""
- The "inputs" argument to build_model() is required to be a dictionary, where the
- keys map to the named arguments in the model's forward function. The inputs
- received by build_model() were of type {type(inputs)}, not dict.
- """
- raise exp.IntakeError(msg)
-
-
-def identify_model_type(model) -> build.ModelType:
- # Validate that the model's type is supported by build_model()
- # and assign a ModelType tag
- if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)):
- model_type = build.ModelType.PYTORCH
- elif isinstance(model, str):
- if model.endswith(".onnx"):
- model_type = build.ModelType.ONNX_FILE
- elif tf_helpers.is_keras_model(model):
- model_type = build.ModelType.KERAS
- if not tf_helpers.is_executing_eagerly():
- raise exp.IntakeError(
- "`build_model()` requires Keras models to be run in eager execution mode. "
- "Enable eager execution to continue."
- )
- if not model.built:
- raise exp.IntakeError(
- "Keras model has not been built. Please call "
- "model.build(input_shape) before running build_model()"
- )
- elif hummingbird.is_supported_model(model):
- model_type = build.ModelType.HUMMINGBIRD
- else:
- raise exp.IntakeError(
- "Argument 'model' passed to build_model() is "
- f"of unsupported type {type(model)}"
- )
-
- return model_type
-
-
-def model_intake(
- user_model,
- user_inputs,
- user_sequence: Optional[stage.Sequence],
-) -> Tuple[Any, Any, stage.Sequence, build.ModelType]:
- # Model intake structure options:
- # user_model
- # |
- # |------- path to onnx model file
- # |
- # |------- pytorch model object
- # |
- # |------- keras model object
- # |
- # |------- Hummingbird-supported model object
-
- if user_sequence is None or user_sequence.enable_model_validation:
- if user_model is None and user_inputs is None:
- msg = """
- You are running build_model() without any model, inputs, or custom Sequence. The purpose
- of non-customized build_model() is to build a model against some inputs, so you need to
- provide both.
- """
- raise exp.IntakeError(msg)
-
- # Make sure that if the model is a file path, it is valid
- if isinstance(user_model, str):
- if not os.path.isfile(user_model):
- msg = f"""
- build_model() model argument was passed a string (path to a model file),
- however no file was found at {user_model}.
- """
- raise exp.IntakeError(msg)
-
- if not user_model.endswith(".onnx"):
- msg = f"""
- build_model() received a model argument that was a string. However, model string
- arguments are required to be a path to a .onnx file, but the argument was: {user_model}
- """
- raise exp.IntakeError(msg)
-
- # Create dummy inputs based on the ONNX spec, if none were provided by the user
- if user_inputs is None:
- inputs = onnx_helpers.dummy_inputs(user_model)
- else:
- inputs = user_inputs
- else:
- inputs = user_inputs
-
- model_type = identify_model_type(user_model)
-
- sequence = copy.deepcopy(user_sequence)
- if sequence is None:
- sequence = stage.Sequence(
- "top_level_sequence",
- "Top Level Sequence",
- [sequences.onnx_fp32],
- )
-
- # If there is an ExportPlaceholder Stage in the sequence, replace it with
- # a framework-specific export Stage.
- # First, make a deepcopy of any sequence we bring in here. We do not want to modify
- # the original.
- sequence = copy.deepcopy(sequence)
- for index, stage_instance in enumerate(sequence.stages):
- if isinstance(stage_instance, export.ExportPlaceholder):
- sequence.stages[index] = export_map[model_type]
-
- validate_inputs(inputs)
-
- else:
- # We turn off a significant amount of automation and validation
- # to provide custom stages and sequences with maximum flexibility
- inputs = user_inputs
- sequence = user_sequence
- model_type = build.ModelType.UNKNOWN
-
- return (user_model, inputs, sequence, model_type)
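
As an aside on the cache-invalidation rule in validate_cached_model() above, the sketch below (not part of the deleted file) isolates the turnkey-version comparison; the helper names mirror the removed code and the printed examples are illustrative.

```python
# Illustrative sketch of the turnkey-version check used to invalidate cached builds
from typing import Dict

def decode_version_number(version: str) -> Dict[str, int]:
    major, minor, patch = (int(x) for x in version.split("."))
    return {"major": major, "minor": minor, "patch": patch}

def cache_out_of_date(current: str, cached: str) -> str:
    """Return "major" or "minor" if the cached build is stale, else ""."""
    cur, old = decode_version_number(current), decode_version_number(cached)
    if cur["major"] > old["major"]:
        return "major"
    if cur["minor"] > old["minor"]:
        return "minor"
    return ""  # patch-level changes keep the cache valid

print(cache_out_of_date("2.1.0", "1.9.3"))  # "major" -> rebuild required
print(cache_out_of_date("2.1.0", "2.0.5"))  # "minor" -> rebuild required
print(cache_out_of_date("2.1.1", "2.1.0"))  # ""      -> cache still valid
```
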
diff --git a/src/turnkeyml/build/sequences.py b/src/turnkeyml/build/sequences.py
deleted file mode 100644
index abeb159c..00000000
--- a/src/turnkeyml/build/sequences.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import turnkeyml.build.export as export
-import turnkeyml.build.stage as stage
-import turnkeyml.common.plugins as plugins
-
-optimize_fp16 = stage.Sequence(
- "optimize_fp16",
- "Optimized FP16 ONNX file",
- [
- export.ExportPlaceholder(),
- export.OptimizeOnnxModel(),
- export.ConvertOnnxToFp16(),
- ],
- enable_model_validation=True,
-)
-
-optimize_fp32 = stage.Sequence(
- "optimize_fp32",
- "Optimized FP32 ONNX File",
- [
- export.ExportPlaceholder(),
- export.OptimizeOnnxModel(),
- ],
- enable_model_validation=True,
-)
-
-onnx_fp32 = stage.Sequence(
- "onnx_fp32",
- "Base Sequence",
- [
- export.ExportPlaceholder(),
- ],
- enable_model_validation=True,
-)
-
-# Plugin interface for sequences
-discovered_plugins = plugins.discover()
-
-# Populate the supported sequences dict with builtin sequences
-SUPPORTED_SEQUENCES = {
- "optimize-fp16": optimize_fp16,
- "optimize-fp32": optimize_fp32,
- "onnx-fp32": onnx_fp32,
-}
-
-# Add sequences from plugins to supported sequences dict
-for module in discovered_plugins.values():
- if "sequences" in module.implements.keys():
- for seq_name, seq_info in module.implements["sequences"].items():
- if seq_name in SUPPORTED_SEQUENCES:
- raise ValueError(
- f"Your turnkeyml installation has two sequences named '{seq_name}' "
- "installed. You must uninstall one of your plugins that includes "
- f"{seq_name}. Your imported sequence plugins are: {SUPPORTED_SEQUENCES}\n"
- f"This error was thrown while trying to import {module}"
- )
-
- SUPPORTED_SEQUENCES[seq_name] = seq_info["sequence_instance"]
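
For reference, a custom Sequence could be composed against the removed API above as in the illustrative sketch below; the sequence name and monitor message are hypothetical, while the Stage classes are the ones used by the built-in sequences.

```python
# Illustrative sketch of composing a custom Sequence (removed API)
import turnkeyml.build.export as export
import turnkeyml.build.stage as stage

my_sequence = stage.Sequence(
    "my_custom_sequence",            # unique, filename-safe name
    "My Custom ONNX Sequence",       # message shown by the build monitor
    [
        export.ExportPlaceholder(),  # swapped for a framework-specific exporter at intake
        export.OptimizeOnnxModel(),
    ],
    enable_model_validation=True,
)
```
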
diff --git a/src/turnkeyml/build/stage.py b/src/turnkeyml/build/stage.py
deleted file mode 100644
index 21b48166..00000000
--- a/src/turnkeyml/build/stage.py
+++ /dev/null
@@ -1,391 +0,0 @@
-import abc
-import sys
-import time
-import os
-import copy
-from typing import List, Tuple
-from multiprocessing import Process
-import psutil
-import turnkeyml.common.printing as printing
-import turnkeyml.common.exceptions as exp
-import turnkeyml.common.build as build
-import turnkeyml.common.filesystem as fs
-
-
-def _spinner(message):
- try:
- parent_process = psutil.Process(pid=os.getppid())
- while parent_process.status() == psutil.STATUS_RUNNING:
- for cursor in [" ", ". ", ".. ", "..."]:
- time.sleep(0.5)
- status = f" {message}{cursor}\r"
- sys.stdout.write(status)
- sys.stdout.flush()
- except psutil.NoSuchProcess:
- # If the parent process stopped existing, we can
- # safely assume the spinner no longer needs to spin
- # NOTE: this only seems to be needed on Windows
- pass
-
-
-def _name_is_file_safe(name: str):
- """
- Make sure the name can be used in a filename
- """
-
- allowed_in_unique_name = set(
- "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"
- )
-
- if len(name) == 0:
- msg = """
- Stage __init__() was passed a unique_name with no length. A
- uniquely identifying unique_name is required.
- """
- raise ValueError(msg)
-
- for char in name:
- if char not in allowed_in_unique_name:
- msg = f"""
- Stage __init__() was passed a unique_name:
- {name}
- with illegal characters. The unique_name must be safe to
- use in a filename, meaning it can only use characters: {allowed_in_unique_name}
- """
- raise ValueError(msg)
-
-
-class Stage(abc.ABC):
- def status_line(self, successful, verbosity):
- """
- Print a line of status information for this Stage into the monitor.
- """
- if verbosity:
- # Only use special characters when the terminal encoding supports it
- if sys.stdout.encoding == "utf-8":
- success_tick = "✓"
- fail_tick = "×"
- else:
- success_tick = "+"
- fail_tick = "x"
-
- if successful is None:
- # Initialize the message
- printing.logn(f" {self.monitor_message} ")
- elif successful:
- # Print success message
- printing.log(f" {success_tick} ", c=printing.Colors.OKGREEN)
- printing.logn(self.monitor_message + " ")
- else:
- # successful == False, print failure message
- printing.log(f" {fail_tick} ", c=printing.Colors.FAIL)
- printing.logn(self.monitor_message + " ")
-
- def __init__(
- self,
- unique_name,
- monitor_message,
- ):
- _name_is_file_safe(unique_name)
-
- self.unique_name = unique_name
- self.status_key = f"{fs.Keys.STAGE_STATUS}:{unique_name}"
- self.duration_key = f"{fs.Keys.STAGE_DURATION}:{unique_name}"
- self.monitor_message = monitor_message
- self.progress = None
- self.logfile_path = None
- self.stages = None
-
- @abc.abstractmethod
- def fire(self, state: build.State) -> build.State:
- """
- Developer-defined function to fire the stage.
- In less punny terms, this is the function that
- build_model() will run to implement a model-to-model
- transformation on the flow to producing a Model.
- """
-
- def fire_helper(self, state: build.State) -> Tuple[build.State, int]:
- """
- Wraps the user-defined .fire method with helper functionality.
- Specifically:
- - Provides a path to a log file
- - Redirects the stdout of the stage to that log file
- - Monitors the progress of the stage on the command line,
- including in the event of an exception
- """
-
- # Set the build status to INCOMPLETE to indicate that a Stage
- # started running. This allows us to test whether the Stage exited
- # unexpectedly, before it was able to set ERROR
- state.build_status = build.FunctionStatus.INCOMPLETE
-
- self.logfile_path = os.path.join(
- build.output_dir(state.cache_dir, state.config.build_name),
- f"log_{self.unique_name}.txt",
- )
-
- if state.monitor:
- self.progress = Process(target=_spinner, args=[self.monitor_message])
- self.progress.start()
-
- try:
- # Execute the build stage
- with build.Logger(self.monitor_message, self.logfile_path):
- state = self.fire(state)
-
- except exp.StageError:
- self.status_line(
- successful=False,
- verbosity=state.monitor,
- )
- state.build_status = build.FunctionStatus.ERROR
- raise
-
- else:
- self.status_line(successful=True, verbosity=state.monitor)
-
- # Stages should not set build.FunctionStatus.SUCCESSFUL for the whole build,
- # as that is reserved for Sequence.launch()
- if state.build_status == build.FunctionStatus.SUCCESSFUL:
- raise exp.StageError(
- "TurnkeyML Stages are not allowed to set "
- "`state.build_status == build.FunctionStatus.SUCCESSFUL`, "
- "however that has happened. If you are a plugin developer, "
- "do not do this. If you are a user, please file an issue at "
- "https://github.com/onnx/turnkeyml/issues."
- )
-
- finally:
- if state.monitor:
- self.progress.terminate()
-
- return state
-
- def get_names(self) -> List[str]:
- """
-        Sequence uses self.get_names() to recursively get the names of all
- Stages in the Sequence. An individual Stage just needs to return
- its own name.
- """
- if self.stages is None:
- return [self.unique_name]
- else:
- result = []
- for stage in self.stages:
- result = result + stage.get_names()
-
- return result
-
- def get_depth(self) -> int:
- """
- Sequence needs to know the depth of each Stage within the Sequence in order
- to properly update the terminal UI. An individual Stage just needs to return
- the value 1.
- """
- if self.stages is None:
- return 1
- else:
- count = 0
- for stage in self.stages:
- count = count + stage.get_depth()
- return count
-
-
-def _rewind_stdout(lines: int):
- """
- Helper function for the command line monitor. Moves the cursor up a
- certain number of lines in the terminal, corresponding to the
- status line for a Stage, so that we can update the status of
- that Stage.
- """
- rewind_stdout_one_line = "\033[1A"
- rewind_multiple_lines = rewind_stdout_one_line * lines
- print(rewind_multiple_lines, end="")
-
-
-def unroll_stages(stages):
- """
-    Recursively goes through all nested Sequences and returns a flat list of Stages
- """
-
- unrolled_stages = []
- for stage in stages:
- if isinstance(stage, Sequence):
- unrolled_stages += unroll_stages(stage.stages)
- else:
- unrolled_stages += [stage]
- return unrolled_stages
-
-
-class Sequence(Stage):
- def __init__(
- self,
- unique_name,
- monitor_message,
- stages: List[Stage],
- enable_model_validation=False,
- ):
- super().__init__(unique_name, monitor_message)
-
- # The `stages` argument can be a nested Sequence (ie, Sequence of Sequence of Stage).
- # Unroll the stages to make the Sequence easier to deal with
- self.stages = unroll_stages(stages)
-
- # Follow default model validation steps in ignition.model_intake()
- self.enable_model_validation = enable_model_validation
-
- # Make sure all the stage names are unique
- stage_names = self.get_names()
-
- if len(stage_names) != len(set(stage_names)):
- msg = f"""
- All Stages in a Sequence must have unique unique_names, however Sequence
- received duplicates in the list of names: {stage_names}
- """
- raise ValueError(msg)
-
- def show_monitor(self, config: build.Config, verbosity: bool):
- """
- Displays the monitor on the terminal. The purpose of the monitor
- is to show the status of each stage (success, failure, not started yet,
- or in-progress).
- """
-
- if verbosity:
- print("\n\n")
-
- printing.logn(
- f'Building "{config.build_name}"',
- c=printing.Colors.BOLD,
- )
-
- for stage in self.stages:
- stage.status_line(successful=None, verbosity=True)
-
- _rewind_stdout(self.get_depth())
-
- def launch(self, state: build.State) -> build.State:
- """
- Executes a launch sequence.
- In less punny terms, this method is called by the top-level
- build_model() function to iterate over all of the Stages required for a build.
- Builds are defined by self.stages in a top-level Sequence, and self.stages
- can include both Stages and Sequences (ie, sequences can be nested).
- """
-
- if state.build_status == build.FunctionStatus.SUCCESSFUL:
- msg = """
- build_model() is running a build on a model that already built successfully, which
- should not happen because the build should have loaded from cache or rebuilt from scratch.
- If you are using custom Stages and Sequences then you have some debugging to do. Otherwise,
- please file an issue at https://github.com/onnx/turnkeyml/issues
- """
- raise exp.Error(msg)
-
- # Collect telemetry for the build
- stats = fs.Stats(state.cache_dir, state.config.build_name, state.evaluation_id)
- stats.save_model_eval_stat(
- fs.Keys.SELECTED_SEQUENCE_OF_STAGES,
- self.get_names(),
- )
-
- # At the beginning of a sequence no stage has started
- for stage in self.stages:
- stats.save_model_eval_stat(
- stage.status_key, build.FunctionStatus.NOT_STARTED.value
- )
- stats.save_model_eval_stat(stage.duration_key, "-")
-
- # Run the build
- for stage in self.stages:
- start_time = time.time()
-
- try:
-
- # Set status as incomplete, since stage just started
- stats.save_model_eval_stat(
- stage.status_key, build.FunctionStatus.INCOMPLETE.value
- )
-
- # Collect telemetry about the stage
- state.current_build_stage = stage.unique_name
-
- # Run the stage
- state = stage.fire_helper(state)
-
- # Broad exception is desirable as we want to capture
- # all exceptions (including those we can't anticipate)
- except Exception as e: # pylint: disable=broad-except
-
- # Update Stage Status
- stats.save_model_eval_stat(
- stage.status_key, build.FunctionStatus.ERROR.value
- )
-
- # Save the log file for the failed stage to stats for easy reference
- stats.save_eval_error_log(stage.logfile_path)
-
- # Advance the cursor below the monitor so
- # we can print an error message
- stage_depth_in_sequence = self.get_depth() - self.get_names().index(
- stage.unique_name # pylint: disable=undefined-loop-variable
- )
- stdout_lines_to_advance = stage_depth_in_sequence - 2
- cursor_down = "\n" * stdout_lines_to_advance
-
- print(cursor_down)
-
- printing.log_error(e)
-
- raise
-
- else:
- # Update Stage Status
- stats.save_model_eval_stat(
- stage.status_key, build.FunctionStatus.SUCCESSFUL.value
- )
-
- finally:
- # Store stage duration
- execution_time = time.time() - start_time
- stats.save_model_eval_stat(stage.duration_key, execution_time)
-
- state.current_build_stage = None
- state.build_status = build.FunctionStatus.SUCCESSFUL
-
- # We use a deepcopy here because the Stage framework supports
- # intermediate_results of any type, including model objects in memory.
- # The deepcopy ensures that we are providing a result that users
- # are free to take any action with.
- state.results = copy.deepcopy(state.intermediate_results)
-
- return state
-
- def status_line(self, successful, verbosity):
- """
- This override of status_line simply propagates status_line()
- to every Stage in the Sequence
- FIXME: A cleaner implementation of Stage/Sequence might not need this
- """
- for stage in self.stages:
- stage.status_line(successful=None, verbosity=verbosity)
-
- def fire(self, state: build.State) -> build.State:
- """
- This override of fire simply propagates fire()
- to every Stage in the Sequence
- FIXME: A cleaner implementation of Stage/Sequence might not need this
- """
- for stage in self.stages:
- state = stage.fire_helper(state)
-
- return state
-
- def fire_helper(self, state: build.State) -> build.State:
- """
-        Sequence doesn't need any help calling self.fire(), so its fire_helper
-        simply calls self.fire().
- FIXME: A cleaner implementation of Stage/Sequence might not need this
- """
- return self.fire(state)
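
To illustrate the Stage contract documented above (fire() takes and returns a build.State, and must not set the whole-build SUCCESSFUL status), here is a minimal sketch of a custom Stage written against the removed API; the class name, unique_name, and monitor message are hypothetical.

```python
# Illustrative sketch of a custom Stage built against the removed Stage API
import turnkeyml.common.build as build
import turnkeyml.build.stage as stage

class PrintIntermediateResults(stage.Stage):
    def __init__(self):
        super().__init__(
            unique_name="print_intermediate_results",  # must be filename-safe
            monitor_message="Printing intermediate results",
        )

    def fire(self, state: build.State) -> build.State:
        # A Stage may read and modify state, but setting
        # state.build_status = FunctionStatus.SUCCESSFUL is reserved for Sequence.launch()
        print(state.intermediate_results)
        return state
```
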
diff --git a/src/turnkeyml/build_api.py b/src/turnkeyml/build_api.py
deleted file mode 100644
index 95272661..00000000
--- a/src/turnkeyml/build_api.py
+++ /dev/null
@@ -1,125 +0,0 @@
-import os
-from typing import Optional, List, Dict, Any
-import turnkeyml.build.ignition as ignition
-import turnkeyml.build.stage as stage
-import turnkeyml.common.printing as printing
-import turnkeyml.common.build as build
-import turnkeyml.common.filesystem as filesystem
-
-
-def build_model(
- model: build.UnionValidModelInstanceTypes = None,
- inputs: Optional[Dict[str, Any]] = None,
- build_name: Optional[str] = None,
- evaluation_id: Optional[str] = "build",
- cache_dir: str = filesystem.DEFAULT_CACHE_DIR,
- monitor: Optional[bool] = None,
- rebuild: Optional[str] = None,
- sequence: Optional[List[stage.Stage]] = None,
- onnx_opset: Optional[int] = None,
- device: Optional[str] = None,
-) -> build.State:
-    """Build a model instance into an optimized ONNX file.
-
- Args:
- model: Model to be mapped to an optimized ONNX file, which can be a PyTorch
- model instance, Keras model instance, Hummingbird model instance,
- or a path to an ONNX file.
- inputs: Example inputs to the user's model. The ONNX file will be
- built to handle inputs with the same static shape only.
- build_name: Unique name for the model that will be
- used to store the ONNX file and build state on disk. Defaults to the
- name of the file that calls build_model().
- evaluation_id: Unique name for evaluation statistics that should persist across multiple
- builds of the same model.
- cache_dir: Directory to use as the cache for this build. Output files
-            from this build will be stored at cache_dir/build_name/.
- Defaults to the current working directory, but we recommend setting it to
- an absolute path of your choosing.
- monitor: Display a monitor on the command line that
- tracks the progress of this function as it builds the ONNX file.
- rebuild: determines whether to rebuild or load a cached build. Options:
- - "if_needed" (default): overwrite invalid cached builds with a rebuild
- - "always": overwrite valid cached builds with a rebuild
- - "never": load cached builds without checking validity, with no guarantee
- of functionality or correctness
- - None: Falls back to default
- sequence: Override the default sequence of build stages. Power
- users only.
- onnx_opset: ONNX opset to use during ONNX export.
- device: Specific device target to take into account during the build sequence.
- Use the format "device_family", "device_family::part", or
- "device_family::part::configuration" to refer to a family of devices,
- part within a family, or configuration of a part model, respectively.
-
- More information is available in the Tools User Guide:
- https://github.com/onnx/turnkeyml/blob/main/docs/tools_user_guide.md
- """
-
- # Allow monitor to be globally disabled by an environment variable
- if monitor is None:
- if os.environ.get("TURNKEY_BUILD_MONITOR") == "False":
- monitor_setting = False
- else:
- monitor_setting = True
- else:
- monitor_setting = monitor
-
- # Support "~" in the cache_dir argument
- parsed_cache_dir = os.path.expanduser(cache_dir)
-
- # Validate and lock in the config (user arguments that
- # configure the build) that will be used by the rest of the toolchain
- config = ignition.lock_config(
- model=model,
- build_name=build_name,
- sequence=sequence,
- onnx_opset=onnx_opset,
- device=device,
- )
-
- # Analyze the user's model argument and lock in the model, inputs,
- # and sequence that will be used by the rest of the toolchain
- (
- model_locked,
- inputs_locked,
- sequence_locked,
- model_type,
- ) = ignition.model_intake(
- model,
- inputs,
- sequence,
- )
-
- # Get the state of the model from the cache if a valid build is available
- state = ignition.load_or_make_state(
- config=config,
- evaluation_id=evaluation_id,
- cache_dir=parsed_cache_dir,
- rebuild=rebuild or build.DEFAULT_REBUILD_POLICY,
- model_type=model_type,
- monitor=monitor_setting,
- model=model_locked,
- inputs=inputs_locked,
- )
-
- # Return a cached build if possible, otherwise prepare the model State for
- # a build
- if state.build_status == build.FunctionStatus.SUCCESSFUL:
- # Successful builds can be loaded from cache and returned with
- # no additional steps
- additional_msg = " (build_name auto-selected)" if config.auto_name else ""
- printing.log_success(
- f' Build "{config.build_name}"{additional_msg} found in cache. Loading it!',
- )
-
- return state
-
- sequence_locked.show_monitor(config, state.monitor)
- state = sequence_locked.launch(state)
-
- printing.log_success(
- f"\n Saved to **{build.output_dir(state.cache_dir, config.build_name)}**"
- )
-
- return state
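
For completeness, here is an illustrative usage sketch of the removed build_model() API documented above; it assumes PyTorch is installed, and the model class, input name, build name, and opset are hypothetical.

```python
# Illustrative usage of the removed build_model() API
import torch
from turnkeyml.build_api import build_model

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

state = build_model(
    model=TinyModel(),
    inputs={"x": torch.randn(1, 10)},  # keys map to forward() argument names
    build_name="tiny_model",           # defaults to the calling script's name
    rebuild="if_needed",               # or "always" / "never"
    onnx_opset=14,                     # hypothetical opset choice
)
print(state.results)                   # e.g., path(s) to the built ONNX file(s)
```
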
diff --git a/src/turnkeyml/cli/cli.py b/src/turnkeyml/cli/cli.py
index 0718f6c4..f19a108b 100644
--- a/src/turnkeyml/cli/cli.py
+++ b/src/turnkeyml/cli/cli.py
@@ -1,267 +1,143 @@
import argparse
-import os
import sys
-import copy
+import os
from difflib import get_close_matches
-import turnkeyml.common.build as build
-import turnkeyml.common.exceptions as exceptions
-import turnkeyml.common.filesystem as filesystem
-import turnkeyml.cli.report as report
-import turnkeyml.cli.parser_helpers as parser_helpers
-from turnkeyml.files_api import benchmark_files
-from turnkeyml.version import __version__ as turnkey_version
-from turnkeyml.run.devices import SUPPORTED_DEVICES, SUPPORTED_RUNTIMES
-from turnkeyml.build.sequences import SUPPORTED_SEQUENCES
+from typing import List
+import turnkeyml.common.filesystem as fs
+from turnkeyml.sequence import Sequence
+from turnkeyml.tools import Tool, FirstTool, NiceHelpFormatter
+from turnkeyml.sequence.tool_plugins import SUPPORTED_TOOLS
from turnkeyml.cli.spawn import DEFAULT_TIMEOUT_SECONDS
-from turnkeyml.run.benchmark_build import benchmark_cache_cli
-from turnkeyml.analyze.status import Verbosity
-
-
-class MyParser(argparse.ArgumentParser):
- def error(self, message):
- sys.stderr.write(f"error: {message}\n\n")
- sys.stderr.write(f"Run '{self.prog} --help' for more information\n\n")
- self.print_usage(sys.stderr)
- sys.exit(2)
-
-
-def print_version(_):
- """
- Print the package version number
- """
- print(turnkey_version)
-
-
-def print_stats(args):
- state_path = build.state_file(args.cache_dir, args.build_name)
- filesystem.print_yaml_file(state_path, "build state")
-
- filesystem.print_yaml_file(
- filesystem.Stats(args.cache_dir, args.build_name).file, "stats"
- )
-
+from turnkeyml.files_api import evaluate_files
+import turnkeyml.common.printing as printing
+from turnkeyml.tools.management_tools import ManagementTool
-def benchmark_command(args):
- """
- Map the argparse args into benchmark_files() arguments
- Assumes the following rules:
- - All args passed to a "benchmark" command should be forwarded to the benchmark_files()
- API, except as explicitly handled below.
- - The "dest" names of all CLI args must exactly match the names of the corresponding API arg
- """
+class CustomArgumentParser(argparse.ArgumentParser):
- api_args = copy.deepcopy(vars(args))
-
- # Remove the function ID because it was only used to get us into this method
- api_args.pop("func")
-
- # Decode CLI arguments before calling the API
- api_args["rt_args"] = parser_helpers.decode_args(api_args["rt_args"])
-
- benchmark_files(**api_args)
+ def error(self, message):
+ self.print_usage()
+ printing.log_error(message)
+ self.exit(2)
-def main():
- """
- Parses arguments passed by user and forwards them into a
- command function
- """
-
- parser = MyParser(
- description="TurnkeyML benchmarking command line interface",
- formatter_class=argparse.RawTextHelpFormatter,
- )
+def _tool_list_help(tools: List[Tool], subclass, exclude=None) -> str:
+ help = ""
- # We use sub-parsers to keep the help info neatly organized for each command
-    # Sub-parsers also allow us to set command-specific help on options like --cache-dir
- # that are used in multiple commands
+ for tool_class in tools:
+ if exclude and issubclass(tool_class, exclude):
+ continue
+ if issubclass(tool_class, subclass):
+ help = (
+ help
+ + f" * {tool_class.unique_name}: {tool_class.parser().short_description}\n"
+ )
- subparsers = parser.add_subparsers(
- title="command",
- help="Choose one of the following commands:",
- metavar="COMMAND",
- required=True,
- )
+ return help
- #######################################
- # Parser for the "benchmark" command
- #######################################
- def check_extension(choices, file_name, error_func):
- _, extension = os.path.splitext(file_name.split("::")[0])
- if extension[1:].lower() not in choices:
+def _check_extension(
+ choices: List[str], file_name: str, error_func: callable, tool_names: List[str]
+):
+ _, extension = os.path.splitext(file_name.split("::")[0])
+ if not extension:
+ close_matches = get_close_matches(file_name, tool_names)
+ if close_matches:
+ # Misspelled tool names can be picked up as input files, so we check
+ # for this case here and try to provide a better suggestion
error_func(
- f"input_files must end with .py, .onnx, or .txt (got '{file_name}')\n"
+ f"unrecognized argument '{file_name}', did you mean '{close_matches[0]}'?"
)
- return file_name
-
- benchmark_parser = subparsers.add_parser(
- "benchmark",
- help="Benchmark the performance of one or more models",
- description="Analyze, build, and then benchmark the model(s) within input file(s).",
- )
- benchmark_parser.set_defaults(func=benchmark_command)
-
- benchmark_parser.add_argument(
- "input_files",
- nargs="+",
- help="One or more script (.py), ONNX (.onnx), or input list (.txt) files to be benchmarked",
- type=lambda file: check_extension(
- ("py", "onnx", "txt"), file, benchmark_parser.error
- ),
- )
-
- toolchain_select_group = benchmark_parser.add_argument_group(
- "Select which phase(s) of the toolchain to run "
- "(default is to run analyze, build, and benchmark)"
- )
+ else:
+ error_func(
+ f"{file_name} was recognized as an argument to `--input-files`, "
+ "however it is not a file name (no file extension). If it was "
+ "meant to be a tool name, please check whether that tool is "
+ "available and correctly spelled in the list of available tools "
+ "when calling `turnkey -h`."
+ )
+ if extension[1:].lower() not in choices:
+ error_func(
+ f"input_files must end with .py, .onnx, or .txt (got '{file_name}')\n"
+ )
+ return file_name
- toolchain_select_group.add_argument(
- "-a",
- "--analyze-only",
- dest="analyze_only",
- help="Stop this command after the analyze phase",
- action="store_true",
- )
- toolchain_select_group.add_argument(
- "-b",
- "--build-only",
- dest="build_only",
- help="Stop this command after the analyze and build phases",
- action="store_true",
- )
+def main():
- analyze_group = benchmark_parser.add_argument_group(
- "Options that specifically apply to the `analyze` phase of the toolflow"
- )
+ tool_parsers = {tool.unique_name: tool.parser() for tool in SUPPORTED_TOOLS}
+ tool_classes = {tool.unique_name: tool for tool in SUPPORTED_TOOLS}
- analyze_group.add_argument(
- "--labels",
- dest="labels",
- help="Only benchmark the scripts that have the provided labels",
- nargs="*",
- default=[],
+ # Define the argument parser
+ parser = CustomArgumentParser(
+ description="This utility runs tools in a sequence. "
+ "To use it, provide a list of tools and "
+ "their arguments. See "
+ "https://github.com/onnx/turnkeyml/blob/main/docs/tools_user_guide.md "
+ "to learn the exact syntax.\n\nExample: turnkey -i my_model.py discover export-pytorch",
+ formatter_class=NiceHelpFormatter,
)
- analyze_group.add_argument(
- "--script-args",
- dest="script_args",
- type=str,
- help="Arguments to pass into the target script(s)",
- )
+ # Sort tools into categories and format for the help menu
+ first_tool_choices = _tool_list_help(SUPPORTED_TOOLS, FirstTool)
+ eval_tool_choices = _tool_list_help(SUPPORTED_TOOLS, Tool, exclude=FirstTool)
+ mgmt_tool_choices = _tool_list_help(SUPPORTED_TOOLS, ManagementTool)
- analyze_group.add_argument(
- "--max-depth",
- dest="max_depth",
- type=int,
- default=0,
- help="Maximum depth to analyze within the model structure of the target script(s)",
- )
+ tools_action = parser.add_argument(
+ "tools",
+ metavar="tool --tool-args [tool --tool-args...]",
+ nargs="?",
+ help=f"""\
+Available tools that can be sequenced together to perform a build.
- both_build_benchmark_group = benchmark_parser.add_argument_group(
- "Options that apply to both the `build` and `benchmark` phases of the toolflow"
- )
+Call `turnkey TOOL -h` to learn more about each tool.
- benchmark_default_device = "x86"
- both_build_benchmark_group.add_argument(
- "--device",
- choices=SUPPORTED_DEVICES,
- dest="device",
- help="Type of hardware device to be used for the benchmark "
- f'(defaults to "{benchmark_default_device}")',
- required=False,
- default=benchmark_default_device,
+Tools that can start a sequence:
+{first_tool_choices}
+Tools that go into a sequence:
+{eval_tool_choices}
+Management tool choices:
+{mgmt_tool_choices}""",
+ choices=tool_parsers.keys(),
)
- both_build_benchmark_group.add_argument(
- "--runtime",
- choices=SUPPORTED_RUNTIMES.keys(),
- dest="runtime",
- help="Software runtime that will be used to collect the benchmark. "
- "Must be compatible with the selected device. "
- "Automatically selects a sequence if `--sequence` is not used. "
- "If this argument is not set, the default runtime of the selected device will be used.",
- required=False,
- default=None,
+ parser.add_argument(
+ "-i",
+ "--input-files",
+ nargs="+",
+ help="One or more inputs that will be evaluated by the tool sequence "
+ "(e.g., script (.py), ONNX (.onnx), turnkey build state (state.yaml), "
+ "input list (.txt) files)",
+ type=lambda file: _check_extension(
+ ("py", "onnx", "txt", "yaml"), file, parser.error, tool_classes
+ ),
)
- both_build_benchmark_group.add_argument(
+ parser.add_argument(
"-d",
"--cache-dir",
- dest="cache_dir",
- help="Build cache directory where the resulting build directories will "
- f"be stored (defaults to {filesystem.DEFAULT_CACHE_DIR})",
+ help="Build cache directory where results will "
+ f"be stored (defaults to {fs.DEFAULT_CACHE_DIR})",
required=False,
- default=filesystem.DEFAULT_CACHE_DIR,
+ default=fs.DEFAULT_CACHE_DIR,
)
- both_build_benchmark_group.add_argument(
+ parser.add_argument(
"--lean-cache",
dest="lean_cache",
- help="Delete all build artifacts except for log files when the command completes",
+ help="Delete all build artifacts (e.g., .onnx files) when the command completes",
action="store_true",
)
- build_group = benchmark_parser.add_argument_group(
- "Options that apply specifically to the `build` phase of the toolflow"
- )
-
- build_group.add_argument(
- "--sequence",
- choices=SUPPORTED_SEQUENCES.keys(),
- dest="sequence",
- help="Name of a build sequence that will define the model-to-model transformations, "
- "used to build the models. Each runtime has a default sequence that it uses.",
- required=False,
- default=None,
- )
-
- build_group.add_argument(
- "--rebuild",
- choices=build.REBUILD_OPTIONS,
- dest="rebuild",
- help=f"Sets the cache rebuild policy (defaults to {build.DEFAULT_REBUILD_POLICY})",
- required=False,
- default=build.DEFAULT_REBUILD_POLICY,
- )
-
- build_group.add_argument(
- "--onnx-opset",
- dest="onnx_opset",
- type=int,
- default=None,
- help=f"ONNX opset used when creating ONNX files (default={build.DEFAULT_ONNX_OPSET}). "
- "Not applicable when input model is already a .onnx file.",
- )
-
- benchmark_group = benchmark_parser.add_argument_group(
- "Options that apply specifically to the `benchmark` phase of the toolflow"
- )
-
- benchmark_group.add_argument(
- "--iterations",
- dest="iterations",
- type=int,
- default=100,
-        help="Number of execution iterations of the model to capture "
-        "the benchmarking performance (e.g., mean latency)",
- )
-
- benchmark_group.add_argument(
- "--rt-args",
- dest="rt_args",
- type=str,
+ parser.add_argument(
+ "--labels",
+ dest="labels",
+ help="Filter the --input-files to only include files that have the provided labels",
nargs="*",
- help="Optional arguments provided to the runtime being used",
- )
-
- all_toolflows_group = benchmark_parser.add_argument_group(
- "Options that apply to all toolflows"
+ default=[],
)
- slurm_or_processes_group = all_toolflows_group.add_mutually_exclusive_group()
+ slurm_or_processes_group = parser.add_mutually_exclusive_group()
slurm_or_processes_group.add_argument(
"--use-slurm",
@@ -277,7 +153,7 @@ def check_extension(choices, file_name, error_func):
action="store_true",
)
- all_toolflows_group.add_argument(
+ parser.add_argument(
"--timeout",
type=int,
default=None,
@@ -286,346 +162,90 @@ def check_extension(choices, file_name, error_func):
"applies when --process-isolation or --use-slurm is also used.",
)
- default_verbosity = Verbosity.AUTO.value
- all_toolflows_group.add_argument(
- "--verbosity",
- choices=[field.value for field in Verbosity],
- default=default_verbosity,
- help="Verbosity of the status updates printed to the command line "
- f"(default={default_verbosity}). '{Verbosity.DYNAMIC.value}': "
- "take over the terminal, updating "
- " it with a summary of all turnkey information. "
- f"'{Verbosity.STATIC.value}': print each evaluation as it takes place and "
- "never clear the terminal.",
- )
-
- #######################################
- # Subparser for the "cache" command
- #######################################
-
- cache_parser = subparsers.add_parser(
- "cache",
- help="Commands for managing the build cache",
- )
-
- cache_subparsers = cache_parser.add_subparsers(
- title="cache",
- help="Commands for managing the build cache",
- required=True,
- dest="cache_cmd",
- )
-
- #######################################
- # Parser for the "cache report" command
- #######################################
-
- report_parser = cache_subparsers.add_parser(
- "report", help="Generate reports in CSV format"
- )
- report_parser.set_defaults(func=report.summary_spreadsheets)
-
- report_parser.add_argument(
- "-d",
- "--cache-dir",
- dest="cache_dirs",
- help=(
- "One or more build cache directories to generate the report "
- f"(defaults to {filesystem.DEFAULT_CACHE_DIR})"
- ),
- default=[filesystem.DEFAULT_CACHE_DIR],
- nargs="*",
- )
-
- report_parser.add_argument(
- "-r",
- "--report-dir",
- dest="report_dir",
- help="Path to folder where report will be saved (defaults to current working directory)",
- required=False,
- default=os.getcwd(),
- )
-
- #######################################
- # Parser for the "cache list" command
- #######################################
-
- list_parser = cache_subparsers.add_parser(
- "list", help="List all builds in a target cache"
- )
- list_parser.set_defaults(func=filesystem.print_available_builds)
-
- list_parser.add_argument(
- "-d",
- "--cache-dir",
- dest="cache_dir",
-        help="The builds in this build cache directory will be printed to the terminal"
- f" (defaults to {filesystem.DEFAULT_CACHE_DIR})",
- required=False,
- default=filesystem.DEFAULT_CACHE_DIR,
- )
-
- #######################################
- # Parser for the "cache stats" command
- #######################################
-
- stats_parser = cache_subparsers.add_parser(
- "stats", help="Print stats about a build in a target cache"
- )
- stats_parser.set_defaults(func=print_stats)
-
- stats_parser.add_argument(
- "-d",
- "--cache-dir",
- dest="cache_dir",
-        help="The stats of a build in this build cache directory will be printed to the terminal"
- f" (defaults to {filesystem.DEFAULT_CACHE_DIR})",
- required=False,
- default=filesystem.DEFAULT_CACHE_DIR,
- )
-
- stats_parser.add_argument(
- "build_name",
- help="Name of the specific build whose stats are to be printed, within the cache directory",
- )
-
- #######################################
- # Parser for the "cache delete" command
- #######################################
-
- delete_parser = cache_subparsers.add_parser(
- "delete", help="Delete one or more builds in a build cache"
- )
- delete_parser.set_defaults(func=filesystem.delete_builds)
-
- delete_parser.add_argument(
- "-d",
- "--cache-dir",
- dest="cache_dir",
- help="Search path for builds " f"(defaults to {filesystem.DEFAULT_CACHE_DIR})",
- required=False,
- default=filesystem.DEFAULT_CACHE_DIR,
- )
-
- delete_group = delete_parser.add_mutually_exclusive_group(required=True)
-
- delete_group.add_argument(
- "build_name",
- nargs="?",
- help="Name of the specific build to be deleted, within the cache directory",
- )
-
- delete_group.add_argument(
- "--all",
- dest="delete_all",
- help="Delete all builds in the cache directory",
- action="store_true",
- )
-
- #######################################
- # Parser for the "cache clean" command
- #######################################
-
- clean_parser = cache_subparsers.add_parser(
- "clean",
- help="Remove the build artifacts from one or more builds in a build cache",
- )
- clean_parser.set_defaults(func=filesystem.clean_builds)
-
- clean_parser.add_argument(
- "-d",
- "--cache-dir",
- dest="cache_dir",
- help="Search path for builds " f"(defaults to {filesystem.DEFAULT_CACHE_DIR})",
- required=False,
- default=filesystem.DEFAULT_CACHE_DIR,
- )
-
- clean_group = clean_parser.add_mutually_exclusive_group(required=True)
-
- clean_group.add_argument(
- "build_name",
- nargs="?",
- help="Name of the specific build to be cleaned, within the cache directory",
- )
-
- clean_group.add_argument(
- "--all",
- dest="clean_all",
- help="Clean all builds in the cache directory",
- action="store_true",
- )
-
- #######################################
- # Parser for the "cache location" command
- #######################################
-
- cache_location_parser = cache_subparsers.add_parser(
- "location",
- help="Print the location of the default build cache directory",
- )
- cache_location_parser.set_defaults(func=filesystem.print_cache_dir)
-
- #######################################
- # Parser for the "cache benchmark" command
- #######################################
-
- cache_benchmark_parser = cache_subparsers.add_parser(
- "benchmark",
- help="Benchmark one or more builds in a build cache",
- )
- cache_benchmark_parser.set_defaults(func=benchmark_cache_cli)
-
- cache_benchmark_parser.add_argument(
- "-d",
- "--cache-dir",
- dest="cache_dir",
- help="Search path for builds " f"(defaults to {filesystem.DEFAULT_CACHE_DIR})",
- required=False,
- default=filesystem.DEFAULT_CACHE_DIR,
- )
-
- cache_benchmark_group = cache_benchmark_parser.add_mutually_exclusive_group(
- required=True
- )
-
- cache_benchmark_group.add_argument(
- "build_name",
- nargs="?",
- help="Name of the specific build to be benchmarked, within the cache directory",
- )
-
- cache_benchmark_group.add_argument(
- "--all",
- dest="benchmark_all",
- help="Benchmark all builds in the cache directory",
- action="store_true",
- )
-
- skip_policy_default = "attempted"
- cache_benchmark_parser.add_argument(
- "--skip",
- choices=[skip_policy_default, "failed", "successful", "none"],
- dest="skip_policy",
-        help=f"Sets the policy for skipping benchmark attempts (defaults to {skip_policy_default}). "
-        "`attempted` means to skip any previously-attempted benchmark, "
-        "whether it succeeded or failed. "
-        "`failed` skips benchmarks that have already failed once. "
-        "`successful` skips benchmarks that have already succeeded. "
- "`none` will attempt all benchmarks, regardless of whether they were previously attempted.",
- required=False,
- default=skip_policy_default,
- )
-
- cache_benchmark_parser.add_argument(
- "--timeout",
- type=int,
- default=1800,
- help="Benchmark timeout, in seconds, after which each benchmark will be canceled "
- "(default: 30min).",
- )
-
- cache_benchmark_parser.add_argument(
- "--runtime",
- choices=SUPPORTED_RUNTIMES.keys(),
- dest="runtime",
- help="Software runtime that will be used to collect the benchmark. "
- "Must be compatible with the device chosen for the build. "
- "If this argument is not set, the default runtime of the selected device will be used.",
- required=False,
- default=None,
- )
-
- cache_benchmark_parser.add_argument(
- "--iterations",
- dest="iterations",
- type=int,
- default=100,
- help="Number of execution iterations of the model to capture\
- the benchmarking performance (e.g., mean latency)",
- )
-
- cache_benchmark_parser.add_argument(
- "--rt-args",
- dest="rt_args",
- type=str,
- nargs="*",
- help="Optional arguments provided to the runtime being used",
- )
-
- #######################################
- # Subparser for the "models" command
- #######################################
-
- models_parser = subparsers.add_parser(
- "models",
- help="Commands for managing the models",
- )
-
- models_subparsers = models_parser.add_subparsers(
- title="models",
- help="Commands for managing the models",
- required=True,
- dest="models_cmd",
- )
-
- models_location_parser = models_subparsers.add_parser(
- "location",
- help="Print the location of the models directory",
- )
- models_location_parser.set_defaults(func=filesystem.print_models_dir)
-
- models_location_parser.add_argument(
- "--quiet",
- dest="verbose",
- help="Command output will only include the directory path",
- required=False,
- action="store_false",
- )
-
- #######################################
- # Parser for the "version" command
- #######################################
-
- version_parser = subparsers.add_parser(
- "version",
- help="Print the package version number",
- )
- version_parser.set_defaults(func=print_version)
-
- #######################################
- # Execute the command
- #######################################
-
- # The default behavior of this CLI is to run the build command
- # on a target script. If the user doesn't provide a command,
- # we alter argv to insert the command for them.
-
- # Special characters that indicate a string is a filename, not a command
- file_chars = [".", "/", "\\", "*"]
-
- if len(sys.argv) > 1:
- first_arg = sys.argv[1]
- if first_arg not in subparsers.choices.keys() and "-h" not in first_arg:
- if any(char_to_check in first_arg for char_to_check in file_chars):
- # User has provided a file as the first positional arg
- sys.argv.insert(1, "benchmark")
- else:
- # User has provided a command as the first positional arg
- # Check how close we are from each of the valid options
- # NOTE: if we are not close to a valid option, we will let
- # argparse detect and raise the error
- valid_options = list(subparsers.choices.keys())
- close_matches = get_close_matches(first_arg, valid_options)
-
- if close_matches:
- raise exceptions.ArgError(
- f"Unexpected command `turnkey {first_arg}`. "
- f"Did you mean `turnkey {close_matches[0]}`?"
- )
-
- args = parser.parse_args()
-
- args.func(args)
+    # Run as if "-h" was passed when no other parameters are provided
+ if len(sys.argv) == 1:
+ sys.argv.append("-h")
+
+ # Break sys.argv into categories based on which tools were invoked
+ # Arguments that are passed prior to invoking a tool are categorized as
+ # global arguments that should be used to initialize the state.
+ current_tool = "globals"
+ tools_invoked = {current_tool: []}
+ cmd = sys.argv[1:]
+ while len(cmd):
+ if cmd[0] in tool_parsers.keys():
+ # Make sure each tool was only called once
+ if cmd[0] in tools_invoked.keys():
+ parser.error(
+ "A single call to turnkey can only invoke each tool once, "
+ f"however this call invokes tool {cmd[0]} multiple times."
+ )
+ current_tool = cmd.pop(0)
+ tools_invoked[current_tool] = []
+ else:
+ tools_invoked[current_tool].append(cmd.pop(0))
+
+    # Trick argparse into thinking `tools` was not a positional argument;
+    # this helps avoid an error where an incorrect arg/value pair
+    # can be misinterpreted as the `tools` positional argument
+ tools_action.option_strings = "--tools"
+
+    # Parse the global arguments; this pass also handles -h if it was used
+ global_args = vars(parser.parse_args(tools_invoked["globals"]))
+
+ # Remove "tools" from global args because it was just there
+ # as a placeholder
+ global_args.pop("tools")
+
+    # Remove globals from the list since it's already been parsed
+ tools_invoked.pop("globals")
+ evaluation_tools = []
+ management_tools = []
+ for cmd, argv in tools_invoked.items():
+ tool_parsers[cmd].parse_args(argv)
+
+ # Keep track of whether the tools are ManagementTool or not,
+ # since ManagementTools are mutually exclusive with evaluation
+ # tools
+ if issubclass(tool_classes[cmd], ManagementTool):
+ management_tools.append(cmd)
+ else:
+ evaluation_tools.append(cmd)
+
+ if len(management_tools) > 0 and len(evaluation_tools) > 0:
+ parser.error(
+ "This call to turnkey invoked both management and "
+ "evaluation tools, however each call to turnkey "
+ "is only allowed to invoke one or the other. "
+ f"Management tools: {management_tools};"
+ f"Evaluation tools: {evaluation_tools}."
+ )
+
+ if len(management_tools) == 0 and len(evaluation_tools) == 0:
+ parser.error(
+ "Calls to turnkey are required to call at least "
+ "one tool or management tool."
+ )
+
+ # Convert tool names into Tool instances
+ tool_instances = {tool_classes[cmd](): argv for cmd, argv in tools_invoked.items()}
+
+ if len(evaluation_tools) > 0:
+ if not issubclass(tool_classes[evaluation_tools[0]], FirstTool):
+ parser.error(
+ "The first tool in the sequence needs to be one "
+ "of the 'tools that can start a sequence.' Use "
+ "`turnkey -h` to see that list of tools."
+ )
+ # Run the evaluation tools as a build
+ sequence = Sequence(tools=tool_instances)
+ evaluate_files(sequence=sequence, **global_args)
+ else:
+ # Run the management tools
+ for management_tool, argv in tool_instances.items():
+ # Support "~" in the cache_dir argument
+ parsed_cache_dir = os.path.expanduser(global_args[fs.Keys.CACHE_DIR])
+ management_tool.parse_and_run(parsed_cache_dir, argv)
if __name__ == "__main__":
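
To illustrate the argv-bucketing loop added above, here is a minimal, self-contained sketch. The tool names and the sample command line are made up for illustration; arguments that appear before the first tool name land in the "globals" bucket, which the global-argument parse above consumes.

# Hypothetical tool names standing in for tool_parsers.keys()
tool_names = {"discover", "benchmark"}

# A made-up command line: turnkey -d my_cache discover my_model.py benchmark --iterations 10
argv = ["turnkey", "-d", "my_cache", "discover", "my_model.py", "benchmark", "--iterations", "10"]

current_tool = "globals"
tools_invoked = {current_tool: []}
cmd = argv[1:]
while len(cmd):
    if cmd[0] in tool_names:
        # Start collecting arguments for the next tool
        current_tool = cmd.pop(0)
        tools_invoked[current_tool] = []
    else:
        tools_invoked[current_tool].append(cmd.pop(0))

print(tools_invoked)
# {"globals": ["-d", "my_cache"],
#  "discover": ["my_model.py"],
#  "benchmark": ["--iterations", "10"]}
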
diff --git a/src/turnkeyml/cli/report.py b/src/turnkeyml/cli/report.py
deleted file mode 100644
index be5defc9..00000000
--- a/src/turnkeyml/cli/report.py
+++ /dev/null
@@ -1,201 +0,0 @@
-import os
-import csv
-from datetime import datetime
-from pathlib import Path
-from typing import Dict, List
-import yaml
-import pandas as pd
-import turnkeyml.common.printing as printing
-import turnkeyml.common.filesystem as fs
-import turnkeyml.common.build as bd
-
-
-def get_report_name(prefix: str = "") -> str:
- """
- Returns the name of the .csv report
- """
- day = datetime.now().day
- month = datetime.now().month
- year = datetime.now().year
- date_key = f"{year}-{str(month).zfill(2)}-{str(day).zfill(2)}"
- return f"{prefix}{date_key}.csv"
-
-
-def _good_get(
- dict: Dict, key: str, return_keys: bool = False, return_values: bool = False
-):
- if key in dict:
- if return_keys:
- return list(dict[key].keys())
- elif return_values:
- return list(dict[key].values())
- else:
- return dict[key]
- else:
- return "-"
-
-
-def summary_spreadsheets(args) -> None:
- # Input arguments from CLI
- cache_dirs = [os.path.expanduser(dir) for dir in args.cache_dirs]
- cache_dirs = fs.expand_inputs(cache_dirs)
- report_dir = os.path.expanduser(args.report_dir)
-
- # Name report file
- report_path = os.path.join(report_dir, get_report_name())
-
- # Create report dict
- Path(report_dir).mkdir(parents=True, exist_ok=True)
-
- report: List[Dict] = []
- all_evaluation_stats = []
-
- # Add results from all user-provided cache folders
- for cache_dir in cache_dirs:
- # Check if this is a valid cache directory
- fs.check_cache_dir(cache_dir)
-
- # List all yaml files available
- all_model_stats_yamls = fs.get_all(
- path=cache_dir, file_type="turnkey_stats.yaml"
- )
- all_model_stats_yamls = sorted(all_model_stats_yamls)
-
- # Bring all of the stats for all of the models into memory
- for model_stats_yaml in all_model_stats_yamls:
- with open(model_stats_yaml, "r", encoding="utf8") as stream:
- try:
- # load the yaml into a dict
- model_stats = yaml.load(stream, Loader=yaml.FullLoader)
-
- # create a separate dict for each evaluation
- for evaluation in model_stats[fs.Keys.EVALUATIONS].values():
- evaluation_stats = {}
-
- # Copy all of the stats for the model that are common across evaluation
- for key, value in model_stats.items():
- if key != fs.Keys.EVALUATIONS:
- evaluation_stats[key] = value
-
- # Copy the evaluation-specific stats
- for key, value in evaluation.items():
- # If a build or benchmark is still marked as "incomplete" at
- # reporting time, it must have been killed by a time out,
- # out-of-memory (OOM), or some other uncaught exception
- if (
- (
- key == fs.Keys.BUILD_STATUS
- or fs.Keys.BENCHMARK_STATUS
- )
- or fs.Keys.STAGE_STATUS in key
- ) and value == bd.FunctionStatus.INCOMPLETE.value:
- value = bd.FunctionStatus.KILLED.value
-
- # Add stats ensuring that those are all in lower case
- evaluation_stats[key.lower()] = value
-
- all_evaluation_stats.append(evaluation_stats)
- except yaml.scanner.ScannerError:
- continue
-
- # Scan the build stats to determine the set of columns for the CSV file.
- # The CSV will have one column for every key in any build stats dict.
- column_headers = []
- for evaluation_stats in all_evaluation_stats:
- # Add any key that isn't already in column_headers
- for header in evaluation_stats.keys():
- if header not in column_headers:
- column_headers.append(header)
-
- # Sort all columns alphabetically
- column_headers = sorted(column_headers)
-
- # Add each build to the report
- for evaluation_stats in all_evaluation_stats:
- # Start with a dictionary where all of the values are "-". If a build
- # has a value for each key we will fill it in, and otherwise the "-"
- # will indicate that no value was available
- result = {k: "-" for k in column_headers}
-
- for key in column_headers:
- result[key] = _good_get(evaluation_stats, key)
-
- report.append(result)
-
- # Populate results spreadsheet
- with open(report_path, "w", newline="", encoding="utf8") as spreadsheet:
- writer = csv.writer(spreadsheet)
- writer.writerow(column_headers)
- for build in report:
- writer.writerow([build[col] for col in column_headers])
-
- # Print message with the output file path
- printing.log("Summary spreadsheet saved at ")
- printing.logn(str(report_path), printing.Colors.OKGREEN)
-
- # Save the unique errors and counts to a file
- errors = []
- for evaluation_stats in all_evaluation_stats:
- if (
- "compilation_error" in evaluation_stats.keys()
- and "compilation_error_id" in evaluation_stats.keys()
- ):
- error = evaluation_stats["compilation_error"]
- id = evaluation_stats["compilation_error_id"]
- if id != "":
- unique_error = True
- for reported_error in errors:
- if reported_error["id"] == id:
- unique_error = False
- reported_error["count"] = reported_error["count"] + 1
- reported_error["models_impacted"] = reported_error[
- "models_impacted"
- ] + [evaluation_stats["model_name"]]
-
- if unique_error:
- reported_error = {
- "id": id,
- "count": 1,
- "models_impacted": [evaluation_stats["model_name"]],
- "example": error,
- }
- errors.append(reported_error)
-
- if len(errors) > 0:
- errors_path = os.path.join(report_dir, get_report_name("errors-"))
- with open(errors_path, "w", newline="", encoding="utf8") as spreadsheet:
- writer = csv.writer(spreadsheet)
- error_headers = errors[0].keys()
- writer.writerow(error_headers)
- for unique_error in errors:
- writer.writerow([unique_error[col] for col in error_headers])
-
- printing.log("Compilation errors spreadsheet saved at ")
- printing.logn(str(errors_path), printing.Colors.OKGREEN)
- else:
- printing.logn(
- "No compilation errors in any cached build, skipping errors spreadsheet."
- )
-
-
-def get_dict(report_csv: str, columns: List[str]) -> Dict[str, Dict[str, str]]:
- """
- Returns a dictionary where the keys are model names and the values are dictionaries.
- Each dictionary represents a model with column names as keys and their corresponding values.
- args:
- - report_csv: path to a report.csv file generated by turnkey CLI
- - columns: list of column names in the report.csv file whose values will be used to
- populate the dictionary
- """
-
- # Load the report as a dataframe
- dataframe = pd.read_csv(report_csv)
-
- # Create a nested dictionary with model_name as keys and another
- # dictionary of {column: value} pairs as values
- result = {
- row[0]: row[1].to_dict()
- for row in dataframe.set_index("model_name")[columns].iterrows()
- }
-
- return result
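
For reference, the removed summary_spreadsheets logic built its CSV columns as the union of every key found in any evaluation's stats, filling gaps with "-". A minimal sketch of that idea, with made-up stats dictionaries and a hypothetical output file name:

import csv

all_evaluation_stats = [
    {"model_name": "bert", "build_status": "successful"},
    {"model_name": "resnet", "mean_latency": 1.2},
]

# One column for every key that appears in any stats dict
column_headers = sorted({key for stats in all_evaluation_stats for key in stats})

with open("example_report.csv", "w", newline="", encoding="utf8") as spreadsheet:
    writer = csv.writer(spreadsheet)
    writer.writerow(column_headers)
    for stats in all_evaluation_stats:
        # "-" marks values that a given evaluation did not produce
        writer.writerow([stats.get(key, "-") for key in column_headers])
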
diff --git a/src/turnkeyml/cli/spawn.py b/src/turnkeyml/cli/spawn.py
index 3dc34014..0fb111d7 100644
--- a/src/turnkeyml/cli/spawn.py
+++ b/src/turnkeyml/cli/spawn.py
@@ -12,13 +12,12 @@
from time import monotonic
import getpass
from typing import List, Optional, Dict, Union
-from enum import Enum
import psutil
import turnkeyml.common.filesystem as filesystem
import turnkeyml.common.printing as printing
import turnkeyml.common.build as build
from turnkeyml.cli.parser_helpers import encode_args
-from turnkeyml.analyze.status import Verbosity
+from turnkeyml.sequence import Sequence
class WatchdogTimer(Thread):
@@ -122,11 +121,6 @@ def parse_build_name(line: str, current_value: str) -> Optional[str]:
DEFAULT_TIMEOUT_SECONDS = 3600
-class Target(Enum):
- SLURM = "slurm"
- LOCAL_PROCESS = "local_process"
-
-
def slurm_jobs_in_queue(job_name=None) -> List[str]:
"""Return the set of slurm jobs that are currently pending/running"""
user = getpass.getuser()
@@ -171,10 +165,6 @@ def value_arg(key: str, value: Union[str, int]):
return ""
-def verbosity_arg(key: str, value: Verbosity):
- return f'{key}="{value.value}"'
-
-
def bool_arg(key: str, value: bool):
if value:
return f"{key}"
@@ -189,16 +179,26 @@ def dict_arg(key: str, value: Dict):
return ""
+def sequence_arg(value: Sequence) -> str:
+    # Flatten the sequence into a space-separated "tool_name tool_args" string
+    result = " ".join(f"{tool} {' '.join(args)}" for tool, args in value.info.items())
+
+    return result
+
+
def run_turnkey(
- op: str,
+ build_name: str,
+ sequence: Sequence,
file_name: str,
- target: Target,
+ process_isolation: bool,
+ use_slurm: bool,
cache_dir: str,
+ lean_cache: bool,
timeout: Optional[int] = DEFAULT_TIMEOUT_SECONDS,
working_dir: str = os.getcwd(),
ml_cache_dir: Optional[str] = os.environ.get("SLURM_ML_CACHE"),
max_jobs: int = 50,
- **kwargs,
):
"""
Run turnkey on a single input file in a separate process (e.g., Slurm, subprocess).
@@ -208,27 +208,33 @@ def run_turnkey(
The key must be the snake_case version of the CLI argument (e.g, build_only for --build-only)
"""
+ if use_slurm and process_isolation:
+ raise ValueError(
+ "use_slurm and process_isolation are mutually exclusive, but both are True"
+ )
+
type_to_formatter = {
str: value_arg,
int: value_arg,
bool: bool_arg,
list: list_arg,
dict: dict_arg,
- Verbosity: verbosity_arg,
}
- invocation_args = f"{op} {file_name}"
+ invocation_args = f"-i {file_name}"
# Add cache_dir to kwargs so that it gets processed
# with the other arguments
- kwargs["cache_dir"] = cache_dir
+ kwargs = {"cache_dir": cache_dir, "lean_cache": lean_cache}
for key, value in kwargs.items():
if value is not None:
arg_str = type_to_formatter[type(value)](arg_format(key), value)
invocation_args = invocation_args + " " + arg_str
- if target == Target.SLURM:
+ invocation_args = invocation_args + " " + sequence_arg(sequence)
+
+ if use_slurm:
# Change args into the format expected by Slurm
slurm_args = " ".join(shlex.split(invocation_args))
@@ -270,7 +276,7 @@ def run_turnkey(
print(f"Submitting job {job_name} to Slurm")
subprocess.check_call(slurm_command)
- elif target == Target.LOCAL_PROCESS:
+ else: # process isolation
command = "turnkey " + invocation_args
printing.log_info(f"Starting process with command: {command}")
@@ -326,74 +332,44 @@ def run_turnkey(
f"turnkey will move on to the next input.\n\n{e}"
)
- # If an evaluation failed, it will be the last build mentioned in the
- # subprocess's stdout. We look for the last instance because sometimes
- # a single input file will contain multiple models, and therefore multiple
- # builds.
- # NOTE: the turnkey status outputs use the term "build" to refer to both
- # builds and benchmarks, we collectively we refer to as evaluations here
- build_name = None
- evaluation_id = None
- for line in process_output:
- evaluation_id = parse_evaluation_id(line, evaluation_id)
- build_name = parse_build_name(line, build_name)
-
- # Perform fault handling if we found a failed evaluation
- if build_name:
- printing.log_info(
- f"Detected failed build {build_name}. "
- "The parent process will attempt to clean up."
- )
-
- # Cleaning the cache is the last step in evaluation
- # If a "lean cache" evaluation was killed, it is safe to assume we still
- # need to clean the cache
- # It is also harmless to run clean_output_dir() again even if the subprocess
- # did have a chance to run it before the subprocess was killed
- if "--lean-cache" in command:
- printing.log_info("Removing build artifacts...")
- filesystem.clean_output_dir(cache_dir, build_name)
-
- # Perform fault handling within the stats file if there is a stats
- # file and we know the evaluation ID of the failed evaluation
- if (
- os.path.isfile(filesystem.stats_file(cache_dir, build_name))
- and evaluation_id
- ):
- try:
- # Amend the stats with a specific function status if possible
- if isinstance(e, subprocess.TimeoutExpired):
- evaluation_status = build.FunctionStatus.TIMEOUT
- else:
- evaluation_status = build.FunctionStatus.KILLED
-
- stats = filesystem.Stats(
- cache_dir,
- build_name,
- evaluation_id,
- )
-
- for key in stats.evaluation_stats.keys():
- if (
- stats.evaluation_stats[key]
- == build.FunctionStatus.INCOMPLETE.value
- ):
- stats.save_model_eval_stat(key, evaluation_status.value)
-
- # Save the exception into the error log stat
- stats.save_model_eval_stat(filesystem.Keys.ERROR_LOG, str(e))
-
- except Exception as stats_exception: # pylint: disable=broad-except
- printing.log_info(
- "Stats file found, but unable to perform cleanup due to "
- f"exception: {stats_exception}"
- )
-
- else:
- printing.log_info(
- "Turnkey subprocess was killed before any "
- "build or benchmark could start."
- )
+ # Perform fault handling
+ printing.log_info(
+ f"Detected failed build {build_name}. "
+ "The parent process will attempt to clean up."
+ )
- else:
- raise ValueError(f"Unsupported value for target: {target}.")
+ # Cleaning the cache is the last step in evaluation
+ # If a "lean cache" evaluation was killed, it is safe to assume we still
+ # need to clean the cache
+ # It is also harmless to run clean_output_dir() again even if the subprocess
+ # did have a chance to run it before the subprocess was killed
+ if "--lean-cache" in command:
+ printing.log_info("Removing build artifacts...")
+ filesystem.clean_output_dir(cache_dir, build_name)
+
+ # Perform fault handling within the stats file if it exists
+ if os.path.isfile(filesystem.stats_file(cache_dir, build_name)):
+ try:
+ # Amend the stats with a specific function status if possible
+ if isinstance(e, subprocess.TimeoutExpired):
+ evaluation_status = build.FunctionStatus.TIMEOUT
+ else:
+ evaluation_status = build.FunctionStatus.KILLED
+
+ stats = filesystem.Stats(
+ cache_dir,
+ build_name,
+ )
+
+ for key in stats.stats.keys():
+ if stats.stats[key] == build.FunctionStatus.INCOMPLETE:
+ stats.save_stat(key, evaluation_status)
+
+ # Save the exception into the error log stat
+ stats.save_stat(filesystem.Keys.ERROR_LOG, str(e))
+
+ except Exception as stats_exception: # pylint: disable=broad-except
+ printing.log_info(
+ "Stats file found, but unable to perform cleanup due to "
+ f"exception: {stats_exception}"
+ )
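
As a quick illustration of the invocation format that the reworked run_turnkey emits, the sketch below shows what sequence_arg contributes to the subprocess command. The tool names and the Sequence stand-in are assumptions for illustration only.

class FakeSequence:
    # Mirrors the shape that sequence_arg consumes: tool name -> argv list
    info = {
        "export-pytorch": ["--opset", "17"],
        "benchmark": ["--iterations", "10"],
    }

def sequence_arg(value) -> str:
    return " ".join(f"{tool} {' '.join(args)}" for tool, args in value.info.items())

print("turnkey -i my_model.py " + sequence_arg(FakeSequence()))
# turnkey -i my_model.py export-pytorch --opset 17 benchmark --iterations 10
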
diff --git a/src/turnkeyml/analyze/model.py b/src/turnkeyml/common/analyze_model.py
similarity index 96%
rename from src/turnkeyml/analyze/model.py
rename to src/turnkeyml/common/analyze_model.py
index 5590f9ca..fd41b121 100644
--- a/src/turnkeyml/analyze/model.py
+++ b/src/turnkeyml/common/analyze_model.py
@@ -3,7 +3,6 @@
import torch
import onnx
from turnkeyml.common import printing
-import turnkeyml.common.build as build
import turnkeyml.common.filesystem as fs
@@ -13,15 +12,13 @@ class AnalysisException(Exception):
"""
-def count_parameters(model: torch.nn.Module, model_type: build.ModelType) -> int:
+def count_parameters(model: torch.nn.Module) -> int:
"""
Returns the number of parameters of a given model
"""
- if model_type == build.ModelType.PYTORCH:
+ if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)):
return sum([parameter.numel() for _, parameter in model.named_parameters()])
- elif model_type == build.ModelType.KERAS:
- return model.count_params()
- elif model_type == build.ModelType.ONNX_FILE:
+ elif isinstance(model, str) and model.endswith(".onnx"):
onnx_model = onnx.load(model)
return int(
sum(
@@ -30,9 +27,11 @@ def count_parameters(model: torch.nn.Module, model_type: build.ModelType) -> int
if tensor.name not in onnx_model.graph.input
)
)
+ elif isinstance(model, str) and model.endswith(".yaml"):
+ return None
# Raise exception if an unsupported model type is provided
- raise AnalysisException(f"model_type {model_type} is not supported")
+ raise AnalysisException(f"model type {type(model)} is not supported")
def get_onnx_ops_list(onnx_model) -> Dict:
@@ -324,19 +323,19 @@ def analyze_onnx(build_name: str, cache_dir: str, stats: fs.Stats):
onnx_model_info = populate_onnx_model_info(final_onnx_file)
input_dimensions = onnx_input_dimensions(final_onnx_file)
- stats.save_model_stat(
+ stats.save_stat(
fs.Keys.ONNX_OPS_COUNTER,
onnx_ops_counter,
)
- stats.save_model_stat(
+ stats.save_stat(
fs.Keys.ONNX_TOTAL_FLOPS,
onnx_total_flops,
)
- stats.save_model_stat(
+ stats.save_stat(
fs.Keys.ONNX_MODEL_INFO,
onnx_model_info,
)
- stats.save_model_stat(
+ stats.save_stat(
fs.Keys.ONNX_INPUT_DIMENSIONS,
input_dimensions,
)
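
A minimal usage sketch of the type-dispatched count_parameters. The toy torch module is an assumption; the ONNX branch would be exercised by passing a path ending in ".onnx".

import torch
from turnkeyml.common.analyze_model import count_parameters  # new module path

model = torch.nn.Linear(in_features=8, out_features=2)
print(count_parameters(model))  # 8*2 weights + 2 biases = 18
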
diff --git a/src/turnkeyml/common/build.py b/src/turnkeyml/common/build.py
index fc0342ba..cfdb2d65 100644
--- a/src/turnkeyml/common/build.py
+++ b/src/turnkeyml/common/build.py
@@ -1,24 +1,17 @@
import os
import logging
import sys
-import pathlib
-import copy
import traceback
import platform
import subprocess
-import enum
-from typing import Optional, Any, List, Dict, Union, Type
-import dataclasses
+from typing import Dict, Union
import hashlib
import pkg_resources
import psutil
import yaml
import torch
import numpy as np
-import sklearn.base
import turnkeyml.common.exceptions as exp
-import turnkeyml.common.tf_helpers as tf_helpers
-from turnkeyml.version import __version__ as turnkey_version
UnionValidModelInstanceTypes = Union[
@@ -26,8 +19,6 @@
str,
torch.nn.Module,
torch.jit.ScriptModule,
- "tf.keras.Model",
- sklearn.base.BaseEstimator,
]
if os.environ.get("TURNKEY_ONNX_OPSET"):
@@ -41,20 +32,7 @@
REBUILD_OPTIONS = ["if_needed", "always", "never"]
-class ModelType(enum.Enum):
- PYTORCH = "pytorch"
- PYTORCH_COMPILED = "pytorch_compiled"
- KERAS = "keras"
- ONNX_FILE = "onnx_file"
- HUMMINGBIRD = "hummingbird"
- UNKNOWN = "unknown"
-
-
-# Indicates that the build should take take any specific device into account
-DEFAULT_DEVICE = "default"
-
-
-def load_yaml(file_path):
+def load_yaml(file_path) -> Dict:
with open(file_path, "r", encoding="utf8") as stream:
try:
return yaml.load(stream, Loader=yaml.FullLoader)
@@ -76,24 +54,29 @@ def state_file(cache_dir, build_name):
return path
-def hash_model(model, model_type: ModelType, hash_params: bool = True):
+def hash_model(model, hash_params: bool = True):
# If the model is a path to a file, hash the file
- if model_type == ModelType.ONNX_FILE:
- # TODO: Implement a way of hashing the models but not the parameters
- # of ONNX inputs.
- if not hash_params:
- msg = "hash_params must be True for model_type ONNX_FILE"
- raise ValueError(msg)
- if os.path.isfile(model):
+ if isinstance(model, str):
+ if model.endswith(".onnx"):
+ # TODO: Implement a way of hashing the models but not the parameters
+ # of ONNX inputs.
+ if not hash_params:
+ msg = "hash_params must be True for ONNX files"
+ raise ValueError(msg)
+ if os.path.isfile(model):
+ with open(model, "rb") as f:
+ file_content = f.read()
+ return hashlib.sha256(file_content).hexdigest()
+ else:
+ raise ValueError(
+ "hash_model received str model that doesn't correspond to a file"
+ )
+ else:
with open(model, "rb") as f:
file_content = f.read()
return hashlib.sha256(file_content).hexdigest()
- else:
- raise ValueError(
- "hash_model received str model that doesn't correspond to a file"
- )
- elif model_type in [ModelType.PYTORCH, ModelType.PYTORCH_COMPILED]:
+ if isinstance(model, (torch.nn.Module, torch.jit.ScriptModule)):
# Convert model parameters and topology to string
hashable_params = {}
for name, param in model.named_parameters():
@@ -106,54 +89,30 @@ def hash_model(model, model_type: ModelType, hash_params: bool = True):
# Return hash of topology and parameters
return hashlib.sha256(hashable_model).hexdigest()
- elif model_type == ModelType.KERAS:
- # Convert model parameters and topology to string
- summary_list = [] # type: List[str]
-
- # pylint: disable=unnecessary-lambda
- model.summary(print_fn=lambda x: summary_list.append(x))
-
- summary_str = " ".join(summary_list)
- hashable_params = {}
- for layer in model.layers:
- hashable_params[layer.name] = layer.weights
- if hash_params:
- hashable_model = (summary_str + str(hashable_params)).encode()
- else:
- hashable_model = summary_str.encode()
-
- # Return hash of topology and parameters
- return hashlib.sha256(hashable_model).hexdigest()
-
- elif model_type == ModelType.HUMMINGBIRD:
- import pickle
-
- return hashlib.sha256(pickle.dumps(model)).hexdigest()
-
else:
msg = f"""
- model_type "{model_type}" unsupported by this hash_model function
+ model type "{type(model)}" unsupported by this hash_model function
"""
raise ValueError(msg)
-class FunctionStatus(enum.Enum):
+class FunctionStatus:
"""
- Status values that are assigned to stages, builds, benchmarks, and other
+ Status values that are assigned to tools, builds, benchmarks, and other
functionality to help the user understand whether that function completed
successfully or not.
"""
- # SUCCESSFUL means the stage/build/benchmark completed successfully.
+ # SUCCESSFUL means the tool/build/benchmark completed successfully.
SUCCESSFUL = "successful"
- # ERROR means the stage/build/benchmark failed and threw some error that
+ # ERROR means the tool/build/benchmark failed and threw some error that
# was caught by turnkey. You should proceed by looking at the build
# logs to see what happened.
ERROR = "error"
- # TIMEOUT means the stage/build/benchmark failed because it exceeded the timeout
+ # TIMEOUT means the tool/build/benchmark failed because it exceeded the timeout
# set for the turnkey command.
TIMEOUT = "timeout"
@@ -163,21 +122,21 @@ class FunctionStatus(enum.Enum):
# why it is being killed (e.g., watch the RAM utilization to diagnose an OOM).
KILLED = "killed"
- # The NOT_STARTED status is applied to all stages/builds/benchmarks at startup.
- # It will be replaced by one of the other status values if the stage/build/benchmark
+ # The NOT_STARTED status is applied to all tools/builds/benchmarks at startup.
+ # It will be replaced by one of the other status values if the tool/build/benchmark
# has a chance to start running.
- # A value of NOT_STARTED in the report CSV indicates that the stage/build/benchmark
+ # A value of NOT_STARTED in the report CSV indicates that the tool/build/benchmark
# never had a chance to start because turnkey exited before that functionality had
# a chance to start running.
NOT_STARTED = "not_started"
- # INCOMPLETE indicates that a stage/build/benchmark started running and did not complete.
- # Each stage, build, and benchmark are marked as INCOMPLETE when they start running.
- # If you open the turnkey_stats.yaml file while the stage/build/benchmark
- # is still running, the status will show as INCOMPLETE. If the stage/build/benchmark
+ # INCOMPLETE indicates that a tool/build/benchmark started running and did not complete.
+ # Each tool, build, and benchmark are marked as INCOMPLETE when they start running.
+ # If you open the turnkey_stats.yaml file while the tool/build/benchmark
+ # is still running, the status will show as INCOMPLETE. If the tool/build/benchmark
# is killed without the chance to do any stats cleanup, the status will continue to
# show as INCOMPLETE in turnkey_stats.yaml.
- # When the report CSV is created, any instance of an INCOMPLETE stage/build/benchmark
+ # When the report CSV is created, any instance of an INCOMPLETE tool/build/benchmark
# status will be replaced by KILLED.
INCOMPLETE = "incomplete"
@@ -218,9 +177,6 @@ def get_shapes_and_dtypes(inputs: dict):
elif torch.is_tensor(value):
shapes[key] = np.array(value.detach()).shape
dtypes[key] = np.array(value.detach()).dtype.name
- elif tf_helpers.is_keras_tensor(value):
- shapes[key] = np.array(value).shape
- dtypes[key] = np.array(value).dtype.name
elif isinstance(value, np.ndarray):
shapes[key] = value.shape
dtypes[key] = value.dtype.name
@@ -238,185 +194,6 @@ def get_shapes_and_dtypes(inputs: dict):
return shapes, dtypes
-@dataclasses.dataclass(frozen=True)
-class Config:
- """
- User-provided build configuration. Instances of Config should not be modified
- once they have been instantiated (frozen=True enforces this).
-
- Note: modifying this struct can create a breaking change that
- requires users to rebuild their models. Increment the minor
- version number of the turnkey package if you do make a build-
- breaking change.
- """
-
- build_name: str
- auto_name: bool
- sequence: List[str]
- onnx_opset: int
- device: Optional[str]
-
-
-@dataclasses.dataclass
-class State:
- # User-provided args that influence the generated model
- config: Config
-
- # User-provided args that do not influence the generated model
- monitor: bool = False
- rebuild: str = ""
- cache_dir: str = ""
- evaluation_id: str = ""
-
- # User-provided args that will not be saved as part of state.yaml
- model: UnionValidModelInstanceTypes = None
- inputs: Optional[Dict[str, Any]] = None
-
- # Member variable that helps the code know if State has called
- # __post_init__ yet
- save_when_setting_attribute: bool = False
-
- # All of the following are critical aspects of the build,
- # including properties of the tool and choices made
- # while building the model, which determine the outcome of the build.
- # NOTE: adding or changing a member name in this struct can create
- # a breaking change that requires users to rebuild their models.
- # Increment the minor version number of the turnkey package if you
- # do make a build-breaking change.
-
- turnkey_version: str = turnkey_version
- model_type: ModelType = ModelType.UNKNOWN
- uid: Optional[int] = None
- model_hash: Optional[int] = None
- build_status: FunctionStatus = FunctionStatus.NOT_STARTED
- expected_input_shapes: Optional[Dict[str, list]] = None
- expected_input_dtypes: Optional[Dict[str, list]] = None
- expected_output_names: Optional[List] = None
-
- # Whether or not inputs must be downcasted during inference
- downcast_applied: bool = False
-
- # The results of the most recent stage that was executed
- current_build_stage: str = None
- intermediate_results: Any = None
-
- # Results of a successful build
- results: Any = None
-
- def __post_init__(self):
- if self.uid is None:
- self.uid = unique_id()
- if self.inputs is not None:
- (
- self.expected_input_shapes,
- self.expected_input_dtypes,
- ) = get_shapes_and_dtypes(self.inputs)
- if self.model is not None and self.model_type != ModelType.UNKNOWN:
- self.model_hash = hash_model(self.model, self.model_type)
-
- self.save_when_setting_attribute = True
-
- def __setattr__(self, name, val):
- super().__setattr__(name, val)
-
- # Always automatically save the state.yaml whenever State is modified
- # But don't bother saving until after __post_init__ is done (indicated
- # by the save_when_setting_attribute flag)
- # Note: This only works when elements of the state are set directly.
- if self.save_when_setting_attribute and name != "save_when_setting_attribute":
- self.save()
-
- @property
- def original_inputs_file(self):
- return os.path.join(
- output_dir(self.cache_dir, self.config.build_name), "inputs.npy"
- )
-
- def prepare_file_system(self):
- # Create output folder if it doesn't exist
- os.makedirs(output_dir(self.cache_dir, self.config.build_name), exist_ok=True)
-
- def prepare_state_dict(self) -> Dict:
- state_dict = {
- key: value
- for key, value in vars(self).items()
- if not key == "inputs"
- and not key == "model"
- and not key == "save_when_setting_attribute"
- }
-
- # Special case for saving objects
- state_dict["config"] = copy.deepcopy(vars(self.config))
-
- state_dict["model_type"] = self.model_type.value
- state_dict["build_status"] = self.build_status.value
-
- return state_dict
-
- def save_yaml(self, state_dict: Dict):
- with open(
- state_file(self.cache_dir, self.config.build_name), "w", encoding="utf8"
- ) as outfile:
- yaml.dump(state_dict, outfile)
-
- def save(self):
- self.prepare_file_system()
-
- state_dict = self.prepare_state_dict()
-
- self.save_yaml(state_dict)
-
-
-def load_state(
- cache_dir=None,
- build_name=None,
- state_path=None,
- state_type: Type = State,
-) -> State:
- if state_path is not None:
- file_path = state_path
- elif build_name is not None and cache_dir is not None:
- file_path = state_file(cache_dir, build_name)
- else:
- raise ValueError(
- "This function requires either build_name and cache_dir to be set, "
- "or state_path to be set, not both or neither"
- )
-
- state_dict = load_yaml(file_path)
-
- # Get the type of Config and Info in case they have been overloaded
- field_types = {field.name: field.type for field in dataclasses.fields(state_type)}
- config_type = field_types["config"]
-
- try:
- # Special case for loading enums
- state_dict["model_type"] = ModelType(state_dict["model_type"])
- state_dict["build_status"] = FunctionStatus(state_dict["build_status"])
- state_dict["config"] = config_type(**state_dict["config"])
-
- state = state_type(**state_dict)
-
- except (KeyError, TypeError) as e:
- if state_path is not None:
- path_suggestion = pathlib.Path(state_path).parent
- else:
- path_suggestion = output_dir(cache_dir, build_name)
- msg = f"""
- The cached build of this model was built with an
- incompatible older version of the tool.
-
- Suggested solution: delete the build with
- rm -rf {path_suggestion}
-
- The underlying code raised this exception:
- {e}
- """
- raise exp.StateError(msg)
-
- return state
-
-
class Logger:
"""
Redirects stdout to to file (and console if needed)
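
A small sketch of how the reworked hash_model dispatches on the model's Python type rather than a ModelType enum; the toy module below is an assumption.

import torch
from turnkeyml.common.build import hash_model

# A torch module is hashed from its topology and parameters
module_hash = hash_model(torch.nn.Linear(4, 4))
print(module_hash)  # sha256 hex digest

# For an ONNX file, pass the path instead (hash_params must remain True):
# onnx_hash = hash_model("path/to/model.onnx")
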
diff --git a/src/turnkeyml/common/exceptions.py b/src/turnkeyml/common/exceptions.py
index 95c23c6d..db63f4c2 100644
--- a/src/turnkeyml/common/exceptions.py
+++ b/src/turnkeyml/common/exceptions.py
@@ -32,14 +32,14 @@ class ArgError(Error):
"""
-class StageError(Exception):
+class ToolError(Exception):
"""
Let the user know that something went wrong while
- firing off a Stage.
+ running a tool.
Note: not overloading __init__() so that the
attempt to print to stdout isn't captured into
- the Stage's log file.
+ the Tool's log file.
"""
diff --git a/src/turnkeyml/common/filesystem.py b/src/turnkeyml/common/filesystem.py
index b6b339c6..3ae640d3 100644
--- a/src/turnkeyml/common/filesystem.py
+++ b/src/turnkeyml/common/filesystem.py
@@ -2,7 +2,7 @@
import shutil
import glob
import pathlib
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Tuple
import importlib.util
import yaml
import turnkeyml.common.printing as printing
@@ -92,8 +92,15 @@ def get_all(path, exclude_path=False, file_type="state.yaml", recursive=True):
def clean_file_name(script_path: str) -> str:
- # Trim the ".py" / ".onnx"
- return pathlib.Path(script_path).stem
+ """
+ Trim the ".py" / ".onnx" if present.
+
+ If its a state.yaml file, trim the "state.yaml"
+ """
+ if script_path.endswith("_state.yaml"):
+ return pathlib.Path(script_path).stem.replace("_state", "")
+ else:
+ return pathlib.Path(script_path).stem
class CacheError(exp.Error):
@@ -110,7 +117,7 @@ def _load_yaml(file) -> Dict:
return {}
-def _save_yaml(dict: Dict, file):
+def save_yaml(dict: Dict, file):
with open(file, "w", encoding="utf8") as outfile:
yaml.dump(dict, outfile)
@@ -206,6 +213,28 @@ def get_available_scripts(search_dir: str):
return scripts
+def decode_input_arg(input: str) -> Tuple[str, List[str], str]:
+ # Parse the targets out of the file name
+ # Targets use the format:
+ # file_path.ext::target0,target1,...,targetN
+ decoded_input = input.split("::")
+ file_path = os.path.abspath(decoded_input[0])
+
+ if len(decoded_input) == 2:
+ targets = decoded_input[1].split(",")
+ encoded_input = file_path + "::" + decoded_input[1]
+ elif len(decoded_input) == 1:
+ targets = []
+ encoded_input = file_path
+ else:
+ raise ValueError(
+ "Each file input to turnkey should have either 0 or 1 '::' in it."
+ f"However, {file_path} was received."
+ )
+
+ return file_path, targets, encoded_input
+
+
def get_available_builds(cache_dir):
"""
Get all of the build directories within the build cache
@@ -225,52 +254,6 @@ def get_available_builds(cache_dir):
return builds
-def print_available_builds(args):
- printing.log_info(f"Builds available in cache {args.cache_dir}:")
- builds = get_available_builds(args.cache_dir)
- printing.list_table(builds, num_cols=1)
- print()
-
-
-def delete_builds(args):
- check_cache_dir(args.cache_dir)
-
- if args.delete_all:
- builds = get_available_builds(args.cache_dir)
- else:
- builds = [args.build_name]
-
- for build in builds:
- build_path = os.path.join(args.cache_dir, build)
- if is_build_dir(args.cache_dir, build):
- rmdir(build_path)
- printing.log_info(f"Deleted build: {build}")
- else:
- raise CacheError(
- f"No build found with name: {build}. "
- "Try running `turnkey cache list` to see the builds in your build cache."
- )
-
-
-def clean_builds(args):
- check_cache_dir(args.cache_dir)
-
- if args.clean_all:
- builds = get_available_builds(args.cache_dir)
- else:
- builds = [args.build_name]
-
- for build in builds:
- if is_build_dir(args.cache_dir, build):
- clean_output_dir(args.cache_dir, build)
- printing.log_info(f"Removed the build artifacts from: {build}")
- else:
- raise CacheError(
- f"No build found with name: {build}. "
- "Try running `turnkey cache list` to see the builds in your build cache."
- )
-
-
def clean_build_name(build_name: str) -> str:
"""
Remove hash from build name
@@ -323,8 +306,8 @@ class Keys:
ONNX_MODEL_INFO = "onnx_model_information"
# ONNX model input tensor dimensions
ONNX_INPUT_DIMENSIONS = "onnx_input_dimensions"
- # List of all build stages in the Sequence
- SELECTED_SEQUENCE_OF_STAGES = "selected_sequence_of_stages"
+ # List of all build tools in the Sequence
+ SELECTED_SEQUENCE_OF_TOOLS = "selected_sequence_of_tools"
# Location of the most up-to-date ONNX file for this build. If the
# build completed successfully, this is the final ONNX file.
ONNX_FILE = "onnx_file"
@@ -338,8 +321,6 @@ class Keys:
DEVICE = "device"
# Name of the model
MODEL_NAME = "model_name"
- # References the per-evaluation stats section
- EVALUATIONS = "evaluations"
# Catch-all for storing a file's labels
LABELS = "labels"
# Author of the model
@@ -356,25 +337,40 @@ class Keys:
MODEL_SCRIPT = "builtin_model_script"
# Indicates status of the most recent build tool run: FunctionStatus
BUILD_STATUS = "build_status"
- # Indicates status of the most recent benchmark tool run: FunctionStatus
- BENCHMARK_STATUS = "benchmark_status"
# Indicates the match between the TorchScript IR graph and
# the exported onnx model (verified with torch.onnx.verification)
TORCH_ONNX_EXPORT_VALIDITY = "torch_export_validity"
- # Prefix for reporting the execution duration of a stage
- # In the report this will look like stage_duration:STAGE_NAME
- STAGE_DURATION = "stage_duration"
- # Prefix for reporting the execution status of a stage
- # In the report this will look like stage_status:STAGE_NAME
- STAGE_STATUS = "stage_status"
- # Parent key that holds all of the arguments to turnkey's
- # evaluate_file() API
- EVALUATION_ARGS = "turnkey_args"
+ # Prefix for reporting the execution duration of a tool
+ # In the report this will look like tool_duration:TOOL_NAME
+ TOOL_DURATION = "tool_duration"
+ # Prefix for reporting the execution status of a tool
+ # In the report this will look like tool_status:TOOL_NAME
+ TOOL_STATUS = "tool_status"
# Records the date and time of the evaluation after analysis but before
# build and benchmark
TIMESTAMP = "timestamp"
- # Records the logfile of any failed stage/benchmark
+ # Records the logfile of any failed tool/benchmark
ERROR_LOG = "error_log"
+ # Name of the build in the cache
+ BUILD_NAME = "build_name"
+ # Sequence of tools used for this build, along with their args
+ SEQUENCE_INFO = "sequence_info"
+ # Version of TurnkeyML used for the build
+ TURNKEY_VERSION = "turnkey_version"
+ # Unique ID for this build
+ UID = "uid"
+ # Unique hash for this model
+ MODEL_HASH = "model_hash"
+ # Input shapes expected by the model
+ EXPECTED_INPUT_SHAPES = "expected_input_shapes"
+ # Input data types expected by the model
+ EXPECTED_INPUT_DTYPES = "expected_input_dtypes"
+ # Whether or not inputs must be downcasted during inference
+ DOWNCAST_APPLIED = "downcast_applied"
+ # Directory where the turnkey build cache is stored
+ CACHE_DIR = "cache_dir"
+ # Example inputs to the model
+ INPUTS = "inputs"
def _clean_logfile(logfile_lines: List[str]) -> List[str]:
@@ -393,14 +389,13 @@ def stats_file(cache_dir: str, build_name: str):
class Stats:
- def __init__(self, cache_dir: str, build_name: str, evaluation_id: str = None):
+ def __init__(self, cache_dir: str, build_name: str):
self.file = stats_file(cache_dir, build_name)
- self.evaluation_id = evaluation_id
os.makedirs(os.path.dirname(self.file), exist_ok=True)
if not os.path.exists(self.file):
- initial = {Keys.EVALUATIONS: {}}
- _save_yaml(initial, self.file)
+ # Start an empty stats file
+ save_yaml({}, self.file)
@property
def stats(self):
@@ -423,7 +418,7 @@ def _set_key(self, dict, keys: List["str"], value):
self._set_key(dict[keys[0]], keys[1:], value)
- def save_model_stat(self, key: str, value):
+ def save_stat(self, key: str, value):
"""
Save statistics to an yaml file in the build directory
"""
@@ -432,29 +427,20 @@ def save_model_stat(self, key: str, value):
self._set_key(stats_dict, [key], value)
- _save_yaml(stats_dict, self.file)
+ save_yaml(stats_dict, self.file)
- def save_model_eval_stat(self, key: str, value):
+ def save_sub_stat(self, parent_key: str, key: str, value):
stats_dict = self.stats
- self._set_key(stats_dict, [Keys.EVALUATIONS, self.evaluation_id, key], value)
+ self._set_key(stats_dict, [parent_key, key], value)
- _save_yaml(stats_dict, self.file)
-
- def save_model_eval_sub_stat(self, parent_key: str, key: str, value):
- stats_dict = self.stats
-
- self._set_key(
- stats_dict, [Keys.EVALUATIONS, self.evaluation_id, parent_key, key], value
- )
-
- _save_yaml(stats_dict, self.file)
-
- @property
- def evaluation_stats(self):
- return self.stats[Keys.EVALUATIONS][self.evaluation_id]
+ save_yaml(stats_dict, self.file)
def save_eval_error_log(self, logfile_path):
+ if logfile_path is None:
+ # Avoid an error in the situation where we crashed before
+ # initializing the tool (in which case it has no logfile path yet)
+ return
if os.path.exists(logfile_path):
with open(logfile_path, "r", encoding="utf-8") as f:
full_log = f.readlines()
@@ -480,18 +466,7 @@ def save_eval_error_log(self, logfile_path):
else:
stats_log = _clean_logfile(full_log)
- self.save_model_eval_stat(Keys.ERROR_LOG, stats_log)
-
-
-def print_cache_dir(_=None):
- printing.log_info(f"The default cache directory is: {DEFAULT_CACHE_DIR}")
-
-
-def print_models_dir(args=None):
- if args.verbose:
- printing.log_info(f"The models directory is: {MODELS_DIR}")
- else:
- print(MODELS_DIR)
+ self.save_stat(Keys.ERROR_LOG, stats_log)
def expand_inputs(input_paths: List[str]) -> List[str]:
@@ -511,6 +486,16 @@ def expand_inputs(input_paths: List[str]) -> List[str]:
return input_paths_expanded
+def read_labels(file_path: str) -> Dict[str, str]:
+    # Load label data from Python scripts
+    # This is not compatible with ONNX files, so we return
+    # an empty dictionary in that case
+ if file_path.endswith(".py"):
+ return labels.load_from_file(file_path)
+ else:
+ return {}
+
+
def rebase_cache_dir(input_path: str, build_name: str, new_cache_dir: str):
"""
Rebase a turnkey build path onto a new turnkey cache directory.
@@ -528,3 +513,10 @@ def rebase_cache_dir(input_path: str, build_name: str, new_cache_dir: str):
relative_input_path = input_path.split(build_name, 1)[1][1:]
return os.path.join(new_cache_dir, build_name, relative_input_path)
+
+
+def check_extension(choices, file_name, error_func):
+ _, extension = os.path.splitext(file_name.split("::")[0])
+ if extension[1:].lower() not in choices:
+ error_func(f"input_files must end with {choices} (got '{file_name}')\n")
+ return file_name
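
A usage sketch for the new decode_input_arg helper, using a hypothetical script path and targets:

from turnkeyml.common.filesystem import decode_input_arg

file_path, targets, encoded_input = decode_input_arg("models/my_model.py::target0,target1")
# file_path      -> absolute path to models/my_model.py
# targets        -> ["target0", "target1"]
# encoded_input  -> "<absolute path>::target0,target1"
print(file_path, targets, encoded_input)
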
diff --git a/src/turnkeyml/build/onnx_helpers.py b/src/turnkeyml/common/onnx_helpers.py
similarity index 79%
rename from src/turnkeyml/build/onnx_helpers.py
rename to src/turnkeyml/common/onnx_helpers.py
index cfffd36a..1d02e126 100644
--- a/src/turnkeyml/build/onnx_helpers.py
+++ b/src/turnkeyml/common/onnx_helpers.py
@@ -2,13 +2,49 @@
Helper functions for dealing with ONNX files and ONNX models
"""
-from typing import Tuple
+import os
+from typing import Tuple, Union
import re
import math
import numpy as np
import onnx
import onnxruntime as ort
import turnkeyml.common.exceptions as exp
+from turnkeyml.state import State
+import turnkeyml.common.build as build
+
+
+def check_model(onnx_file, success_message, fail_message) -> bool:
+ if os.path.isfile(onnx_file):
+ print(success_message)
+ else:
+ print(fail_message)
+ return False
+ try:
+ onnx.checker.check_model(onnx_file)
+ print("\tSuccessfully checked onnx file")
+ return True
+ except onnx.checker.ValidationError as e:
+ print("\tError while checking generated ONNX file")
+ print(e)
+ return False
+
+
+def original_inputs_file(cache_dir: str, build_name: str):
+ return os.path.join(build.output_dir(cache_dir, build_name), "inputs.npy")
+
+
+def onnx_dir(state: State):
+ return os.path.join(build.output_dir(state.cache_dir, state.build_name), "onnx")
+
+
+def get_output_names(
+ onnx_model: Union[str, onnx.ModelProto]
+): # pylint: disable=no-member
+ # Get output names of ONNX file/model
+ if not isinstance(onnx_model, onnx.ModelProto): # pylint: disable=no-member
+ onnx_model = onnx.load(onnx_model)
+ return [node.name for node in onnx_model.graph.output] # pylint: disable=no-member
def parameter_count(model):
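
A minimal sketch of the new get_output_names helper, which accepts either a loaded ModelProto or a path to a .onnx file. The tiny Identity graph below is constructed only for illustration.

import onnx
from onnx import helper, TensorProto
from turnkeyml.common.onnx_helpers import get_output_names

node = helper.make_node("Identity", inputs=["x"], outputs=["y"])
graph = helper.make_graph(
    [node],
    "tiny_graph",
    [helper.make_tensor_value_info("x", TensorProto.FLOAT, [1])],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [1])],
)
model = helper.make_model(graph)
print(get_output_names(model))  # ["y"]
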
diff --git a/src/turnkeyml/common/performance.py b/src/turnkeyml/common/performance.py
index d0ad1461..5bc9a4b2 100644
--- a/src/turnkeyml/common/performance.py
+++ b/src/turnkeyml/common/performance.py
@@ -1,17 +1,45 @@
from dataclasses import dataclass
-from typing import Optional, Union, Dict
+from typing import Optional, Union, Dict, List
+import argparse
import turnkeyml.common.printing as printing
import turnkeyml.common.exceptions as exp
+from turnkeyml.state import State
+
+
+def enumerate_supported_devices(rt_supported_devices: set) -> List[str]:
+
+ result = []
+ if isinstance(rt_supported_devices, dict):
+ for family, parts in rt_supported_devices.items():
+ result.append(family)
+
+ if isinstance(parts, dict):
+ for part, configs in parts.items():
+ result.append(f"{family}::{part}")
+
+ for config in configs:
+ result.append(f"{family}::{part}::{config}")
+ elif isinstance(parts, list):
+ for part in parts:
+ result.append(f"{family}::{part}")
+
+ else:
+ for family in rt_supported_devices:
+ result.append(family)
+
+ return result
class Device:
- def __init__(self, selected_device: str, rt_supported_devices: Optional[Dict] = None):
+ def __init__(
+ self, selected_device: str, rt_supported_devices: Optional[Dict] = None
+ ):
self.family: str
self.part: Optional[str] = None
self.config: Optional[str] = None
# Unpack selected_device
- values = selected_device.split("::")
+ values = str(selected_device).split("::")
if len(values) > 3:
raise exp.ArgError(
f"Recieved a device argument that has more than 3 members: {selected_device}. "
@@ -21,11 +49,11 @@ def __init__(self, selected_device: str, rt_supported_devices: Optional[Dict] =
# Set family, part, and config straight away if rt_supported_devices is not provided
if rt_supported_devices is None:
- if len(values)>0:
+ if len(values) > 0:
self.family = values[0]
- if len(values)>1:
+ if len(values) > 1:
self.part = values[1]
- if len(values)>2:
+ if len(values) > 2:
self.config = values[2]
return
@@ -50,13 +78,13 @@ def __init__(self, selected_device: str, rt_supported_devices: Optional[Dict] =
if values[1] in rt_supported_devices[self.family]:
self.part = values[1]
elif len(rt_supported_devices[self.family]) == 0:
- raise exp.ArgError(
- f"Device family {self.family} supports no parts."
- )
+ raise exp.ArgError(f"Device family {self.family} supports no parts.")
else:
error_msg = f"Part {values[1]} is not supported by this device family."
if len(rt_supported_devices[self.family]) > 0:
- error_msg += f" Supported parts are: {rt_supported_devices[self.family]}"
+ error_msg += (
+ f" Supported parts are: {rt_supported_devices[self.family]}"
+ )
raise exp.ArgError(error_msg)
elif rt_supported_devices[self.family]:
self.part = next(iter(rt_supported_devices[self.family]))
@@ -97,17 +125,47 @@ class MeasuredPerformance:
device_type: Union[str, Device]
build_name: str
throughput_units: str = "inferences per second (IPS)"
- latency_units: str = "milliseconds (ms)"
+ mean_latency_units: str = "milliseconds (ms)"
def print(self):
printing.log_info(
f"\nPerformance of build {self.build_name} on {self.device} "
f"({self.runtime} v{self.runtime_version}) is:"
)
- print(f"\tMean Latency: {self.mean_latency:.3f} {self.latency_units}")
+ print(f"\tMean Latency: {self.mean_latency:.3f} {self.mean_latency_units}")
print(f"\tThroughput: {self.throughput:.1f} {self.throughput_units}")
print()
def __post_init__(self):
if isinstance(self.device_type, Device):
self.device_type = str(self.device_type)
+
+
+def parse_device(
+ state: State,
+ parsed_args: argparse.Namespace,
+ default_device: str,
+ tool_name: str,
+ supported_devices=None,
+):
+ # Inherit the device from the state of a prior tool, if available
+ if parsed_args.device is None:
+ if vars(state).get("device") is None:
+ device_to_use = default_device
+ else:
+ device_to_use = state.device
+ else:
+ if vars(state).get("device") is not None and str(state.device) != str(
+ parsed_args.device
+ ):
+ raise exp.ArgError(
+ f"A previous tool set the device to {state.device}, "
+ f"however this tool ({tool_name}) "
+ f"is attempting to set device to {parsed_args.device}. "
+ "We suggest omitting the `--device` argument from "
+ "this tool."
+ )
+
+ device_to_use = parsed_args.device
+
+ parsed_args.device = Device(device_to_use, supported_devices)
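
A small sketch of how Device unpacks a "family::part::config" selector when no runtime-supported-devices table is supplied; the device names are made up.

from turnkeyml.common.performance import Device

device = Device("x86::i7_1195g7::fp32")
print(device.family, device.part, device.config)  # x86 i7_1195g7 fp32
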
diff --git a/src/turnkeyml/analyze/status.py b/src/turnkeyml/common/status.py
similarity index 64%
rename from src/turnkeyml/analyze/status.py
rename to src/turnkeyml/common/status.py
index 2abc9888..f72f9e0b 100644
--- a/src/turnkeyml/analyze/status.py
+++ b/src/turnkeyml/common/status.py
@@ -1,16 +1,13 @@
import os
import sys
import math
-from enum import Enum
import dataclasses
-import platform
from typing import Callable, List, Union, Dict, Optional
import torch
from turnkeyml.common import printing
import turnkeyml.common.build as build
-from turnkeyml.common.performance import MeasuredPerformance
import turnkeyml.common.filesystem as fs
-import turnkeyml.analyze.model as analyze_model
+import turnkeyml.common.analyze_model as analyze_model
def _pretty_print_key(key: str) -> str:
@@ -31,12 +28,6 @@ def parameters_to_size(parameters: int, byte_per_parameter: int = 4) -> str:
return "%s %s" % (s, size_name[i])
-class Verbosity(Enum):
- AUTO = "auto"
- DYNAMIC = "dynamic"
- STATIC = "static"
-
-
@dataclasses.dataclass
class BasicInfo:
name: str
@@ -46,8 +37,6 @@ class BasicInfo:
params: int = 0
depth: int = 0
parent_hash: Union[str, None] = None
- build_model: bool = False
- model_type: build.ModelType = build.ModelType.PYTORCH
model_class: type = None
# This is the "model hash", not to be confused with the
# "invocation hash"
@@ -65,7 +54,6 @@ class SkipFields:
file_name: bool = False
model_name: bool = False
- model_type: bool = False
parameters: bool = False
location: bool = False
input_shape: bool = False
@@ -82,19 +70,20 @@ class UniqueInvocationInfo(BasicInfo):
"""
invocation_hash: Union[str, None] = None
- performance: MeasuredPerformance = None
traceback: List[str] = None
inputs: Union[dict, None] = None
input_shapes: Union[dict, None] = None
executed: int = 0
exec_time: float = 0.0
status_message: str = ""
+ extra_status: Optional[str] = ""
is_target: bool = False
+ auto_selected: bool = False
status_message_color: printing.Colors = printing.Colors.ENDC
traceback_message_color: printing.Colors = printing.Colors.FAIL
- stats_keys: Optional[List[str]] = None
- stats: fs.Stats = None
-
+ stats_keys: List[str] = dataclasses.field(default_factory=list)
+    forward_function_pointer: Optional[Callable] = None
+    original_forward_function: Optional[Callable] = None
# Fields specific to printing status
skip: SkipFields = None
extension: str = None
@@ -117,7 +106,7 @@ def _print_heading(
print(f"{self.script_name}{self.extension}:")
# Print invocation about the model (only applies to scripts, not ONNX files)
- if self.model_type != build.ModelType.ONNX_FILE:
+        if self.extension not in (".onnx", "_state.yaml"):
if self.depth == 0 and multiple_unique_invocations:
if not model_visited:
printing.logn(f"{self.indent}{self.name}")
@@ -131,30 +120,20 @@ def _print_heading(
self.skip.file_name = True
self.skip.model_name = True
- def _print_model_type(self):
- if self.skip.model_type:
- return
-
- if self.depth == 0:
- if self.model_type == build.ModelType.PYTORCH:
- print(f"{self.indent}\tModel Type:\tPytorch (torch.nn.Module)")
- elif self.model_type == build.ModelType.KERAS:
- print(f"{self.indent}\tModel Type:\tKeras (tf.keras.Model)")
- elif self.model_type == build.ModelType.ONNX_FILE:
- print(f"{self.indent}\tModel Type:\tONNX File (.onnx)")
-
- self.skip.model_type = True
-
def _print_location(self):
- if self.skip.location:
+ if self.skip.location or self.file == "":
return
if self.depth == 0:
- print(f"{self.indent}\tLocation:\t{self.file}, line {self.line}")
+ print(f"{self.indent}\tLocation:\t{self.file}", end="")
+ if self.extension == ".onnx":
+ print()
+ else:
+ print(f", line {self.line}")
self.skip.location = True
def _print_parameters(self):
- if self.skip.parameters:
+ if self.skip.parameters or self.params is None:
return
# Display number of parameters and size
@@ -184,7 +163,7 @@ def _print_unique_input_shape(
self.skip.unique_input_shape = True
def _print_input_shape(self):
- if self.skip.input_shape:
+ if self.skip.input_shape or self.input_shapes is None:
return
# Prepare input shape to be printed
@@ -197,14 +176,15 @@ def _print_input_shape(self):
self.skip.input_shape = True
def _print_build_dir(self, cache_dir: str, build_name: str):
- if self.skip.build_dir:
+ if self.skip.build_dir or not self.is_target:
return
- print(f"{self.indent}\tBuild dir:\t {build.output_dir(cache_dir, build_name)}")
+ print(f"{self.indent}\tBuild dir:\t{build.output_dir(cache_dir, build_name)}")
self.skip.build_dir = True
- def _print_status(self):
+ def _print_status(self, cache_dir: str, build_name: str):
+ stats = fs.Stats(cache_dir, build_name)
if self.skip.previous_status_message:
if self.skip.previous_status_message == self.status_message:
# This is a special case for skipping: we only want to skip
@@ -215,57 +195,45 @@ def _print_status(self):
# Print some whitespace to help the status stand out
print()
- # Print turnkey results if turnkey was run
- if self.performance:
- printing.log(f"{self.indent}\tStatus:\t\t")
- printing.logn(
- f"Successfully benchmarked on {self.performance.device} "
- f"({self.performance.runtime} "
- f"v{self.performance.runtime_version}) ",
- c=self.status_message_color,
- )
- printing.logn(
- f"{self.indent}\t\t\tMean Latency:\t{self.performance.mean_latency:.3f}"
- f"\t{self.performance.latency_units}"
- )
- printing.logn(
- f"{self.indent}\t\t\tThroughput:\t{self.performance.throughput:.1f}"
- f"\t{self.performance.throughput_units}"
- )
-
- if self.stats_keys is not None:
- for key in self.stats_keys:
- nice_key = _pretty_print_key(key)
- try:
- value = self.stats.evaluation_stats[key]
- printing.logn(f"{self.indent}\t\t\t{nice_key}:\t{value}")
- except KeyError:
- # Ignore any keys that are missing because that means the
- # evaluation did not produce them
- pass
- print()
- else:
- if self.is_target and self.build_model:
- printing.log(f"{self.indent}\tStatus:\t\t")
- printing.logn(
- f"{self.status_message}",
- c=self.status_message_color,
- )
+ printing.log(f"{self.indent}\tStatus:\t\t")
+ printing.logn(
+ f"{self.status_message}",
+ c=self.status_message_color,
+ )
+ if self.is_target:
+
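+            # Show any status stats that the tools in the sequence asked to have
+            # displayed for this build (each tool contributes its status_stats keys)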
+ for key in self.stats_keys:
+ nice_key = _pretty_print_key(key)
+ try:
+ value = stats.stats[key]
+ if isinstance(value, float):
+ value = "{0:.3f}".format(value)
+ # Tools may provide a unit of measurement for their status
+ # stats, whose key name should follow the format
+ # "STATUS_STATS_KEY_units"
+ units_key = key + "_units"
+ units = stats.stats.get(units_key)
+ units = units if units is not None else ""
+ printing.logn(f"{self.indent}\t\t\t{nice_key}:\t{value} {units}")
+ except KeyError:
+ # Ignore any keys that are missing because that means the
+ # evaluation did not produce them
+ pass
+
+ if self.traceback is not None:
+ if os.environ.get("TURNKEY_TRACEBACK") != "False":
+ for line in self.traceback:
+ for subline in line.split("\n")[:-1]:
+ print(f"{self.indent}\t{subline}")
- if self.traceback is not None:
- if os.environ.get("TURNKEY_TRACEBACK") != "False":
- for line in self.traceback:
- for subline in line.split("\n")[:-1]:
- print(f"{self.indent}\t{subline}")
-
- else:
- printing.logn(
- f"{self.indent}\t\t\tTo see the full stack trace, "
- "rerun with `export TURNKEY_TRACEBACK=True`.\n",
- c=self.status_message_color,
- )
else:
- print()
+ printing.logn(
+ f"{self.indent}\t\t\tTo see the full stack trace, "
+ "rerun with `export TURNKEY_TRACEBACK=True`.\n",
+ c=self.status_message_color,
+ )
+ else:
+ print()
self.skip.previous_status_message = self.status_message
@@ -282,14 +250,12 @@ def print(
Print information about a given model or submodel.
"""
- if self.model_type == build.ModelType.ONNX_FILE:
- self.extension = ".onnx"
+ if self.extension == ".onnx":
self.indent = "\t" * (2 * self.depth)
else:
- self.extension = ".py"
self.indent = "\t" * (2 * self.depth + 1)
- if self.exec_time == 0 or self.build_model:
+ if self.exec_time == 0:
exec_time_formatted = ""
else:
exec_time_formatted = f" - {self.exec_time:.2f}s"
@@ -302,7 +268,6 @@ def print(
)
if (self.depth == 0 and not model_visited) or (self.depth != 0):
# Print this information only once per model
- self._print_model_type()
self._print_location()
self._print_parameters()
self._print_unique_input_shape(
@@ -310,7 +275,7 @@ def print(
)
self._print_input_shape()
self._print_build_dir(cache_dir=cache_dir, build_name=build_name)
- self._print_status()
+ self._print_status(cache_dir=cache_dir, build_name=build_name)
print()
@@ -325,55 +290,7 @@ class ModelInfo(BasicInfo):
last_unique_invocation_executed: Union[str, None] = None
def __post_init__(self):
- self.params = analyze_model.count_parameters(self.model, self.model_type)
-
-
-def update(
- models_found: Dict[str, ModelInfo],
- build_name: str,
- cache_dir: str,
- invocation_info: UniqueInvocationInfo,
- verbosity: Verbosity,
-) -> None:
- """
- Prints all models and submodels found
- """
-
- if verbosity == Verbosity.DYNAMIC:
- if platform.system() != "Windows":
- os.system("clear")
- else:
- os.system("cls")
-
- printing.logn(
- "\nModels discovered during profiling:\n",
- c=printing.Colors.BOLD,
- )
- recursive_print(
- models_found=models_found,
- build_name=build_name,
- cache_dir=cache_dir,
- parent_model_hash=None,
- parent_invocation_hash=None,
- script_names_visited=[],
- )
- else: # Verbosity.STATIC
- if invocation_info.model_type == build.ModelType.ONNX_FILE:
- # We don't invoke the ONNX files, so they can't have multiple invocations
- multiple_unique_invocations = False
- else:
- multiple_unique_invocations = (
- len(models_found[invocation_info.hash].unique_invocations) > 1
- )
-
- invocation_info.print(
- build_name=build_name,
- cache_dir=cache_dir,
- print_file_name=True,
- invocation_idx=0,
- model_visited=False,
- multiple_unique_invocations=multiple_unique_invocations,
- )
+ self.params = analyze_model.count_parameters(self.model)
def recursive_print(
diff --git a/src/turnkeyml/build/tensor_helpers.py b/src/turnkeyml/common/tensor_helpers.py
similarity index 91%
rename from src/turnkeyml/build/tensor_helpers.py
rename to src/turnkeyml/common/tensor_helpers.py
index 7aa3df8d..e22a809d 100644
--- a/src/turnkeyml/build/tensor_helpers.py
+++ b/src/turnkeyml/common/tensor_helpers.py
@@ -8,7 +8,7 @@
import numpy as np
import turnkeyml.common.exceptions as exp
import turnkeyml.common.build as build
-import turnkeyml.common.tf_helpers as tf_helpers
+
# Checks whether a given input has the expected shape
def check_shapes_and_dtypes(
@@ -16,7 +16,7 @@ def check_shapes_and_dtypes(
):
current_shapes, current_dtypes = build.get_shapes_and_dtypes(inputs)
- # If we are modifying the data type of inputs on a later stage we
+ # If we are modifying the data type of inputs on a later tool we
# verify input type based on the future data type conversion
if expect_downcast:
for key, value in current_dtypes.items():
@@ -57,8 +57,6 @@ def save_inputs(inputs, inputs_file, input_dtypes=None, downcast=True):
continue
if torch.is_tensor(inputs_converted[i][k]):
inputs_converted[i][k] = inputs_converted[i][k].cpu().detach().numpy()
- if tf_helpers.is_keras_tensor(inputs_converted[i][k]):
- inputs_converted[i][k] = inputs_converted[i][k].numpy()
if downcast:
if input_dtypes is not None and input_dtypes[k] is not None:
inputs_converted[i][k] = inputs_converted[i][k].astype(
diff --git a/test/helpers/common.py b/src/turnkeyml/common/test_helpers.py
similarity index 89%
rename from test/helpers/common.py
rename to src/turnkeyml/common/test_helpers.py
index 4a86782a..a4c1c2b8 100644
--- a/test/helpers/common.py
+++ b/src/turnkeyml/common/test_helpers.py
@@ -1,8 +1,8 @@
import os
import shutil
from typing import Dict
-import turnkeyml.common.filesystem as filesystem
-import turnkeyml.common.build as build
+import turnkeyml.common.filesystem as fs
+from turnkeyml.state import load_state
# We generate a corpus on to the filesystem during the test
@@ -104,7 +104,7 @@ def create_test_dir(key: str, test_scripts: Dict = None):
# Define paths to be used
base_dir = os.path.dirname(os.path.abspath(__file__))
cache_dir = os.path.join(base_dir, "generated", f"{key}_cache_dir")
- corpus_dir = os.path.join(base_dir, "generated", f"test_corpus")
+ corpus_dir = os.path.join(base_dir, "generated", "test_corpus")
# Delete folders if they exist and
if os.path.isdir(cache_dir):
@@ -133,17 +133,16 @@ def get_stats_and_state(
cache_dir: str,
) -> int:
# Figure out the build name by surveying the build cache
- builds = filesystem.get_all(cache_dir)
+ builds = fs.get_all(cache_dir)
test_script_name = strip_dot_py(test_script)
for build_state_file in builds:
if test_script_name in build_state_file:
- build_state = build.load_state(state_path=build_state_file)
- stats = filesystem.Stats(
+ build_state = load_state(state_path=build_state_file)
+ stats = fs.Stats(
build_state.cache_dir,
- build_state.config.build_name,
- build_state.evaluation_id,
+ build_state.build_name,
)
- return stats.evaluation_stats, build_state
+ return stats.stats, build_state
raise Exception(f"Stats not found for {test_script}")
diff --git a/src/turnkeyml/common/tf_helpers.py b/src/turnkeyml/common/tf_helpers.py
deleted file mode 100644
index 359efff4..00000000
--- a/src/turnkeyml/common/tf_helpers.py
+++ /dev/null
@@ -1,67 +0,0 @@
-"""
-Functions that help us avoid importing tensorflow (TF), since that import
-takes a very long time.
-
-The test `if "tensorflow" in sys.modules:` checks to see if TF
-has already been imported. This will always be true if someone is passing
-a TF model since... that TF model had to come from somewhere :)
-
-If TF hasn't already been imported, then there is no change that an object
-is a TF instance, or TF is in any particular mode, or anything else, so
-we can just return False on those checks.
-"""
-
-import sys
-import inspect
-from typing import List
-
-
-def is_keras_model(model) -> bool:
- if "tensorflow" in sys.modules:
- return isinstance(model, sys.modules["tensorflow"].keras.Model)
- else:
- return False
-
-
-def is_keras_tensor(tensor) -> bool:
- if "tensorflow" in sys.modules:
- return sys.modules["tensorflow"].is_tensor(tensor)
- else:
- return False
-
-
-def is_executing_eagerly() -> bool:
- if "tensorflow" in sys.modules:
- return sys.modules["tensorflow"].executing_eagerly()
- else:
- return False
-
-
-def type_is_tf_tensor(object) -> bool:
- if "tensorflow" in sys.modules:
- return object is sys.modules["tensorflow"].Tensor
- else:
- return False
-
-def is_keras_subclass(obj_type) -> bool:
- if "tensorflow" in sys.modules:
- return issubclass(obj_type, sys.modules["tensorflow"].keras.Model)
- else:
- return False
-
-def get_classes(module) -> List[str]:
- """
- Returns all classes within a module
- """
- return [y for x, y in inspect.getmembers(module, inspect.isclass)]
-
-def get_transformers_activations() -> List:
- """
- We need this helper because `import transformers.activations` brings in `tensorflow`
- We can apply this helper because there is 0 chance of encountering a
- transformers activation if user code has not already imported transformers
- """
- if "transformers" in sys.modules:
- return get_classes(sys.modules["transformers"].activations)
- else:
- return []
diff --git a/src/turnkeyml/files_api.py b/src/turnkeyml/files_api.py
index 8c8ae12e..ebc12b42 100644
--- a/src/turnkeyml/files_api.py
+++ b/src/turnkeyml/files_api.py
@@ -1,28 +1,15 @@
import time
import os
-import copy
import glob
-import pathlib
-from typing import Tuple, List, Dict, Optional, Union
+from typing import List, Dict, Optional, Union
+import git
import turnkeyml.common.printing as printing
import turnkeyml.common.exceptions as exceptions
-import turnkeyml.build.stage as stage
+from turnkeyml.sequence import Sequence
import turnkeyml.cli.spawn as spawn
-import turnkeyml.common.filesystem as filesystem
+import turnkeyml.common.filesystem as fs
import turnkeyml.common.labels as labels_library
-import turnkeyml.run.devices as devices
-from turnkeyml.common.performance import Device
-from turnkeyml.run.devices import SUPPORTED_RUNTIMES
-from turnkeyml.analyze.script import (
- evaluate_script,
- TracerArgs,
- Action,
- explore_invocation,
- get_model_hash,
-)
-from turnkeyml.analyze.status import ModelInfo, UniqueInvocationInfo, Verbosity
-import turnkeyml.common.build as build
-import turnkeyml.build.onnx_helpers as onnx_helpers
+from turnkeyml.state import State
# The licensing for tqdm is confusing. Pending a legal scan,
# the following code provides tqdm to users who have installed
@@ -36,93 +23,6 @@ def tqdm(iterable, **kwargs): # pylint: disable=unused-argument
return iterable
-def _select_verbosity(
- verbosity: str, input_files_expanded: List[str], process_isolation: bool
-) -> Tuple[Verbosity, bool]:
- """
- Choose verbosity based on the following policies:
- 1. The explicit verbosity argument takes priority over AUTO and the env var
- 2. The env var takes priority over AUTO
- 3. Use STATIC when there are many inputs, or in process isolation mode,
- and use DYNAMIC otherwise
-
- Returns the selected verbosity.
- """
-
- verbosity_choices = {
- field.value: field for field in Verbosity if field != Verbosity.AUTO
- }
- verbosity_env_var = os.environ.get("TURNKEY_VERBOSITY")
-
- if verbosity != Verbosity.AUTO.value:
- # Specific verbosity argument takes priority over env var
- verbosity_selected = verbosity_choices[verbosity]
- elif verbosity_env_var in verbosity_choices.keys():
- # Env var takes priority over AUTO
- verbosity_selected = verbosity_choices[verbosity_env_var]
- else:
- # Verbosity.AUTO and no env var
- if len(input_files_expanded) > 4 or process_isolation:
- # Automatically select STATIC if:
- # - There are many evaluations (>4), since DYNAMIC mode works
- # best when all results fit on one screen
- # - Process isolation mode is active, since DYNAMIC mode is
- # incompatible with process isolation
- verbosity_selected = Verbosity.STATIC
- else:
- verbosity_selected = Verbosity.DYNAMIC
-
- # Use a progress bar in STATIC mode if there is more than 1 input
- use_progress_bar = (
- verbosity_selected == Verbosity.STATIC and len(input_files_expanded) > 1
- )
-
- return verbosity_selected, use_progress_bar
-
-
-def decode_input_arg(input: str) -> Tuple[str, List[str], str]:
- # Parse the targets out of the file name
- # Targets use the format:
- # file_path.ext::target0,target1,...,targetN
- decoded_input = input.split("::")
- file_path = os.path.abspath(decoded_input[0])
-
- if len(decoded_input) == 2:
- targets = decoded_input[1].split(",")
- encoded_input = file_path + "::" + decoded_input[1]
- elif len(decoded_input) == 1:
- targets = []
- encoded_input = file_path
- else:
- raise ValueError(
- "Each file input to turnkey should have either 0 or 1 '::' in it."
- f"However, {file_path} was received."
- )
-
- return file_path, targets, encoded_input
-
-
-def check_sequence_type(
- sequence: Union[str, stage.Sequence],
- use_slurm: bool,
- process_isolation: bool,
-):
- """
- Check to make sure the user's sequence argument is valid.
- use_slurm or process_isolation: only work with names of installed sequences
- otherwise: sequence instances and sequence names are allowed
- """
-
- if sequence is not None:
- if use_slurm or process_isolation:
- # The spawned process will need to load a sequence file
- if not isinstance(sequence, str):
- raise ValueError(
- "The 'sequence' arg must be a str (name of an installed sequence) "
- "when use_slurm=True or process_isolation=True."
- )
-
-
def unpack_txt_inputs(input_files: List[str]) -> List[str]:
"""
Replace txt inputs with models listed inside those files
@@ -145,38 +45,41 @@ def unpack_txt_inputs(input_files: List[str]) -> List[str]:
return processed_files + [f for f in input_files if not f.endswith(".txt")]
-# pylint: disable=unused-argument
-def benchmark_files(
+def evaluate_files(
input_files: List[str],
- use_slurm: bool = False,
- process_isolation: bool = False,
+ sequence: Union[Dict, Sequence] = None,
+ cache_dir: str = fs.DEFAULT_CACHE_DIR,
lean_cache: bool = False,
- cache_dir: str = filesystem.DEFAULT_CACHE_DIR,
labels: List[str] = None,
- rebuild: Optional[str] = None,
- device: str = "x86",
- runtime: str = None,
- iterations: int = 100,
- analyze_only: bool = False,
- build_only: bool = False,
- script_args: Optional[str] = None,
- max_depth: int = 0,
- onnx_opset: Optional[int] = None,
+ use_slurm: bool = False,
+ process_isolation: bool = False,
timeout: Optional[int] = None,
- sequence: Union[str, stage.Sequence] = None,
- rt_args: Optional[Dict] = None,
- verbosity: str = Verbosity.STATIC.value,
):
-
- # Capture the function arguments so that we can forward them
- # to downstream APIs
- benchmarking_args = copy.deepcopy(locals())
- regular_files = []
+ """
+ Iterate over a list of input files, evaluating each one with the provided sequence.
+
+ Args:
+ input_files: each file in this list will be passed into the first tool in
+ the provided build sequence.
+ sequence: the build tools and their arguments used to act on the inputs.
+ cache_dir: Directory to use as the cache for this build. Output files
+ from this build will be stored at cache_dir/build_name/
+ lean_cache: delete build artifacts from the cache after the build has completed.
+        labels: if provided, only input files that are marked with these labels will be
+ passed into the sequence; the other input files will be skipped.
+        use_slurm: evaluate each input file as its own slurm job (requires slurm to be
+            set up in advance on your system).
+ process_isolation: evaluate each input file in a subprocess. If one subprocess
+ fails, this function will move on to the next input file.
+ timeout: in slurm or process isolation modes, the evaluation of each input file
+ will be canceled if it exceeds this timeout value (in seconds).
+ """
# Replace .txt files with the models listed inside them
input_files = unpack_txt_inputs(input_files)
# Iterate through each string in the input_files list
+ regular_files = []
for input_string in input_files:
if not any(char in input_string for char in "*?[]"):
regular_files.append(input_string)
@@ -204,40 +107,15 @@ def benchmark_files(
else:
timeout_to_use = spawn.DEFAULT_TIMEOUT_SECONDS
- benchmarking_args["timeout"] = timeout_to_use
-
# Convert regular expressions in input files argument
# into full file paths (e.g., [*.py] -> [a.py, b.py] )
- input_files_expanded = filesystem.expand_inputs(input_files)
-
- # Do not forward arguments to downstream APIs
- # that will be decoded in this function body
- benchmarking_args.pop("input_files")
- benchmarking_args.pop("labels")
- benchmarking_args.pop("use_slurm")
- benchmarking_args.pop("process_isolation")
+ input_files_expanded = fs.expand_inputs(input_files)
# Make sure the cache directory exists
- filesystem.make_cache_dir(cache_dir)
-
- check_sequence_type(sequence, use_slurm, process_isolation)
-
- if device is None:
- device = "x86"
-
- # Replace the runtime with a default value, if needed
- selected_runtime = devices.apply_default_runtime(device, runtime)
- benchmarking_args["runtime"] = selected_runtime
-
- # Get the default part and config by providing the Device class with
- # the supported devices by the runtime
- runtime_supported_devices = SUPPORTED_RUNTIMES[selected_runtime][
- "supported_devices"
- ]
- benchmarking_args["device"] = str(Device(device, runtime_supported_devices))
+ fs.make_cache_dir(cache_dir)
# Force the user to specify a legal cache dir in NFS if they are using slurm
- if cache_dir == filesystem.DEFAULT_CACHE_DIR and use_slurm:
+ if cache_dir == fs.DEFAULT_CACHE_DIR and use_slurm:
printing.log_warning(
"Using the default cache directory when using Slurm will cause your cached "
"files to only be available at the Slurm node. If this is not the behavior "
@@ -247,41 +125,21 @@ def benchmark_files(
# Get list containing only file names
clean_file_names = [
- decode_input_arg(file_name)[0] for file_name in input_files_expanded
+ fs.decode_input_arg(file_name)[0] for file_name in input_files_expanded
]
# Validate that the files have supported file extensions
# Note: We are not checking for .txt files here as those were previously handled
for file_name in clean_file_names:
- if not file_name.endswith(".py") and not file_name.endswith(".onnx"):
+ if (
+ not file_name.endswith(".py")
+ and not file_name.endswith(".onnx")
+ and not file_name.endswith("state.yaml")
+ ):
raise exceptions.ArgError(
f"File extension must be .py, .onnx, or .txt (got {file_name})"
)
- # Decode turnkey args into TracerArgs flags
- if analyze_only:
- actions = [
- Action.ANALYZE,
- ]
- elif build_only:
- actions = [
- Action.ANALYZE,
- Action.BUILD,
- ]
- else:
- actions = [
- Action.ANALYZE,
- Action.BUILD,
- Action.BENCHMARK,
- ]
-
- if Action.BENCHMARK in actions:
- printing.log_warning(
- "The benchmarking functionality of ONNX TurnkeyML has been "
- "deprecated. See https://github.com/onnx/turnkeyml/milestone/3 "
- "for details."
- )
-
if use_slurm:
jobs = spawn.slurm_jobs_in_queue()
if len(jobs) > 0:
@@ -290,137 +148,97 @@ def benchmark_files(
"Suggest quitting turnkey, running 'scancel -u $USER' and trying again."
)
- # Use this data structure to keep a running index of all models
- models_found: Dict[str, ModelInfo] = {}
-
- verbosity_policy, use_progress_bar = _select_verbosity(
- verbosity, input_files_expanded, process_isolation
- )
- benchmarking_args["verbosity"] = verbosity_policy
-
- # Fork the args for analysis since they have differences from the spawn args:
- # build_only and analyze_only are encoded into actions
- analysis_args = copy.deepcopy(benchmarking_args)
- analysis_args.pop("build_only")
- analysis_args.pop("analyze_only")
- analysis_args["actions"] = actions
- analysis_args.pop("timeout")
+ use_progress_bar = len(input_files_expanded) > 1
for file_path_encoded in tqdm(input_files_expanded, disable=not use_progress_bar):
- # Check runtime requirements if needed. All benchmarking will be halted
- # if requirements are not met. This happens regardless of whether
- # process-isolation is used or not.
- runtime_info = SUPPORTED_RUNTIMES[selected_runtime]
- if "requirement_check" in runtime_info and Action.BENCHMARK in actions:
- runtime_info["requirement_check"]()
printing.log_info(f"Running turnkey on {file_path_encoded}")
- file_path_absolute, targets, encoded_input = decode_input_arg(file_path_encoded)
+ file_path_absolute, targets, encoded_input = fs.decode_input_arg(
+ file_path_encoded
+ )
+
+ file_labels = fs.read_labels(file_path_absolute)
+
+ build_name = fs.get_build_name(
+ fs.clean_file_name(file_path_absolute),
+ file_labels,
+ targets[0] if len(targets) > 0 else None,
+ )
# Skip a file if the required_labels are not a subset of the script_labels.
if labels:
- # Labels argument is not supported for ONNX files
- if file_path_absolute.endswith(".onnx"):
+ # Labels argument is not supported for ONNX files or cached builds
+ if file_path_absolute.endswith(".onnx") or file_path_absolute.endswith(
+ ".yaml"
+ ):
raise ValueError(
"The labels argument is not supported for .onnx files, got",
file_path_absolute,
)
required_labels = labels_library.to_dict(labels)
- script_labels = labels_library.load_from_file(encoded_input)
- if not labels_library.is_subset(required_labels, script_labels):
+ if not labels_library.is_subset(required_labels, file_labels):
continue
if use_slurm or process_isolation:
- # Decode args into spawn.Target
- if use_slurm and process_isolation:
- raise ValueError(
- "use_slurm and process_isolation are mutually exclusive, but both are True"
- )
- elif use_slurm:
- process_type = spawn.Target.SLURM
- elif process_isolation:
- process_type = spawn.Target.LOCAL_PROCESS
- else:
- raise ValueError(
- "This code path requires use_slurm or use_process to be True, "
- "but both are False"
- )
-
spawn.run_turnkey(
- op="benchmark",
- target=process_type,
+ build_name=build_name,
+ sequence=sequence,
file_name=encoded_input,
- **benchmarking_args,
+ use_slurm=use_slurm,
+ process_isolation=process_isolation,
+ timeout=timeout_to_use,
+ lean_cache=lean_cache,
+ cache_dir=cache_dir,
)
else:
- # Instantiate an object that holds all of the arguments
- # for analysis, build, and benchmarking
- tracer_args = TracerArgs(
- models_found=models_found,
- targets=targets,
- input=file_path_absolute,
- **analysis_args,
+ # Forward the selected input to the first tool in the sequence
+ first_tool_args = next(iter(sequence.tools.values()))
+ first_tool_args.append("--input")
+ first_tool_args.append(file_path_encoded)
+
+ # Collection of statistics that the sequence instance should save
+ # to the stats file
+ stats_to_save = {}
+
+            # Save labels info
+ if fs.Keys.AUTHOR in file_labels:
+ stats_to_save[fs.Keys.AUTHOR] = file_labels[fs.Keys.AUTHOR][0]
+ if fs.Keys.TASK in file_labels:
+ stats_to_save[fs.Keys.TASK] = file_labels[fs.Keys.TASK][0]
+
+            # Save all of the labels in one place
+ stats_to_save[fs.Keys.LABELS] = file_labels
+
+ # If the input script is a built-in TurnkeyML model, make a note of
+ # which one
+ if os.path.abspath(fs.MODELS_DIR) in os.path.abspath(file_path_absolute):
+ try:
+ # If this turnkey installation is in a git repo, use the
+ # specific git hash
+ git_repo = git.Repo(search_parent_directories=True)
+ git_hash = git_repo.head.object.hexsha
+ except git.exc.InvalidGitRepositoryError:
+ # If we aren't in a git repo (e.g., PyPI package), point the user back to main
+ git_hash = "main"
+
+ relative_path = file_path_absolute.replace(
+ fs.MODELS_DIR,
+ f"https://github.com/onnx/turnkeyml/tree/{git_hash}/models",
+ ).replace("\\", "/")
+ stats_to_save[fs.Keys.MODEL_SCRIPT] = relative_path
+
+ state = State(
+ cache_dir=cache_dir,
+ build_name=build_name,
+ sequence_info=sequence.info,
+ )
+ sequence.launch(
+ state,
+ lean_cache=lean_cache,
+ stats_to_save=stats_to_save,
)
-
- if file_path_absolute.endswith(".py"):
- # Run analysis, build, and benchmarking on every model
- # in the python script
- models_found = evaluate_script(tracer_args)
- elif file_path_absolute.endswith(".onnx"):
- # Skip analysis and go straight to dealing with the model
- # We need to manufacture ModelInfo and UniqueInvocatioInfo instances to do this,
- # since we didn't get them from analysis.
-
- # Gather information about the ONNX model
- onnx_name = pathlib.Path(file_path_absolute).stem
- onnx_hash = get_model_hash(
- file_path_absolute, build.ModelType.ONNX_FILE
- )
- onnx_inputs = onnx_helpers.dummy_inputs(file_path_absolute)
- input_shapes = {key: value.shape for key, value in onnx_inputs.items()}
-
- # Create the UniqueInvocationInfo
- # - execute=1 is required or else the ONNX model will be
- # skipped in later stages of evaluation
- # - is_target=True is required or else traceback wont be printed for
- # in the event of any errors
- # - Most other values can be left as default
- invocation_info = UniqueInvocationInfo(
- name=onnx_name,
- script_name=onnx_name,
- file=file_path_absolute,
- build_model=not build_only,
- model_type=build.ModelType.ONNX_FILE,
- executed=1,
- input_shapes=input_shapes,
- hash=onnx_hash,
- is_target=True,
- )
-
- # Create the ModelInfo
- model_info = ModelInfo(
- model=file_path_absolute,
- name=onnx_name,
- script_name=onnx_name,
- file=file_path_absolute,
- build_model=not build_only,
- model_type=build.ModelType.ONNX_FILE,
- unique_invocations={onnx_hash: invocation_info},
- hash=onnx_hash,
- )
-
- # Begin evaluating the ONNX model
- tracer_args.script_name = onnx_name
- tracer_args.models_found[tracer_args.script_name] = model_info
- explore_invocation(
- model_inputs=onnx_inputs,
- model_info=model_info,
- invocation_info=invocation_info,
- tracer_args=tracer_args,
- )
- models_found = tracer_args.models_found
# Wait until all the Slurm jobs are done
if use_slurm:
@@ -430,5 +248,3 @@ def benchmark_files(
f"jobs left in queue: {spawn.slurm_jobs_in_queue()}"
)
time.sleep(5)
-
- printing.log_success("The 'benchmark' command is complete.")
diff --git a/src/turnkeyml/run/basert.py b/src/turnkeyml/run/basert.py
index b797f99f..938647fb 100644
--- a/src/turnkeyml/run/basert.py
+++ b/src/turnkeyml/run/basert.py
@@ -10,7 +10,8 @@
from turnkeyml.common.performance import MeasuredPerformance, Device
import turnkeyml.common.build as build
import turnkeyml.common.exceptions as exp
-from turnkeyml.common.filesystem import Stats, rebase_cache_dir
+import turnkeyml.common.filesystem as fs
+from turnkeyml.state import load_state
def _check_docker_install():
@@ -41,7 +42,7 @@ def __init__(
self,
cache_dir: str,
build_name: str,
- stats: Stats,
+ stats: fs.Stats,
device_type: Union[str, Device],
runtime: str,
runtimes_supported: List[str],
@@ -177,12 +178,21 @@ def benchmark(self) -> MeasuredPerformance:
os.remove(self.local_outputs_file)
# Transfer input artifacts
- state = build.load_state(self.cache_dir, self.build_name)
+ state = load_state(self.cache_dir, self.build_name)
+
+ # Make sure state.results is an ONNX file
+ if not (isinstance(state.results, str) and state.results.endswith(".onnx")):
+ raise exp.ToolError(
+ "This benchmarking runtime requires the preceeding "
+ "tools to produce an ONNX file, however they did not. "
+ "Please either select different tools, or select a different "
+ "benchmarking runtime that does not require an ONNX result."
+ )
# Just in case the model file was generated on a different machine:
# strip the state's cache dir, then prepend the current cache dir
- model_file = rebase_cache_dir(
- state.results[0], state.config.build_name, self.cache_dir
+ model_file = fs.rebase_cache_dir(
+ state.results, state.build_name, self.cache_dir
)
if not os.path.exists(model_file):
diff --git a/src/turnkeyml/run/benchmark_build.py b/src/turnkeyml/run/benchmark_build.py
deleted file mode 100644
index 27b058b0..00000000
--- a/src/turnkeyml/run/benchmark_build.py
+++ /dev/null
@@ -1,323 +0,0 @@
-from typing import Dict, Optional
-import multiprocessing
-import traceback
-import psutil
-import turnkeyml.common.build as build
-import turnkeyml.common.exceptions as exp
-import turnkeyml.common.filesystem as fs
-import turnkeyml.common.printing as printing
-from turnkeyml.analyze.script import set_status_on_exception
-from turnkeyml.run.devices import SUPPORTED_RUNTIMES, apply_default_runtime
-import turnkeyml.cli.parser_helpers as parser_helpers
-
-# The licensing for tqdm is confusing. Pending a legal scan,
-# the following code provides tqdm to users who have installed
-# it already, while being transparent to users who do not
-# have tqdm installed.
-try:
- from tqdm import tqdm
-except ImportError:
-
- def tqdm(iterable, **kwargs): # pylint: disable=unused-argument
- return iterable
-
-
-class SkippedBenchmark(Exception):
- """
- Indicates that a benchmark was skipped
- """
-
-
-class Process(multiprocessing.Process):
- """
- Standardized way to make it possible to catch exceptions from a
- multiprocessing.Process.
- """
-
- def __init__(self, *args, **kwargs):
- multiprocessing.Process.__init__(self, *args, **kwargs)
- self._pconn, self._cconn = multiprocessing.Pipe()
- self._exception = None
-
- def run(self):
- try:
- multiprocessing.Process.run(self)
- self._cconn.send(None)
- except Exception as e: # pylint: disable=broad-except
- tb = traceback.format_exc()
- self._cconn.send((e, tb))
-
- @property
- def exception(self):
- if self._pconn.poll():
- self._exception = self._pconn.recv()
- return self._exception
-
-
-def benchmark_build(
- first: bool,
- cache_dir: str,
- build_name: str,
- runtime: str,
- iterations: int,
- rt_args: Optional[Dict] = None,
-):
- """
- Benchmark the build artifact from a successful turnkey build.
-
- For example, `turnkey linear.py --build-only` would produce a build whose
- resulting artifact is an optimized ONNX file. This function would benchmark
- that optimized ONNX file.
-
- How it works:
- 1. Attempt to load build state from the cache_dir/build_name specified
- 2. Pass the build state directly into an instance of BaseRT and
- run the benchmark method
- 3. Save stats to the same evaluation entry from the original build
-
- Args:
- first: whether this is the first benchmark in the job
- cache_dir: same as turnkey
- build_name: same as turnkey
- runtime: same as turnkey
- iterations: same as turnkey
- rt_args: same as turnkey
- """
-
- state = build.load_state(cache_dir, build_name)
-
- if state.build_status != build.FunctionStatus.SUCCESSFUL:
- raise SkippedBenchmark(
- "Only successful builds can be benchmarked with this "
- f"function, however selected build at {build_name} "
- f"has state: {state.build_status}"
- )
-
- selected_runtime = apply_default_runtime(state.config.device, runtime)
-
- if rt_args is None:
- rt_args_to_use = {}
- else:
- rt_args_to_use = rt_args
-
- try:
- runtime_info = SUPPORTED_RUNTIMES[selected_runtime]
- except KeyError as e:
- # User should never get this far without hitting an actionable error message,
- # but let's raise an exception just in case.
- raise SkippedBenchmark(
- f"Selected runtime is not supported: {selected_runtime}"
- ) from e
-
- # Check whether the device and runtime are ready for use prior to
- # running the first benchmark in the job
- # NOTE: we perform this check here, instead of in the outer loop,
- # because this is where we know `runtime_info`
- if first and "requirement_check" in runtime_info:
- runtime_info["requirement_check"]()
-
- # Load the stats file using the same evaluation ID used in the original build.
- # This allows us to augment those stats with more data instead of starting a new
- # evaluation entry.
- stats = fs.Stats(cache_dir, build_name, state.evaluation_id)
-
- stats.save_model_eval_stat(
- fs.Keys.BENCHMARK_STATUS, build.FunctionStatus.INCOMPLETE.value
- )
-
- benchmark_logfile_path = ""
- try:
- # Instantiate BaseRT for the selected runtime
- runtime_handle = runtime_info["RuntimeClass"](
- cache_dir=cache_dir,
- build_name=build_name,
- stats=stats,
- iterations=iterations,
- model=state.results[0],
- # The `inputs` argument to BaseRT is only meant for
- # benchmarking runtimes that have to keep their inputs
- # in memory (e.g., `torch-eager`). We provide None here
- # because this function only works with runtimes that
- # keep their model and inputs on disk.
- inputs=None,
- device_type=state.config.device,
- runtime=selected_runtime,
- **rt_args_to_use,
- )
- benchmark_logfile_path = runtime_handle.logfile_path
- perf = runtime_handle.benchmark()
-
- for key, value in vars(perf).items():
- stats.save_model_eval_stat(
- key=key,
- value=value,
- )
-
- # Inform the user of the result
- perf.print()
-
- stats.save_model_eval_stat(
- fs.Keys.BENCHMARK_STATUS, build.FunctionStatus.SUCCESSFUL.value
- )
- except Exception as e:
- set_status_on_exception(
- runtime_info["build_required"], state, stats, benchmark_logfile_path
- )
-
- raise e
-
- # Check whether this benchmark left the device and runtime in a good state
- if "requirement_check" in runtime_info:
- runtime_info["requirement_check"]()
-
-
-def benchmark_cache_cli(args):
- """
- Wrapper function for benchmark_cache() that passes in the CLI arguments
- """
-
- rt_args = parser_helpers.decode_args(args.rt_args)
-
- benchmark_cache(
- cache_dir=args.cache_dir,
- build_name=args.build_name,
- benchmark_all=args.benchmark_all,
- skip_policy=args.skip_policy,
- runtime=args.runtime,
- iterations=args.iterations,
- timeout=args.timeout,
- rt_args=rt_args,
- )
-
-
-def benchmark_cache(
- cache_dir: str,
- build_name: str,
- benchmark_all: bool,
- skip_policy: str,
- runtime: str,
- iterations: int = 100,
- timeout: Optional[int] = None,
- rt_args: Optional[Dict] = None,
-):
- """
- Benchmark one or more builds in a cache using the benchmark_build()
- function.
-
- These benchmarks always run in process isolation mode because the purpose
- of this function is to quickly iterate over many builds.
- """
-
- printing.log_warning(
- "This is an experimental feature. Our plan is to deprecate it "
- "in favor of a new command, `turnkey benchmark cache/*`, ASAP. "
- "Please see https://github.com/onnx/turnkeyml/issues/115 "
- "for more info.\n\n"
- )
-
- if benchmark_all:
- builds = fs.get_available_builds(cache_dir)
- else:
- builds = [build_name]
-
- # Keep track of whether this is the first build we are benchmarking
- first = True
-
- # Iterate over all of the selected builds and benchmark them
- for build_name in tqdm(builds):
- if not fs.is_build_dir(cache_dir, build_name):
- raise exp.CacheError(
- f"No build found with name: {build_name}. "
- "Try running `turnkey cache list` to see the builds in your build cache."
- )
-
- state = build.load_state(cache_dir, build_name)
- stats = fs.Stats(cache_dir, build_name, state.evaluation_id)
-
- # Apply the skip policy by skipping over this iteration of the
- # loop if the evaluation's pre-existing benchmark status doesn't
- # meet certain criteria
- eval_stats = stats.evaluation_stats
- if (
- fs.Keys.BENCHMARK_STATUS in eval_stats
- and eval_stats[fs.Keys.BENCHMARK_STATUS]
- != build.FunctionStatus.NOT_STARTED.value
- ):
- if skip_policy == "attempted":
- printing.log_warning(
- f"Skipping because it was previously attempted: {build_name}"
- )
- continue
- elif (
- skip_policy == "successful"
- and eval_stats[fs.Keys.BENCHMARK_STATUS]
- == build.FunctionStatus.SUCCESSFUL.value
- ):
- printing.log_warning(
- f"Skipping because it was already successfully benchmarked: {build_name}"
- )
- continue
- elif (
- skip_policy == "failed"
- and eval_stats[fs.Keys.BENCHMARK_STATUS]
- != build.FunctionStatus.SUCCESSFUL.value
- ):
- printing.log_warning(
- f"Skipping because it was previously attempted and failed: {build_name}"
- )
- continue
- elif skip_policy == "none":
- # Skip policy of "none" means we should never skip over a build
- pass
-
- printing.log_info(f"Attempting to benchmark: {build_name}")
-
- p = Process(
- target=benchmark_build,
- args=[first, cache_dir, build_name, runtime, iterations, rt_args],
- )
- p.start()
- p.join(timeout=timeout)
-
- if p.is_alive():
- # Handle the timeout, which is needed if the process is still alive after
- # waiting `timeout` seconds
- parent = psutil.Process(p.pid)
- for child in parent.children(recursive=True):
- child.kill()
- parent.kill()
- stats.save_model_eval_stat(
- fs.Keys.BENCHMARK_STATUS, build.FunctionStatus.TIMEOUT.value
- )
-
- printing.log_warning(
- f"Benchmarking {build_name} canceled because it exceeded the {timeout} "
- "seconds timeout"
- )
- elif p.exception:
- # Handle any exception raised by the child process. In most cases, we should
- # move on to the next benchmark. However, if the exception was a
- # HardwareError that means the underlying runtime or device
- # is not able to conduct any more benchmarking. In this case the program
- # should exit and the user should follow the suggestion in the exception
- # message (e.g., restart their computer).
-
- if isinstance(p.exception[0], SkippedBenchmark):
- stats.save_model_eval_stat(
- fs.Keys.BENCHMARK_STATUS, build.FunctionStatus.NOT_STARTED.value
- )
- else:
- stats.save_model_eval_stat(
- fs.Keys.BENCHMARK_STATUS, build.FunctionStatus.ERROR.value
- )
-
- if isinstance(p.exception[0], exp.HardwareError):
- stats.save_model_eval_stat(fs.Keys.ERROR_LOG, p.exception[1])
- raise p.exception[0]
- else:
- printing.log_warning("Benchmarking failed with exception:")
- print(p.exception[1])
- else:
- printing.log_success(f"Done benchmarking: {build_name}")
-
- first = False
diff --git a/src/turnkeyml/run/benchmark_model.py b/src/turnkeyml/run/benchmark_model.py
new file mode 100644
index 00000000..f231d23e
--- /dev/null
+++ b/src/turnkeyml/run/benchmark_model.py
@@ -0,0 +1,179 @@
+import argparse
+from typing import Optional
+import turnkeyml.common.exceptions as exp
+import turnkeyml.common.filesystem as fs
+from turnkeyml.tools import Tool
+from turnkeyml.state import State
+from turnkeyml.run.devices import (
+ SUPPORTED_RUNTIMES,
+ SUPPORTED_DEVICES,
+ apply_default_runtime,
+)
+import turnkeyml.cli.parser_helpers as parser_helpers
+from turnkeyml.common.performance import Device, parse_device
+
+default_iterations = 100
+benchmark_default_device = "x86"
+
+
+class Benchmark(Tool):
+ """
+ Tool that benchmarks a model based on the selected device and runtime.
+
+ Expected inputs:
+ - state.results is a model to be benchmarked
+
+ Outputs: None
+ """
+
+ unique_name = "benchmark"
+
+ def __init__(self):
+ super().__init__(monitor_message="Benchmarking model")
+
+ self.status_stats = ["throughput", "mean_latency"]
+
+ @staticmethod
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
+ parser = __class__.helpful_parser(
+ short_description="Benchmark a model",
+ add_help=add_help,
+ )
+
+ parser.add_argument(
+ "--device",
+ choices=SUPPORTED_DEVICES,
+ dest="device",
+ help="Type of hardware device to be used for the benchmark "
+ f'(defaults to "{benchmark_default_device}")',
+ required=False,
+ )
+
+ parser.add_argument(
+ "--runtime",
+ choices=SUPPORTED_RUNTIMES.keys(),
+ dest="runtime",
+ help="Software runtime that will be used to collect the benchmark. "
+ "Must be compatible with the selected device. "
+ "Automatically selects a sequence if `--sequence` is not used. "
+ "If this argument is not set, the default runtime of the selected device will be used.",
+ required=False,
+ default=None,
+ )
+
+ parser.add_argument(
+ "--iterations",
+ dest="iterations",
+ type=int,
+ default=default_iterations,
+ help="Number of execution iterations of the model to capture\
+ the benchmarking performance (e.g., mean latency)",
+ )
+
+ parser.add_argument(
+ "--rt-args",
+ dest="rt_args",
+ type=str,
+ nargs="*",
+ help="Optional arguments provided to the runtime being used",
+ )
+
+ return parser
+
+ def parse(self, state: State, args, known_only=True) -> argparse.Namespace:
+ parsed_args = super().parse(state, args, known_only)
+
+        # Inherit the device from a prior tool, if available
+ parse_device(
+ state, parsed_args, benchmark_default_device, self.__class__.__name__
+ )
+
+ parsed_args.rt_args = parser_helpers.decode_args(parsed_args.rt_args)
+
+ return parsed_args
+
+ def run(
+ self,
+ state: State,
+ device: str = benchmark_default_device,
+ runtime: str = None,
+ iterations: int = default_iterations,
+ rt_args: Optional[str] = None,
+ ):
+
+ selected_runtime = apply_default_runtime(device, runtime)
+
+ # Get the default part and config by providing the Device class with
+ # the supported devices by the runtime
+ runtime_supported_devices = SUPPORTED_RUNTIMES[selected_runtime][
+ "supported_devices"
+ ]
+ specific_device = str(Device(device, runtime_supported_devices))
+
+ if rt_args is None:
+ rt_args_to_use = {}
+ else:
+ rt_args_to_use = rt_args
+
+ try:
+ runtime_info = SUPPORTED_RUNTIMES[selected_runtime]
+ except KeyError as e:
+ # User should never get this far without hitting an actionable error message,
+ # but let's raise an exception just in case.
+ raise exp.ToolError(
+ f"Selected runtime is not supported: {selected_runtime}"
+ ) from e
+
+ # Save the device name that will be used for the benchmark
+ state.save_stat(fs.Keys.DEVICE, runtime_info["RuntimeClass"].device_name())
+
+ # Save specific information into its own key for easier access
+ state.save_stat(
+ fs.Keys.DEVICE_TYPE,
+ specific_device,
+ )
+ state.save_stat(
+ fs.Keys.RUNTIME,
+            selected_runtime,
+ )
+
+ state.save_stat(
+ fs.Keys.ITERATIONS,
+ iterations,
+ )
+
+ # Check whether the device and runtime are ready for use prior to
+ # running the benchmark
+ if "requirement_check" in runtime_info:
+ runtime_info["requirement_check"]()
+
+        # Each runtime can contribute its own status stats
+ if runtime_info.get("status_stats"):
+ self.status_stats += runtime_info.get("status_stats")
+
+ # Instantiate BaseRT for the selected runtime
+ runtime_handle = runtime_info["RuntimeClass"](
+ cache_dir=state.cache_dir,
+ build_name=state.build_name,
+ stats=fs.Stats(state.cache_dir, state.build_name),
+ iterations=iterations,
+ model=state.results,
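+            # state.results holds the artifact produced by the preceding tools
+            # (e.g., an ONNX file path, or a torch.nn.Module for the torch runtimes)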
+ inputs=vars(state).get(fs.Keys.INPUTS),
+ device_type=specific_device,
+ runtime=selected_runtime,
+ **rt_args_to_use,
+ )
+ perf = runtime_handle.benchmark()
+
+ for key, value in vars(perf).items():
+ state.save_stat(
+ key=key,
+ value=value,
+ )
+
+ # Inform the user of the result
+ perf.print()
+
+ state.perf = perf
+
+ return state
diff --git a/src/turnkeyml/run/devices.py b/src/turnkeyml/run/devices.py
index c48d81b5..3ec62b14 100644
--- a/src/turnkeyml/run/devices.py
+++ b/src/turnkeyml/run/devices.py
@@ -4,8 +4,7 @@
import turnkeyml.run.tensorrt as tensorrt
import turnkeyml.run.torchrt as torchrt
import turnkeyml.common.plugins as plugins
-from turnkeyml.build.stage import Sequence
-import turnkeyml.build.sequences as sequences
+from turnkeyml.sequence import Sequence
import turnkeyml.common.exceptions as exp
@@ -72,7 +71,7 @@ def supported_devices_list(data: Dict, parent_key: str = "") -> List:
def apply_default_runtime(device: str, runtime: Optional[str] = None):
if runtime is None:
- return DEVICE_RUNTIME_MAP[device][DEFAULT_RUNTIME]
+ return DEVICE_RUNTIME_MAP[str(device)][DEFAULT_RUNTIME]
else:
return runtime
@@ -84,54 +83,27 @@ def _check_suggestion(value: str):
)
-def select_runtime_and_sequence(
- device: str, runtime: Optional[str], sequence: Optional[Sequence]
-) -> Tuple[str, str, Sequence]:
- selected_runtime = apply_default_runtime(device, runtime)
+def select_runtime(device: str, runtime: Optional[str]) -> Tuple[str, Dict]:
+    # Convert to str in case it's an instance of Device
+ device_str = str(device)
+
+ selected_runtime = apply_default_runtime(device_str, runtime)
# Validate device and runtime selections
- if device not in SUPPORTED_DEVICES:
+ if device_str not in SUPPORTED_DEVICES:
raise exp.ArgError(
- f"Device argument '{device}' is not one of the available "
+ f"Device argument '{device_str}' is not one of the available "
f"supported devices {SUPPORTED_DEVICES}\n"
- f"{_check_suggestion(device)}"
+ f"{_check_suggestion(device_str)}"
)
- if selected_runtime not in DEVICE_RUNTIME_MAP[device]:
+ if selected_runtime not in DEVICE_RUNTIME_MAP[device_str]:
raise exp.ArgError(
f"Runtime argument '{selected_runtime}' is not one of the available "
- f"runtimes supported for device '{device}': {DEVICE_RUNTIME_MAP[device]}\n"
+ f"runtimes supported for device '{device_str}': {DEVICE_RUNTIME_MAP[device_str]}\n"
f"{_check_suggestion(selected_runtime)}"
)
# Get the plugin module for the selected runtime
runtime_info = SUPPORTED_RUNTIMES[selected_runtime]
- # Perform a build, if necessary
- if runtime_info["build_required"]:
- # Get the build sequence that will be used for the model
- if sequence is None:
- # Automatically choose a Sequence based on what the runtime expects
- sequence_selected = runtime_info["default_sequence"]
- else:
- # User-specified Sequence
- if isinstance(sequence, str):
- # Sequence is defined by a plugin
- if sequence in sequences.SUPPORTED_SEQUENCES.keys():
- sequence_selected = sequences.SUPPORTED_SEQUENCES[sequence]
- else:
- raise ValueError(
- f"Sequence argument {sequence} is not one of the "
- "available sequences installed: "
- f"{sequences.SUPPORTED_SEQUENCES.keys()} \n"
- f"{_check_suggestion(sequence)}"
- )
-
- elif isinstance(sequence, Sequence):
- # Sequence is a user-defined instance of Sequence
- sequence_selected = sequence
-
- else:
- # Sequence is only needed for builds
- sequence_selected = None
-
- return selected_runtime, runtime_info, sequence_selected
+ return selected_runtime, runtime_info
diff --git a/src/turnkeyml/run/onnxrt/__init__.py b/src/turnkeyml/run/onnxrt/__init__.py
index 6ecc93de..d163fbd2 100644
--- a/src/turnkeyml/run/onnxrt/__init__.py
+++ b/src/turnkeyml/run/onnxrt/__init__.py
@@ -1,13 +1,10 @@
-import turnkeyml.build.sequences as sequences
from .runtime import OnnxRT
implements = {
"runtimes": {
"ort": {
- "build_required": True,
"RuntimeClass": OnnxRT,
"supported_devices": {"x86"},
- "default_sequence": sequences.optimize_fp32,
}
}
}
diff --git a/src/turnkeyml/run/tensorrt/__init__.py b/src/turnkeyml/run/tensorrt/__init__.py
index 9a26d9e7..bbdb4477 100644
--- a/src/turnkeyml/run/tensorrt/__init__.py
+++ b/src/turnkeyml/run/tensorrt/__init__.py
@@ -1,14 +1,11 @@
-import turnkeyml.build.sequences as sequences
from .runtime import TensorRT
implements = {
"runtimes": {
"trt": {
- "build_required": True,
"RuntimeClass": TensorRT,
"supported_devices": {"nvidia"},
- "default_sequence": sequences.optimize_fp16,
}
}
}
diff --git a/src/turnkeyml/run/tensorrt/runtime.py b/src/turnkeyml/run/tensorrt/runtime.py
index e56896af..e99c18d0 100644
--- a/src/turnkeyml/run/tensorrt/runtime.py
+++ b/src/turnkeyml/run/tensorrt/runtime.py
@@ -97,7 +97,7 @@ def _execute(
# Add the GPU driver version to the stats file before execution
gpu_driver_version = _get_nvidia_driver_version()
- self.stats.save_model_eval_stat("gpu_driver_version", gpu_driver_version)
+ self.stats.save_stat("gpu_driver_version", gpu_driver_version)
power_thread.start()
run(
diff --git a/src/turnkeyml/run/torchrt/__init__.py b/src/turnkeyml/run/torchrt/__init__.py
index dadc8d89..c6e746dd 100644
--- a/src/turnkeyml/run/torchrt/__init__.py
+++ b/src/turnkeyml/run/torchrt/__init__.py
@@ -3,12 +3,10 @@
implements = {
"runtimes": {
"torch-eager": {
- "build_required": False,
"RuntimeClass": TorchRT,
"supported_devices": {"x86"},
},
"torch-compiled": {
- "build_required": False,
"RuntimeClass": TorchRT,
"supported_devices": {"x86"},
},
diff --git a/src/turnkeyml/run/torchrt/runtime.py b/src/turnkeyml/run/torchrt/runtime.py
index d604e670..d7e3c87d 100644
--- a/src/turnkeyml/run/torchrt/runtime.py
+++ b/src/turnkeyml/run/torchrt/runtime.py
@@ -9,7 +9,6 @@
from turnkeyml.run.basert import BaseRT
from turnkeyml.common.performance import MeasuredPerformance
from turnkeyml.run.onnxrt.execute import get_cpu_specs
-import turnkeyml.build.ignition as ignition
import turnkeyml.common.build as build
import turnkeyml.common.exceptions as exp
import turnkeyml.common.filesystem as fs
@@ -93,8 +92,7 @@ def _setup(self) -> None:
"""
# Ensure we have the correct model type
- model_type = ignition.identify_model_type(self.model)
- if model_type != build.ModelType.PYTORCH:
+ if not isinstance(self.model, (torch.nn.Module, torch.jit.ScriptModule)):
raise exp.IntakeError(
f"Only Pytorch models are valid when runtime is {self.runtime}"
)
@@ -106,7 +104,7 @@ def _setup(self) -> None:
end_time = time.perf_counter()
total_time = end_time - start_time
- self.stats.save_model_eval_stat("torch_compilation_seconds", total_time)
+ self.stats.save_stat("torch_compilation_seconds", total_time)
def _calculate_performance(
self, per_iteration_latency: List[float]
@@ -180,9 +178,7 @@ def _benchmark_inner(self) -> MeasuredPerformance:
# Record the number of iterations actually used for the benchmark,
# which will be less than the `iterations` argument if the time
# limit was reached
- self.stats.save_model_eval_stat(
- fs.Keys.ITERATIONS, len(per_iteration_latency)
- )
+ self.stats.save_stat(fs.Keys.ITERATIONS, len(per_iteration_latency))
return self._calculate_performance(per_iteration_latency)
diff --git a/src/turnkeyml/sequence/__init__.py b/src/turnkeyml/sequence/__init__.py
new file mode 100644
index 00000000..52f02f11
--- /dev/null
+++ b/src/turnkeyml/sequence/__init__.py
@@ -0,0 +1 @@
+from .sequence import Sequence
diff --git a/src/turnkeyml/sequence/sequence.py b/src/turnkeyml/sequence/sequence.py
new file mode 100644
index 00000000..657fd883
--- /dev/null
+++ b/src/turnkeyml/sequence/sequence.py
@@ -0,0 +1,282 @@
+import sys
+import time
+import os
+import copy
+from datetime import datetime
+from typing import List, Dict, Optional
+import turnkeyml.common.printing as printing
+import turnkeyml.common.exceptions as exp
+import turnkeyml.common.build as build
+import turnkeyml.common.filesystem as fs
+import turnkeyml.common.status as status
+from turnkeyml.tools.tool import Tool
+from turnkeyml.state import State
+
+
+def _rewind_stdout(lines: int = 1):
+ """
+ Helper function for the command line monitor. Moves the cursor up a
+ certain number of lines in the terminal, corresponding to the
+ status line for a Tool, so that we can update the status of
+ that Tool.
+ """
+ rewind_stdout_one_line = "\033[1A"
+ rewind_multiple_lines = rewind_stdout_one_line * lines
+ print(rewind_multiple_lines, end="")
+ sys.stdout.flush()
+
+
+class Sequence:
+ """
+ Helper class to launch and manage build tools.
+ """
+
+ def __init__(
+ self,
+ tools: Dict[Tool, List[str]],
+ ):
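+        # `tools` maps each Tool instance to the argv-style list of arguments
+        # that Sequence.launch() will pass to that tool's parse_and_run()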
+
+ self.tools = tools
+
+ # Make sure all the tool names are unique
+ self.tool_names = [tool.__class__.unique_name for tool in self.tools.keys()]
+
+ if len(self.tool_names) != len(set(self.tool_names)):
+ msg = f"""
+            All tools in a Sequence must have a unique value of unique_name, however this
+            Sequence received duplicates in its list of names: {self.tool_names}
+ """
+ raise ValueError(msg)
+
+ def show_monitor(self, state: State, verbosity: bool):
+ """
+ Displays the monitor on the terminal. The purpose of the monitor
+ is to show the status of each tool (success, failure, not started yet,
+ or in-progress).
+ """
+
+ if verbosity:
+ print()
+
+ printing.logn(
+ f'Building "{state.build_name}"',
+ c=printing.Colors.BOLD,
+ )
+
+ for tool in self.tools:
+ tool.status_line(successful=None, verbosity=True)
+
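+            # Move the cursor back up to the first tool's status line so that
+            # each tool can update its own line in place as it runs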
+ _rewind_stdout(len(self.tools))
+
+ def _advance_cursor(self, current_tool_name: str):
+ # Advance the cursor below the monitor so
+ # we can print a message
+ tool_depth_in_sequence = len(self.tool_names) - self.tool_names.index(
+ current_tool_name
+ )
+ stdout_lines_to_advance = tool_depth_in_sequence - 2
+ cursor_down = "\n" * stdout_lines_to_advance
+
+ print(cursor_down)
+
+ def launch(
+ self,
+ state: State,
+ lean_cache: bool = False,
+ stats_to_save: Optional[Dict] = None,
+ ) -> State:
+ """
+ Executes the sequence of tools.
+ """
+
+ # Create a build directory in the cache
+ fs.make_build_dir(state.cache_dir, state.build_name)
+
+ self.show_monitor(state, state.monitor)
+
+ if state.build_status == build.FunctionStatus.SUCCESSFUL:
+ msg = """
+            Sequence.launch() was called on a build that already completed successfully, which
+            should not happen because the build should have been loaded from cache or rebuilt from scratch.
+ If you are using custom tools and Sequences then you have some debugging to do. Otherwise,
+ please file an issue at https://github.com/onnx/turnkeyml/issues
+ """
+ raise exp.Error(msg)
+
+ # Keep a copy of any stats we loaded from disk, in case we need to
+ # restore them later
+ saved_stats = copy.deepcopy(fs.Stats(state.cache_dir, state.build_name).stats)
+
+ # Indicate that the build is running. If the build fails for any reason,
+ # we will try to catch the exception and note it in the stats.
+ # If a concluded build still has a status of "running", this means
+ # there was an uncaught exception.
+ state.save_stat(fs.Keys.BUILD_STATUS, build.FunctionStatus.INCOMPLETE)
+
+ # Save a timestamp so that we know the order of builds within a cache
+ state.save_stat(
+ fs.Keys.TIMESTAMP,
+ datetime.now(),
+ )
+
+ # Save the system information used for this build
+ system_info = build.get_system_info()
+ state.save_stat(
+ fs.Keys.SYSTEM_INFO,
+ system_info,
+ )
+
+ # Collect telemetry for the build
+ state.save_stat(
+ fs.Keys.SELECTED_SEQUENCE_OF_TOOLS,
+ self.tool_names,
+ )
+
+ # At the beginning of a sequence no tool has started
+ for tool in self.tools:
+ state.save_stat(tool.status_key, build.FunctionStatus.NOT_STARTED)
+ state.save_stat(tool.duration_key, "-")
+
+ # Save any additional stats passed in via arguments
+ if stats_to_save:
+ for stat_key, stat_value in stats_to_save.items():
+ state.save_stat(stat_key, stat_value)
+
+ # Run the build
+ saved_exception = None
+ for tool, argv in self.tools.items():
+ start_time = time.time()
+
+ try:
+
+ # Set status as incomplete, since tool just started
+ state.save_stat(tool.status_key, build.FunctionStatus.INCOMPLETE)
+
+ # Collect telemetry about the tool
+ state.current_build_tool = tool.unique_name
+
+ # Run the tool
+ state = tool.parse_and_run(state, argv)
+
+ # Save the state so that it can be assessed for a cache hit
+ state.save()
+
+ except exp.SkipBuild as e:
+ # SkipBuild is a special exception, which means that a build
+ # was loaded from disk, then we realized we want to skip it.
+ # In order to preserve the original stats and state of the build,
+ # we need to restore the stats file to what it was at the beginning
+ # of this function call. We also need to avoid calling state.save().
+
+ # Restore the prior stats
+ fs.save_yaml(
+ saved_stats, fs.Stats(state.cache_dir, state.build_name).file
+ )
+
+ # Advance the cursor below the monitor so
+ # we can print a message
+ self._advance_cursor(tool.unique_name)
+ printing.log_warning(str(e))
+ return
+
+ # Broad exception is desirable as we want to capture
+ # all exceptions (including those we can't anticipate)
+ except Exception as e: # pylint: disable=broad-except
+
+ if os.environ.get("TURNKEY_DEBUG", "").lower() == "true":
+ # It may be useful to raise the exception here, since
+ # if any of the subsequent lines of code raise another
+ # exception it will be very hard to root cause e.
+ raise e
+
+ # Update tool and build status
+ state.save_stat(tool.status_key, build.FunctionStatus.ERROR)
+ state.save_stat(fs.Keys.BUILD_STATUS, build.FunctionStatus.ERROR)
+
+ # Save the log file for the failed tool to stats for easy reference
+ stats = fs.Stats(state.cache_dir, state.build_name)
+ stats.save_eval_error_log(tool.logfile_path)
+
+ # Advance the cursor below the monitor so
+ # we can print a message
+ self._advance_cursor(tool.unique_name)
+
+ if vars(state).get("invocation_info"):
+ state.invocation_info.status_message = f"Error: {e}"
+ state.invocation_info.status_message_color = printing.Colors.WARNING
+ else:
+ printing.log_error(e)
+
+ # We will raise this exception after we capture as many statistics
+ # about the build as possible
+ saved_exception = e
+
+ # Don't run any more tools
+ break
+
+ else:
+ # Update tool Status
+ state.save_stat(tool.status_key, build.FunctionStatus.SUCCESSFUL)
+ state.current_build_tool = None
+
+ finally:
+ # Store tool duration
+ execution_time = time.time() - start_time
+ state.save_stat(tool.duration_key, execution_time)
+
+ if not saved_exception:
+ state.build_status = build.FunctionStatus.SUCCESSFUL
+ state.save_stat(fs.Keys.BUILD_STATUS, build.FunctionStatus.SUCCESSFUL)
+
+ if vars(state).get("invocation_info"):
+ state.invocation_info.status_message = (
+ f"Successful build! {state.invocation_info.extra_status}"
+ )
+ state.invocation_info.status_message_color = printing.Colors.OKGREEN
+
+ if vars(state).get("models_found") and vars(state).get("invocation_info"):
+
+ # Present status statistics from the tools
+ for tool in self.tools:
+ state.invocation_info.stats_keys += tool.status_stats
+
+ print()
+
+ status.recursive_print(
+ models_found=state.models_found,
+ build_name=state.build_name,
+ cache_dir=state.cache_dir,
+ parent_model_hash=None,
+ parent_invocation_hash=None,
+ script_names_visited=[],
+ )
+
+ if lean_cache:
+ printing.log_info("Removing build artifacts...")
+ fs.clean_output_dir(state.cache_dir, state.build_name)
+
+ state.save()
+
+ if saved_exception:
+ raise saved_exception
+
+ printing.log_success(
+ f"\n Saved to **{build.output_dir(state.cache_dir, state.build_name)}**"
+ )
+
+ return state
+
+ def status_line(self, verbosity):
+ """
+ Print a status line in the monitor for every tool in the sequence
+ """
+ for tool in self.tools:
+ tool.status_line(successful=None, verbosity=verbosity)
+
+ @property
+ def info(self) -> Dict[str, Dict]:
+ """
+ Return a dictionary of tool_name:argv for the sequence
+ """
+
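+        # Illustrative return value (tool names and argv entries are examples only):
+        #   {"discover": [], "export-pytorch": ["--opset", "17"]}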
+ return {tool.__class__.unique_name: argv for tool, argv in self.tools.items()}
diff --git a/src/turnkeyml/sequence/tool_plugins.py b/src/turnkeyml/sequence/tool_plugins.py
new file mode 100644
index 00000000..a9aa5e8f
--- /dev/null
+++ b/src/turnkeyml/sequence/tool_plugins.py
@@ -0,0 +1,42 @@
+import turnkeyml.tools.export as export
+import turnkeyml.tools.onnx as onnx_tools
+import turnkeyml.common.plugins as plugins
+import turnkeyml.tools.management_tools as mgmt
+from turnkeyml.run.benchmark_model import Benchmark
+from turnkeyml.tools.discovery import Discover
+import turnkeyml.tools.report as report
+from turnkeyml.tools.load_build import LoadBuild
+
+# Plugin interface for sequences
+discovered_plugins = plugins.discover()
+
+# Populate the supported tools list with built-in tools
+SUPPORTED_TOOLS = [
+ mgmt.Version,
+ mgmt.Cache,
+ mgmt.ModelsLocation,
+ report.Report,
+ Benchmark,
+ Discover,
+ export.ExportPytorchModel,
+ onnx_tools.OptimizeOnnxModel,
+ onnx_tools.LoadOnnx,
+ onnx_tools.ConvertOnnxToFp16,
+ export.VerifyOnnxExporter,
+ LoadBuild,
+]
+
+# Add tools from plugins to the supported tools list
+for module in discovered_plugins.values():
+ if "tools" in module.implements.keys():
+ for tool_class in module.implements["tools"]:
+ if tool_class in SUPPORTED_TOOLS:
+                name = tool_class.unique_name
+ raise ValueError(
+ f"Your turnkeyml installation has two tools named '{name}' "
+ "installed. You must uninstall one of your plugins that includes "
+ f"{name}. Your imported sequence plugins are: {SUPPORTED_TOOLS}\n"
+ f"This error was thrown while trying to import {module}"
+ )
+
+ SUPPORTED_TOOLS.append(tool_class)
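+
+# Illustrative shape of a plugin module's `implements` dict, as consumed by the
+# loop above (the tool class name `MyCustomTool` is a hypothetical example):
+#   implements = {"tools": [MyCustomTool]}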
diff --git a/src/turnkeyml/state.py b/src/turnkeyml/state.py
new file mode 100644
index 00000000..2463f5ce
--- /dev/null
+++ b/src/turnkeyml/state.py
@@ -0,0 +1,166 @@
+import os
+import sys
+from typing import Dict, Optional, Any
+import yaml
+import turnkeyml.common.build as build
+import turnkeyml.common.filesystem as fs
+from turnkeyml.version import __version__ as turnkey_version
+
+
+def _is_nice_to_write(value):
+ """
+ Checks whether a value is nice to write to YAML.
+ Returns True if the value is a string, int, float, bool, list, dict, or tuple.
+ Returns False otherwise.
+ """
+ if isinstance(value, (str, int, float, bool)):
+ return True
+ elif isinstance(value, list) or isinstance(value, tuple):
+ # Check if all elements in the list are nice to write
+ return all(_is_nice_to_write(item) for item in value)
+ elif isinstance(value, dict):
+ # Check if all values in the dictionary are nice to write
+ return all(_is_nice_to_write(item) for item in value.values())
+ return False
+
+
+def _sanitize_for_yaml(input_dict: Dict) -> Dict:
+ """
+ Creates a new dictionary containing only nice-to-write values
+ from the original dictionary.
+ """
+ result = {}
+ for key, value in input_dict.items():
+ if _is_nice_to_write(value):
+ result[key] = value
+ return result
+
+
+class State:
+ """
+ The State class is meant to carry build state, starting with the user's
+ initial arguments, through each build Tool in the Sequence, and finally
+ to the disk, where it is used to assess cache hits.
+
+ State is initialized with the key members that are shared by every build,
+ and reasonable default values are assigned as appropriate.
+
+ Tool developers can also add any members they wish. To get or set an
+ attribute, reference it as an attribute:
+ 1. get: `my_variable = state.attribute_name`
+ 2. set: `state.attribute_name = my_variable`
+
+ Build State can be saved and loaded from disk in the form of a state.yaml file
+ via State.save() and load_state(), respectively. Note that while State can
+ contain members of any type, only YAML-safe members (str, int, bool, float,
+ list, dict, tuple) will be saved and loaded.
+ """
+
+ def __init__(
+ self,
+ cache_dir: str,
+ monitor: Optional[bool] = None,
+ build_name: Optional[str] = None,
+        sequence_info: Optional[Dict[str, Dict]] = None,
+ **kwargs,
+ ):
+
+ # Allow monitor to be globally disabled by an environment variable
+ if monitor is None:
+ if os.environ.get("TURNKEY_BUILD_MONITOR") == "False":
+ monitor_setting = False
+ else:
+ monitor_setting = True
+ else:
+ monitor_setting = monitor
+
+ # The default model name is the name of the python file that calls build_model()
+ if build_name is None:
+ build_name = os.path.basename(sys.argv[0])
+
+ # Support "~" in the cache_dir argument
+ parsed_cache_dir = os.path.expanduser(cache_dir)
+
+ # Save settings as State members
+ self.monitor = monitor_setting
+ self.cache_dir = parsed_cache_dir
+ self.build_name = build_name
+ self.sequence_info = sequence_info
+ self.turnkey_version = turnkey_version
+ self.build_status = build.FunctionStatus.NOT_STARTED
+ self.downcast_applied = False
+ self.uid = build.unique_id()
+ self.results = None
+
+ # Store any additional kwargs as members
+ for key, value in kwargs.items():
+ self.__dict__[key] = value
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ """
+ Tool developers can add a new member to State by simply
+ assigning it as an attribute, i.e., `state.new_member = value`.
+ """
+ return super().__setattr__(name, value)
+
+ def save_stat(self, key: str, value):
+ """
+        Save a statistic to a YAML file in the build directory
+ """
+
+ stats = fs.Stats(self.cache_dir, self.build_name)
+ stats.save_stat(key, value)
+
+ def save_sub_stat(self, parent_key: str, key: str, value):
+ """
+        Save a nested statistic to a YAML file in the build directory
+ """
+
+ stats = fs.Stats(self.cache_dir, self.build_name)
+ stats.save_sub_stat(parent_key, key, value)
+
+ def save(self):
+ """
+ Save all YAML-friendly members to disk as a state.yaml file.
+
+        Note that `model` and `inputs` will typically not be saved, since
+        they are usually non-YAML-friendly types such as `torch.nn.Module`
+        and `torch.Tensor`.
+ """
+
+ state_to_save = _sanitize_for_yaml(vars(self))
+
+ # Create a build directory in the cache
+ fs.make_build_dir(self.cache_dir, self.build_name)
+
+ with open(
+ build.state_file(self.cache_dir, self.build_name),
+ "w",
+ encoding="utf8",
+ ) as outfile:
+ yaml.dump(state_to_save, outfile)
+
+
+def load_state(
+ cache_dir=None,
+ build_name=None,
+ state_path=None,
+) -> State:
+ """
+ Read a state.yaml file corresponding to a specific build in a specific
+ cache, and use its contents to initialize a State instance.
+ """
+
+ if state_path is not None:
+ file_path = state_path
+ elif build_name is not None and cache_dir is not None:
+ file_path = build.state_file(cache_dir, build_name)
+ else:
+ raise ValueError(
+ "This function requires either build_name and cache_dir to be set, "
+ "or state_path to be set, not both or neither"
+ )
+
+ state_dict = build.load_yaml(file_path)
+
+ return State(**state_dict)
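+
+
+# Illustrative usage sketch of the attribute get/set and save/load round trip
+# described in the State docstring above; the cache directory and build name
+# below are hypothetical placeholders.
+#
+#   state = State(cache_dir="~/.turnkey_cache", build_name="example_build")
+#   state.my_member = 42                    # any tool can attach new members
+#   state.save()                            # writes YAML-safe members to state.yaml
+#   reloaded = load_state(
+#       cache_dir="~/.turnkey_cache", build_name="example_build"
+#   )
+#   assert reloaded.my_member == 42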
diff --git a/src/turnkeyml/tools/__init__.py b/src/turnkeyml/tools/__init__.py
new file mode 100644
index 00000000..107f87e5
--- /dev/null
+++ b/src/turnkeyml/tools/__init__.py
@@ -0,0 +1 @@
+from .tool import Tool, FirstTool, NiceHelpFormatter
diff --git a/src/turnkeyml/tools/discovery/__init__.py b/src/turnkeyml/tools/discovery/__init__.py
new file mode 100644
index 00000000..945a14a2
--- /dev/null
+++ b/src/turnkeyml/tools/discovery/__init__.py
@@ -0,0 +1 @@
+from .discover import Discover
diff --git a/src/turnkeyml/tools/discovery/discover.py b/src/turnkeyml/tools/discovery/discover.py
new file mode 100644
index 00000000..45e5b30d
--- /dev/null
+++ b/src/turnkeyml/tools/discovery/discover.py
@@ -0,0 +1,252 @@
+import argparse
+import copy
+import os
+import inspect
+from typing import Optional, List
+import torch
+from turnkeyml.tools import FirstTool
+import turnkeyml.common.build as build
+import turnkeyml.common.exceptions as exp
+import turnkeyml.common.filesystem as fs
+from turnkeyml.tools.discovery.script import (
+ evaluate_script,
+ TracerArgs,
+)
+import turnkeyml.common.printing as printing
+from turnkeyml.state import State
+
+
+default_max_depth = 0
+
+
+class Discover(FirstTool):
+ """
+ Discover the PyTorch models and their corresponding inputs in a python script (.py)
+ and send one model/inputs pair onwards into the sequence.
+
+ Expected inputs:
+ - Input file is a python script (.py file) that invokes at least one PyTorch model
+
+ Outputs:
+ - state.results: a PyTorch model instance (torch.nn.Module)
+ - state.inputs: a dictionary of example inputs to the model's forward function,
+ e.g., model(**inputs)
+
+ You can learn more about how discovery and its arguments work at
+ https://github.com/onnx/turnkeyml/blob/main/docs/tools_user_guide.md
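+
+    Example usage (illustrative; the script name is a placeholder):
+        turnkey -i my_script.py discover --max-depth 1 export-pytorch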
+ """
+
+ unique_name = "discover"
+
+ def __init__(self):
+ super().__init__(monitor_message="Discovering PyTorch models")
+
+ @staticmethod
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
+ parser = __class__.helpful_parser(
+ short_description="Discover the PyTorch models in a python script",
+ add_help=add_help,
+ )
+
+ parser.add_argument(
+ "--script-args",
+ dest="script_args",
+ type=str,
+ help="Arguments to pass into the target script(s)",
+ )
+
+ parser.add_argument(
+ "--max-depth",
+ dest="max_depth",
+ type=int,
+ default=default_max_depth,
+ help="Maximum depth to analyze within the model structure of the target script(s)",
+ )
+
+ return parser
+
+ def parse(self, state: State, args, known_only=True) -> argparse.Namespace:
+ parsed_args = super().parse(state, args, known_only)
+
+ file_path, targets, encoded_input = fs.decode_input_arg(parsed_args.input)
+
+ parsed_args.input = file_path
+
+ if len(targets) > 1:
+ raise exp.ArgError(
+ "Only one target (number after the ::) is allowed, "
+ f"but received {encoded_input}"
+ )
+ elif len(targets) == 1:
+ parsed_args.target = targets[0]
+ else: # len(targets)==0
+ parsed_args.target = None
+
+ return parsed_args
+
+ def run(
+ self,
+ state: State,
+ input: str = "",
+ target: Optional[List[str]] = None,
+ script_args: str = "",
+ max_depth: int = default_max_depth,
+ ):
+ if not input.endswith(".py"):
+ raise exp.ArgError(
+ "Inputs to the `discover` tool must by python scripts "
+ f"(.py files), got {input}",
+ )
+
+ if target is None:
+ target_to_use = []
+ else:
+ target_to_use = [target]
+
+ tracer_args = TracerArgs(
+ input=input,
+ script_args=script_args,
+ targets=target_to_use,
+ max_depth=max_depth,
+ )
+
+ # Discover the models in the python script by executing it with
+ # a tracer enabled
+ state.models_found = evaluate_script(tracer_args)
+
+        # Count the number of buildable model invocations discovered
+ # If there is only 1, pass it to the next build tool. Otherwise,
+ # print all the invocations and suggest that the user select one.
+ count = 0
+ for model_info in state.models_found.values():
+ for (
+ invocation_hash,
+ invocation_info,
+ ) in model_info.unique_invocations.items():
+ count += 1
+
+ # Set the same status for all invocations at first
+ # The next code block will be responsible for the selected
+ # invocation.
+
+ invocation_info.status_message = (
+ "Discovered; select with `-i "
+ f"{os.path.basename(input)}::{invocation_hash}"
+ )
+ invocation_info.status_message_color = printing.Colors.OKCYAN
+
+ # The potential outcomes of target selection are:
+ # Case 1. Input file has only one model, so we select it and don't
+ # bother the user about target selection
+ # Case 2. Input file has more than one model, and...
+ # a. user didn't select a target, so we auto-select the
+ # least-deep (last discovered) model and let the user
+ # know about target selection
+ # b. user selected a target, so we run with it
+ # Case 3. Exception: Input file contained no models
+        #   Case 4. Exception: input file has one or more models, but the user
+ # selected an invalid target
+ #
+ # The purpose of this loop is to identify which of those cases is
+ # active.
+
+ if count == 0:
+ # Case 3
+ raise exp.ToolError(f"No models discovered in input file {input}")
+
+ model_selected = None
+ invocation_selected = None
+ valid_hashes = []
+ case_1 = target is None and count == 1
+ case_2a = target is None and count > 1
+ for model_info in state.models_found.values():
+ for invocation_info in model_info.unique_invocations.values():
+ valid_hashes.append(invocation_info.invocation_hash)
+
+ case_2b = (
+ target is not None and invocation_info.invocation_hash == target
+ )
+
+ if any([case_1, case_2b]):
+ model_selected = model_info
+ state.invocation_info = invocation_info
+ break
+ if case_2a:
+ # Point to the most recent model and invocation identified
+ # We do this so that we can auto-select the last model and invocation
+ # that was discovered, which is typically the least-deep model
+ # because discovery is recursive.
+ model_selected = model_info
+ invocation_selected = invocation_info
+
+ if vars(state).get("invocation_info") is not None:
+ # If we have already selected then there is no need to keep iterating
+ break
+
+ if model_selected is None:
+ # Case 4
+ raise exp.ToolError(
+ f"Hash {target} was selected, but the only "
+ f"valid hashes are {valid_hashes}"
+ )
+
+ if case_2a:
+ state.invocation_info = invocation_selected
+ state.invocation_info.extra_status = (
+ "(auto-selected; select manually with "
+ f"`-i {os.path.basename(input)}"
+ f"::{state.invocation_info.invocation_hash})"
+ )
+
+ # Save stats about the model
+ state.save_stat(
+ fs.Keys.HASH,
+ model_selected.hash,
+ )
+ state.save_stat(
+ "selected_invocation_hash",
+ state.invocation_info.invocation_hash,
+ )
+ state.save_stat(
+ fs.Keys.MODEL_NAME,
+ tracer_args.script_name,
+ )
+ state.save_stat(
+ fs.Keys.PARAMETERS,
+ model_selected.params,
+ )
+
+ state.save_stat(
+ fs.Keys.CLASS,
+ type(model_selected.model).__name__,
+ )
+
+ # Organize the inputs to python model instances
+ args, kwargs = state.invocation_info.inputs
+ inputs = {}
+ for k in kwargs.keys():
+ if torch.is_tensor(kwargs[k]):
+ inputs[k] = torch.tensor(kwargs[k].detach().numpy())
+ else:
+ inputs[k] = copy.deepcopy(kwargs[k])
+
+ # Convert all positional arguments into keyword arguments
+ if args != ():
+
+            forward_function = model_selected.model.forward
+ all_args = list(inspect.signature(forward_function).parameters.keys())
+ for i in range(len(args)):
+ if torch.is_tensor(args[i]):
+ inputs[all_args[i]] = torch.tensor(args[i].detach().numpy())
+ else:
+ inputs[all_args[i]] = args[i]
+
+ # Pass the model and inputs to the next tool
+ state.results = model_selected.model
+ state.model_hash = build.hash_model(model_selected.model)
+ state.expected_input_shapes, state.expected_input_dtypes = (
+ build.get_shapes_and_dtypes(inputs)
+ )
+ state.inputs = inputs
+
+ return state
diff --git a/src/turnkeyml/tools/discovery/script.py b/src/turnkeyml/tools/discovery/script.py
new file mode 100644
index 00000000..378220e5
--- /dev/null
+++ b/src/turnkeyml/tools/discovery/script.py
@@ -0,0 +1,465 @@
+import sys
+import os
+import inspect
+import importlib.util
+import time
+import shlex
+import functools
+import dataclasses
+import traceback
+import hashlib
+from typing import Union, List, Dict, Tuple, Optional
+from types import FrameType, TracebackType
+import torch
+import turnkeyml.common.build as build
+import turnkeyml.common.status as status
+import turnkeyml.common.analyze_model as analyze_model
+import turnkeyml.common.filesystem as fs
+
+
+def _get_classes(module) -> List[str]:
+ """
+ Returns all classes within a module.
+ """
+ return [y for x, y in inspect.getmembers(module, inspect.isclass)]
+
+
+def _get_transformers_activations() -> List:
+ """
+    We need this helper because transformers is not a required dependency for
+ this project, however if we are analyzing a transformers model then we need
+ to inspect its activations.
+ """
+ if "transformers" in sys.modules:
+ return _get_classes(sys.modules["transformers"].activations)
+ else:
+ return []
+
+
+@dataclasses.dataclass
+class TracerArgs:
+ input: str
+ script_args: str
+ targets: List[str]
+ max_depth: int
+ models_found: Dict[str, status.ModelInfo] = dataclasses.field(default_factory=dict)
+ script_name: Optional[str] = None
+
+ @functools.cached_property
+ def torch_activations(self) -> List[str]:
+ act = _get_classes(torch.nn.modules.activation)
+ act += _get_transformers_activations()
+ return act
+
+ @property
+ def hash(self) -> str:
+ """
+ Returns a unique hash representing the arguments. Useful for distinguishing
+ between evaluations of the same model that have different arguments.
+ """
+
+ return hashlib.sha256(str(self).encode()).hexdigest()[:8]
+
+
+def get_model_hash(model: Union[torch.nn.Module, str]):
+ if isinstance(model, str) and model.endswith(".onnx"):
+ hash_params = True
+ else:
+ hash_params = False
+ return build.hash_model(model, hash_params=hash_params)[:8]
+
+
+def get_invocation_hash(
+ model_hash: str, parent_invocation_hash: str, args: Tuple, kwargs: Dict
+) -> Tuple[str, Dict]:
+ """
+ Combines the model hash and the input shapes to create the invocation hash
+ We also ensure that invocations that come from different parents have different hashes
+ """
+
+ # Merge positional and keyword args
+ args = {"Positional Arg {}".format(i + 1): arg for i, arg in enumerate(args)}
+ kwargs = {**kwargs, **args}
+
+ # Get input shapes and types
+ input_shapes, input_dtypes = build.get_shapes_and_dtypes(kwargs)
+
+ hashable_content = (
+ f"{model_hash}{parent_invocation_hash}{input_shapes}{input_dtypes}"
+ )
+ return hashlib.sha256(hashable_content.encode()).hexdigest()[:8], input_shapes
+
+
+def store_model_info(
+ model: torch.nn.Module,
+ model_name: str,
+ frame: FrameType,
+ event: str,
+ tracer_args: TracerArgs,
+ depth: int,
+ parent_hash: str,
+):
+ model_hash = get_model_hash(model)
+
+ # File where the model was found
+ file = str(frame)[str(frame).find("file ") + 6 : str(frame).find("',")]
+
+ # Line where the model was found
+ line = frame.f_lineno if event == "return" else frame.f_lineno - 1
+
+ # Keep track of all models details
+
+ # If we have already found a model, don't add it to models_found again
+ # We have to use both the model hash and the script name, since we don't
+ # want to ignore a model if it was explicitly called in two different scripts
+ identifier = f"{model_hash}_{tracer_args.script_name}"
+ model_already_found = False
+ for model_info in tracer_args.models_found.values():
+ if identifier == f"{model_info.hash}_{model_info.script_name}":
+ model_already_found = True
+
+ if not model_already_found:
+ tracer_args.models_found[model_hash] = status.ModelInfo(
+ model=model,
+ name=model_name,
+ file=file,
+ line=line,
+ depth=depth,
+ hash=model_hash,
+ parent_hash=parent_hash,
+ script_name=tracer_args.script_name,
+ )
+
+
+def explore_frame(
+ frame,
+ event,
+ local_var_name,
+ local_var,
+ tracer_args: TracerArgs,
+ depth: int = 0,
+ parent_hash: Union[str, None] = None,
+):
+ """
+ This function checks whether local_var is a torch model.
+ If it is, we will modify its forward function to know when it
+ is called.
+ """
+
+ # Exit frame exploration if Python is shutting down
+ if not bool(sys.modules):
+ return
+
+ # Skip all variables that are not a subclass of torch.nn.Module
+ # Note: try block used since dead weakreferences fail when checking subclass
+ try:
+ if issubclass(type(local_var), torch.nn.Module):
+ if type(local_var) in tracer_args.torch_activations:
+ return
+ else:
+ return
+ except AttributeError:
+ return
+
+ # Skip self variable and variable names commonly used by child models
+ if (
+ local_var_name == "self"
+ or local_var_name == "instance"
+ or local_var_name == "child"
+ or local_var_name == "layer"
+ or local_var_name == "module"
+ ):
+ return
+
+ # Check if we are inside of a subclass of torch.nn.Module
+ inside_class = False
+ inside_nn_subclass = False
+ if "self" in frame.f_locals:
+ self_var = frame.f_locals["self"]
+ inside_class = type(self_var)
+ inside_nn_subclass = issubclass(inside_class, torch.nn.Module)
+
+ if not inside_nn_subclass:
+ if hasattr(local_var, "forward_instrumented"):
+
+ # Starting in version 2.2.0, torch dynamo added wrappers to callbacks
+            # while tracing frames, which conflicts with TurnkeyML's analysis. Here,
+ # we suppress errors caused by those callback wrappers and only raise an
+ # error if the compiled model actually tries to execute within TurnkeyML.
+ td = torch._dynamo # pylint: disable=protected-access
+ td.config.suppress_errors = True
+ if hasattr(td.eval_frame, "guarded_backend_cache"):
+ td.eval_frame.guarded_backend_cache.skip_backend_check_for_run_only_mode = (
+ True
+ )
+
+ return
+
+ # Avoid instrumenting models before they have been fully loaded
+ if analyze_model.count_parameters(local_var) == 0:
+ return
+
+ # Mark this model as instrumented
+ local_var.forward_instrumented = True
+
+ # Create a copy of the old forward function
+ old_forward = local_var.forward
+
+ # Recursively look for sub-models within the found model
+ # This is only possible on Pytorch, since each layer of a torch.nn.module
+ # is also a torch.nn.module.
+ model_hash = get_model_hash(local_var)
+ local_var.turnkey_hash = model_hash
+ if depth < tracer_args.max_depth:
+ recursive_search(frame, event, local_var, depth, model_hash, tracer_args)
+
+ # We can keep track of Pytorch models even before they are executed
+ store_model_info(
+ local_var,
+ local_var_name,
+ frame,
+ event,
+ tracer_args,
+ depth,
+ parent_hash,
+ )
+
+ local_var.old_forward = old_forward
+
+ def forward_spy(*args, **kwargs):
+ tracer = sys.getprofile()
+ if tracer is not None:
+ # Turn tracing off while the model is being executed for speed
+ sys.setprofile(None)
+ elif depth == 0:
+ # If this model is being executed and the tracing is already off
+ # we are calling a module within a parent module. We only run
+ # on child models if the user has explicitly asked us to
+ # do so by setting the max_depth flag.
+ return old_forward(*args, **kwargs)
+
+ # Get parent invocation hash
+ parent_invocation_hash = None
+ if parent_hash:
+ parent_invocation_hash = tracer_args.models_found[
+ parent_hash
+ ].last_unique_invocation_executed
+
+ model_hash = get_model_hash(local_var)
+ invocation_hash, input_shapes = get_invocation_hash(
+ model_hash, parent_invocation_hash, args, kwargs
+ )
+ model_info = tracer_args.models_found[model_hash]
+
+ if invocation_hash not in model_info.unique_invocations:
+ model_info.unique_invocations[invocation_hash] = (
+ status.UniqueInvocationInfo(
+ name=model_info.name,
+ script_name=model_info.script_name,
+ file=model_info.file,
+ line=model_info.line,
+ params=model_info.params,
+ depth=model_info.depth,
+ model_class=type(model_info.model),
+ invocation_hash=invocation_hash,
+ hash=model_info.hash,
+ is_target=invocation_hash in tracer_args.targets
+ or len(tracer_args.targets) == 0,
+ input_shapes=input_shapes,
+ parent_hash=parent_invocation_hash,
+ inputs=[args, kwargs],
+ extension=f".{tracer_args.input.split('.')[-1]}",
+ forward_function_pointer=local_var.forward,
+ original_forward_function=old_forward,
+ )
+ )
+ model_info.last_unique_invocation_executed = invocation_hash
+
+ # Keep track of execution time
+ start_time = time.time()
+ outputs = old_forward(*args, **kwargs)
+ end_time = time.time()
+
+ invocation_info = model_info.unique_invocations[invocation_hash]
+ invocation_info.exec_time = (
+ invocation_info.exec_time + end_time - start_time
+ )
+ invocation_info.executed = invocation_info.executed + 1
+
+ # Turn tracing on again after computing the outputs
+ sys.setprofile(tracer)
+
+ return outputs
+
+ # The inspect module offers the ability to actually copy the signature of the wrapped
+ # function. This allows other functions to see the correct parameters instead of the
+ # enigmatic *args, **kwargs.
+ forward_spy.__signature__ = inspect.signature(old_forward)
+
+ # Use modified forward/call function
+ local_var.forward = forward_spy
+
+
+def tracefunc(
+ frame: FrameType, event: str, _, tracer_args: TracerArgs
+) -> TracebackType:
+ """
+ This function is used to trace the program as it runs in order
+    to keep track of all instantiated models.
+ This function is passed to sys.setprofile() as a callback function.
+ It receives three arguments:
+ frame (the stack frame from the code being run),
+ event (a string naming the type of notification), and
+ arg (an event-specific value)
+
+ """
+
+ # Create a copy of f_locals.keys() to avoid errors due to dict changing
+ local_names = list(frame.f_locals.keys())
+
+ # Loop over all local variables to check if new models can be found
+ for local_var_name in local_names:
+ explore_frame(
+ frame,
+ event,
+ local_var_name,
+ frame.f_locals[local_var_name],
+ tracer_args=tracer_args,
+ depth=0,
+ )
+
+ return tracefunc
+
+
+def recursive_search(
+ frame: FrameType,
+ event: str,
+ model: torch.nn.Module,
+ depth: int,
+ parent_hash: Union[str, None],
+ tracer_args: TracerArgs,
+):
+ """
+ Recursively check for submodels within found models
+ """
+ element_names = list(dict(model.named_modules()).keys())[1:]
+ for element_name in element_names:
+ if hasattr(model, element_name):
+ element = getattr(model, element_name)
+ if issubclass(type(element), torch.nn.Module):
+ explore_frame(
+ frame,
+ event,
+ element_name,
+ element,
+ tracer_args,
+ depth=depth + 1,
+ parent_hash=parent_hash,
+ )
+
+
+@dataclasses.dataclass
+class HelpfulHandler:
+ # Type of exception to handle
+ exc_type: Exception
+ # Do not print any traceback after this message is encountered
+ traceback_stop_msg: str
+ # Message to print that gives context to the traceback
+ helpful_msg: str
+
+
+class AnalysisException(Exception):
+ pass
+
+
+class HelpfulExceptions:
+ """
+ Catch certain exceptions, defined by `HelpfulHandler`s, and print a more helpful
+ error message and traceback than what would ordinarily be printed out. This is
+ useful to avoid showing the user a giant traceback that goes all the way through
+ our profiling code.
+ """
+
+ def __init__(self, exceptions_to_handle: List[HelpfulHandler]):
+ self.excs = exceptions_to_handle
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, _exc_value, exc_tb):
+ for exc_handler in self.excs:
+ if exc_type == exc_handler.exc_type:
+ # Search the full traceback for the traceback_stop_msg
+ tb = traceback.format_tb(exc_tb)
+
+ # This default value of offending_line makes it so we will print
+ # the entire traceback if we can't find the traceback_stop_msg
+ offending_line = -2
+ for i, line in enumerate(tb):
+ if exc_handler.traceback_stop_msg in line:
+ offending_line = i
+
+ # Eliminate the lines of traceback before and after the traceback_stop_msg
+ # Typically, the lines that follow will be related to our profiling
+ # code and not helpful to the user
+
+ # Start the helpful_traceback after line 3, since the first 3 lines are related
+ # to our profiler
+ start_line = 3
+ helpful_traceback = "\n".join(tb[start_line : offending_line + 1])
+
+ # sys.tracebacklimit = 0 prevents the unwanted traceback from printing
+ # when we raise our AnalysisException
+ sys.tracebacklimit = 0
+ raise AnalysisException(
+ f"{exc_handler.helpful_msg}\n\nTraceback: \n\n: {helpful_traceback}"
+ )
+
+
+def evaluate_script(tracer_args: TracerArgs) -> Dict[str, status.ModelInfo]:
+ tracer_args.script_name = fs.clean_file_name(tracer_args.input)
+
+ # Get a pointer to the script's python module
+ spec = importlib.util.spec_from_file_location("__main__", tracer_args.input)
+ module = importlib.util.module_from_spec(spec)
+
+    # Overwrite sys.argv so that the imported script receives the user's script-args
+ if tracer_args.script_args is None:
+ tracer_args.script_args = []
+ else:
+ tracer_args.script_args = shlex.split(tracer_args.script_args)
+ sys.argv = [tracer_args.input] + tracer_args.script_args
+ sys.path.append(os.getcwd())
+
+ # Create a tracer object that bundles a callback function with some args
+ tracer = functools.partial(tracefunc, tracer_args=tracer_args)
+
+ # Enabling analysis via setprofile
+ sys.setprofile(tracer)
+
+ # Import input script. Each executed frame of the input script will
+ # trigger the tracefunc() callback function (defined above)
+ with HelpfulExceptions(
+ [
+ HelpfulHandler(
+ torch.jit.frontend.NotSupportedError,
+ "torch.jit.script(",
+ "torch.jit.script() is not supported by turnkey CLI and benchmark_files() API, "
+ "however torch.jit.script() is being called in your script."
+ "You can try passing your model instance into the build_model() API instead. ",
+ )
+ ]
+ ):
+ spec.loader.exec_module(module)
+
+ # Stop profiling when we're done executing the module
+ sys.setprofile(None)
+
+ # Restore the original forward function for all models
+ for model_info in tracer_args.models_found.values():
+ for invocation_info in model_info.unique_invocations.values():
+ invocation_info.forward_function_pointer = (
+ invocation_info.original_forward_function
+ )
+
+ return tracer_args.models_found
diff --git a/src/turnkeyml/tools/export.py b/src/turnkeyml/tools/export.py
new file mode 100644
index 00000000..e3da96a9
--- /dev/null
+++ b/src/turnkeyml/tools/export.py
@@ -0,0 +1,252 @@
+import os
+import inspect
+import warnings
+import sys
+import copy
+import argparse
+import torch
+import torch.onnx.verification
+from turnkeyml.tools import Tool
+import turnkeyml.common.exceptions as exp
+import turnkeyml.common.build as build
+import turnkeyml.common.tensor_helpers as tensor_helpers
+import turnkeyml.common.onnx_helpers as onnx_helpers
+import turnkeyml.common.filesystem as fs
+from turnkeyml.state import State
+
+
+def _warn_to_stdout(message, category, filename, line_number, _, line):
+ sys.stdout.write(
+ warnings.formatwarning(message, category, filename, line_number, line)
+ )
+
+
+def base_onnx_file(state: State):
+ return os.path.join(
+ onnx_helpers.onnx_dir(state),
+ f"{state.build_name}-op{state.onnx_opset}-base.onnx",
+ )
+
+
+class ExportPytorchModel(Tool):
+ """
+ Tool that takes a PyTorch model instance, from the state of the previous
+ tool in the sequence, and exports it to an ONNX file.
+
+ Expected inputs:
+ - state.results: torch.nn.Module or torch.jit.ScriptModule
+ - state.inputs: dict that represents valid kwargs to the forward
+ function of state.results
+
+ Outputs:
+ - state.results: a *-base.onnx file that implements state.results
+ given state.inputs
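+
+    Example usage (illustrative; the script name and opset value are placeholders):
+        turnkey -i my_script.py discover export-pytorch --opset 17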
+ """
+
+ unique_name = "export-pytorch"
+
+ def __init__(self):
+ super().__init__(monitor_message="Exporting PyTorch to ONNX")
+
+ @staticmethod
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
+ parser = __class__.helpful_parser(
+ short_description="Export a PyTorch model to ONNX",
+ add_help=add_help,
+ )
+
+ parser.add_argument(
+ "--opset",
+ type=int,
+ default=build.DEFAULT_ONNX_OPSET,
+ help=f"ONNX opset to export into (default: {build.DEFAULT_ONNX_OPSET})",
+ )
+
+ return parser
+
+ def run(self, state: State, opset: int = build.DEFAULT_ONNX_OPSET):
+ if not isinstance(state.results, (torch.nn.Module, torch.jit.ScriptModule)):
+ msg = f"""
+ The current tool (ExportPytorchModel) is only compatible with
+ models of type torch.nn.Module or torch.jit.ScriptModule, however
+ the tool received a model of type {type(state.results)}.
+ """
+ raise exp.ToolError(msg)
+
+ state.onnx_opset = opset
+
+ # The `torch.onnx.export()` function accepts a tuple of positional inputs
+ # followed by a dictionary with all keyword inputs.
+ # The dictionary must be last item in tuple.
+ user_provided_args = list(state.inputs.keys())
+
+ if isinstance(state.results, torch.nn.Module):
+ # Validate user provided args
+ all_args = list(inspect.signature(state.results.forward).parameters.keys())
+
+ for inp in user_provided_args:
+ if inp not in all_args:
+ msg = f"""
+ Input name {inp} not found in the model's forward method. Available
+                    input names are: {all_args}
+ """
+ raise ValueError(msg)
+
+ # Most pytorch models have args that are kind = positional_or_keyword.
+ # The `torch.onnx.export()` function accepts model args as
+ # (all_positional_args_value,{keyword_arg:value}).
+ # To map the input_args correctly and to build an accurate model
+ # the order of the input_names must reflect the order of the model args.
+
+ # Collect order of pytorch model args.
+ all_args_order_mapping = {arg: idx for idx, arg in enumerate(all_args)}
+
+ # Sort the user provided inputs with respect to model args and store as tuple.
+ sorted_user_inputs = sorted(
+ user_provided_args, key=lambda x: all_args_order_mapping[x]
+ )
+ dummy_input_names = tuple(sorted_user_inputs)
+
+ # If a single input is provided torch.onnx.export will
+ # not accept a dictionary, so pop the first arg
+ user_args = copy.deepcopy(state.inputs)
+ first_input = user_args.pop(dummy_input_names[0])
+
+ # Create tuple: (first input, {rest of user_args dict as keyword args})
+ dummy_inputs = (first_input, user_args)
+
+ else: # state.results is a torch.jit.ScriptModule
+ dummy_inputs = tuple(state.inputs.values())
+
+ # Collect input names
+ dummy_input_names = tuple(state.inputs.keys())
+
+ # Send torch export warnings to stdout (and therefore the log file)
+ # so that they don't fill up the command line
+ default_warnings = warnings.showwarning
+ warnings.showwarning = _warn_to_stdout
+
+ # Export the model to ONNX
+ output_path = base_onnx_file(state)
+ os.makedirs(onnx_helpers.onnx_dir(state), exist_ok=True)
+
+ torch.onnx.export(
+ state.results,
+ dummy_inputs,
+ output_path,
+ input_names=dummy_input_names,
+ do_constant_folding=True,
+ opset_version=opset,
+ verbose=False,
+ )
+
+ # Save output names to ensure we are preserving the order of the outputs
+ state.expected_output_names = onnx_helpers.get_output_names(output_path)
+
+ # Restore default warnings behavior
+ warnings.showwarning = default_warnings
+
+ tensor_helpers.save_inputs(
+ [state.inputs],
+ onnx_helpers.original_inputs_file(state.cache_dir, state.build_name),
+ downcast=False,
+ )
+
+        # Check if the base model has been exported successfully
+ success_msg = "\tSuccess exporting model to ONNX"
+ fail_msg = "\tFailed exporting model to ONNX"
+
+ if onnx_helpers.check_model(output_path, success_msg, fail_msg):
+ state.results = output_path
+
+ state.save_stat(
+ fs.Keys.ONNX_FILE,
+ output_path,
+ )
+ else:
+ msg = f"""
+ Unable to export model to ONNX using Torch's ONNX exporter.
+ We recommend that you modify your model until it is
+ compatible with this third party software, then re-run.
+ More information may be available in the log file at **{self.logfile_path}**
+ """
+ raise exp.ToolError(msg)
+
+ return state
+
+
+class VerifyOnnxExporter(Tool):
+ """
+ Tool that runs a parity test on an input PyTorch model and an ONNX
+ file derived from that model.
+
+ Note that the derived ONNX file is discarded by the verification API,
+ so we can't use it in downstream Tools. To use this tool in the same sequence
+ as other build tools, we recommend:
+ discover -> verify-exporter -> export-pytorch -> other tools
+
+ Expected inputs:
+ - state.results: torch.nn.Module or torch.jit.ScriptModule
+ - state.inputs: dict that represents valid kwargs to the forward
+ function of state.results
+
+ Outputs: No change to state
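+
+    Example usage (illustrative; the script name is a placeholder):
+        turnkey -i my_script.py discover verify-exporter export-pytorch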
+ """
+
+ unique_name = "verify-exporter"
+
+ def __init__(self):
+ super().__init__(monitor_message="Verifying ONNX exporter")
+
+ @staticmethod
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
+ parser = __class__.helpful_parser(
+ short_description="Verify if model can be exported to ONNX without major "
+ "numerical discrepancies",
+ add_help=add_help,
+ )
+
+ return parser
+
+ def run(self, state: State):
+
+ # Verify if the exported model matches the input torch model
+ try:
+ # Tolerance levels for the torch export are recommended by Pytorch here:
+ # https://pytorch.org/docs/stable/testing.html#module-torch.testing
+ fp32_tolerance = torch.onnx.verification.VerificationOptions(
+ rtol=1.3e-6, atol=1e-5
+ )
+
+ # The `torch.onnx.verification.find_mismatch()` takes input arguments to the
+ # model as `input_args (Tuple[Any, ...])`
+ export_verification = torch.onnx.verification.find_mismatch(
+ state.results,
+ tuple(state.inputs.values()),
+ opset_version=state.onnx_opset,
+ options=fp32_tolerance,
+ )
+
+ # `export_verification.has_mismatch()` returns True if a mismatch is found and
+        # False otherwise. If no mismatch is found, `is_export_valid` is set to "valid",
+        # indicating successful verification.
+        # If a mismatch is found, `is_export_valid` is set to "invalid", indicating
+ # the verification failed.
+ if not export_verification.has_mismatch():
+ is_export_valid = "valid"
+ else:
+ is_export_valid = "invalid"
+
+ # The except block catches any type of exception that might occur during the
+ # verification process. If any exception occurs,`is_export_valid` is set to
+ # "Unverified", indicating that the verification process could not be completed,
+ # and therefore the model's export status is unverified.
+ except Exception: # pylint: disable=broad-except
+ is_export_valid = "unverified"
+
+ state.save_stat(
+ fs.Keys.TORCH_ONNX_EXPORT_VALIDITY,
+ is_export_valid,
+ )
+
+ return state
diff --git a/src/turnkeyml/tools/load_build.py b/src/turnkeyml/tools/load_build.py
new file mode 100644
index 00000000..b3566f51
--- /dev/null
+++ b/src/turnkeyml/tools/load_build.py
@@ -0,0 +1,199 @@
+import pathlib
+import copy
+import argparse
+from typing import Union, Dict
+from turnkeyml.tools import FirstTool
+import turnkeyml.common.exceptions as exp
+import turnkeyml.common.build as build
+import turnkeyml.common.filesystem as fs
+from turnkeyml.common.status import ModelInfo, UniqueInvocationInfo
+from turnkeyml.state import State, load_state
+import turnkeyml.common.printing as printing
+from turnkeyml.version import __version__ as turnkey_version
+
+skip_policy_default = "attempted"
+
+
+def _decode_version_number(version: str) -> Dict[str, int]:
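+    # e.g., "1.2.3" -> {"major": 1, "minor": 2, "patch": 3}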
+ numbers = [int(x) for x in version.split(".")]
+ return {"major": numbers[0], "minor": numbers[1], "patch": numbers[0]}
+
+
+class LoadBuild(FirstTool):
+ """
+ Tool that loads a build from a previous usage of TurnkeyML and passes
+ its saved State on to the next tool in the sequence.
+
+ Works best with build State that is complete on disk.
+
+ For example:
+ - State that references an ONNX file is a good target, because the ONNX file can
+ be loaded from disk.
+ - State that references a PyTorch model in memory is a poor target, because
+ that PyTorch model will not be available when the State file is loaded
+ from disk.
+
+ Expected inputs:
+ - Input file is a *_state.yaml file in a turnkey cache build directory
+
+ Outputs:
+ - State has the contents of the state.yaml file of the target build.
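+
+    Example usage (illustrative; the state file path is a placeholder):
+        turnkey -i path/to/build_dir/build_state.yaml load-build --skip-policy none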
+ """
+
+ unique_name = "load-build"
+
+ def __init__(self):
+ super().__init__(monitor_message="Loading cached build")
+
+ @staticmethod
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
+ parser = __class__.helpful_parser(
+ short_description="Load build state from the cache",
+ add_help=add_help,
+ )
+
+ parser.add_argument(
+ "--skip-policy",
+ choices=[skip_policy_default, "failed", "successful", "none"],
+ help="Sets the policy for skipping evaluation attempts "
+ f"(defaults to {skip_policy_default})."
+ "`attempted` means to skip any previously-attempted evaluation, "
+ "whether it succeeded or failed."
+ "`failed` skips evaluations that have already failed once."
+ "`successful` skips evaluations that have already succeeded."
+ "`none` will attempt all evaluations, regardless of whether "
+ "they were previously attempted.",
+ required=False,
+ default=skip_policy_default,
+ )
+
+ return parser
+
+ def run(self, state: State, input: str = "", skip_policy=skip_policy_default):
+
+ # Extract the cache directory, build directory, and build name from the input
+ source_build_dir = pathlib.Path(input).parent
+ source_build_dir_name = source_build_dir.name
+ source_cache_dir = source_build_dir.parent
+
+ # Make sure that the target yaml file is actually the state of a turnkey build
+ if not fs.is_build_dir(source_cache_dir, source_build_dir_name):
+ raise exp.CacheError(
+ f"No build found at path: {input}. "
+ "Try running `turnkey cache --list --all` to see the builds in your build cache."
+ )
+
+ # Record the new sequence's information so that we can append it to the
+ # loaded build's sequence information later
+ new_sequence_info = state.sequence_info
+
+ # Load the cached build
+ printing.log_info(f"Attempting to load: {input}")
+ state = load_state(state_path=input)
+
+ # Record the sequence used for the loaded build so that we examine it later
+ prior_selected_sequence = list(state.sequence_info.keys())
+
+ # Raise an exception if there is a version mismatch between the installed
+ # version of turnkey and the version of turnkey used to create the loaded
+ # build
+ current_version_decoded = _decode_version_number(turnkey_version)
+ state_version_decoded = _decode_version_number(state.turnkey_version)
+ out_of_date: Union[str, bool] = False
+ if current_version_decoded["major"] > state_version_decoded["major"]:
+ out_of_date = "major"
+ elif current_version_decoded["minor"] > state_version_decoded["minor"]:
+ out_of_date = "minor"
+
+ if out_of_date:
+ raise exp.SkipBuild(
+ f"Your build {state.build_name} was previously built against "
+ f"turnkey version {state.turnkey_version}, "
+ f"however you are now using turnkey version {turnkey_version}. "
+ "The previous build is "
+ f"incompatible with this version of turnkey, as indicated by the {out_of_date} "
+ "version number changing. See **docs/versioning.md** for details."
+ )
+
+ # Append the sequence of this build to the sequence of the loaded build.
+ # so that the stats file reflects the complete set of Tools that have been
+ # attempted on this build
+ stats = fs.Stats(state.cache_dir, state.build_name)
+ combined_selected_sequence = copy.deepcopy(prior_selected_sequence)
+ for new_tool, new_tool_args in new_sequence_info.items():
+ combined_selected_sequence.append(new_tool)
+ state.sequence_info[new_tool] = new_tool_args
+ stats.save_stat(fs.Keys.SELECTED_SEQUENCE_OF_TOOLS, combined_selected_sequence)
+
+ # Apply the skip policy by raising a SkipBuild exception
+ # if the pre-existing build status doesn't meet certain criteria
+ if self.__class__.unique_name not in prior_selected_sequence:
+ if state.build_status != build.FunctionStatus.SUCCESSFUL:
+ if skip_policy == "attempted" or skip_policy == "failed":
+ raise exp.SkipBuild(
+ f"Skipping {state.build_name} because it has a "
+ f"status of {state.build_status} and the skip policy "
+ f"is set to {skip_policy}."
+ )
+ else:
+ # Issue a warning to users if they loaded an unsuccessful build
+ # This is a warning, instead of an exception, to allow for the case
+ # where a Tool is being re-attempted under different conditions (e.g.,
+ # re-attempting a benchmark after a system restart).
+ if state.build_status != build.FunctionStatus.SUCCESSFUL:
+ print(f"Warning: loaded build status is {state.build_status}")
+ else:
+ if skip_policy == "attempted":
+ raise exp.SkipBuild(
+ f"Skipping {state.build_name} because it was previously attempted "
+ f"and the skip policy is set to {skip_policy}"
+ )
+ elif (
+ skip_policy == "successful"
+ and state.build_status == build.FunctionStatus.SUCCESSFUL
+ ):
+ raise exp.SkipBuild(
+ f"Skipping {state.build_name} because it was previously successfully "
+ f"attempted and the skip policy is set to {skip_policy}"
+ )
+ elif (
+ skip_policy == "failed"
+ and state.build_status != build.FunctionStatus.SUCCESSFUL
+ ):
+ raise exp.SkipBuild(
+ f"Skipping {state.build_name} because it was previously "
+ f"unsuccessfully attempted and the skip policy is set to {skip_policy}"
+ )
+ elif skip_policy == "none":
+ # Skip policy of "none" means we should never skip over a build
+ pass
+ else:
+ # The skip condition is not met, so we will continue
+ pass
+
+ # Mark the build status as incomplete now that we have re-opened it
+ state.build_status = build.FunctionStatus.INCOMPLETE
+
+ # Create a UniqueInvocationInfo and ModelInfo so that we can display status
+ # at the end of the sequence
+ state.invocation_info = UniqueInvocationInfo(
+ name=input,
+ script_name=fs.clean_file_name(input),
+ hash=0,
+ is_target=True,
+ extension="_state.yaml",
+ executed=1,
+ )
+ state.models_found = {
+ "state_file": ModelInfo(
+ model=input,
+ name=input,
+ script_name=input,
+ file=input,
+ unique_invocations={0: state.invocation_info},
+ hash=0,
+ )
+ }
+ state.invocation_info.params = state.models_found["state_file"].params
+
+ return state
diff --git a/src/turnkeyml/tools/management_tools.py b/src/turnkeyml/tools/management_tools.py
new file mode 100644
index 00000000..5a0d2905
--- /dev/null
+++ b/src/turnkeyml/tools/management_tools.py
@@ -0,0 +1,267 @@
+import argparse
+import abc
+import os
+from typing import List
+import turnkeyml.common.filesystem as fs
+import turnkeyml.common.exceptions as exp
+import turnkeyml.common.printing as printing
+from turnkeyml.tools.tool import ToolParser
+from turnkeyml.version import __version__ as turnkey_version
+
+
+class ManagementTool(abc.ABC):
+ """
+ Intended for management functions, such as managing the cache
+ or printing the version number.
+ """
+
+ unique_name: str
+
+ @classmethod
+ def helpful_parser(cls, short_description: str, **kwargs):
+ epilog = (
+ f"`{cls.unique_name}` is a Management Tool. It is intended to be invoked by itself "
+ "(i.e., not as part of a sequence), to accomplish a utility function. "
+ )
+
+ return ToolParser(
+ prog=f"turnkey {cls.unique_name}",
+ short_description=short_description,
+ description=cls.__doc__,
+ epilog=epilog,
+ **kwargs,
+ )
+
+ @staticmethod
+ @abc.abstractmethod
+ def parser() -> argparse.ArgumentParser:
+ """
+ Static method that returns an ArgumentParser that defines the command
+ line interface for this Tool.
+ """
+
+ # pylint: disable=unused-argument
+ def parse(self, args, known_only=True) -> argparse.Namespace:
+ """
+ Run the parser and return a Namespace of keyword arguments that the user
+ passed to the Tool via the command line.
+
+ Tools should extend this function only if they require specific parsing
+ logic.
+
+ Args:
+ args: command line arguments passed from the CLI.
+ known_only: this argument allows the CLI framework to
+ incrementally parse complex commands.
+ """
+
+ if known_only:
+ parsed_args = self.__class__.parser().parse_args(args)
+ else:
+ parsed_args, _ = self.__class__.parser().parse_known_args(args)
+
+ return parsed_args
+
+ @abc.abstractmethod
+ def run(self, cache_dir: str):
+ """
+ Execute the functionality of the Tool.
+ """
+
+ def parse_and_run(self, cache_dir: str, args, known_only=True):
+ """
+ Helper function to parse CLI arguments into the args expected
+ by run(), and then forward them into the run() method.
+ """
+
+ parsed_args = self.parse(args, known_only)
+ self.run(cache_dir, **parsed_args.__dict__)
+
+
+class Version(ManagementTool):
+ """
+ Simply prints the version number of the turnkeyml installation.
+ """
+
+ unique_name = "version"
+
+ @staticmethod
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
+ parser = __class__.helpful_parser(
+ short_description="Print the turnkeyml version number",
+ add_help=add_help,
+ )
+
+ return parser
+
+ def run(self, _):
+ print(turnkey_version)
+
+
+class Cache(ManagementTool):
+ # pylint: disable=pointless-statement,f-string-without-interpolation
+ f"""
+ A set of functions for managing the turnkey build cache. The default
+ cache location is {fs.DEFAULT_CACHE_DIR}, and can also be selected with
+ the global --cache-dir option or the TURNKEY_CACHE_DIR environment variable.
+
+ Users must set either "--all" or "--build-names" to let the tool
+ know what builds to operate on.
+
+ Users must also set one of the available actions (e.g., list, stats, etc.).
+
+ That action will be applied to all selected builds.
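+
+    Example usage (illustrative; the build name is a placeholder):
+        turnkey cache --build-names my_build --stats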
+ """
+
+ unique_name = "cache"
+
+ @staticmethod
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
+ # NOTE: `--cache-dir` is set as a global input to the turnkey CLI and
+ # passed directly to the `run()` method
+
+ parser = __class__.helpful_parser(
+ short_description="Manage the turnkey build cache " f"",
+ add_help=add_help,
+ )
+
+ build_selection_group = parser.add_mutually_exclusive_group(required=True)
+
+ build_selection_group.add_argument(
+ "-b",
+ "--build-names",
+ nargs="+",
+ help="Name of the specific build(s) to be operated upon, within the cache directory",
+ )
+
+ build_selection_group.add_argument(
+ "-a",
+ "--all",
+ dest="all_builds",
+ help="Operate on all the builds in the cache",
+ action="store_true",
+ )
+
+ action_group = parser.add_mutually_exclusive_group(required=True)
+
+ action_group.add_argument(
+ "-l",
+ "--list",
+ dest="list_builds",
+ action="store_true",
+ help="List all of the builds in the cache",
+ )
+
+ action_group.add_argument(
+ "-s",
+ "--stats",
+ action="store_true",
+ help="Print the collected stats for the selected build(s)",
+ )
+
+ action_group.add_argument(
+ "--delete",
+ action="store_true",
+ help="Permanently delete the selected build(s)",
+ )
+
+ action_group.add_argument(
+ "--clean",
+ action="store_true",
+ help="Remove the build artifacts from the selected build(s)",
+ )
+
+ return parser
+
+ def run(
+ self,
+ cache_dir: str,
+ all_builds: bool = False,
+ build_names: List[str] = None,
+ list_builds: bool = False,
+ stats: bool = False,
+ delete: bool = False,
+ clean: bool = False,
+ ):
+ fs.check_cache_dir(cache_dir)
+
+ if all_builds and build_names:
+ raise ValueError(
+ "all_builds and build_names are mutually exclusive, "
+ "but both are used in this call."
+ )
+ elif all_builds:
+ builds = fs.get_available_builds(cache_dir)
+ elif build_names:
+ builds = build_names
+ else:
+ raise ValueError(
+ "Either all_builds or build_names must be set, "
+ "but this call sets neither."
+ )
+
+ # Print a nice heading
+ printing.log_info(f"Operating on cache directory {cache_dir}")
+
+ if not builds:
+ printing.log_warning("No builds found.")
+
+ for build in builds:
+ build_path = os.path.join(cache_dir, build)
+ if fs.is_build_dir(cache_dir, build):
+ # Run actions on the build
+ # These actions are intended to be mutually exclusive, so we
+ # use an if-elif block in order from least to most destructive
+ if list_builds:
+ print(build)
+ elif stats:
+ fs.print_yaml_file(fs.Stats(cache_dir, build).file, "stats")
+ elif clean:
+ fs.clean_output_dir(cache_dir, build)
+ printing.log_info(f"Removed the build artifacts from: {build}")
+
+ elif delete:
+ fs.rmdir(build_path)
+ printing.log_info(f"Deleted build: {build}")
+ else:
+ raise exp.CacheError(
+ f"No build found with name: {build}. "
+ "Try running `turnkey cache list` to see the builds in your build cache."
+ )
+
+ print()
+
+
+class ModelsLocation(ManagementTool):
+ """
+    Prints the location of the turnkeyml built-in models corpora.
+
+ This is especially useful for when turnkey was installed from PyPI
+ with `pip install turnkeyml`. Example usage in this context:
+ models=$(turnkey models-location --quiet)
+ turnkey -i $models/selftest/linear.py discover export-pytorch
+ """
+
+ unique_name = "models-location"
+
+ @staticmethod
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
+ parser = __class__.helpful_parser(
+ short_description="Print the location of the built-in turnkeyml models",
+ add_help=add_help,
+ )
+
+ parser.add_argument(
+ "-q",
+ "--quiet",
+ action="store_true",
+ help="Print only the file path, with no other text",
+ )
+
+ return parser
+
+ def run(self, _, quiet: bool = False):
+ if quiet:
+ print(fs.MODELS_DIR)
+ else:
+ printing.log_info(f"The models directory is: {fs.MODELS_DIR}")
diff --git a/src/turnkeyml/tools/onnx.py b/src/turnkeyml/tools/onnx.py
new file mode 100644
index 00000000..9e2d2e73
--- /dev/null
+++ b/src/turnkeyml/tools/onnx.py
@@ -0,0 +1,364 @@
+import os
+import shutil
+import warnings
+import sys
+import argparse
+import numpy as np
+import onnxruntime
+import onnxmltools
+import onnx
+from turnkeyml.tools import Tool, FirstTool
+import turnkeyml.common.exceptions as exp
+import turnkeyml.common.build as build
+import turnkeyml.common.tensor_helpers as tensor_helpers
+import turnkeyml.common.onnx_helpers as onnx_helpers
+import turnkeyml.common.filesystem as fs
+from turnkeyml.common.status import ModelInfo, UniqueInvocationInfo
+from turnkeyml.state import State
+
+
+def _warn_to_stdout(message, category, filename, line_number, _, line):
+ sys.stdout.write(
+ warnings.formatwarning(message, category, filename, line_number, line)
+ )
+
+
+def loaded_onnx_file(state: State):
+ return os.path.join(
+ onnx_helpers.onnx_dir(state),
+ f"{state.build_name}-op{state.onnx_opset}-loaded.onnx",
+ )
+
+
+def opt_onnx_file(state: State):
+ return os.path.join(
+ onnx_helpers.onnx_dir(state),
+ f"{state.build_name}-op{state.onnx_opset}-opt.onnx",
+ )
+
+
+def converted_onnx_file(state: State):
+ return os.path.join(
+ onnx_helpers.onnx_dir(state),
+ f"{state.build_name}-op{state.onnx_opset}-opt-f16.onnx",
+ )
+
+
+class LoadOnnx(FirstTool):
+ """
+ Tool that takes an ONNX model as input and passes it to the following
+ tools.
+
+ Expected inputs:
+ - Input: a .onnx file
+
+ Outputs:
+      - state.results: a .onnx file that has been copied to the turnkey cache
+ - state.inputs: valid inputs to that .onnx file
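+
+    Example usage (illustrative; the model path is a placeholder):
+        turnkey -i my_model.onnx load-onnx optimize-ort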
+ """
+
+ unique_name = "load-onnx"
+
+ def __init__(self):
+ super().__init__(monitor_message="Loading ONNX Model")
+
+ @staticmethod
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
+ parser = __class__.helpful_parser(
+ short_description="Load an ONNX model",
+ add_help=add_help,
+ )
+
+ return parser
+
+ def run(self, state: State, input: str = ""):
+
+ onnx_file = input
+ state.model_hash = build.hash_model(onnx_file)
+
+ if not onnx_file.endswith(".onnx"):
+ msg = f"""
+            The current tool (load-onnx) expects a path to an ONNX
+            model, however the tool received {onnx_file}.
+ """
+ raise exp.ToolError(msg)
+
+ state.inputs = onnx_helpers.dummy_inputs(onnx_file)
+ dummy_inputs = tuple(state.inputs.values())
+ dummy_input_names = tuple(state.inputs.keys())
+ state.inputs = dict(zip(dummy_input_names, dummy_inputs))
+
+ model = onnx.load(onnx_file)
+ opset = onnx_helpers.get_opset(model)
+ state.onnx_opset = opset
+ input_shapes = [
+ [d.dim_value for d in _input.type.tensor_type.shape.dim]
+ for _input in model.graph.input # pylint: disable=no-member
+ ]
+
+ # Save output node names
+ state.expected_output_names = onnx_helpers.get_output_names(model)
+
+        # Check for dynamic shapes in the model. They can be represented as 0, -1, or "unk__".
+ for input in input_shapes:
+ for dimension in input:
+ if dimension < 1 or not isinstance(dimension, int):
+ msg = f"""
+ The received model has dynamic input dimensions. Please freeze the model with static
+ input dimensions.
+ More information may be available in the log file at **{self.logfile_path}**
+ """
+ raise exp.ToolError(msg)
+
+ if opset < build.DEFAULT_ONNX_OPSET and opset >= build.MINIMUM_ONNX_OPSET:
+ print(
+ f" \n The received model has an opset {opset}. Though this opset is supported \
+ we recommend upgrading the model to opset {build.MINIMUM_ONNX_OPSET}"
+ )
+ elif opset < build.MINIMUM_ONNX_OPSET:
+ msg = f"""
+            The received model has opset {opset}. Opsets below {build.MINIMUM_ONNX_OPSET}
+            are not supported. Please try upgrading the model to opset {build.MINIMUM_ONNX_OPSET} or newer.
+ More information may be available in the log file at **{self.logfile_path}**
+ """
+ raise exp.ToolError(msg)
+
+ output_path = loaded_onnx_file(state)
+ os.makedirs(onnx_helpers.onnx_dir(state), exist_ok=True)
+ shutil.copy(onnx_file, output_path)
+
+ tensor_helpers.save_inputs(
+ [state.inputs],
+ onnx_helpers.original_inputs_file(state.cache_dir, state.build_name),
+ downcast=False,
+ )
+
+        # Check that the loaded ONNX model is valid
+ success_msg = "\tSuccess receiving ONNX Model"
+ fail_msg = "\tFailed receiving ONNX Model"
+
+ if onnx_helpers.check_model(output_path, success_msg, fail_msg):
+ state.results = output_path
+
+ state.save_stat(
+ fs.Keys.ONNX_FILE,
+ output_path,
+ )
+ else:
+ msg = f"""
+ Unable to process ONNX Model. We recommend that you verify the source of the model.
+ Any optimizations performed on the model could result in an error.
+ More information may be available in the log file at **{self.logfile_path}**
+ """
+ raise exp.ToolError(msg)
+
+ # Create a UniqueInvocationInfo and ModelInfo so that we can display status
+ # at the end of the sequence
+ state.invocation_info = UniqueInvocationInfo(
+ name=onnx_file,
+ script_name=fs.clean_file_name(onnx_file),
+ file=onnx_file,
+ input_shapes={key: value.shape for key, value in state.inputs.items()},
+ hash=state.model_hash,
+ is_target=True,
+ extension=".onnx",
+ executed=1,
+ )
+ state.models_found = {
+ "onnx_file": ModelInfo(
+ model=onnx_file,
+ name=onnx_file,
+ script_name=onnx_file,
+ file=onnx_file,
+ unique_invocations={state.model_hash: state.invocation_info},
+ hash=state.model_hash,
+ )
+ }
+ state.invocation_info.params = state.models_found["onnx_file"].params
+
+ return state
+
+
+class OptimizeOnnxModel(Tool):
+ """
+ Tool that takes a .onnx file and uses ONNX Runtime to optimize it by
+ performing constant folding, redundant node eliminations,
+ semantics-preserving node fusions, etc.
+
+ Expected inputs:
+ - state.results: a .onnx file
+
+ Outputs:
+ - state.results: a *-opt.onnx file
+ """
+
+ unique_name = "optimize-ort"
+
+ def __init__(self):
+ super().__init__(monitor_message="Optimizing ONNX file")
+
+ @staticmethod
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
+ parser = __class__.helpful_parser(
+ short_description="Use OnnxRuntime to optimize an ONNX model",
+ add_help=add_help,
+ )
+
+ return parser
+
+ def run(self, state: State):
+ input_onnx = state.results
+ output_path = opt_onnx_file(state)
+
+        # Perform some basic optimizations on the model to remove shape-related
+        # information inserted for dynamic shape inference. Given that we're
+        # compiling against a fixed sequence length, the dynamic shape
+        # information is not necessary.
+ session_options = onnxruntime.SessionOptions()
+
+ # Set graph optimization level
+ session_options.graph_optimization_level = (
+ onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
+ )
+
+        # Setting this path enables model serialization after graph optimization
+ session_options.optimized_model_filepath = output_path
+
+ # Optimize graph
+ onnxruntime.InferenceSession(input_onnx, session_options)
+
+        # Check that the optimized model is still valid
+ success_msg = "\tSuccess optimizing ONNX model"
+ fail_msg = "\tFailed optimizing ONNX model"
+
+ if onnx_helpers.check_model(output_path, success_msg, fail_msg):
+ state.results = output_path
+
+ state.save_stat(
+ fs.Keys.ONNX_FILE,
+ output_path,
+ )
+ else:
+ msg = f"""
+ Unable to optimize ONNX file using ONNX runtime.
+ We recommend that you modify your model until it is
+ compatible with this third party software, then re-run.
+ More information may be available in the log file at **{self.logfile_path}**
+ """
+ raise exp.ToolError(msg)
+
+ return state
+
+
+class ConvertOnnxToFp16(Tool):
+ """
+ Tool that takes an ONNX file and converts its trained parameters
+ to fp16.
+
+ Expected inputs:
+ - state.results: a .onnx file
+
+ Outputs:
+ - state.results: a *-f16.onnx file with FP16 trained parameters
+ """
+
+ unique_name = "convert-fp16"
+
+ def __init__(self):
+ super().__init__(
+ monitor_message="Converting to FP16",
+ )
+
+ @staticmethod
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
+ parser = __class__.helpful_parser(
+ short_description="Use OnnxMLTools to convert an ONNX model to fp16",
+ add_help=add_help,
+ )
+
+ return parser
+
+ def run(self, state: State):
+ input_onnx = state.results
+
+        # Convert the model to FP16.
+        # Some ops will not be converted to fp16 because they are in a block list.
+        # The latest block list can be found at the link below; note that it is not
+        # necessarily the list that our version of onnxmltools uses.
+        # https://github.com/microsoft/onnxconverter-common/blob/master/onnxconverter_common/float16.py#L82
+
+ # Send onnxmltools warnings to stdout (and therefore the log file)
+ # so that they don't fill up the command line
+ default_warnings = warnings.showwarning
+ warnings.showwarning = _warn_to_stdout
+
+        # "Legalize" ops are ops that are (or have been) in the block list but that
+        # we explicitly want converted to fp16, so we remove them from the block list
+ legalize_ops = ["InstanceNormalization", "Resize", "Max"]
+ op_block_list = onnxmltools.utils.float16_converter.DEFAULT_OP_BLOCK_LIST.copy()
+ for op in legalize_ops:
+            # Only remove ops that are actually in the block list, since the
+            # default block list may change across onnxmltools versions
+ if op in op_block_list:
+ op_block_list.remove(op)
+
+        # Infer shapes before converting to FP16 to support models larger than 2 GB
+ onnx.shape_inference.infer_shapes_path(input_onnx)
+
+ fp32_model = onnx.load_model(input_onnx)
+ fp16_model = onnxmltools.utils.float16_converter.convert_float_to_float16(
+ fp32_model, op_block_list=op_block_list, disable_shape_infer=True
+ )
+
+ # Load inputs and convert to fp16
+ inputs_file = onnx_helpers.original_inputs_file(
+ state.cache_dir, state.build_name
+ )
+ if os.path.isfile(inputs_file):
+ inputs = np.load(inputs_file, allow_pickle=True)
+ inputs_converted = tensor_helpers.save_inputs(
+ inputs, inputs_file, downcast=True
+ )
+ else:
+ raise exp.ToolError(
+ "Attempted to convert inputs to FP16, however inputs file was not found."
+ )
+
+ # Overwrite expected dtypes
+ _, state.expected_input_dtypes = build.get_shapes_and_dtypes(
+ inputs_converted[0]
+ )
+
+ # Indicate that inputs must be downcasted during inference
+ state.downcast_applied = True
+
+ # Save FP16 model (use external data format if needed)
+ output_path = converted_onnx_file(state)
+ try:
+ onnxmltools.utils.save_model(fp16_model, output_path)
+ except ValueError:
+ onnx.save_model(fp16_model, output_path, save_as_external_data=True)
+
+ # Restore default warnings behavior
+ warnings.showwarning = default_warnings
+
+ # Check that the converted model is still valid
+ success_msg = "\tSuccess converting ONNX model to fp16"
+ fail_msg = "\tFailed converting ONNX model to fp16"
+
+ if onnx_helpers.check_model(output_path, success_msg, fail_msg):
+ state.results = output_path
+
+ state.save_stat(
+ fs.Keys.ONNX_FILE,
+ output_path,
+ )
+ else:
+ msg = f"""
+ Attempted to use onnxmltools, a third party library, to convert your
+ model to the float16 datatype, however this operation was not successful.
+ More information may be available in the log file at **{self.logfile_path}**
+ """
+ raise exp.ToolError(msg)
+
+ return state
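
For reference, the offline-optimization trick that OptimizeOnnxModel relies on (asking ONNX Runtime to serialize the graph produced by its basic optimizations, rather than running inference) can be reproduced standalone. A minimal sketch, with hypothetical file paths; this is illustrative and not part of the patch:

import onnxruntime

def optimize_offline(input_onnx: str, output_onnx: str) -> str:
    """Apply ORT's basic graph optimizations and write the result to disk."""
    session_options = onnxruntime.SessionOptions()

    # Constant folding, redundant node elimination, and other
    # semantics-preserving rewrites
    session_options.graph_optimization_level = (
        onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
    )

    # Asking ORT to serialize the optimized graph turns session creation
    # into an offline optimization pass
    session_options.optimized_model_filepath = output_onnx

    # Creating the session triggers optimization and writes output_onnx
    onnxruntime.InferenceSession(input_onnx, session_options)
    return output_onnx

# Example usage (hypothetical paths):
# optimize_offline("model.onnx", "model-opt.onnx")
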
diff --git a/src/turnkeyml/tools/report.py b/src/turnkeyml/tools/report.py
new file mode 100644
index 00000000..6e0bdeed
--- /dev/null
+++ b/src/turnkeyml/tools/report.py
@@ -0,0 +1,236 @@
+import os
+import argparse
+import csv
+from datetime import datetime
+from pathlib import Path
+from typing import Dict, List
+import yaml
+import pandas as pd
+import turnkeyml.common.printing as printing
+import turnkeyml.common.filesystem as fs
+import turnkeyml.common.build as build
+from turnkeyml.tools.management_tools import ManagementTool
+
+
+def get_report_name(prefix: str = "") -> str:
+ """
+ Returns the name of the .csv report
+ """
+ day = datetime.now().day
+ month = datetime.now().month
+ year = datetime.now().year
+ date_key = f"{year}-{str(month).zfill(2)}-{str(day).zfill(2)}"
+ return f"{prefix}{date_key}.csv"
+
+
+def _good_get(
+    dictionary: Dict, key: str, return_keys: bool = False, return_values: bool = False
+):
+    if key in dictionary:
+        if return_keys:
+            return list(dictionary[key].keys())
+        elif return_values:
+            return list(dictionary[key].values())
+        else:
+            return dictionary[key]
+    else:
+        return "-"
+
+
+class Report(ManagementTool):
+ """
+ Analyzes the input turnkeyml cache(s) and produces an aggregated report
+ in csv format that contains the build stats for all builds in all cache(s).
+ """
+
+ unique_name = "report"
+
+ @staticmethod
+ def parser(add_help: bool = True) -> argparse.ArgumentParser:
+ parser = __class__.helpful_parser(
+ short_description="Export statistics from each turnkey run to a CSV file",
+ add_help=add_help,
+ )
+
+ parser.add_argument(
+ "-i",
+ "--input-caches",
+ nargs="*",
+ default=[fs.DEFAULT_CACHE_DIR],
+ help=(
+ "One or more turnkey cache directories to use to generate the report "
+ f"(defaults to {fs.DEFAULT_CACHE_DIR})"
+ ),
+ )
+
+ parser.add_argument(
+ "-o",
+ "--output-dir",
+ help="Path to folder where report will be saved "
+ "(defaults to current working directory)",
+ required=False,
+ default=os.getcwd(),
+ )
+
+ return parser
+
+ def run(
+ self,
+ _,
+ input_caches: List[str] = None,
+ output_dir: str = os.getcwd(),
+ ):
+ # Input arguments from CLI
+ cache_dirs = [os.path.expanduser(dir) for dir in input_caches]
+ cache_dirs = fs.expand_inputs(cache_dirs)
+ report_dir = os.path.expanduser(output_dir)
+
+ # Name report file
+ report_path = os.path.join(report_dir, get_report_name())
+
+        # Create the report directory
+ Path(report_dir).mkdir(parents=True, exist_ok=True)
+
+ report: List[Dict] = []
+ all_evaluation_stats = []
+
+ # Add results from all user-provided cache folders
+ for cache_dir in cache_dirs:
+ # Check if this is a valid cache directory
+ fs.check_cache_dir(cache_dir)
+
+ # List all yaml files available
+ all_model_stats_yamls = fs.get_all(
+ path=cache_dir, file_type="turnkey_stats.yaml"
+ )
+ all_model_stats_yamls = sorted(all_model_stats_yamls)
+
+ # Bring all of the stats for all of the models into memory
+ for model_stats_yaml in all_model_stats_yamls:
+ with open(model_stats_yaml, "r", encoding="utf8") as stream:
+ try:
+ # load the yaml into a dict
+ model_stats = yaml.load(stream, Loader=yaml.FullLoader)
+
+ # Copy the stats to a new dictionary, making any necessary modifications
+ # along the way
+ evaluation_stats = {}
+
+ for key, value in model_stats.items():
+ # If a build or benchmark is still marked as "incomplete" at
+ # reporting time, it must have been killed by a time out,
+ # out-of-memory (OOM), or some other uncaught exception
+ if (
+ key == fs.Keys.BUILD_STATUS
+ or fs.Keys.TOOL_STATUS in key
+ ) and value == build.FunctionStatus.INCOMPLETE:
+ value = build.FunctionStatus.KILLED
+
+ # Add stats ensuring that those are all in lower case
+ evaluation_stats[key.lower()] = value
+
+ all_evaluation_stats.append(evaluation_stats)
+ except yaml.scanner.ScannerError:
+ continue
+
+ # Scan the build stats to determine the set of columns for the CSV file.
+ # The CSV will have one column for every key in any build stats dict.
+ column_headers = []
+ for evaluation_stats in all_evaluation_stats:
+ # Add any key that isn't already in column_headers
+ for header in evaluation_stats.keys():
+ if header not in column_headers:
+ column_headers.append(header)
+
+ # Sort all columns alphabetically
+ column_headers = sorted(column_headers)
+
+ # Add each build to the report
+ for evaluation_stats in all_evaluation_stats:
+            # Start with a dictionary where all of the values are "-". If a build
+            # has a value for a given key we will fill it in; otherwise the "-"
+            # will indicate that no value was available
+ result = {k: "-" for k in column_headers}
+
+ for key in column_headers:
+ result[key] = _good_get(evaluation_stats, key)
+
+ report.append(result)
+
+ # Populate results spreadsheet
+ with open(report_path, "w", newline="", encoding="utf8") as spreadsheet:
+ writer = csv.writer(spreadsheet)
+ writer.writerow(column_headers)
+ for entry in report:
+ writer.writerow([entry[col] for col in column_headers])
+
+ # Print message with the output file path
+ printing.log("Summary spreadsheet saved at ")
+ printing.logn(str(report_path), printing.Colors.OKGREEN)
+
+ # Save the unique errors and counts to a file
+ errors = []
+ for evaluation_stats in all_evaluation_stats:
+ if (
+ "compilation_error" in evaluation_stats.keys()
+ and "compilation_error_id" in evaluation_stats.keys()
+ ):
+ error = evaluation_stats["compilation_error"]
+ id = evaluation_stats["compilation_error_id"]
+ if id != "":
+ unique_error = True
+ for reported_error in errors:
+ if reported_error["id"] == id:
+ unique_error = False
+ reported_error["count"] = reported_error["count"] + 1
+ reported_error["models_impacted"] = reported_error[
+ "models_impacted"
+ ] + [evaluation_stats["model_name"]]
+
+ if unique_error:
+ reported_error = {
+ "id": id,
+ "count": 1,
+ "models_impacted": [evaluation_stats["model_name"]],
+ "example": error,
+ }
+ errors.append(reported_error)
+
+ if len(errors) > 0:
+ errors_path = os.path.join(report_dir, get_report_name("errors-"))
+ with open(errors_path, "w", newline="", encoding="utf8") as spreadsheet:
+ writer = csv.writer(spreadsheet)
+ error_headers = errors[0].keys()
+ writer.writerow(error_headers)
+ for unique_error in errors:
+ writer.writerow([unique_error[col] for col in error_headers])
+
+ printing.log("Compilation errors spreadsheet saved at ")
+ printing.logn(str(errors_path), printing.Colors.OKGREEN)
+ else:
+ printing.logn(
+ "No compilation errors in any cached build, skipping errors spreadsheet."
+ )
+
+
+def get_dict(report_csv: str, columns: List[str]) -> Dict[str, Dict[str, str]]:
+ """
+ Returns a dictionary where the keys are model names and the values are dictionaries.
+ Each dictionary represents a model with column names as keys and their corresponding values.
+ args:
+ - report_csv: path to a report.csv file generated by turnkey CLI
+ - columns: list of column names in the report.csv file whose values will be used to
+ populate the dictionary
+ """
+
+ # Load the report as a dataframe
+ dataframe = pd.read_csv(report_csv)
+
+ # Create a nested dictionary with model_name as keys and another
+ # dictionary of {column: value} pairs as values
+ result = {
+ row[0]: row[1].to_dict()
+ for row in dataframe.set_index("model_name")[columns].iterrows()
+ }
+
+ return result
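
The report CSV produced above can also be consumed programmatically via get_dict. A minimal sketch; the file name and column names are illustrative and depend on which stats were collected in the cache, and get_dict assumes the stats include a model_name column:

import turnkeyml.tools.report as report

# Hypothetical report file produced by `turnkey report`
summary = report.get_dict(
    "2024-05-01.csv",
    columns=["build_status", "onnx_file"],
)

# Each key is a model name; each value maps the requested columns to their values
for model_name, stats in summary.items():
    print(model_name, stats["build_status"])
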
diff --git a/src/turnkeyml/tools/tool.py b/src/turnkeyml/tools/tool.py
new file mode 100644
index 00000000..511771f6
--- /dev/null
+++ b/src/turnkeyml/tools/tool.py
@@ -0,0 +1,307 @@
+import abc
+import sys
+import time
+import os
+import argparse
+import textwrap as _textwrap
+import re
+from typing import Dict
+from multiprocessing import Process
+import psutil
+import turnkeyml.common.printing as printing
+import turnkeyml.common.exceptions as exp
+import turnkeyml.common.build as build
+import turnkeyml.common.filesystem as fs
+from turnkeyml.state import State
+
+
+def _spinner(message):
+ try:
+ parent_process = psutil.Process(pid=os.getppid())
+ while parent_process.status() == psutil.STATUS_RUNNING:
+ for cursor in [" ", ". ", ".. ", "..."]:
+ time.sleep(0.5)
+ status = f" {message}{cursor}\r"
+ sys.stdout.write(status)
+ sys.stdout.flush()
+ except psutil.NoSuchProcess:
+ # If the parent process stopped existing, we can
+ # safely assume the spinner no longer needs to spin
+ # NOTE: this only seems to be needed on Windows
+ pass
+
+
+def _name_is_file_safe(name: str):
+ """
+ Make sure the name can be used in a filename
+ """
+
+ allowed_in_unique_name = set(
+ "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-"
+ )
+
+ if len(name) == 0:
+ msg = """
+ Tool __init__() was passed a unique_name with no length. A
+ uniquely identifying unique_name is required.
+ """
+ raise ValueError(msg)
+
+ for char in name:
+ if char not in allowed_in_unique_name:
+ msg = f"""
+ Tool __init__() was passed a unique_name:
+ {name}
+ with illegal characters. The unique_name must be safe to
+ use in a filename, meaning it can only use characters: {allowed_in_unique_name}
+ """
+ raise ValueError(msg)
+
+
+class NiceHelpFormatter(argparse.RawDescriptionHelpFormatter):
+ def __add_whitespace(self, idx, amount, text):
+ if idx == 0:
+ return text
+ return (" " * amount) + text
+
+ def _split_lines(self, text, width):
+ textRows = text.splitlines()
+ for idx, line in enumerate(textRows):
+ search = re.search(r"\s*[0-9\-]{0,}\.?\s*", line)
+ if line.strip() == "":
+ textRows[idx] = " "
+ elif search:
+ whitespace_needed = search.end()
+ lines = [
+ self.__add_whitespace(i, whitespace_needed, x)
+ for i, x in enumerate(_textwrap.wrap(line, width))
+ ]
+ textRows[idx] = lines
+
+ return [item for sublist in textRows for item in sublist]
+
+
+class ToolParser(argparse.ArgumentParser):
+
+ def error(self, message):
+ if message.startswith("unrecognized arguments"):
+ unrecognized = message.split(": ")[1]
+ if not unrecognized.startswith("-"):
+ # This was probably a misspelled tool name
+ message = message + (
+ f". If `{unrecognized}` was intended to invoke "
+ "a tool, please run `turnkey -h` and check the spelling and "
+ "availability of that tool."
+ )
+ self.print_usage()
+ printing.log_error(message)
+ self.exit(2)
+
+ def __init__(
+ self, short_description: str, description: str, prog: str, epilog: str, **kwargs
+ ):
+ super().__init__(
+ description=description,
+ prog=prog,
+ epilog=epilog,
+ formatter_class=NiceHelpFormatter,
+ **kwargs,
+ )
+
+ self.short_description = short_description
+
+
+class Tool(abc.ABC):
+
+ unique_name: str
+
+ @classmethod
+ def helpful_parser(cls, short_description: str, **kwargs):
+ epilog = (
+ f"`{cls.unique_name}` is a Tool. It is intended to be invoked as "
+ "part of a sequence of Tools, for example: `turnkey -i INPUTS tool-one "
+ "tool-two tool-three`. Tools communicate data to each other via State. "
+ "You can learn more at "
+ "https://github.com/onnx/turnkeyml/blob/main/docs/tools_user_guide.md"
+ )
+
+ return ToolParser(
+ prog=f"turnkey {cls.unique_name}",
+ short_description=short_description,
+ description=cls.__doc__,
+ epilog=epilog,
+ **kwargs,
+ )
+
+ def status_line(self, successful, verbosity):
+ """
+ Print a line of status information for this Tool into the monitor.
+ """
+ if verbosity:
+ # Only use special characters when the terminal encoding supports it
+ if sys.stdout.encoding == "utf-8":
+ success_tick = "✓"
+ fail_tick = "×"
+ else:
+ success_tick = "+"
+ fail_tick = "x"
+
+ if successful is None:
+ # Initialize the message
+ printing.logn(f" {self.monitor_message} ")
+ elif successful:
+ # Print success message
+ printing.log(f" {success_tick} ", c=printing.Colors.OKGREEN)
+ printing.logn(self.monitor_message + " ")
+ else:
+ # successful == False, print failure message
+ printing.log(f" {fail_tick} ", c=printing.Colors.FAIL)
+ printing.logn(self.monitor_message + " ")
+
+ def __init__(
+ self,
+ monitor_message,
+ ):
+ _name_is_file_safe(self.__class__.unique_name)
+
+ self.status_key = f"{fs.Keys.TOOL_STATUS}:{self.__class__.unique_name}"
+ self.duration_key = f"{fs.Keys.TOOL_DURATION}:{self.__class__.unique_name}"
+ self.monitor_message = monitor_message
+ self.progress = None
+ self.logfile_path = None
+ # Tools can provide a list of keys that can be found in
+ # evaluation stats. Those key:value pairs will be presented
+ # in the status at the end of the build.
+ self.status_stats = []
+
+ @abc.abstractmethod
+ def run(self, state: State) -> State:
+ """
+ Execute the functionality of the Tool by acting on the state.
+ """
+
+ @staticmethod
+ @abc.abstractmethod
+ def parser() -> argparse.ArgumentParser:
+ """
+ Static method that returns an ArgumentParser that defines the command
+ line interface for this Tool.
+ """
+
+ # pylint: disable=unused-argument
+ def parse(self, state: State, args, known_only=True) -> argparse.Namespace:
+ """
+ Run the parser and return a Namespace of keyword arguments that the user
+ passed to the Tool via the command line.
+
+ Tools should extend this function only if they require specific parsing
+ logic, for example decoding the name of a data type into a data type class.
+
+ Args:
+ state: the same state passed into the run method of the Tool, useful if
+ the parse decoding logic needs to take the state into account.
+ args: command line arguments passed from the CLI.
+ known_only: this argument allows the CLI framework to
+ incrementally parse complex commands.
+ """
+
+ if known_only:
+ parsed_args = self.__class__.parser().parse_args(args)
+ else:
+ parsed_args, _ = self.__class__.parser().parse_known_args(args)
+
+ return parsed_args
+
+    def parse_and_run(self, state: State, args, known_only=True) -> State:
+ """
+ Helper function to parse CLI arguments into the args expected
+ by run(), and then forward them into the run() method.
+ """
+
+ parsed_args = self.parse(state, args, known_only)
+ return self.run_helper(state, **parsed_args.__dict__)
+
+    def run_helper(self, state: State, **kwargs) -> State:
+ """
+ Wraps the developer-defined .run() method with helper functionality.
+ Specifically:
+ - Provides a path to a log file
+ - Redirects the stdout of the tool to that log file
+ - Monitors the progress of the tool on the command line,
+ including in the event of an exception
+ """
+
+ # Set the build status to INCOMPLETE to indicate that a Tool
+ # started running. This allows us to test whether the Tool exited
+ # unexpectedly, before it was able to set ERROR
+ state.build_status = build.FunctionStatus.INCOMPLETE
+
+ self.logfile_path = os.path.join(
+ build.output_dir(state.cache_dir, state.build_name),
+ f"log_{self.unique_name}.txt",
+ )
+
+ if state.monitor:
+ self.progress = Process(target=_spinner, args=[self.monitor_message])
+ self.progress.start()
+
+ try:
+ # Execute the build tool
+ with build.Logger(self.monitor_message, self.logfile_path):
+ state = self.run(state, **kwargs)
+
+ except Exception: # pylint: disable=broad-except
+ self.status_line(
+ successful=False,
+ verbosity=state.monitor,
+ )
+ state.build_status = build.FunctionStatus.ERROR
+ raise
+
+ else:
+ self.status_line(successful=True, verbosity=state.monitor)
+
+ # Tools should not set build.FunctionStatus.SUCCESSFUL for the whole build,
+ # as that is reserved for Sequence.launch()
+ if state.build_status == build.FunctionStatus.SUCCESSFUL:
+ raise exp.ToolError(
+ "TurnkeyML Tools are not allowed to set "
+ "`state.build_status == build.FunctionStatus.SUCCESSFUL`, "
+ "however that has happened. If you are a plugin developer, "
+ "do not do this. If you are a user, please file an issue at "
+ "https://github.com/onnx/turnkeyml/issues."
+ )
+
+ finally:
+ if state.monitor:
+ self.progress.terminate()
+
+ return state
+
+
+class FirstTool(Tool):
+ """
+ Provides extra features for Tools that are meant to be the first Tool
+ in the sequence.
+
+ Specifically:
+    - FirstTools should not have any expectations of state.results, since
+      they are responsible for populating state.results with an initial value.
+ - All FirstTools implicitly take an `input` argument that points to
+ the input to that Tool, for example an ONNX file or PyTorch script.
+ """
+
+ @classmethod
+ def helpful_parser(cls, short_description: str, **kwargs):
+ parser = super().helpful_parser(short_description, **kwargs)
+
+ # Argument required by TurnkeyML for any tool that starts a sequence
+ parser.add_argument("--input", help=argparse.SUPPRESS)
+
+ return parser
+
+ @abc.abstractmethod
+ def run(self, state: State, input=None) -> State:
+ """
+ The run() method of any FirstTool must accept the `input` argument
+ """
diff --git a/src/turnkeyml/version.py b/src/turnkeyml/version.py
index 5fa9130a..528787cf 100644
--- a/src/turnkeyml/version.py
+++ b/src/turnkeyml/version.py
@@ -1 +1 @@
-__version__ = "2.0.3"
+__version__ = "3.0.0"
diff --git a/test/analysis.py b/test/analysis.py
index 70a96852..45ba7520 100644
--- a/test/analysis.py
+++ b/test/analysis.py
@@ -4,8 +4,6 @@
import os
import unittest
-from pathlib import Path
-import shutil
import glob
import subprocess
import numpy as np
@@ -13,13 +11,11 @@
from unittest.mock import patch
import io
import sys
-import platform
from turnkeyml.cli.cli import main as turnkeycli
-import turnkeyml.common.labels as labels
from turnkeyml.parser import parse
import turnkeyml.common.filesystem as filesystem
-from helpers import common
-from turnkeyml.analyze.status import Verbosity
+import turnkeyml.common.test_helpers as common
+import turnkeyml.common.exceptions as exp
try:
# pylint: disable=unused-import
@@ -110,7 +106,6 @@ def __init__(self, **kwargs):
"two_executions.py": """
import torch
import timm
-from turnkeyml.parser import parse
# Creating model and set it to evaluation mode
model = timm.create_model("mobilenetv2_035", pretrained=False)
@@ -165,14 +160,22 @@ def run_cli(args):
def run_analysis(args):
output = run_cli(args)
+ print(output)
# Process outputs
- output = output[output.rfind("Models discovered") :]
+ output = output[output.rfind("Discovering PyTorch models") :]
models_executed = output.count("(executed")
- models_built = output.count("Model successfully built!")
+ models_built = output.count("Exporting PyTorch to ONNX")
return models_executed, 0, models_built
+def check_discover_log(build_name: str, expected_content: str):
+ log_path = os.path.join(cache_dir, build_name, "log_discover.txt")
+ with open(log_path, "r", encoding="utf-8") as log_file:
+ log_content = log_file.read()
+ assert expected_content in log_content, log_content
+
+
class Testing(unittest.TestCase):
def setUp(self) -> None:
filesystem.rmdir(cache_dir)
@@ -182,10 +185,11 @@ def test_01_basic(self):
pytorch_output = run_analysis(
[
"turnkey",
+ "-i",
os.path.join(corpus_dir, "linear_pytorch.py"),
- "--analyze-only",
- "--verbosity",
- Verbosity.DYNAMIC.value,
+ "--cache-dir",
+ cache_dir,
+ "discover",
]
)
assert np.array_equal(pytorch_output, (1, 0, 0))
@@ -194,12 +198,13 @@ def test_03_depth(self):
output = run_analysis(
[
"turnkey",
+ "-i",
os.path.join(corpus_dir, "linear_pytorch.py"),
+ "--cache-dir",
+ cache_dir,
+ "discover",
"--max-depth",
"1",
- "--analyze-only",
- "--verbosity",
- Verbosity.DYNAMIC.value,
]
)
assert np.array_equal(output, (2, 0, 0))
@@ -208,14 +213,14 @@ def test_04_build(self):
output = run_analysis(
[
"turnkey",
+ "-i",
os.path.join(corpus_dir, "linear_pytorch.py::76af2f62"),
- "--max-depth",
- "1",
- "--build-only",
"--cache-dir",
cache_dir,
- "--verbosity",
- Verbosity.DYNAMIC.value,
+ "discover",
+ "--max-depth",
+ "1",
+ "export-pytorch",
]
)
assert np.array_equal(output, (2, 0, 1))
@@ -225,15 +230,15 @@ def test_05_cache(self):
run_analysis(
[
"turnkey",
+ "-i",
os.path.join(corpus_dir, f"linear_pytorch.py::{model_hash}"),
- "--max-depth",
- "1",
"--cache-dir",
cache_dir,
"--lean-cache",
- "--build-only",
- "--verbosity",
- Verbosity.DYNAMIC.value,
+ "discover",
+ "--max-depth",
+ "1",
+ "export-pytorch",
]
)
build_name = f"linear_pytorch_{model_hash}"
@@ -243,88 +248,89 @@ def test_05_cache(self):
assert cache_is_lean(cache_dir, build_name) and labels_found != {}, labels_found
def test_06_generic_args(self):
- output = run_cli(
+ test_arg = "test_arg"
+ run_cli(
[
"turnkey",
+ "-i",
os.path.join(corpus_dir, "linear_pytorch.py"),
+ "--cache-dir",
+ cache_dir,
+ "discover",
"--max-depth",
"1",
"--script-args",
- "--my-arg test_arg",
- "--analyze-only",
- "--verbosity",
- Verbosity.DYNAMIC.value,
+ f"--my-arg {test_arg}",
]
)
- assert "Received arg test_arg" in output
+ check_discover_log("linear_pytorch", f"Received arg {test_arg}")
- # TODO: Investigate why this test is only failing on Windows
- @unittest.skipIf(
- platform.system() == "Windows",
- "Potential turnkeyml windows bug"
- "The ouputs do match, but fails due to misinterpretation",
- )
def test_07_valid_turnkey_args(self):
height, width, num_channels = parse(["height", "width", "num_channels"])
cmd = [
"turnkey",
+ "-i",
os.path.join(corpus_dir, "turnkey_parser.py"),
+ "--cache-dir",
+ cache_dir,
+ "discover",
"--script-args",
f"--num_channels {num_channels+1}",
- "--verbosity",
- Verbosity.DYNAMIC.value,
]
- process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
- stdout, _ = process.communicate()
- output = stdout.decode("utf-8")
+ subprocess.run(cmd)
expected_output = str([height, width, num_channels + 1])
- assert expected_output in output, f"Got {output} but expected {expected_output}"
+ check_discover_log("turnkey_parser", expected_output)
def test_08_invalid_turnkey_args(self):
cmd = [
"turnkey",
+ "-i",
os.path.join(corpus_dir, "turnkey_parser.py"),
+ "--cache-dir",
+ cache_dir,
+ "discover",
"--script-args",
"--invalid_arg 123",
- "--verbosity",
- Verbosity.DYNAMIC.value,
]
- process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
- _, stderr = process.communicate()
- assert "error: unrecognized argument" in stderr.decode("utf-8")
+
+ subprocess.run(cmd)
+ check_discover_log("turnkey_parser", "error: unrecognized argument")
def test_09_pipeline(self):
output = run_analysis(
[
"turnkey",
+ "-i",
os.path.join(corpus_dir, "pipeline.py"),
- "--analyze-only",
- "--verbosity",
- Verbosity.DYNAMIC.value,
+ "--cache-dir",
+ cache_dir,
+ "discover",
]
)
assert np.array_equal(output, (1, 0, 0))
def test_10_activation(self):
- output = run_analysis(
- [
- "turnkey",
- os.path.join(corpus_dir, "activation.py"),
- "--analyze-only",
- "--verbosity",
- Verbosity.DYNAMIC.value,
- ]
- )
- assert np.array_equal(output, (0, 0, 0))
+ with self.assertRaises(exp.ToolError):
+ run_analysis(
+ [
+ "turnkey",
+ "-i",
+ os.path.join(corpus_dir, "activation.py"),
+ "--cache-dir",
+ cache_dir,
+ "discover",
+ ]
+ )
def test_11_analyze_only(self):
output = run_analysis(
[
"turnkey",
+ "-i",
os.path.join(corpus_dir, "linear_pytorch.py"),
- "--analyze-only",
- "--verbosity",
- Verbosity.DYNAMIC.value,
+ "--cache-dir",
+ cache_dir,
+ "discover",
]
)
assert np.array_equal(output, (1, 0, 0))
@@ -333,14 +339,14 @@ def test_12_turnkey_hashes(self):
output = run_analysis(
[
"turnkey",
+ "-i",
os.path.join(corpus_dir, "linear_pytorch.py::76af2f62"),
- "--build-only",
- "--max-depth",
- "1",
"--cache-dir",
cache_dir,
- "--verbosity",
- Verbosity.DYNAMIC.value,
+ "discover",
+ "--max-depth",
+ "1",
+ "export-pytorch",
]
)
assert np.array_equal(output, (2, 0, 1))
@@ -350,25 +356,26 @@ def test_13_clean_cache(self):
run_analysis(
[
"turnkey",
+ "-i",
os.path.join(corpus_dir, f"linear_pytorch.py::{model_hash}"),
- "--max-depth",
- "1",
"--cache-dir",
cache_dir,
- "--build-only",
- "--verbosity",
- Verbosity.DYNAMIC.value,
+ "discover",
+ "--max-depth",
+ "1",
+ "export-pytorch",
]
)
build_name = f"linear_pytorch_{model_hash}"
cmd = [
"turnkey",
- "cache",
- "clean",
- build_name,
"--cache-dir",
cache_dir,
+ "cache",
+ "--clean",
+ "--build-names",
+ build_name,
]
subprocess.run(cmd, check=True)
@@ -378,10 +385,11 @@ def test_14_same_model_different_input_shapes(self):
output = run_analysis(
[
"turnkey",
+ "-i",
os.path.join(corpus_dir, "two_executions.py"),
- "--analyze-only",
- "--verbosity",
- Verbosity.DYNAMIC.value,
+ "--cache-dir",
+ cache_dir,
+ "discover",
]
)
assert np.array_equal(output, (2, 0, 0))
@@ -390,12 +398,13 @@ def test_15_same_model_different_input_shapes_maxdepth(self):
output = run_analysis(
[
"turnkey",
+ "-i",
os.path.join(corpus_dir, "two_executions.py"),
- "--analyze-only",
+ "--cache-dir",
+ cache_dir,
+ "discover",
"--max-depth",
"1",
- "--verbosity",
- Verbosity.DYNAMIC.value,
]
)
assert np.array_equal(output, (6, 0, 0))
diff --git a/test/build_model.py b/test/build_model.py
deleted file mode 100644
index 7171820f..00000000
--- a/test/build_model.py
+++ /dev/null
@@ -1,710 +0,0 @@
-import os
-import unittest
-import torch
-import onnx
-import tensorflow as tf
-import numpy as np
-import sklearn.ensemble
-import sklearn.neighbors
-import xgboost # pylint: disable=import-error
-import lightgbm # pylint: disable=import-error
-from onnxmltools.utils.float16_converter import convert_float_to_float16
-from onnxmltools.utils import save_model
-from onnxmltools.utils import load_model
-from turnkeyml import build_model
-import turnkeyml.build.export as export
-import turnkeyml.build.stage as stage
-import turnkeyml.common.filesystem as filesystem
-import turnkeyml.common.exceptions as exp
-import turnkeyml.common.build as build
-import turnkeyml.build.sequences as sequences
-
-
-class SmallPytorchModel(torch.nn.Module):
- def __init__(self):
- super(SmallPytorchModel, self).__init__()
- self.fc = torch.nn.Linear(10, 5)
-
- def forward(self, x):
- output = self.fc(x)
- return output
-
-
-class AnotherSimplePytorchModel(torch.nn.Module):
- def __init__(self):
- super(AnotherSimplePytorchModel, self).__init__()
- self.relu = torch.nn.ReLU()
-
- def forward(self, x):
- output = self.relu(x)
- return output
-
-
-class SmallKerasModel(tf.keras.Model): # pylint: disable=abstract-method
- def __init__(self):
- super(SmallKerasModel, self).__init__()
- self.dense = tf.keras.layers.Dense(10)
-
- def call(self, x): # pylint: disable=arguments-differ
- return self.dense(x)
-
-
-base_dir = os.path.dirname(os.path.abspath(__file__))
-cache_location = os.path.join(base_dir, "generated", "build_model_cache")
-
-# Define pytorch model and inputs
-pytorch_model = SmallPytorchModel()
-tiny_pytorch_model = AnotherSimplePytorchModel()
-inputs = {"x": torch.rand(10)}
-inputs_2 = {"x": torch.rand(5)}
-input_tensor = torch.rand(10)
-
-# Define keras models and inputs
-batch_keras_inputs = {"x": tf.random.uniform((1, 10), dtype=tf.float32)}
-keras_subclass_model = SmallKerasModel()
-keras_subclass_model.build(input_shape=(1, 10))
-keras_sequential_model = tf.keras.Sequential()
-keras_sequential_model.add(
- tf.keras.layers.InputLayer(
- batch_size=1,
- input_shape=(10),
- name="x",
- )
-)
-keras_sequential_model.add(tf.keras.layers.Dense(10))
-keras_sequential_model.compile(
- loss="binary_crossentropy",
- optimizer="adam",
- metrics=["accuracy"],
-)
-
-# Define sklearn model and inputs
-np.random.seed(0)
-rf_batch_size = 320
-
-rf_inputs = np.random.rand(rf_batch_size, 10).astype(np.float32)
-
-rf_model = sklearn.ensemble.RandomForestClassifier(
- n_estimators=10, max_depth=5, random_state=0
-)
-xgb_model = xgboost.XGBClassifier(
- n_estimators=10, max_depth=5, random_state=0, objective="binary:logistic"
-)
-lgbm_model = lightgbm.LGBMClassifier(n_estimators=10, max_depth=5, random_state=0)
-kn_model = sklearn.neighbors.KNeighborsClassifier(n_neighbors=10)
-
-
-# Run build_model() and get results
-def full_compilation_pytorch_model():
- build_name = "full_compilation_pytorch_model"
- state = build_model(
- pytorch_model,
- inputs,
- build_name=build_name,
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- )
- return state.build_status == build.FunctionStatus.SUCCESSFUL
-
-
-def full_compilation_keras_subclass_model():
- build_name = "full_compilation_keras_subclass_model"
- state = build_model(
- keras_subclass_model,
- batch_keras_inputs,
- build_name=build_name,
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- )
- return state.build_status == build.FunctionStatus.SUCCESSFUL
-
-
-def full_compilation_keras_sequential_model():
- build_name = "full_compilation_keras_sequential_model"
- state = build_model(
- keras_sequential_model,
- batch_keras_inputs,
- build_name=build_name,
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- )
- return state.build_status == build.FunctionStatus.SUCCESSFUL
-
-
-def full_compilation_onnx_model():
- build_name = "full_compilation_onnx_model"
- torch.onnx.export(
- pytorch_model,
- input_tensor,
- "small_onnx_model.onnx",
- opset_version=build.DEFAULT_ONNX_OPSET,
- input_names=["input"],
- output_names=["output"],
- )
- state = build_model(
- "small_onnx_model.onnx",
- inputs,
- build_name=build_name,
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- )
- return state.build_status == build.FunctionStatus.SUCCESSFUL
-
-
-def full_compilation_hummingbird_rf():
- rf_model.fit(rf_inputs, np.random.randint(2, size=rf_batch_size))
-
- build_name = "full_compilation_hummingbird_rf"
- state = build_model(
- rf_model,
- {"input_0": rf_inputs},
- build_name=build_name,
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- )
- return state.build_status == build.FunctionStatus.SUCCESSFUL
-
-
-def full_compilation_hummingbird_xgb():
- xgb_model.fit(rf_inputs, np.random.randint(2, size=rf_batch_size))
-
- build_name = "full_compilation_hummingbird_xgb"
- state = build_model(
- xgb_model,
- {"input_0": rf_inputs},
- build_name=build_name,
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- )
- return state.build_status == build.FunctionStatus.SUCCESSFUL
-
-
-def full_compilation_hummingbird_lgbm():
- lgbm_model.fit(rf_inputs, np.random.randint(2, size=rf_batch_size))
-
- build_name = "full_compilation_hummingbird_lgbm"
- state = build_model(
- lgbm_model,
- {"input_0": rf_inputs},
- build_name=build_name,
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- )
- return state.build_status == build.FunctionStatus.SUCCESSFUL
-
-
-def full_compilation_hummingbird_kn():
- kn_model.fit(rf_inputs, np.random.randint(2, size=rf_batch_size))
-
- build_name = "full_compilation_hummingbird_kn"
- state = build_model(
- kn_model,
- {"input_0": rf_inputs},
- build_name=build_name,
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- )
- return state.build_status == build.FunctionStatus.SUCCESSFUL
-
-
-def scriptmodule_functional_check():
- build_name = "scriptmodule_functional_check"
- x = torch.rand(10)
- forward_input = x
- input_dict = {"forward": forward_input}
- pytorch_module = torch.jit.trace_module(pytorch_model, input_dict)
- state = build_model(
- pytorch_module,
- inputs,
- build_name=build_name,
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- )
- return state.build_status == build.FunctionStatus.SUCCESSFUL
-
-
-def custom_stage():
- build_name = "custom_stage"
-
- class MyCustomStage(stage.Stage):
- def __init__(self, funny_saying):
- super().__init__(
- unique_name="funny_fp16_convert",
- monitor_message="Funny FP16 conversion",
- )
-
- self.funny_saying = funny_saying
-
- def fire(self, state):
- input_onnx = state.intermediate_results[0]
- output_onnx = os.path.join(export.onnx_dir(state), "custom.onnx")
- fp32_model = load_model(input_onnx)
- fp16_model = convert_float_to_float16(fp32_model)
- save_model(fp16_model, output_onnx)
-
- print(f"funny message: {self.funny_saying}")
-
- state.intermediate_results = [output_onnx]
-
- return state
-
- my_custom_stage = MyCustomStage(
- funny_saying="Is a fail whale a fail at all if it makes you smile?"
- )
- my_sequence = stage.Sequence(
- unique_name="my_sequence",
- monitor_message="Running My Sequence",
- stages=[
- export.ExportPytorchModel(),
- export.OptimizeOnnxModel(),
- my_custom_stage,
- ],
- )
-
- state = build_model(
- pytorch_model,
- inputs,
- build_name=build_name,
- rebuild="always",
- monitor=False,
- sequence=my_sequence,
- cache_dir=cache_location,
- )
-
- return state.build_status == build.FunctionStatus.SUCCESSFUL
-
-
-class FullyCustomStage(stage.Stage):
- def __init__(self, saying, name):
- super().__init__(
- unique_name=name,
- monitor_message=f"Running {name}",
- )
-
- self.saying = saying
-
- def fire(self, state):
- print(self.saying)
-
- return state
-
-
-def custom_sequence():
- build_name = "custom_sequence"
- stage_1_name = "Stage1"
- stage_2_name = "Stage2"
- stage_3_name = "Stage3"
- stage_1_msg = "Developer Velocity is"
- stage_2_msg = "Innovating"
- stage_3_msg = "Faster than ever"
-
- stage_1 = FullyCustomStage(stage_1_msg, stage_1_name)
- stage_2 = FullyCustomStage(stage_2_msg, stage_2_name)
- stage_3 = FullyCustomStage(stage_3_msg, stage_3_name)
-
- my_sequence = stage.Sequence(
- "my_stage", "Running my Sequence", stages=[stage_1, stage_2, stage_3]
- )
-
- build_model(
- build_name=build_name,
- monitor=False,
- rebuild="always",
- sequence=my_sequence,
- cache_dir=cache_location,
- )
-
- log_1_path = os.path.join(cache_location, build_name, f"log_{stage_1_name}.txt")
- log_2_path = os.path.join(cache_location, build_name, f"log_{stage_2_name}.txt")
- log_3_path = os.path.join(cache_location, build_name, f"log_{stage_3_name}.txt")
-
- with open(log_1_path, "r", encoding="utf8") as f:
- log_1 = f.readlines()[1]
-
- with open(log_2_path, "r", encoding="utf8") as f:
- log_2 = f.readlines()[1]
-
- with open(log_3_path, "r", encoding="utf8") as f:
- log_3 = f.readlines()[1]
-
- return stage_1_msg in log_1 and stage_2_msg in log_2 and stage_3_msg in log_3
-
-
-def rebuild_always():
- """
- This function checks to see if the build_name.yaml file has been modified.
- If rebuild="always" the build_name_state.yaml file will have been modified along with
- the rest of the files in model/build_name due to a forced rebuild.
- If rebuild="never" the build_name_state.yaml file should *not* have been modified and
- the rest of the files in model/build_name will remain untouched and the
- model will be loaded from cache.
- To pass this test:
- between build 1 and build 2 the build_name_state.yaml file will be modified and
- therefor have different file modification timestamps
- between build 2 and build 3 the build_name_state.yaml file will *not* be modified
- resulting in identical modification timestamps.
- """
- build_name = "rebuild"
- build_timestamps = {}
- build_purpose_to_rebuild_setting = {
- "initial": "always",
- "rebuild": "always",
- "load": "never",
- }
-
- # Build Initial model, rebuild, and load from cache
- for build_purpose, rebuild_setting in build_purpose_to_rebuild_setting.items():
- build_model(
- pytorch_model,
- inputs,
- build_name=build_name,
- rebuild=rebuild_setting,
- monitor=False,
- cache_dir=cache_location,
- )
-
- yaml_file_path = build.state_file(cache_location, build_name)
-
- # Read the the file modification timestamp
- if os.path.isfile(yaml_file_path):
- build_timestamps[build_purpose] = os.path.getmtime(yaml_file_path)
- else:
- msg = f"""
- The rebuild_always test attempted to load a state.yaml file
- at {yaml_file_path} but couldn't find one.
- """
- raise ValueError(msg)
-
- # Did the second build Rebuild?
- if build_timestamps["initial"] != build_timestamps["rebuild"]:
- rebuild = True
- else:
- rebuild = False
-
- # Was the third build skipped and the model loaded from cache?
- if build_timestamps["rebuild"] == build_timestamps["load"]:
- load = True
- else:
- load = False
-
- return rebuild and load
-
-
-def rebuild_if_needed():
- """
- This function checks to see if the build_name.yaml file has been modified.
- If rebuild="always" the build_name_state.yaml file will have been modified along with
- the rest of the files in model/build_name due to a forced rebuild.
- If rebuild="if_needed" the build_name_state.yaml file should *not* have been modified and
- the rest of the files in model/build_name will remain untouched and the
- model will be loaded from cache.
- To pass this test:
- between build 1 and build 2 the build_name_state.yaml file will *not* be modified
- resulting in identical modification timestamps.
- We also toss in a state.save() call to make sure that doesn't break the cache.
- """
- build_name = "rebuild"
- build_timestamps = {}
- build_purpose_to_rebuild_setting = {
- "initial": "always",
- "load": "if_needed",
- }
-
- # Build Initial model, rebuild, and load from cache
- for build_purpose, rebuild_setting in build_purpose_to_rebuild_setting.items():
- state = build_model(
- pytorch_model,
- inputs,
- build_name=build_name,
- rebuild=rebuild_setting,
- monitor=False,
- cache_dir=cache_location,
- )
-
- if build_purpose == "initial":
- state.save()
-
- yaml_file_path = build.state_file(cache_location, build_name)
-
- # Read the the file modification timestamp
- if os.path.isfile(yaml_file_path):
- build_timestamps[build_purpose] = os.path.getmtime(yaml_file_path)
- else:
- msg = f"""
- The rebuild_always test attempted to load a state.yaml file
- at {yaml_file_path} but couldn't find one.
- """
- raise ValueError(msg)
-
- # Was the third build skipped and the model loaded from cache?
- if build_timestamps["initial"] == build_timestamps["load"]:
- load = True
- else:
- load = False
-
- return load
-
-
-def illegal_onnx_opset():
- build_name = "illegal_onnx_opset"
- torch.onnx.export(
- pytorch_model,
- input_tensor,
- "illegal_onnx_opset.onnx",
- opset_version=(build.MINIMUM_ONNX_OPSET - 1),
- input_names=["input"],
- output_names=["output"],
- )
- build_model(
- "illegal_onnx_opset.onnx",
- inputs,
- build_name=build_name,
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- )
-
-
-class Testing(unittest.TestCase):
- def setUp(self) -> None:
- filesystem.rmdir(cache_location)
-
- return super().setUp()
-
- def test_000_rebuild_always(self):
- assert rebuild_always()
-
- def test_001_rebuild_if_needed(self):
- assert rebuild_if_needed()
-
- def test_002_full_compilation_pytorch_model(self):
- assert full_compilation_pytorch_model()
-
- def test_003_full_compilation_keras_sequential_model(self):
- assert full_compilation_keras_sequential_model()
-
- def test_004_full_compilation_keras_subclass_model(self):
- assert full_compilation_keras_subclass_model()
-
- def test_005_full_compilation_onnx_model(self):
- assert full_compilation_onnx_model()
-
- def test_006_full_compilation_hummingbird_rf(self):
- assert full_compilation_hummingbird_rf()
-
- def test_007_full_compilation_hummingbird_xgb(self):
- assert full_compilation_hummingbird_xgb()
-
- def test_009_custom_stage(self):
- assert custom_stage()
-
- def test_010_nested_sequence(self):
- build_name = "nested_sequence"
- stage_1_name = "Stage1"
- stage_2_name = "Stage2"
- stage_3_name = "Stage3"
- stage_1_msg = "Did you know"
- stage_2_msg = "sequences can go in sequences?"
- stage_3_msg = "Indeed they can!"
-
- stage_1 = FullyCustomStage(stage_1_msg, stage_1_name)
- stage_2 = FullyCustomStage(stage_2_msg, stage_2_name)
- stage_3 = FullyCustomStage(stage_3_msg, stage_3_name)
-
- inner_sequence = stage.Sequence(
- "inner_sequence", "Running my Inner Sequence", stages=[stage_1, stage_2]
- )
-
- outer_sequence = stage.Sequence(
- "outer_sequence",
- "Running my Outer Sequence",
- stages=[inner_sequence, stage_3],
- )
-
- build_model(
- build_name=build_name,
- monitor=False,
- rebuild="always",
- sequence=outer_sequence,
- cache_dir=cache_location,
- )
-
- log_1_path = os.path.join(cache_location, build_name, f"log_{stage_1_name}.txt")
- log_2_path = os.path.join(cache_location, build_name, f"log_{stage_2_name}.txt")
- log_3_path = os.path.join(cache_location, build_name, f"log_{stage_3_name}.txt")
-
- with open(log_1_path, "r", encoding="utf8") as f:
- log_1 = f.readlines()[1]
-
- with open(log_2_path, "r", encoding="utf8") as f:
- log_2 = f.readlines()[1]
-
- with open(log_3_path, "r", encoding="utf8") as f:
- log_3 = f.readlines()[1]
-
- assert stage_1_msg in log_1, f"{stage_1_msg} not in {log_1}"
- assert stage_2_msg in log_2, f"{stage_2_msg} not in {log_2}"
- assert stage_3_msg in log_3, f"{stage_3_msg} not in {log_3}"
-
- def test_011_custom_sequence(self):
- assert custom_sequence()
-
- def test_012_illegal_onnx_opset(self):
- self.assertRaises(exp.StageError, illegal_onnx_opset)
- if os.path.exists("illegal_onnx_opset.onnx"):
- os.remove("illegal_onnx_opset.onnx")
-
- def test_013_set_onnx_opset(self):
- build_name = "full_compilation_pytorch_model"
-
- user_opset = 15
- assert user_opset != build.DEFAULT_ONNX_OPSET
-
- state = build_model(
- pytorch_model,
- inputs,
- build_name=build_name,
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- onnx_opset=user_opset,
- sequence=sequences.optimize_fp16,
- )
-
- assert state.build_status == build.FunctionStatus.SUCCESSFUL
-
- onnx_model = onnx.load(state.results[0])
- model_opset = getattr(onnx_model.opset_import[0], "version", None)
- assert user_opset == model_opset
-
- def test_014_export_only(self):
- build_name = "export_only"
-
- state = build_model(
- pytorch_model,
- inputs,
- build_name=build_name,
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- sequence=sequences.onnx_fp32,
- )
-
- assert state.build_status == build.FunctionStatus.SUCCESSFUL
- assert os.path.exists(export.base_onnx_file(state))
- assert not os.path.exists(export.opt_onnx_file(state))
-
- def test_015_receive_onnx(self):
- """
- Manually export an ONNX file with an opset other than the default
- Then make sure that the state file correctly reflects that opset
- """
- build_name = "receive_onnx"
- onnx_file = f"{build_name} + .onnx"
- user_opset = build.MINIMUM_ONNX_OPSET
-
- # Make sure we are using an non-default ONNX opset
- assert user_opset != build.DEFAULT_ONNX_OPSET
-
- # Create ONNX file
- torch.onnx.export(
- pytorch_model,
- input_tensor,
- onnx_file,
- opset_version=user_opset,
- input_names=["input"],
- output_names=["output"],
- )
-
- # Build the ONNX file
- state = build_model(
- onnx_file,
- inputs,
- build_name=build_name,
- rebuild="always",
- monitor=False,
- )
-
- # Make sure the build was successful
- assert state.build_status == build.FunctionStatus.SUCCESSFUL
-
- # Get ONNX file's opset
- onnx_model = onnx.load(onnx_file)
- model_opset = getattr(onnx_model.opset_import[0], "version", None)
-
- # Make sure the ONNX file matches the opset we asked for
- assert user_opset == model_opset
-
- # Make sure the ONNX file matches the state file
- assert model_opset == state.config.onnx_opset
-
- def test_016_full_compilation_hummingbird_lgbm(self):
- assert full_compilation_hummingbird_lgbm()
-
- def test_017_inputs_conversion(self):
- custom_sequence_fp32 = stage.Sequence(
- "custom_sequence_fp32",
- "Building Pytorch Model without fp16 conversion",
- [
- export.ExportPytorchModel(),
- export.OptimizeOnnxModel(),
- ],
- enable_model_validation=True,
- )
-
- custom_sequence_fp16 = stage.Sequence(
- "custom_sequence_fp16",
- "Building Pytorch Model with fp16 conversion",
- [
- export.ExportPytorchModel(),
- export.OptimizeOnnxModel(),
- export.ConvertOnnxToFp16(),
- ],
- enable_model_validation=True,
- )
-
- # Build model using fp32 inputs
- build_name = "custom_sequence_fp32"
- build_model(
- pytorch_model,
- inputs,
- build_name=build_name,
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- sequence=custom_sequence_fp32,
- )
-
- inputs_path = os.path.join(cache_location, build_name, "inputs.npy")
- assert np.load(inputs_path, allow_pickle=True)[0]["x"].dtype == np.float32
-
- # Build model using fp16 inputs
- build_name = "custom_sequence_fp16"
- build_model(
- pytorch_model,
- inputs,
- build_name="custom_sequence_fp16",
- rebuild="always",
- monitor=False,
- cache_dir=cache_location,
- sequence=custom_sequence_fp16,
- )
-
- inputs_path = os.path.join(cache_location, build_name, "inputs.npy")
- assert np.load(inputs_path, allow_pickle=True)[0]["x"].dtype == np.float16
-
- def test_018_full_compilation_hummingbird_kn(self):
- assert full_compilation_hummingbird_kn()
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/test/cli.py b/test/cli.py
index e2ea7e2d..bddd1de5 100644
--- a/test/cli.py
+++ b/test/cli.py
@@ -17,17 +17,14 @@
import platform
import torch
from turnkeyml.cli.cli import main as turnkeycli
-import turnkeyml.cli.report as report
-import turnkeyml.common.filesystem as filesystem
-from turnkeyml.run.onnxrt.runtime import OnnxRT
-from turnkeyml.run.tensorrt.runtime import TensorRT
+import turnkeyml.tools.report as report
+import turnkeyml.common.filesystem as fs
import turnkeyml.common.build as build
-import turnkeyml.common.filesystem as filesystem
import turnkeyml.common.exceptions as exceptions
-import turnkeyml.build.export as export
-import turnkeyml.cli.spawn as spawn
+import turnkeyml.common.onnx_helpers as onnx_helpers
from turnkeyml.cli.parser_helpers import decode_args, encode_args
-from helpers import common
+import turnkeyml.common.test_helpers as common
+from turnkeyml.state import load_state
def bash(cmd: str) -> List[str]:
@@ -61,7 +58,7 @@ def assert_success_of_builds(
) -> int:
# Figure out the build name by surveying the build cache
# for a build that includes test_script_name in the name
- builds = filesystem.get_all(cache_dir)
+ builds = fs.get_all(cache_dir)
builds_found = 0
for test_script in test_script_files:
@@ -70,11 +67,10 @@ def assert_success_of_builds(
for build_state_file in builds:
if test_script_name in build_state_file:
- build_state = build.load_state(state_path=build_state_file)
- stats = filesystem.Stats(
+ build_state = load_state(state_path=build_state_file)
+ stats = fs.Stats(
build_state.cache_dir,
- build_state.config.build_name,
- build_state.evaluation_id,
+ build_state.build_name,
)
assert build_state.build_status == build.FunctionStatus.SUCCESSFUL
script_build_found = True
@@ -86,20 +82,20 @@ def assert_success_of_builds(
), f"{build_state.info.__dict__[info_property[0]]} == {info_property[1]}"
if check_perf:
- assert stats.evaluation_stats["mean_latency"] > 0
- assert stats.evaluation_stats["throughput"] > 0
+ assert stats.stats["mean_latency"] > 0
+ assert stats.stats["throughput"] > 0
if check_iteration_count:
- iterations = stats.evaluation_stats["iterations"]
+ iterations = stats.stats["iterations"]
assert iterations == check_iteration_count
if check_opset:
- onnx_model = onnx.load(build_state.results[0])
+ onnx_model = onnx.load(build_state.results)
model_opset = getattr(onnx_model.opset_import[0], "version", None)
assert model_opset == check_opset
if check_onnx_file_count:
- onnx_dir = export.onnx_dir(build_state)
+ onnx_dir = onnx_helpers.onnx_dir(build_state)
assert len(os.listdir(onnx_dir)) == check_onnx_file_count
assert script_build_found
@@ -127,7 +123,7 @@ def forward(self, x):
class Testing(unittest.TestCase):
def setUp(self) -> None:
- filesystem.rmdir(cache_dir)
+ fs.rmdir(cache_dir)
return super().setUp()
@@ -137,11 +133,13 @@ def test_001_cli_single(self):
testargs = [
"turnkey",
- "benchmark",
+ "-i",
os.path.join(corpus_dir, test_script),
- "--build-only",
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -154,12 +152,14 @@ def test_002_search_multiple(self):
testargs = [
"turnkey",
- "benchmark",
+ "-i",
os.path.join(corpus_dir, test_scripts[0]),
os.path.join(corpus_dir, test_scripts[1]),
- "--build-only",
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -175,11 +175,13 @@ def test_003_cli_build_dir(self):
testargs = [
"turnkey",
- "benchmark",
+ "-i",
bash(f"{corpus_dir}/*.py"),
- "--build-only",
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
@@ -194,11 +196,13 @@ def test_004_cli_list(self):
# Build the test corpus so we have builds to list
testargs = [
"turnkey",
- "benchmark",
+ "-i",
bash(f"{corpus_dir}/*.py"),
- "--build-only",
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
@@ -207,10 +211,11 @@ def test_004_cli_list(self):
with redirect_stdout(io.StringIO()) as f:
testargs = [
"turnkey",
- "cache",
- "list",
"--cache-dir",
cache_dir,
+ "cache",
+ "--list",
+ "--all",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -228,11 +233,13 @@ def test_005_cli_delete(self):
# Build the test corpus so we have builds to delete
testargs = [
"turnkey",
- "benchmark",
+ "-i",
bash(f"{corpus_dir}/*.py"),
- "--build-only",
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
@@ -241,10 +248,11 @@ def test_005_cli_delete(self):
with redirect_stdout(io.StringIO()) as f:
testargs = [
"turnkey",
- "cache",
- "list",
"--cache-dir",
cache_dir,
+ "cache",
+ "--list",
+ "--all",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -256,11 +264,11 @@ def test_005_cli_delete(self):
# Delete the builds
testargs = [
"turnkey",
- "cache",
- "delete",
- "--all",
"--cache-dir",
cache_dir,
+ "cache",
+ "--delete",
+ "--all",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -269,10 +277,11 @@ def test_005_cli_delete(self):
with redirect_stdout(io.StringIO()) as f:
testargs = [
"turnkey",
- "cache",
- "list",
"--cache-dir",
cache_dir,
+ "cache",
+ "--list",
+ "--all",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -289,11 +298,13 @@ def test_006_cli_stats(self):
# Build the test corpus so we have builds to print
testargs = [
"turnkey",
- "benchmark",
+ "-i",
bash(f"{corpus_dir}/*.py"),
- "--build-only",
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
@@ -301,20 +312,19 @@ def test_006_cli_stats(self):
# Make sure we can print the builds in the cache
for test_script in common.test_scripts_dot_py.keys():
test_script_path = os.path.join(corpus_dir, test_script)
- builds, script_name = filesystem.get_builds_from_file(
- cache_dir, test_script_path
- )
+ builds, script_name = fs.get_builds_from_file(cache_dir, test_script_path)
for build_name in builds:
# Make sure each build can be accessed with `turnkey cache stats`
with redirect_stdout(io.StringIO()) as f:
testargs = [
"turnkey",
- "cache",
- "stats",
- build_name,
"--cache-dir",
cache_dir,
+ "cache",
+ "--stats",
+ "--build-names",
+ build_name,
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -383,13 +393,13 @@ def test_008_cli_turnkey_args(self):
# Set as many turnkey args as possible
testargs = [
"turnkey",
- "benchmark",
+ "-i",
os.path.join(corpus_dir, test_script),
- "--rebuild",
- "always",
- "--build-only",
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -404,10 +414,14 @@ def test_009_cli_benchmark(self):
testargs = [
"turnkey",
- "benchmark",
+ "-i",
os.path.join(corpus_dir, test_script),
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -420,49 +434,55 @@ def test_010_cli_labels(self):
# Only build models labels with test_group::a
testargs = [
"turnkey",
- "benchmark",
+ "-i",
bash(f"{corpus_dir}/*.py"),
"--labels",
"test_group::a",
"--build-only",
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
- state_files = [Path(p).stem for p in filesystem.get_all(cache_dir)]
+ state_files = [Path(p).stem for p in fs.get_all(cache_dir)]
assert state_files == ["linear_d5b1df11_state"]
# Delete the builds
testargs = [
"turnkey",
- "cache",
- "delete",
- "--all",
"--cache-dir",
cache_dir,
+ "cache",
+ "--delete",
+ "--all",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
- assert filesystem.get_all(cache_dir) == []
+ assert fs.get_all(cache_dir) == []
# Only build models labels with test_group::a and test_group::b
testargs = [
"turnkey",
- "benchmark",
+ "-i",
bash(f"{corpus_dir}/*.py"),
"--labels",
"test_group::a,b",
- "--build-only",
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
- state_files = [Path(p).stem for p in filesystem.get_all(cache_dir)]
+ state_files = [Path(p).stem for p in fs.get_all(cache_dir)]
assert state_files == ["linear_d5b1df11_state", "linear2_80b93950_state"]
@unittest.skip("Needs re-implementation")
@@ -470,24 +490,27 @@ def test_011_report_on_failed_build(self):
testargs = [
"turnkey",
bash(f"{corpus_dir}/linear.py"),
- "--device",
- "reimplement_me",
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
+ "--device",
+ "reimplement_me",
]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
# Ensure test failed
- build_state = build.load_state(state_path=filesystem.get_all(cache_dir)[0])
+ build_state = load_state(state_path=fs.get_all(cache_dir)[0])
assert build_state.build_status != build.FunctionStatus.SUCCESSFUL
# Generate report
testargs = [
"turnkey",
- "cache",
"report",
- "--cache-dir",
+ "--input-caches",
cache_dir,
]
with patch.object(sys, "argv", testargs):
@@ -513,10 +536,14 @@ def test_012_runtimes(self):
with self.assertRaises(exceptions.ArgError):
testargs = [
"turnkey",
- "benchmark",
+ "-i",
bash(f"{corpus_dir}/linear.py"),
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
"--device",
"x86",
"--runtime",
@@ -528,10 +555,12 @@ def test_012_runtimes(self):
# Benchmark with Pytorch
testargs = [
"turnkey",
- "benchmark",
+ "-i",
bash(f"{corpus_dir}/linear.py"),
"--cache-dir",
cache_dir,
+ "discover",
+ "benchmark",
"--device",
"x86",
"--runtime",
@@ -543,10 +572,14 @@ def test_012_runtimes(self):
# Benchmark with Onnx Runtime
testargs = [
"turnkey",
- "benchmark",
+ "-i",
bash(f"{corpus_dir}/linear.py"),
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
"--device",
"x86",
"--runtime",
@@ -566,12 +599,16 @@ def test_013_cli_onnx_opset(self):
testargs = [
"turnkey",
- "benchmark",
+ "-i",
os.path.join(corpus_dir, test_script),
"--cache-dir",
cache_dir,
- "--onnx-opset",
+ "discover",
+ "export-pytorch",
+ "--opset",
str(user_opset),
+ "optimize-ort",
+ "benchmark",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -587,10 +624,14 @@ def test_014_cli_iteration_count(self):
test_iterations = 123
testargs = [
"turnkey",
- "benchmark",
+ "-i",
os.path.join(corpus_dir, test_script),
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
"--iterations",
str(test_iterations),
]
@@ -612,30 +653,21 @@ def test_015_cli_process_isolation(self):
with redirect_stdout(io.StringIO()) as f:
testargs = [
"turnkey",
- "benchmark",
+ "-i",
os.path.join(corpus_dir, test_script),
"--cache-dir",
cache_dir,
"--process-isolation",
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
assert_success_of_builds([test_script], cache_dir, None, check_perf=True)
- # We use certain key phrases in stdout to perform cleanup in the event
- # that a turnkey subprocess does not complete.
- # These checks make sure that those key phrases are not removed
- output = f.getvalue().split("\n")
- evaluation_id = None
- build_name = None
- for line in output:
- evaluation_id = spawn.parse_evaluation_id(line, evaluation_id)
- build_name = spawn.parse_build_name(line, build_name)
-
- assert evaluation_id is not None
- assert build_name is not None
-
@unittest.skipIf(
platform.system() == "Windows",
"Skipping, as torch.compile is not supported on Windows"
@@ -645,10 +677,14 @@ def test_016_skip_compiled(self):
test_script = "compiled.py"
testargs = [
"turnkey",
- "benchmark",
+ "-i",
os.path.join(extras_dir, test_script),
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -662,7 +698,15 @@ def test_016_skip_compiled(self):
def test_017_invalid_file_type(self):
# Ensure that we get an error when running turnkey with invalid input_files
with self.assertRaises(SystemExit):
- testargs = ["turnkey", "gobbledegook"]
+ testargs = [
+ "turnkey",
+ "-i",
+ "gobbledegook",
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
+ ]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
@@ -672,12 +716,13 @@ def test_018_cli_export_only(self):
testargs = [
"turnkey",
- "benchmark",
+ "-i",
os.path.join(corpus_dir, test_script),
- "--sequence",
- "onnx-fp32",
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "benchmark",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -703,10 +748,12 @@ def test_019_cli_onnx_model(self):
testargs = [
"turnkey",
- "benchmark",
+ "-i",
onnx_file,
"--cache-dir",
cache_dir,
+ "load-onnx",
+ "benchmark",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -736,10 +783,12 @@ def test_020_cli_onnx_model_opset(self):
testargs = [
"turnkey",
- "benchmark",
+ "-i",
onnx_file,
"--cache-dir",
cache_dir,
+ "load-onnx",
+ "benchmark",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -763,7 +812,15 @@ def test_022_benchmark_non_existent_file(self):
with self.assertRaises(exceptions.ArgError):
filename = "thou_shall_not_exist.py"
with redirect_stdout(io.StringIO()) as f:
- testargs = ["turnkey", "benchmark", filename]
+ testargs = [
+ "turnkey",
+ "-i",
+ filename,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
+ ]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -772,7 +829,15 @@ def test_023_benchmark_non_existent_file_prefix(self):
with self.assertRaises(exceptions.ArgError):
file_prefix = "non_existent_prefix_*.py"
with redirect_stdout(io.StringIO()) as f:
- testargs = ["turnkey", "benchmark", file_prefix]
+ testargs = [
+ "turnkey",
+ "-i",
+ file_prefix,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
+ ]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -783,11 +848,13 @@ def test_024_input_text_file(self):
testargs = [
"turnkey",
- "benchmark",
+ "-i",
os.path.join(extras_dir, "selected_models.txt"),
"--cache-dir",
cache_dir,
- "--build-only",
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
@@ -811,23 +878,25 @@ def test_025_cli_timeout(self):
testargs = [
"turnkey",
- "benchmark",
+ "-i",
os.path.join(extras_dir, "timeout.py"),
"--cache-dir",
cache_dir,
"--process-isolation",
"--timeout",
"10",
- "--build-only",
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
testargs = [
"turnkey",
- "cache",
"report",
- "--cache-dir",
+ "--input-caches",
cache_dir,
]
with patch.object(sys, "argv", testargs):
@@ -850,6 +919,9 @@ def test_025_cli_timeout(self):
# Edge case where the build timed out before the stats.yaml was created,
# which leaves the CSV empty
pass
+ except KeyError:
+ # Edge case where the CSV only contains a key for "error_log"
+ assert "timeout" in timeout_summary["error_log"]
def test_026_cli_report(self):
# NOTE: this is not a unit test, it relies on other command
@@ -861,19 +933,22 @@ def test_026_cli_report(self):
# Benchmark the test corpus so we have builds to report
testargs = [
"turnkey",
- "benchmark",
+ "-i",
bash(f"{corpus_dir}/*.py"),
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
testargs = [
"turnkey",
- "cache",
"report",
- "--cache-dir",
+ "--input-caches",
cache_dir,
]
with patch.object(sys, "argv", testargs):
@@ -896,7 +971,7 @@ def test_026_cli_report(self):
"device",
"mean_latency",
"throughput",
- "selected_sequence_of_stages",
+ "selected_sequence_of_tools",
]
linear_summary = summary[1]
assert len(summary) == len(test_scripts)
@@ -935,29 +1010,31 @@ def test_026_cli_report(self):
result_dict = report.get_dict(
summary_csv_path,
[
- "selected_sequence_of_stages",
- "stage_duration:export_pytorch",
- "stage_duration:optimize_onnx",
- "stage_status:export_pytorch",
- "stage_status:optimize_onnx",
+ "selected_sequence_of_tools",
+ "tool_duration:discover",
+ "tool_duration:export-pytorch",
+ "tool_duration:optimize-ort",
+ "tool_status:discover",
+ "tool_status:export-pytorch",
+ "tool_status:optimize-ort",
],
)
for result in result_dict.values():
# All of the models should have exported to ONNX and optimized the ONNX model
- for stage in ["export_pytorch", "optimize_onnx"]:
- assert stage in result["selected_sequence_of_stages"]
- duration = result[f"stage_duration:{stage}"]
- status = result[f"stage_status:{stage}"]
+ for tool in ["export-pytorch", "optimize-ort"]:
+ assert tool in result["selected_sequence_of_tools"]
+ duration = result[f"tool_duration:{tool}"]
+ status = result[f"tool_status:{tool}"]
assert (
status == "successful"
- ), f"Unexpected status {status} for stage '{stage}'"
+ ), f"Unexpected status {status} for tool '{tool}'"
try:
assert (
float(duration) > 0
- ), f"Stage {stage} has invalid duration '{duration}'"
+ ), f"Tool {tool} has invalid duration '{duration}'"
except ValueError:
# Catch the case where the value is not numeric
- assert False, f"Stage {stage} has invalid duration {duration}"
+ assert False, f"Tool {tool} has invalid duration {duration}"
def test_027_cli_cache_benchmark(self):
@@ -966,40 +1043,48 @@ def test_027_cli_cache_benchmark(self):
# Build the test corpus so we have builds to benchmark
testargs = [
"turnkey",
- "benchmark",
+ "-i",
bash(f"{corpus_dir}/*.py"),
"--cache-dir",
cache_dir,
- "--build-only",
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
# Benchmark the single model from cache directory
- selected_build = filesystem.get_available_builds(cache_dir)[-1]
+ selected_build = fs.get_available_builds(cache_dir)[-1]
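+        # Benchmark a single cached build by pointing load-build at its saved state YAML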
+ state_file_path = os.path.join(
+ cache_dir, selected_build, f"{selected_build}_state.yaml"
+ )
+
testargs = [
"turnkey",
- "cache",
- "benchmark",
- selected_build,
"--cache-dir",
cache_dir,
+ "-i",
+ state_file_path,
+ "load-build",
+ "benchmark",
]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
# Make sure the benchmark happened
- test_script = "_".join(selected_build.split("_")[:-1]) + ".py"
+ test_script = selected_build + ".py"
assert_success_of_builds([test_script], cache_dir, check_perf=True)
# Benchmark the cache directory
testargs = [
"turnkey",
- "cache",
- "benchmark",
- "--all",
"--cache-dir",
cache_dir,
+ "-i",
+ os.path.join(cache_dir, "*", "*_state.yaml"),
+ "load-build",
+ "benchmark",
]
with patch.object(sys, "argv", flatten(testargs)):
turnkeycli()
@@ -1007,6 +1092,46 @@ def test_027_cli_cache_benchmark(self):
# Make sure the benchmarks happened
assert_success_of_builds(test_scripts, cache_dir, check_perf=True)
+ def test_028_cli_onnx_verify(self):
+ # Test the first model in the corpus
+ test_script = list(common.test_scripts_dot_py.keys())[0]
+
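+        # Exercise the verify-exporter tool alongside the standard export sequence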
+ testargs = [
+ "turnkey",
+ "-i",
+ os.path.join(corpus_dir, test_script),
+ "--cache-dir",
+ cache_dir,
+ "discover",
+ "verify-exporter",
+ "export-pytorch",
+ "optimize-ort",
+ ]
+ with patch.object(sys, "argv", testargs):
+ turnkeycli()
+
+ assert_success_of_builds([test_script], cache_dir)
+
+ def test_029_cli_fp16_convert(self):
+ # Test the first model in the corpus
+ test_script = list(common.test_scripts_dot_py.keys())[0]
+
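+        # convert-fp16 runs after optimize-ort to cast the optimized ONNX model to float16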
+ testargs = [
+ "turnkey",
+ "-i",
+ os.path.join(corpus_dir, test_script),
+ "--cache-dir",
+ cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "convert-fp16",
+ ]
+ with patch.object(sys, "argv", testargs):
+ turnkeycli()
+
+ assert_success_of_builds([test_script], cache_dir)
+
if __name__ == "__main__":
# Create a cache directory a directory with test models
diff --git a/test/helpers/check_slurm_output.sh b/test/helpers/check_slurm_output.sh
index f2431e46..3685b17a 100644
--- a/test/helpers/check_slurm_output.sh
+++ b/test/helpers/check_slurm_output.sh
@@ -1,6 +1,6 @@
# Checks whether a slurm output contains any errors
SLURM_OUTPUT="$1"
-if ! grep -q "Model successfully built!" $SLURM_OUTPUT
+if ! grep -q "Successful build!" $SLURM_OUTPUT
then
cat $SLURM_OUTPUT
echo "Model has not been successfully built"
diff --git a/test/plugins.py b/test/plugins.py
index 5bb0e005..ce63a3c8 100644
--- a/test/plugins.py
+++ b/test/plugins.py
@@ -9,7 +9,7 @@
from turnkeyml.cli.cli import main as turnkeycli
import turnkeyml.common.filesystem as filesystem
import turnkeyml.common.build as build
-from helpers import common
+import turnkeyml.common.test_helpers as common
class Testing(unittest.TestCase):
@@ -25,27 +25,31 @@ def test_001_device_naming(self):
test_script = "linear.py"
testargs = [
"turnkey",
- "benchmark",
+ "-i",
os.path.join(corpus_dir, test_script),
- "--device",
- "example_family",
- "--build-only",
"--cache-dir",
cache_dir,
+ "discover",
+ "export-pytorch",
+ "optimize-ort",
+ "benchmark",
+ "--device",
+ "example_family",
]
with patch.object(sys, "argv", testargs):
turnkeycli()
- _, build_state = common.get_stats_and_state(test_script, cache_dir)
+ build_stats, build_state = common.get_stats_and_state(test_script, cache_dir)
# Check if build was successful
assert build_state.build_status == build.FunctionStatus.SUCCESSFUL
# Check if default part and config were assigned
expected_device = "example_family::part1::config1"
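+        # The selected device is now recorded in the build stats ("device_type") rather than on build_state.config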
+ actual_device = build_stats["device_type"]
assert (
- build_state.config.device == expected_device
- ), f"Got {build_state.config.device}, expected {expected_device}"
+ actual_device == expected_device
+ ), f"Got {actual_device}, expected {expected_device}"
if __name__ == "__main__":
diff --git a/trackers/huggingface/.streamlit/config.toml b/trackers/huggingface/.streamlit/config.toml
deleted file mode 100644
index c9b8c45f..00000000
--- a/trackers/huggingface/.streamlit/config.toml
+++ /dev/null
@@ -1,2 +0,0 @@
-[theme]
-base="dark"
\ No newline at end of file
diff --git a/trackers/huggingface/app.py b/trackers/huggingface/app.py
deleted file mode 100644
index c9b62ac7..00000000
--- a/trackers/huggingface/app.py
+++ /dev/null
@@ -1,216 +0,0 @@
-from os import listdir
-from os.path import isfile, join
-import pandas as pd
-import streamlit as st # pylint: disable=import-error
-import graphs
-from streamlit_helpers import add_filter, slider_filter, Collapsable
-
-st.set_page_config(
- page_title="TurnkeyML Tracker",
- page_icon="⚡",
- layout="wide",
-)
-
-# dashboard title
-st.title("TurnkeyML Tracker ⚡")
-
-st.warning(
- (
- "TurnkeyML is under active development and we are currently working on a list of critical data "
- "validation tasks. We are sharing this "
- "dashboard and the data within for the sole purpose of gathering early feedback. See our FAQ below "
- "for more details about license and liability."
- ),
- icon="⚠️",
-)
-
-
-def add_faq() -> None:
- """
- Displays FAQ using Collapsable sections
- """
- faq = Collapsable()
- faq.add_section(
- "How is TurnkeyML different from MLPerf?",
- (
- "Deep learning pioneers have been judging their progress with the Machine Learning "
- "Performance (MLPerf) inference benchmark, but have found that the corpus of models "
- "is small enough that it allows vendors to primarily compete by hand-optimizing "
- "kernels. TurnkeyML offers a complementary approach to MLPerf by examining the "
- "capability of vendors to provide turnkey solutions to a larger corpus of "
- "off-the-shelf models. By providing a workflow that is representative of the "
- "mass adoption customer on a variety of ML accelerators and effectively disallowing "
- "hand-crafted kernels, TurnkeyML bridges the gap between MLPerf and the mass adoption "
- "of hardware acceleration."
- ),
- )
- faq.add_section(
- "Why now for TurnkeyML?",
- (
- "Deep learning algorithms and their associated DL hardware accelerators are "
- "transitioning from early adoption into mass adoption. Production DL is now "
- "becoming available to the masses, with a desire to customize models to tackle "
- "their specific problems, and then take the path of least resistance into "
- "production. A market for turnkey solutions, starting with a model as input and "
- "provision a cost- and latency-effective acceleration solution, often in the cloud, "
- "as output, has emerged."
- ),
- )
- faq.add_section(
- "Which tool was used to generate those results?",
- (
- "All TurnkeyML results have been generated using the turnkey tool v1.0.0, which is part "
- "of the TurnkeyML Github Repository. You can learn more about it "
- 'here.'
- ),
- )
- faq.add_section(
- "What is the experimental setup for each of the devices?",
- [
- "x86: Intel(R) Xeon(R) X40 CPU @ 2.00GHz on Google Cloud (custom: n2, 80 vCPU, 64.00 GiB) and OnnxRuntime version 1.14.0.",
- "nvidia: NVIDIA A100 40GB on Google Cloud (a2-highgpu-1g) and TensorRT version 22.12-py3.",
- (
- "You can find more details about the methodology "
- 'here.'
- ),
- ],
- )
- faq.add_section(
- "What are the current key limitations of those results?",
- [
- (
- "Results currently only represent batch 1 performance on a limited number of models, "
- "devices, vendors, and runtimes. You can learn more about future directions by reading "
- 'the "What are the future directions of TurnkeyML?" FAQ section.'
- ),
- (
- "Results are currently being validated. You can have a look at our current validation "
- "tasks and other limitations "
- 'here.'
- ),
- ],
- )
- faq.add_section(
- "What are the future directions of TurnkeyML?",
- [
- "Include additional classes of models (e.g. LLMs, GNNs, DLRMs).",
- "Perform experiments that include sweeps over batch and input sizes.",
- "Increase the number of devices from existing vendors (e.g. T4, A10, and H100).",
- "Include devices from additional vendors (e.g. ARM, and AMD)."
- "Include the number of runtimes supported (e.g. ORT and PyTorch for CUDA, PyTorch for x86).",
- ],
- )
- faq.add_section(
- "Who runs TurnkeyML?",
- (
- "TurnkeyML is currently maintained by the following individuals (in alphabetical order): "
- "Daniel Holanda Noronha, Jeremy Fowers, Kalin Ovtcharov, and Ramakrishnan Sivakumar. We are actively seeking collaborators from across the industry."
- ),
- )
- faq.add_section(
- "License and Liability",
- (
- 'THE TURNKEY BENCHMARK IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR '
- "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, "
- "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE "
- "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER "
- "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, "
- "OUT OF OR IN CONNECTION WITH THE BENCHMARK OR THE USE OR OTHER DEALINGS IN THE "
- "BENCHMARK. Read more about it "
- 'here.'
- ),
- )
-
- faq.deploy()
-
-
-# Add all filters to sidebar
-with st.sidebar:
- st.markdown("# Filters")
-
- # Get all reports of a given test type
- REPORT_FOLDER = "reports"
- reports = sorted(
- [f for f in listdir(REPORT_FOLDER) if isfile(join(REPORT_FOLDER, f))]
- )
-
- # Select and read a report
- selected_report = st.selectbox("Test date", reports, index=len(reports) - 1)
- selected_report_idx = reports.index(selected_report)
- report = pd.read_csv(f"{REPORT_FOLDER}/{selected_report}")
-
- # Convert int parameters to int/float
- for p in ["params"]:
- report[p] = report[p].replace("-", 0).astype("int64")
-
- # Add parameter filter
- st.markdown("#### Parameters")
-
- report = slider_filter(
- [report], "Select a range parameters (in millions)", filter_by="params"
- )[0]
-
- # Add author filter
- report = add_filter(
- [report],
- "Origin",
- label="author",
- num_cols=2,
- )[0]
-
- # Add task filter
- report = add_filter([report], "Tasks", label="task", options=None)[0]
-
-
-st.markdown("## Summary Results")
-
-graphs.device_funnel(report)
-
-st.markdown("""#### Benchmark results""")
-baseline = st.selectbox("Baseline", ("x86", "nvidia"))
-graphs.speedup_text_summary(report, baseline)
-graphs.speedup_bar_chart(report, baseline)
-
-cols = st.columns(2)
-with cols[0]:
- st.markdown("""#### Workload origin""")
- graphs.workload_origin(report)
-
-with cols[1]:
- st.markdown("""#### Parameter Size Distribution""")
- graphs.parameter_histogram(report, show_assembled=False)
-
-# FAQ Block
-st.markdown("""## About this workload analysis (FAQ)""")
-add_faq()
-
-# Detailed data view (table)
-st.markdown("## Detailed Data View")
-
-# Add columns that do not exist yet
-report["gpu_chips_used"] = 1
-report["cpu_chips_used"] = 1
-
-
-# Using 3 significant digits
-report["nvidia_latency"] = [
- "-" if x == "-" else "{:.3f}".format(float(x)) for x in report["nvidia_latency"]
-]
-report["x86_latency"] = [
- "-" if x == "-" else "{:.3f}".format(float(x)) for x in report["x86_latency"]
-]
-
-renamed_cols = {
- "model_name": "Model Name",
- "author": "Source",
- "params": "Parameters",
- "nvidia_latency": "NVIDIA A100-PCIE-40GB: Latency (ms)",
- "x86_latency": "Intel(R) Xeon(R) x40 CPU: Latency (ms)",
- "gpu_chips_used": "NVIDIA A100-PCIE-40GB: Chips Used",
- "cpu_chips_used": "Intel(R) Xeon(R) x40 CPU: Chips Used",
-}
-
-report.rename(columns=renamed_cols, inplace=True)
-selected_cols = list(renamed_cols.values())
-
-graphs.results_table(report[selected_cols]) # pylint: disable=unsubscriptable-object
diff --git a/trackers/huggingface/graphs.py b/trackers/huggingface/graphs.py
deleted file mode 100644
index d0ba2761..00000000
--- a/trackers/huggingface/graphs.py
+++ /dev/null
@@ -1,646 +0,0 @@
-from collections import Counter
-from streamlit_echarts import st_echarts # pylint: disable=import-error
-import numpy as np
-import pandas as pd
-import streamlit as st # pylint: disable=import-error
-import plotly.figure_factory as ff
-from plotly import graph_objs as go
-import plotly.express as px
-from statistics import median
-
-colors = {
- "blue": "#5470c6",
- "orange": "#FF7F0E",
- "green": "#94cc74",
- "saffron_mango": "#fac858",
- "red": "#ee6666",
- "light_blue": "#73c0de",
- "ocean_green": "#3ba272",
-}
-device_colors = {
- "x86": colors["blue"],
- "nvidia": colors["green"],
-}
-
-
-class StageCount:
- def __init__(self, df: pd.DataFrame) -> None:
- self.all_models = len(df)
- self.base_onnx = int(np.sum(df["base_onnx"]))
- self.optimize_fp32 = int(np.sum(df["optimize_fp32"]))
- self.all_ops_supported = int(np.sum(df["all_ops_supported"]))
- self.fp16_onnx = int(np.sum(df["fp16_onnx"]))
- self.compiles = int(np.sum(df["compiles"]))
- self.assembles = int(np.sum(df["assembles"]))
-
-
-class DeviceStageCount:
- def __init__(self, df: pd.DataFrame) -> None:
- self.all_models = len(df)
- self.base_onnx = int(np.sum(df["onnx_exported"]))
- self.optimize_fp32 = int(np.sum(df["onnx_optimized"]))
- self.fp16_onnx = int(np.sum(df["onnx_converted"]))
- self.x86 = df.loc[df.x86_latency != "-", "x86_latency"].count()
- self.nvidia = df.loc[df.nvidia_latency != "-", "nvidia_latency"].count()
-
-
-def stages_count_summary(current_df: pd.DataFrame, prev_df: pd.DataFrame) -> None:
- """
- Show count of how many models compile, assemble, etc
- """
- current = StageCount(current_df)
- prev = StageCount(prev_df)
-
- kpi = st.columns(7)
-
- kpi[0].metric(
- label="All models",
- value=current.all_models,
- delta=current.all_models - prev.all_models,
- )
-
- kpi[1].metric(
- label="Converts to ONNX",
- value=current.base_onnx,
- delta=current.base_onnx - prev.base_onnx,
- )
-
- kpi[2].metric(
- label="Optimizes ONNX file",
- value=current.optimize_fp32,
- delta=current.optimize_fp32 - prev.optimize_fp32,
- )
-
- kpi[3].metric(
- label="Supports all ops",
- value=current.all_ops_supported,
- delta=current.all_ops_supported - prev.all_ops_supported,
- )
-
- kpi[4].metric(
- label="Converts to FP16",
- value=current.fp16_onnx,
- delta=current.fp16_onnx - prev.fp16_onnx,
- )
-
- kpi[5].metric(
- label="Compiles",
- value=current.compiles,
- delta=current.compiles - prev.compiles,
- )
-
- kpi[6].metric(
- label="Assembles",
- value=current.assembles,
- delta=current.assembles - prev.assembles,
- )
-
- # Show Sankey graph with percentages
- sk_val = {
- "All models": "100%",
- "Converts to ONNX": str(int(100 * current.base_onnx / current.all_models))
- + "%",
- "Optimizes ONNX file": str(
- int(100 * current.optimize_fp32 / current.all_models)
- )
- + "%",
- "Supports all ops": str(
- int(100 * current.all_ops_supported / current.all_models)
- )
- + "%",
- "Converts to FP16": str(int(100 * current.fp16_onnx / current.all_models))
- + "%",
- "Compiles": str(int(100 * current.compiles / current.all_models)) + "%",
- "Assembles": str(int(100 * current.assembles / current.all_models)) + "%",
- }
- option = {
- "series": {
- "type": "sankey",
- "animationDuration": 1,
- "top": "0%",
- "bottom": "20%",
- "left": "0%",
- "right": "13.5%",
- "darkMode": "true",
- "nodeWidth": 2,
- "textStyle": {"fontSize": 16},
- "lineStyle": {"curveness": 0},
- "layoutIterations": 0,
- "layout": "none",
- "emphasis": {"focus": "adjacency"},
- "data": [
- {
- "name": "All models",
- "value": sk_val["All models"],
- "itemStyle": {"color": "white", "borderColor": "white"},
- },
- {
- "name": "Converts to ONNX",
- "value": sk_val["Converts to ONNX"],
- "itemStyle": {"color": "white", "borderColor": "white"},
- },
- {
- "name": "Optimizes ONNX file",
- "value": sk_val["Optimizes ONNX file"],
- "itemStyle": {"color": "white", "borderColor": "white"},
- },
- {
- "name": "Supports all ops",
- "value": sk_val["Supports all ops"],
- "itemStyle": {"color": "white", "borderColor": "white"},
- },
- {
- "name": "Converts to FP16",
- "value": sk_val["Converts to FP16"],
- "itemStyle": {"color": "white", "borderColor": "white"},
- },
- {
- "name": "Compiles",
- "value": sk_val["Compiles"],
- "itemStyle": {"color": "white", "borderColor": "white"},
- },
- {
- "name": "Assembles",
- "value": sk_val["Assembles"],
- "itemStyle": {"color": "white", "borderColor": "white"},
- },
- ],
- "label": {
- "position": "insideTopLeft",
- "borderWidth": 0,
- "fontSize": 16,
- "color": "white",
- "textBorderWidth": 0,
- "formatter": "{c}",
- },
- "links": [
- {
- "source": "All models",
- "target": "Converts to ONNX",
- "value": current.base_onnx,
- },
- {
- "source": "Converts to ONNX",
- "target": "Optimizes ONNX file",
- "value": current.optimize_fp32,
- },
- {
- "source": "Optimizes ONNX file",
- "target": "Supports all ops",
- "value": current.all_ops_supported,
- },
- {
- "source": "Supports all ops",
- "target": "Converts to FP16",
- "value": current.fp16_onnx,
- },
- {
- "source": "Converts to FP16",
- "target": "Compiles",
- "value": current.compiles,
- },
- {
- "source": "Compiles",
- "target": "Assembles",
- "value": current.assembles,
- },
- ],
- }
- }
- st_echarts(
- options=option,
- height="50px",
- )
-
-
-def workload_origin(df: pd.DataFrame) -> None:
- """
- Show pie chart that groups models by author
- """
- all_authors = list(df.loc[:, "author"])
- author_count = {i: all_authors.count(i) for i in all_authors}
- all_models = len(df)
-
- options = {
- "darkMode": "true",
- "textStyle": {"fontSize": 16},
- "tooltip": {"trigger": "item"},
- "series": [
- { # "Invisible" chart, used to show author labels
- "name": "Name of corpus:",
- "type": "pie",
- "radius": ["70%", "70%"],
- "data": [
- {"value": author_count[k], "name": k} for k in author_count.keys()
- ],
- "label": {
- "formatter": "{b}\n{d}%",
- },
- },
- {
- # Actual graph where data is shown
- "name": "Name of corpus:",
- "type": "pie",
- "radius": ["50%", "70%"],
- "data": [
- {"value": author_count[k], "name": k} for k in author_count.keys()
- ],
- "emphasis": {
- "itemStyle": {
- "shadowBlur": 10,
- "shadowOffsetX": 0,
- "shadowColor": "rgba(0, 0, 0, 0.5)",
- }
- },
- "label": {
- "position": "inner",
- "formatter": "{c}",
- "color": "black",
- "textBorderWidth": 0,
- },
- },
- {
- # Show total number of models inside
- "name": "Total number of models:",
- "type": "pie",
- "radius": ["0%", "0%"],
- "data": [{"value": all_models, "name": "Total"}],
- "silent": "true",
- "label": {
- "position": "inner",
- "formatter": "{c}",
- "color": "white",
- "fontSize": 30,
- "textBorderWidth": 0,
- },
- },
- ],
- }
- st_echarts(
- options=options,
- height="400px",
- )
-
-
-def parameter_histogram(df: pd.DataFrame, show_assembled=True) -> None:
- # Add parameters histogram
- all_models = [float(x) / 1000000 for x in df["params"] if x != "-"]
-
- hist_data = []
- group_labels = []
-
- if all_models != []:
- hist_data.append(all_models)
- if show_assembled:
- group_labels.append("Models we tried compiling")
- else:
- group_labels.append("All models")
-
- if show_assembled:
- assembled_models = df[
- df["assembles"] == True # pylint: disable=singleton-comparison
- ]
- assembled_models = [
- float(x) / 1000000 for x in assembled_models["params"] if x != "-"
- ]
- if assembled_models != []:
- hist_data.append(assembled_models)
- group_labels.append("Assembled models")
-
- if hist_data:
- fig = ff.create_distplot(
- hist_data,
- group_labels,
- bin_size=25,
- histnorm="",
- colors=list(colors.values()),
- curve_type="normal",
- )
- fig.layout.update(xaxis_title="Parameters in millions")
- fig.layout.update(yaxis_title="count")
- fig.update_xaxes(range=[1, 1000])
-
- st.plotly_chart(fig, use_container_width=True)
-
- else:
- st.markdown(
- """At least one model needs to reach the compiler to show this graph 😅"""
- )
-
-
-def process_latency_data(df, baseline):
- df = df[["model_name", "nvidia_latency", "x86_latency"]]
- df = df.sort_values(by=["model_name"])
-
- df.x86_latency.replace(["-"], [float("inf")], inplace=True)
- df.nvidia_latency.replace(["-"], [float("inf")], inplace=True)
-
- df["nvidia_latency"] = df["nvidia_latency"].astype(float)
- df["x86_latency"] = df["x86_latency"].astype(float)
-
- df["nvidia_compute_ratio"] = df[f"{baseline}_latency"] / df["nvidia_latency"]
- df["x86_compute_ratio"] = df[f"{baseline}_latency"] / df["x86_latency"]
-
- return df
-
-
-def speedup_bar_chart(df: pd.DataFrame, baseline) -> None:
-
- if len(df) == 0:
- st.markdown(
- ("Nothing to show here since no models have been successfully benchmarked.")
- )
- else:
- df = process_latency_data(df, baseline)
- bar_chart = {}
- bar_chart["nvidia"] = go.Bar(
- x=df["model_name"],
- y=df["nvidia_compute_ratio"],
- name="NVIDIA A100",
- )
- bar_chart["x86"] = go.Bar(
- x=df["model_name"],
- y=df["x86_compute_ratio"],
- name="Intel(R) Xeon(R)",
- )
-
- # Move baseline to the back of the plot
- plot_sequence = list(bar_chart.keys())
- plot_sequence.insert(0, plot_sequence.pop(plot_sequence.index(baseline)))
-
- # Ensure that the baseline is the last bar
- data = [bar_chart[device_type] for device_type in plot_sequence]
- color_sequence = [device_colors[device_type] for device_type in plot_sequence]
-
- layout = go.Layout(
- barmode="overlay", # group
- legend={
- "orientation": "h",
- "xanchor": "center",
- "x": 0.5,
- "y": 1.2,
- },
- yaxis_title="Latency Speedup",
- colorway=color_sequence,
- height=500,
- )
-
- fig = dict(data=data, layout=layout)
- st.plotly_chart(fig, use_container_width=True)
-
-
-def kpi_to_markdown(
- compute_ratio, device, num_baseline_models, is_baseline=False, color="blue"
-):
-
- if is_baseline:
- title = f"""
- Median {device} Acceleration ({len(compute_ratio)} models):
"""
- return (
- title
- + f""" {1}x (Baseline)
"""
- )
-
- title = f"""
- Median {device} Acceleration ({len(compute_ratio)}/{num_baseline_models} models):
"""
-
- if len(compute_ratio) > 0:
- kpi_min, kpi_median, kpi_max = (
- round(compute_ratio.min(), 2),
- round(median(compute_ratio), 2),
- round(compute_ratio.max(), 2),
- )
- else:
- kpi_min, kpi_median, kpi_max = 0, 0, 0
-
- return (
- title
- + f""" {kpi_median}x
- min {kpi_min}x; max {kpi_max}x
- """
- )
-
-
-def speedup_text_summary(df: pd.DataFrame, baseline) -> None:
-
- df = process_latency_data(df, baseline)
-
- # Some latencies are "infinite" because they could not be calculated
- # To calculate statistics, we remove all elements of df where the baseline latency is inf
- df = df[(df[baseline + "_latency"] != float("inf"))]
-
- # Setting latencies that could not be calculated to infinity also causes some compute ratios to be zero
- # We remove those to avoid doing any calculations with infinite latencies
- x86_compute_ratio = df["x86_compute_ratio"].to_numpy()
- nvidia_compute_ratio = df["nvidia_compute_ratio"].to_numpy()
- x86_compute_ratio = x86_compute_ratio[x86_compute_ratio != 0]
- nvidia_compute_ratio = nvidia_compute_ratio[nvidia_compute_ratio != 0]
-
- num_baseline_models = len(df[f"{baseline}_compute_ratio"])
- x86_text = kpi_to_markdown(
- x86_compute_ratio,
- device="Intel(R) Xeon(R) X40 CPU @ 2.00GHz",
- num_baseline_models=num_baseline_models,
- color="blue",
- is_baseline=baseline == "x86",
- )
- nvidia_text = kpi_to_markdown(
- nvidia_compute_ratio,
- device="NVIDIA A100-PCIE-40GB",
- num_baseline_models=num_baseline_models,
- color="green",
- is_baseline=baseline == "nvidia",
- )
-
- cols = st.columns(3)
- with cols[0]:
- st.markdown(f"""{x86_text}""", unsafe_allow_html=True)
- with cols[1]:
- st.markdown(f"""{nvidia_text}""", unsafe_allow_html=True)
-
-
-def compiler_errors(df: pd.DataFrame) -> None:
- compiler_errors = df[df["compiler_error"] != "-"]["compiler_error"]
- compiler_errors = Counter(compiler_errors)
- if len(compiler_errors) > 0:
- compiler_errors = pd.DataFrame.from_dict(
- compiler_errors, orient="index"
- ).reset_index()
- compiler_errors = compiler_errors.set_axis(
- ["error", "count"], axis=1, inplace=False
- )
- compiler_errors["error"] = [ce[:80] for ce in compiler_errors["error"]]
- fig = px.bar(
- compiler_errors,
- x="count",
- y="error",
- orientation="h",
- height=400,
- )
- fig.update_traces(marker_color=colors["blue"])
-
- st.plotly_chart(fig, use_container_width=True)
- else:
- st.markdown("""No compiler errors found :tada:""")
-
-
-def results_table(df: pd.DataFrame):
- model_name = st.text_input("", placeholder="Filter model by name")
- if model_name != "":
- df = df[[model_name in x for x in df["Model Name"]]]
-
- st.dataframe(df, height=min((len(df) + 1) * 35, 35 * 21))
-
-
-def device_funnel_metrics(num_models: int, num_total_models: int) -> str:
- """
- Calculates the percentage between models and total_models
- Avoids ZeroDivisionError when dividend is zero
- """
- models_message = f"{num_models} model"
- models_message = models_message + "s" if num_models != 1 else models_message
- percentage_message = ""
- if num_total_models > 0:
- model_ratio = num_models / num_total_models
- if model_ratio < 0.01 and model_ratio != 0:
- percentage_message = " - < 1%"
- else:
- percentage_message = f" - {int(100*num_models / num_total_models)}%"
- return f"{models_message}{percentage_message}"
-
-
-def device_funnel(df: pd.DataFrame) -> None:
- """
- Show count of how many models compile, assemble, etc
- """
- summ = DeviceStageCount(df)
-
- stages = [
- "All models",
- "Export to ONNX",
- "Optimize ONNX file",
- "Convert to FP16",
- "Acquire Performance",
- ]
- cols = st.columns(len(stages))
-
- for idx, stage in enumerate(stages):
- with cols[idx]:
- st.markdown(stage)
-
- # Show Sankey graph with percentages
- sk_val = {
- "All models": device_funnel_metrics(summ.all_models, summ.all_models),
- "Converts to ONNX": device_funnel_metrics(summ.base_onnx, summ.all_models),
- "Optimizes ONNX file": device_funnel_metrics(
- summ.optimize_fp32, summ.all_models
- ),
- "Converts to FP16": device_funnel_metrics(summ.fp16_onnx, summ.all_models),
- "Acquires Nvidia Perf": device_funnel_metrics(summ.nvidia, summ.all_models)
- + " (Nvidia)",
- "Acquires x86 Perf": device_funnel_metrics(summ.x86, summ.all_models)
- + " (x86)",
- }
-
- # Calculate bar heights for each of the devices
- # Bar height is proportional to the number of models benchmarked by each device
- default_bar_size = 1
- target_combined_height = max(default_bar_size, summ.fp16_onnx)
- device_bar_size = target_combined_height / 3
-
- option = {
- "series": {
- "type": "sankey",
- "animationDuration": 1,
- "top": "0%",
- "bottom": "20%",
- "left": "0%",
- "right": "19%",
- "darkMode": "true",
- "nodeWidth": 2,
- "textStyle": {"fontSize": 16},
- "nodeAlign": "left",
- "lineStyle": {"curveness": 0},
- "layoutIterations": 0,
- "nodeGap": 12,
- "layout": "none",
- "emphasis": {"focus": "adjacency"},
- "data": [
- {
- "name": "All models",
- "value": sk_val["All models"],
- "itemStyle": {"color": "white", "borderColor": "white"},
- },
- {
- "name": "Converts to ONNX",
- "value": sk_val["Converts to ONNX"],
- "itemStyle": {"color": "white", "borderColor": "white"},
- },
- {
- "name": "Optimizes ONNX file",
- "value": sk_val["Optimizes ONNX file"],
- "itemStyle": {"color": "white", "borderColor": "white"},
- },
- {
- "name": "Converts to FP16",
- "value": sk_val["Converts to FP16"],
- "itemStyle": {"color": "white", "borderColor": "white"},
- },
- {
- "name": "Acquires Nvidia Perf",
- "value": sk_val["Acquires Nvidia Perf"],
- "itemStyle": {
- "color": device_colors["nvidia"],
- "borderColor": device_colors["nvidia"],
- },
- },
- {
- "name": "Acquires x86 Perf",
- "value": sk_val["Acquires x86 Perf"],
- "itemStyle": {
- "color": device_colors["x86"],
- "borderColor": device_colors["x86"],
- },
- },
- ],
- "label": {
- "position": "insideTopLeft",
- "borderWidth": 0,
- "fontSize": 16,
- "color": "white",
- "textBorderWidth": 0,
- "formatter": "{c}",
- },
- "links": [
- {
- "source": "All models",
- "target": "Converts to ONNX",
- "value": max(default_bar_size, summ.all_models),
- },
- {
- "source": "Converts to ONNX",
- "target": "Optimizes ONNX file",
- "value": max(default_bar_size, summ.optimize_fp32),
- },
- {
- "source": "Optimizes ONNX file",
- "target": "Converts to FP16",
- "value": max(default_bar_size, summ.fp16_onnx),
- },
- {
- "source": "Converts to FP16",
- "target": "Acquires Nvidia Perf",
- "value": device_bar_size,
- },
- {
- "source": "Converts to FP16",
- "target": "Acquires x86 Perf",
- "value": device_bar_size,
- },
- ],
- }
- }
- st_echarts(
- options=option,
- height="70px",
- )
diff --git a/trackers/huggingface/reports/randomized_data.csv b/trackers/huggingface/reports/randomized_data.csv
deleted file mode 100644
index 65136c70..00000000
--- a/trackers/huggingface/reports/randomized_data.csv
+++ /dev/null
@@ -1,169 +0,0 @@
-model_name,author,class,downloads,assembles,params,hash,license,task,model_type,cycles,nvidia_compute_latency,nvidia_latency,x86_latency,onnx_exported,onnx_optimized,onnx_converted
-alexnet,torch hub,AlexNet,0,TRUE,61100840,2891f54c,Unknown,-,pytorch,,-,-,-,TRUE,TRUE,TRUE
-bart,huggingface pytorch,BartModel,0,FALSE,404079238,cb0751ce,Unknown,-,pytorch,,-,0.3594731854,0.9526702776,TRUE,TRUE,TRUE
-beit,huggingface pytorch,BeitModel,0,FALSE,85530815,6b5d54c6,Unknown,-,pytorch,,-,0.9212441442,0.8234458168,TRUE,TRUE,TRUE
-bert,huggingface pytorch,BertModel,0,TRUE,109166782,d59172a2,Unknown,-,pytorch,,-,0.530061979,0.7047938409,TRUE,TRUE,TRUE
-bert for question answering,huggingface pytorch,BertForQuestionAnswering,0,FALSE,333701493,64bce7df,Unknown,-,pytorch,,-,0.6683411346,0.348481092,TRUE,TRUE,TRUE
-bert generation,huggingface pytorch,EncoderDecoderModel,0,FALSE,465654648,42b8fae4,Unknown,-,pytorch,,-,0.1077982091,0.3473692903,TRUE,TRUE,TRUE
-bert tiny for sequence classification,huggingface pytorch,BertForSequenceClassification,0,TRUE,4353194,ca662a9e,Unknown,-,pytorch,,-,0.5748178435,0.2362540791,TRUE,TRUE,TRUE
-blenderbot small,huggingface pytorch,BlenderbotSmallModel,0,FALSE,84607202,d65dd9e3,Unknown,-,pytorch,,-,-,0.8786105288,TRUE,TRUE,TRUE
-camembert,huggingface pytorch,CamembertModel,0,TRUE,109461696,a2ac5985,Unknown,-,pytorch,,-,0.1664160905,0.8746881469,TRUE,TRUE,TRUE
-clip text encoder,diffusers,CLIPTextModel,0,TRUE,123066839,d312ecd1,Unknown,-,pytorch,,-,0.2097790728,0.9720412135,TRUE,TRUE,TRUE
-convbert,huggingface pytorch,ConvBertModel,0,FALSE,105389032,b39013e9,Unknown,-,pytorch,,-,0.3710670566,0.9451956605,TRUE,TRUE,TRUE
-convnext,huggingface pytorch,ConvNextModel,0,FALSE,27766278,80414def,Unknown,-,pytorch,,-,0.7483780632,-,TRUE,TRUE,TRUE
-convnext base,torch hub,ConvNeXt,0,FALSE,88438765,c68282ce,Unknown,-,pytorch,,-,0.8708838228,0.6443308707,TRUE,TRUE,TRUE
-convnext large,torch hub,ConvNeXt,0,FALSE,197538285,af479213,Unknown,-,pytorch,,-,0.06627834059,0.3834119618,TRUE,TRUE,TRUE
-convnext small,torch hub,ConvNeXt,0,FALSE,50109165,32bd6900,Unknown,-,pytorch,,-,0.1387087279,0.1858895092,TRUE,TRUE,TRUE
-convnext tiny,torch hub,ConvNeXt,0,FALSE,28536813,4f884eed,Unknown,-,pytorch,,-,0.8040838302,0.7140804057,TRUE,TRUE,TRUE
-deberta,huggingface pytorch,DebertaModel,0,TRUE,123641957,f4e4f0d1,Unknown,-,pytorch,,-,0.07356748404,0.8320810276,TRUE,TRUE,TRUE
-deit,huggingface pytorch,DeiTModel,0,FALSE,86272702,4519cd75,Unknown,-,pytorch,,-,-,0.4257759263,TRUE,TRUE,TRUE
-deit base for image classification,huggingface pytorch,ViTForImageClassification,0,FALSE,86567846,8fa842d1,Unknown,-,pytorch,,-,0.07511358418,0.1587937427,TRUE,TRUE,TRUE
-deit tiny for image classification,huggingface pytorch,ViTForImageClassification,0,TRUE,5717606,4f7bba18,Unknown,-,pytorch,,-,0.5986953321,0.1589380348,TRUE,TRUE,TRUE
-densenet121,torch hub,DenseNet,0,FALSE,7928960,d5f7254d,Unknown,-,pytorch,,-,0.223103432,0.3582776808,TRUE,TRUE,TRUE
-densenet161,torch hub,DenseNet,0,FALSE,28564768,6c360ce5,Unknown,-,pytorch,,-,0.5760521071,-,TRUE,TRUE,TRUE
-densenet169,torch hub,DenseNet,0,FALSE,14079232,ccd997cb,Unknown,-,pytorch,,-,0.925285356,0.2066641834,TRUE,TRUE,TRUE
-densenet201,torch hub,DenseNet,0,FALSE,19901952,e355a66c,Unknown,-,pytorch,,-,0.5952328131,0.7150526363,TRUE,TRUE,FALSE
-detr,huggingface pytorch,DetrModel,0,FALSE,-,c328f5b8,Unknown,-,pytorch,,-,0.5943997277,0.3579429127,TRUE,TRUE,FALSE
-detr for object detection,huggingface pytorch,DetrForObjectDetection,0,FALSE,-,a2481ba5,Unknown,-,pytorch,,-,0.3708182675,0.8912486329,TRUE,TRUE,TRUE
-distil wav2vec2 for audio classification,huggingface pytorch,Wav2Vec2ForSequenceClassification,0,FALSE,37866370,cd811c97,Unknown,-,pytorch,,-,0.3044858549,0.4356521705,TRUE,TRUE,TRUE
-distilbert,huggingface pytorch,DistilBertModel,0,FALSE,66068114,38518005,Unknown,-,pytorch,,-,0.753096493,0.5359567334,TRUE,TRUE,TRUE
-distilbert for question answering,huggingface pytorch,DistilBertForQuestionAnswering,0,FALSE,66069655,65b3ff1b,Unknown,-,pytorch,,-,-,0.7001622561,TRUE,TRUE,TRUE
-distilhubert for audio classification,huggingface pytorch,HubertForSequenceClassification,0,FALSE,23700597,4170140a,Unknown,-,pytorch,,-,0.5306673544,0.5248636079,TRUE,TRUE,TRUE
-efficientnet b0,torch hub,EfficientNet,0,TRUE,5242196,94890704,Unknown,-,pytorch,,-,0.8726445517,0.9883209454,TRUE,TRUE,TRUE
-efficientnet b1,torch hub,EfficientNet,0,TRUE,7724900,8e53a932,Unknown,-,pytorch,,-,-,0.2983805803,TRUE,TRUE,TRUE
-efficientnet b2,torch hub,EfficientNet,0,TRUE,9034582,204800dc,Unknown,-,pytorch,,-,0.8041922217,0.5590392182,TRUE,TRUE,TRUE
-efficientnet b3,torch hub,EfficientNet,0,FALSE,12134224,2950ca5b,Unknown,-,pytorch,,-,0.6351977538,0.7910649523,TRUE,TRUE,TRUE
-efficientnet b4,torch hub,EfficientNet,0,FALSE,19197120,7d75dda2,Unknown,-,pytorch,,-,0.6566002185,-,TRUE,TRUE,TRUE
-efficientnet b5,torch hub,EfficientNet,0,FALSE,30187756,204c9208,Unknown,-,pytorch,,-,-,0.7448938632,TRUE,TRUE,TRUE
-efficientnet b6,torch hub,EfficientNet,0,FALSE,42776110,d5bd9458,Unknown,-,pytorch,,-,0.08317055388,0.5736338115,TRUE,TRUE,TRUE
-efficientnet b7,torch hub,EfficientNet,0,FALSE,65977888,6973429a,Unknown,-,pytorch,,-,0.923207231,0.1441279948,TRUE,TRUE,TRUE
-efficientnet v2 l,torch hub,EfficientNet,0,FALSE,117896136,f5ddf7f0,Unknown,-,pytorch,,-,0.04204524828,0.4086286848,TRUE,TRUE,TRUE
-efficientnet v2 m,torch hub,EfficientNet,0,FALSE,53790556,a041aef8,Unknown,-,pytorch,,-,0.6019568858,0.5245239163,TRUE,TRUE,TRUE
-efficientnet v2 s,torch hub,EfficientNet,0,FALSE,21275536,ae743058,Unknown,-,pytorch,,-,0.5937655039,0.5348455458,TRUE,TRUE,TRUE
-electra,huggingface pytorch,ElectraModel,0,TRUE,13411517,8da49ae6,Unknown,-,pytorch,,-,0.6692486205,0.4009829786,TRUE,TRUE,TRUE
-electra for sequence classification,huggingface pytorch,ElectraForSequenceClassification,0,TRUE,109285824,5ccb19c4,Unknown,-,pytorch,,-,0.3971938312,0.5677821826,TRUE,TRUE,TRUE
-encoder decoder,huggingface pytorch,EncoderDecoderModel,0,FALSE,269541724,051eeb05,Unknown,-,pytorch,,-,0.9346869012,0.9847387641,TRUE,TRUE,TRUE
-fasterrcnn mobilenet v3 large 320 fpn,torchvision,FasterRCNN,0,FALSE,-,59bcc1a5,Unknown,-,pytorch,,-,-,0.6083279321,TRUE,TRUE,TRUE
-fasterrcnn mobilenet v3 large fpn,torchvision,FasterRCNN,0,FALSE,-,e32c9090,Unknown,-,pytorch,,-,0.9809107221,0.8068068871,TRUE,TRUE,TRUE
-fasterrcnn resnet50 fpn,torchvision,FasterRCNN,0,FALSE,-,d8b3f65a,Unknown,-,pytorch,,-,0.7137446244,0.3051665576,TRUE,TRUE,TRUE
-fasterrcnn resnet50 fpn v2,torchvision,FasterRCNN,0,FALSE,-,7147702b,Unknown,-,pytorch,,-,0.2270062371,0.6995605109,TRUE,TRUE,TRUE
-fcos resnet50 fpn,torchvision,FCOS,0,FALSE,-,78b52a80,Unknown,-,pytorch,,-,0.3569217638,0.9152885619,TRUE,TRUE,TRUE
-flaubert,huggingface pytorch,FlaubertModel,0,FALSE,665991453,6202b0cf,Unknown,-,pytorch,,-,0.1632512774,-,TRUE,TRUE,TRUE
-funnel,huggingface pytorch,FunnelModel,0,FALSE,126026366,ab8f5fd3,Unknown,-,pytorch,,-,0.7044616903,-,TRUE,TRUE,TRUE
-funnel base,huggingface pytorch,FunnelBaseModel,0,FALSE,111060055,37ecc84c,Unknown,-,pytorch,,-,0.2091450516,-,TRUE,TRUE,TRUE
-googlenet,torch hub,GoogLeNet,0,TRUE,6613040,6e59c54b,Unknown,-,pytorch,,-,0.7150814403,-,TRUE,TRUE,TRUE
-gpt1,huggingface pytorch,OpenAIGPTModel,0,TRUE,116160216,0342a9fe,Unknown,-,pytorch,,-,0.5142592189,-,TRUE,TRUE,TRUE
-gpt2,huggingface pytorch,GPT2Model,0,FALSE,123654106,af143a10,Unknown,-,pytorch,,-,0.158932048,-,TRUE,TRUE,TRUE
-gpt2 doublehead,huggingface pytorch,GPT2DoubleHeadsModel,0,FALSE,162253019,7befd733,Unknown,-,pytorch,,-,0.3122759696,0.02552053518,TRUE,TRUE,TRUE
-hardnet39ds,torch hub,HarDNet,0,TRUE,3475386,47ba431c,Unknown,-,pytorch,,-,0.7285997594,0.3964458821,TRUE,TRUE,TRUE
-hardnet68,torch hub,HarDNet,0,TRUE,17557570,9d6d24cf,Unknown,-,pytorch,,-,0.6983739791,0.9057760945,TRUE,TRUE,TRUE
-hardnet68ds,torch hub,HarDNet,0,TRUE,4162836,85f34cd3,Unknown,-,pytorch,,-,0.4661837048,0.5113252359,TRUE,TRUE,TRUE
-hardnet85,torch hub,HarDNet,0,FALSE,36657020,acb062f3,Unknown,-,pytorch,,-,0.5940045827,0.8055358081,TRUE,TRUE,TRUE
-imagegpt,huggingface pytorch,ImageGPTModel,0,FALSE,75872158,3b5850cc,Unknown,-,pytorch,,-,0.1787505865,0.3447673324,TRUE,TRUE,TRUE
-inception v3,torch hub,Inception3,0,TRUE,23802160,46db3db5,Unknown,-,pytorch,,-,0.08009607107,0.6512160638,TRUE,TRUE,TRUE
-keypointrcnn resnet50 fpn,torchvision,KeypointRCNN,0,FALSE,-,2f5908b4,Unknown,-,pytorch,,-,0.8088128085,0.4302517376,FALSE,FALSE,FALSE
-layoutlm,huggingface pytorch,LayoutLMModel,0,TRUE,112312513,33ec397d,Unknown,-,pytorch,,-,0.7114299737,0.2745420037,TRUE,TRUE,TRUE
-luke,huggingface pytorch,LukeModel,0,TRUE,124625858,431c265c,Unknown,-,pytorch,,-,0.2215668713,0.3966205355,FALSE,FALSE,FALSE
-m2m 100,huggingface pytorch,M2M100Model,0,FALSE,484582485,533285d2,Unknown,-,pytorch,,-,0.8410742278,0.5874390004,TRUE,TRUE,TRUE
-marian,huggingface pytorch,MarianModel,0,FALSE,73968682,ea99ab2b,Unknown,-,pytorch,,-,0.8877533581,0.6225305022,TRUE,TRUE,TRUE
-marianmt,huggingface pytorch,MarianMTModel,0,FALSE,105222820,f4dcd1cc,Unknown,-,pytorch,,-,0.6196875754,0.3745179239,TRUE,TRUE,TRUE
-maskrcnn resnet50 fpn,torchvision,MaskRCNN,0,FALSE,-,a5f78569,Unknown,-,pytorch,,-,0.6920710863,0.3714439234,TRUE,TRUE,TRUE
-maskrcnn resnet50 fpn v2,torchvision,MaskRCNN,0,FALSE,-,f4f1de9a,Unknown,-,pytorch,,-,0.9412028203,0.06505430483,TRUE,TRUE,TRUE
-megatron bert,huggingface pytorch,MegatronBertModel,0,FALSE,333060466,2fa53f3f,Unknown,-,pytorch,,-,0.9346378407,0.3721718198,TRUE,TRUE,TRUE
-minilmv2,huggingface pytorch,BertModel,0,TRUE,22565860,f969d36d,Unknown,-,pytorch,,-,0.7004551396,0.7460853493,TRUE,TRUE,TRUE
-mnasnet0 5,torch hub,MNASNet,0,TRUE,2200880,9.13E+07,Unknown,-,pytorch,,-,0.2601984045,0.5719197358,TRUE,TRUE,TRUE
-mnasnet0 75,torch hub,MNASNet,0,TRUE,3144288,4a915154,Unknown,-,pytorch,,-,0.8419733408,0.2381885875,TRUE,TRUE,TRUE
-mnasnet1 0,torch hub,MNASNet,0,TRUE,4350160,041e693a,Unknown,-,pytorch,,-,0.1304315094,0.02978677531,TRUE,TRUE,TRUE
-mnasnet1 3,torch hub,MNASNet,0,TRUE,6239320,87ea0deb,Unknown,-,pytorch,,-,0.6102472479,0.323515366,TRUE,TRUE,TRUE
-mobilebert,huggingface pytorch,MobileBertModel,0,TRUE,24552318,72442a94,Unknown,-,pytorch,,-,0.8232095313,0.1631495713,TRUE,TRUE,TRUE
-mobilebert for sequence classification,huggingface pytorch,MobileBertForSequenceClassification,0,TRUE,21063042,c6599ac3,Unknown,-,pytorch,,-,0.4737991483,0.7552069153,TRUE,TRUE,TRUE
-mobilenet v2,torch hub,MobileNetV2,0,TRUE,3475010,a81033ae,Unknown,-,pytorch,,-,-,0.472360152,TRUE,TRUE,TRUE
-mobilenet v3 large,torch hub,MobileNetV3,0,TRUE,5457176,777649,Unknown,-,pytorch,,-,0.4776212777,0.8175223709,TRUE,TRUE,TRUE
-mobilenet v3 small,torch hub,MobileNetV3,0,TRUE,2529712,e7fae853,Unknown,-,pytorch,,-,0.1028122177,0.5063675839,TRUE,TRUE,TRUE
-mobilevit,huggingface pytorch,MobileViTModel,0,FALSE,4913337,47b02614,Unknown,-,pytorch,,-,0.3613737719,0.2915829332,TRUE,TRUE,TRUE
-mobilevit small for semantic segmentation,huggingface pytorch,MobileViTForSemanticSegmentation,0,FALSE,6351130,5621d1d8,Unknown,-,pytorch,,-,0.211294006,0.8895246707,TRUE,TRUE,TRUE
-mobilevit x small for semantic segmentation,huggingface pytorch,MobileViTForSemanticSegmentation,0,TRUE,2938906,f9f29c8e,Unknown,-,pytorch,,-,0.5634407352,0.3229371116,TRUE,TRUE,TRUE
-mobilevit xx small for semantic segmentation,huggingface pytorch,MobileViTForSemanticSegmentation,0,FALSE,1851794,535af098,Unknown,-,pytorch,,-,0.7550517067,0.2948776764,TRUE,TRUE,TRUE
-mpnet,huggingface pytorch,MPNetModel,0,TRUE,109563840,747bb620,Unknown,-,pytorch,,-,0.8518452593,0.01015326091,TRUE,TRUE,TRUE
-mt5 base,huggingface pytorch,MT5Model,0,FALSE,393067559,6a56180f,Unknown,-,pytorch,,-,0.9836297157,0.04919817726,TRUE,TRUE,TRUE
-mt5 encoder,huggingface pytorch,MT5EncoderModel,0,FALSE,147030657,760f744b,Unknown,-,pytorch,,-,0.4275603609,0.4299510572,TRUE,TRUE,TRUE
-mt5 small,huggingface pytorch,MT5Model,0,FALSE,173102451,9625f18b,Unknown,-,pytorch,,-,0.7055376793,0.7080656531,TRUE,TRUE,TRUE
-openai doublehead,huggingface pytorch,OpenAIGPTDoubleHeadsModel,0,FALSE,147248857,a4df98ec,Unknown,-,pytorch,,-,0.05237705151,0.1912185114,TRUE,TRUE,TRUE
-pegasus,huggingface pytorch,PegasusModel,0,FALSE,403947598,b92cca23,Unknown,-,pytorch,,-,0.6718773318,0.2042362347,TRUE,TRUE,TRUE
-perceiver,huggingface pytorch,PerceiverModel,0,FALSE,259427480,a4732115,Unknown,-,pytorch,,-,0.01378931341,0.6404767106,TRUE,TRUE,TRUE
-poolformer,huggingface pytorch,PoolFormerModel,0,TRUE,11371373,a8cfe755,Unknown,-,pytorch,,-,0.953904252,0.6987515128,TRUE,TRUE,TRUE
-rag,huggingface pytorch,RagModel,0,FALSE,455992031,7e502070,Unknown,-,pytorch,,-,0.02073222123,0.2134296424,TRUE,TRUE,TRUE
-realm,huggingface pytorch,RealmEmbedder,0,TRUE,109265344,d9107239,Unknown,-,pytorch,,-,0.8876490931,0.8417903689,TRUE,TRUE,TRUE
-regnet x 16gf,torch hub,RegNet,0,FALSE,54171112,90fe350f,Unknown,-,pytorch,,-,0.9060867619,0.5952690409,TRUE,TRUE,TRUE
-regnet x 1 6gf,torch hub,RegNet,0,FALSE,9148224,9b6af29e,Unknown,-,pytorch,,-,0.396030724,0.9598102533,TRUE,TRUE,TRUE
-regnet x 32gf,torch hub,RegNet,0,FALSE,107654448,2.49E+08,Unknown,-,pytorch,,-,0.4390239856,0.7861245276,TRUE,TRUE,TRUE
-regnet x 3 2gf,torch hub,RegNet,0,FALSE,15235752,731da922,Unknown,-,pytorch,,-,0.8759118734,0.5483516564,TRUE,TRUE,TRUE
-regnet x 400mf,torch hub,RegNet,0,FALSE,5458776,08b8712e,Unknown,-,pytorch,,-,0.3031631823,0.6028060577,TRUE,TRUE,TRUE
-regnet x 800mf,torch hub,RegNet,0,FALSE,7223528,1e12c62e,Unknown,-,pytorch,,-,-,0.07790472777,TRUE,TRUE,TRUE
-regnet x 8gf,torch hub,RegNet,0,FALSE,39485176,26bfacd7,Unknown,-,pytorch,,-,0.3240022857,0.8347574034,TRUE,TRUE,TRUE
-regnet y 128gf,torch hub,RegNet,0,FALSE,644409734,a2a92eba,Unknown,-,pytorch,,-,0.6964564787,0.3659986558,TRUE,TRUE,TRUE
-regnet y 16gf,torch hub,RegNet,0,FALSE,83472284,a44f744c,Unknown,-,pytorch,,-,-,0.2275266886,TRUE,TRUE,TRUE
-regnet y 1 6gf,torch hub,RegNet,0,FALSE,11151182,993181bc,Unknown,-,pytorch,,-,0.2067551592,0.8753373872,TRUE,TRUE,TRUE
-regnet y 32gf,torch hub,RegNet,0,FALSE,144894546,16e3920e,Unknown,-,pytorch,,-,0.2675957934,0.7232885837,TRUE,TRUE,TRUE
-regnet y 3 2gf,torch hub,RegNet,0,FALSE,19372586,a06a50b4,Unknown,-,pytorch,,-,0.5064092553,0.1491575794,TRUE,TRUE,TRUE
-regnet y 400mf,torch hub,RegNet,0,FALSE,4317824,74d9ef17,Unknown,-,pytorch,,-,0.5741097604,0.2428770285,TRUE,TRUE,TRUE
-regnet y 800mf,torch hub,RegNet,0,FALSE,6403424,efe4b887,Unknown,-,pytorch,,-,0.9147002002,0.9654421187,TRUE,TRUE,TRUE
-regnet y 8gf,torch hub,RegNet,0,FALSE,39298560,0c98c39d,Unknown,-,pytorch,,-,0.6723404719,0.560880337,TRUE,TRUE,TRUE
-rembert,huggingface pytorch,RemBertModel,0,FALSE,575380202,1a69d8de,Unknown,-,pytorch,,-,0.8297280138,0.5962212296,TRUE,TRUE,TRUE
-resnet101,torch hub,ResNet,0,TRUE,44447848,285cd579,Unknown,-,pytorch,,-,0.2129516123,0.08371641283,TRUE,TRUE,TRUE
-resnet152,torch hub,ResNet,0,TRUE,60045416,c732f780,Unknown,-,pytorch,,-,0.4378760167,0.1026012754,FALSE,FALSE,FALSE
-resnet18,torch hub,ResNet,0,TRUE,11680872,11f0e9e3,Unknown,-,pytorch,,-,0.791935827,0.1849698907,FALSE,FALSE,FALSE
-resnet34,torch hub,ResNet,0,TRUE,21781608,85df0c4a,Unknown,-,pytorch,,-,0.04500402699,0.8266632657,TRUE,TRUE,TRUE
-resnet50,torch hub,ResNet,0,TRUE,25507944,3ba0a685,Unknown,-,pytorch,,-,0.7112983347,0.04298736488,TRUE,TRUE,TRUE
-resnext101 32x8d,torch hub,ResNet,0,FALSE,88592360,0b88b3d8,Unknown,-,pytorch,,-,0.03848879881,0.994887442,TRUE,TRUE,TRUE
-resnext50 32x4d,torch hub,ResNet,0,FALSE,24964712,ce6f3fb8,Unknown,-,pytorch,,-,0.5413600234,0.0387678123,TRUE,TRUE,TRUE
-retinanet resnet50 fpn,torchvision,RetinaNet,0,FALSE,-,7cc11439,Unknown,-,pytorch,,-,0.5087831001,0.6018223323,TRUE,TRUE,TRUE
-retinanet resnet50 fpn v2,torchvision,RetinaNet,0,FALSE,-,20403119,Unknown,-,pytorch,,-,0.3942354788,0.8491541692,TRUE,TRUE,TRUE
-retribert,huggingface pytorch,RetriBertModel,0,FALSE,81150221,4c3ee101,Unknown,-,pytorch,,-,0.9963637796,0.8654928427,TRUE,TRUE,TRUE
-roberta,huggingface pytorch,RobertaModel,0,TRUE,109461696,f75bf095,Unknown,-,pytorch,,-,0.9524146197,0.5684564905,TRUE,TRUE,TRUE
-roformer,huggingface pytorch,RoFormerModel,0,FALSE,123454241,a48eefbd,Unknown,-,pytorch,,-,0.9332141151,0.5463205479,TRUE,TRUE,TRUE
-safety clipvision,diffusers,CLIPVisionModel,0,FALSE,303180456,bd5ab0a3,Unknown,-,pytorch,,-,0.5250482893,0.9337669034,TRUE,TRUE,TRUE
-segformer,huggingface pytorch,SegformerModel,0,TRUE,3301468,28a23805,Unknown,-,pytorch,,-,0.7589393868,0.1985242188,TRUE,TRUE,TRUE
-shufflenet v2 x0 5,torch hub,ShuffleNetV2,0,TRUE,1360182,15046a84,Unknown,-,pytorch,,-,0.5956733578,0.9538399132,TRUE,TRUE,TRUE
-shufflenet v2 x1 0,torch hub,ShuffleNetV2,0,TRUE,2264028,81185b92,Unknown,-,pytorch,,-,0.06747308314,0.01711918436,TRUE,TRUE,TRUE
-shufflenet v2 x1 5,torch hub,ShuffleNetV2,0,TRUE,3481998,51805568,Unknown,-,pytorch,,-,0.9875014394,0.6526094216,TRUE,TRUE,TRUE
-shufflenet v2 x2 0,torch hub,ShuffleNetV2,0,TRUE,7363356,670c36ac,Unknown,-,pytorch,,-,0.4039454795,0.7854043332,TRUE,TRUE,TRUE
-speech to text,huggingface pytorch,Speech2TextModel,0,TRUE,29738198,fc9ef5d8,Unknown,-,pytorch,,-,-,0.9559223994,TRUE,TRUE,TRUE
-splinter,huggingface pytorch,SplinterModel,0,TRUE,108576957,d8703a6e,Unknown,-,pytorch,,-,0.3054487993,0.5885042017,TRUE,TRUE,TRUE
-squeezebert,huggingface pytorch,SqueezeBertModel,0,FALSE,50775742,c54b2d76,Unknown,-,pytorch,,-,0.921134318,0.9592404522,TRUE,TRUE,TRUE
-squeezenet1 0,torch hub,SqueezeNet,0,TRUE,1246280,8b319b5b,Unknown,-,pytorch,,-,0.2578457753,0.4000357773,TRUE,TRUE,TRUE
-squeezenet1 1,torch hub,SqueezeNet,0,TRUE,1233288,db09563d,Unknown,-,pytorch,,-,0.2250988563,0.4158708533,TRUE,TRUE,TRUE
-ssd300 vgg16,torchvision,SSDFeatureExtractorVGG,0,FALSE,22941893,ba239042,Unknown,-,pytorch,,-,0.7458929344,0.4738323319,TRUE,TRUE,TRUE
-ssdlite320 mobilenet v3 large,torchvision,SSDLiteFeatureExtractorMobileNet,0,TRUE,3531146,0b96e723,Unknown,-,pytorch,,-,0.6098419908,0.5463400678,TRUE,TRUE,TRUE
-ssdlite320 mobilenet v3 large,torchvision,SSD,0,TRUE,3531146,cb077411,Unknown,-,pytorch,,-,0.8067197133,0.1740465113,TRUE,TRUE,TRUE
-swin b,torch hub,SwinTransformer,0,FALSE,88738859,f0e93177,Unknown,-,pytorch,,-,0.9865009173,0.472041202,TRUE,TRUE,TRUE
-swin s,torch hub,SwinTransformer,0,FALSE,50404109,cc85d49e,Unknown,-,pytorch,,-,0.068262417,0.5886588024,TRUE,TRUE,TRUE
-swin t,torch hub,SwinTransformer,0,FALSE,28766603,89de9245,Unknown,-,pytorch,,-,0.1546109902,0.2681236252,TRUE,TRUE,TRUE
-t5 base,huggingface pytorch,T5ForConditionalGeneration,0,FALSE,250330147,ba7c8360,Unknown,-,pytorch,,-,0.9236510639,0.01785099731,TRUE,TRUE,TRUE
-t5 encoder,huggingface pytorch,T5EncoderModel,0,TRUE,35455582,0559914f,Unknown,-,pytorch,,-,-,0.8752267279,TRUE,TRUE,TRUE
-t5 large,huggingface pytorch,T5ForConditionalGeneration,0,FALSE,777382975,47d226ef,Unknown,-,pytorch,,-,0.4080238654,0.3749054603,TRUE,TRUE,TRUE
-t5 small,huggingface pytorch,T5ForConditionalGeneration,0,FALSE,78004501,6f1dd5bb,Unknown,-,pytorch,,-,0.1525429461,0.7361882706,TRUE,TRUE,TRUE
-unet 2d condition,diffusers,UNet2DConditionModel,0,FALSE,2324093279,b6cc8b9c,Unknown,-,pytorch,,-,0.5983132622,0.2389244031,TRUE,TRUE,TRUE
-unet,torch hub,UNet,0,FALSE,7760097,a76ab7f4,Unknown,-,pytorch,,-,0.8406553374,0.8770121489,TRUE,TRUE,TRUE
-vae decoder,diffusers,Decoder,0,FALSE,66269573,d2afe38b,Unknown,-,pytorch,,-,0.7418232932,0.9520182421,TRUE,TRUE,TRUE
-vgg11,torch hub,VGG,0,FALSE,132857448,b38617af,Unknown,-,pytorch,,-,-,0.3415078108,TRUE,TRUE,TRUE
-vgg11 bn,torch hub,VGG,0,FALSE,132857448,8550040,Unknown,-,pytorch,,-,0.7010413771,0.8456017941,TRUE,TRUE,TRUE
-vgg13,torch hub,VGG,0,FALSE,133041768,20ce33fd,Unknown,-,pytorch,,-,0.6163820998,0.04209412966,TRUE,TRUE,TRUE
-vgg13 bn,torch hub,VGG,0,FALSE,133041768,20dffe7e,Unknown,-,pytorch,,-,0.9464643567,0.638601289,TRUE,TRUE,TRUE
-vgg16,torch hub,VGG,0,FALSE,138350184,b628f277,Unknown,-,pytorch,,-,-,0.5968935779,TRUE,TRUE,TRUE
-vgg16 bn,torch hub,VGG,0,FALSE,138350184,8e2b426b,Unknown,-,pytorch,,-,0.8792166923,0.8928780708,TRUE,TRUE,TRUE
-vgg19 bn,torch hub,VGG,0,FALSE,143658600,bc2392e4,Unknown,-,pytorch,,-,0.7572185912,0.9926552952,TRUE,TRUE,TRUE
-vgg19,torch hub,VGG,0,FALSE,143658600,d889f054,Unknown,-,pytorch,,-,0.3934931812,0.3968043966,TRUE,TRUE,TRUE
-vit,huggingface pytorch,ViTModel,0,FALSE,86271166,993623dd,Unknown,-,pytorch,,-,-,0.6170905797,TRUE,TRUE,TRUE
-vit b 16,torch hub,VisionTransformer,0,FALSE,86497183,dd47dfd6,Unknown,-,pytorch,,-,0.8292088496,0.6260986518,TRUE,TRUE,TRUE
-vit b 32,torch hub,VisionTransformer,0,TRUE,88153759,48d88bc1,Unknown,-,pytorch,,-,0.2875190457,0.5313556644,TRUE,TRUE,TRUE
-vit h 14,torch hub,VisionTransformer,0,FALSE,631723703,c682724f,Unknown,-,pytorch,,-,0.05994937867,0.2917794972,TRUE,TRUE,TRUE
-vit l 16,torch hub,VisionTransformer,0,FALSE,304134471,44b6c5a5,Unknown,-,pytorch,,-,0.913958232,0.7577796257,TRUE,TRUE,FALSE
-vit l 32,torch hub,VisionTransformer,0,TRUE,306343239,f137eddc,Unknown,-,pytorch,,-,0.86291783,0.1090945909,TRUE,TRUE,FALSE
-wide resnet101 2,torch hub,ResNet,0,FALSE,126752872,0eb07645,Unknown,-,pytorch,,-,0.2180280201,0.809231502,TRUE,TRUE,FALSE
-wide resnet50 2,torch hub,ResNet,0,TRUE,68819048,fd743f94,Unknown,-,pytorch,,-,0.01156058265,0.6298108722,TRUE,TRUE,FALSE
-xglm,huggingface pytorch,XGLMModel,0,FALSE,566264670,41f01198,Unknown,-,pytorch,,-,0.970793283,0.9201076994,TRUE,TRUE,FALSE
-xlm,huggingface pytorch,XLMModel,0,FALSE,665991453,6918ed2c,Unknown,-,pytorch,,-,0.0619859172,0.4852305783,TRUE,TRUE,FALSE
-xlm roberta,huggingface pytorch,XLMRobertaModel,0,TRUE,109461696,a0532c05,Unknown,-,pytorch,,-,0.04920299384,0.2057686784,TRUE,TRUE,FALSE
-xlnet,huggingface pytorch,XLNetModel,0,FALSE,341121042,5cfcb429,Unknown,-,pytorch,,-,0.937412215,0.04126296064,TRUE,TRUE,FALSE
-yolos tiny for object detection,huggingface pytorch,YolosForObjectDetection,0,FALSE,6488935,8f6a6a55,Unknown,-,pytorch,,-,0.136900365,0.7939705911,TRUE,TRUE,TRUE
\ No newline at end of file
diff --git a/trackers/huggingface/requirements.txt b/trackers/huggingface/requirements.txt
deleted file mode 100644
index fe1fe75b..00000000
--- a/trackers/huggingface/requirements.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-plotly>=5.10.0
-pandas>=1.4.3
-scipy>=1.9.1
-streamlit_echarts
-streamlit_toggle_switch
\ No newline at end of file
diff --git a/trackers/huggingface/streamlit_helpers.py b/trackers/huggingface/streamlit_helpers.py
deleted file mode 100644
index 4f6605e4..00000000
--- a/trackers/huggingface/streamlit_helpers.py
+++ /dev/null
@@ -1,150 +0,0 @@
-from collections import Counter
-from typing import List
-import numpy as np
-import streamlit as st # pylint: disable=import-error
-import pandas as pd
-
-
-class Collapsable:
- """
-    Creates a collapsible block of text composed of a preamble (the clickable
-    section of text) and an epilogue (the collapsible body).
- """
-
- def __init__(self, preamble="", epilogue=""):
- self.preamble = preamble
- self.epilogue = epilogue
- self.small_font = 18
- self.large_font = 18
- self.sections = []
-
- def add_section(self, heading, text):
- # Convert text to bullet points if it is a list
-        if isinstance(text, list):
-            text = (
-                "<ul>"
-                + "".join(
-                    [
-                        f'<li style="font-size:{self.small_font}px;">{x}</li>'
-                        for x in text
-                    ]
-                )
-                + "</ul>"
-            )
-
- # Append section
- self.sections.append((heading, text))
-
- def deploy(self):
-
- secs = "".join(
- [
- (
- ""
- f"{heading}
"
- f"{text}
"
- )
- for heading, text in self.sections
- ]
- )
- collapsable_sec = f"""
-
- {self.preamble}
- {secs}
- {self.epilogue}
-
- """
- st.markdown(collapsable_sec, unsafe_allow_html=True)
-
-
-def add_filter(
- data_frame_list: List[pd.DataFrame],
- name: str,
- label: str,
- options: List[str] = None,
- num_cols: int = 1,
- last_is_others: bool = True,
-):
- """
-    Creates a filter in the sidebar using checkboxes
- """
-
- # Get list of all options and return if no options are available
- all_options = set(data_frame_list[-1][label])
- if "-" in all_options:
- all_options.remove("-")
- if len(all_options) == 0:
- return data_frame_list
-
- st.markdown(f"#### {name}")
-
- # Create list of options if selectable options are not provided
- if options is None:
- options_dict = Counter(data_frame_list[-1][label])
- sorted_options = sorted(options_dict, key=options_dict.get, reverse=True)
- if "-" in sorted_options:
- sorted_options.remove("-")
- if len(sorted_options) > 8:
- options = list(sorted_options[:7]) + ["others"]
- last_is_others = True
- else:
- options = list(sorted_options)
- last_is_others = False
-
- cols = st.columns(num_cols)
- instantiated_checkbox = []
- for idx in range(len(options)):
- with cols[idx % num_cols]:
- instantiated_checkbox.append(
- st.checkbox(options[idx], False, key=f"{label}_{options[idx]}")
- )
-
- selected_options = [
- options[idx] for idx, checked in enumerate(instantiated_checkbox) if checked
- ]
-
-    # When last_is_others is set, the last checkbox corresponds to "others"
- if instantiated_checkbox[-1] and last_is_others:
- selected_options = selected_options[:-1]
- other_options = [x for x in all_options if x not in options]
- selected_options = set(selected_options + other_options)
-
- if len(selected_options) > 0:
- for idx, _ in enumerate(data_frame_list):
- data_frame_list[idx] = data_frame_list[idx][
- [
-                    model_entry in selected_options
-                    for model_entry in data_frame_list[idx][label]
- ]
- ]
- return data_frame_list
-
-
-def slider_filter(
- data_frame_list: List[pd.DataFrame],
- title: str,
- filter_by: str,
- max_val: int = 1000,
-):
- """
-    Creates a slider to filter dataframes according to the column named by
-    `filter_by`. That column must be numeric; slider values are in millions.
- """
-
- start_val, end_val = st.select_slider(
- title,
- options=[str(x) for x in np.arange(0, max_val + 1, 10, dtype=int)],
- value=("0", str(max_val)),
- )
-
-    for idx in range(len(data_frame_list)):
-        data_frame_list[idx] = data_frame_list[idx][
-            [
-                # Skip entries reported as "-" (no numeric value available)
-                model_entry != "-"
-                and int(model_entry) >= int(start_val) * 1000000
-                and int(model_entry) <= int(end_val) * 1000000
-                for model_entry in data_frame_list[idx][filter_by]
-            ]
-        ]
-
- return data_frame_list
diff --git a/trackers/report_plots.py b/trackers/report_plots.py
deleted file mode 100644
index 5d64b7a5..00000000
--- a/trackers/report_plots.py
+++ /dev/null
@@ -1,168 +0,0 @@
-import plotly.graph_objects as go
-import pandas as pd
-import plotly.figure_factory as ff
-import plotly.express as px
-import numpy as np
-
-df = pd.read_csv(r"C:\Users\danie\turnkeyml\models\timm\2023-08-30.csv")
-
-colors = {
- "blue": "#5470c6",
- "orange": "#FF7F0E",
- "green": "#94cc74",
- "saffron_mango": "#fac858",
- "red": "#ee6666",
- "light_blue": "#73c0de",
- "ocean_green": "#3ba272",
-}
-
-
-def throughput_acceleration(df):
- vitisep_results = df[df["runtime"] == "vitisep"]
- ort_results = df[df["runtime"] == "ort"]
- assert len(vitisep_results) == len(ort_results)
- on_ipu = vitisep_results.ipu_compilation_successful.to_numpy()
- ratio = vitisep_results.throughput.to_numpy() / ort_results.throughput.to_numpy()
-
- y0 = [ratio[idx] for idx in range(len(ratio)) if on_ipu[idx] == "True"]
- y1 = [ratio[idx] for idx in range(len(ratio)) if on_ipu[idx] == "False"]
- y2 = np.concatenate([y0, y1])
-
- y0_label = ["Yes"] * len(y0)
- y1_label = ["No"] * len(y1)
- y2_label = y0_label + y1_label
-
- df = pd.DataFrame(
- {
- "graph_name": ["Running on IPU"] * len(y0)
- + ["Fallback to CPU"] * len(y1)
- + ["All models"] * len(y2),
- "value": np.concatenate([y0, y1, y2], 0),
- "Actually running on the IPU?": y0_label + y1_label + y2_label,
- }
- )
-
- fig = px.strip(
- df,
- x="graph_name",
- y="value",
- color="Actually running on the IPU?",
- stripmode="overlay",
- )
-
- fig.add_trace(
- go.Box(
- y=df.query('graph_name == "Running on IPU"')["value"],
- name="Running on IPU",
- marker=dict(opacity=0.1),
- )
- )
- fig.add_trace(
- go.Box(
- y=df.query('graph_name == "Fallback to CPU"')["value"],
- name="Fallback to CPU",
- )
- )
- fig.add_trace(
- go.Box(y=df.query('graph_name == "All models"')["value"], name="All models")
- )
-
- fig.update_layout(
- autosize=False,
- legend={"traceorder": "normal"},
- )
- fig.update_yaxes(title_text="Acceleration compared to OnnxRuntime CPU EP")
- fig.update_xaxes(title_text="")
- fig.show()
-
-
-def parameter_histogram(df: pd.DataFrame) -> None:
- # Add parameters histogram
- all_models = [
- float(x) / 1000000
- for x in df[df["runtime"] == "vitisep"]["parameters"]
- if x != "-"
- ]
-
- hist_data = []
- group_labels = []
-
- if all_models != []:
- hist_data.append(all_models)
- group_labels.append("All models")
-
- if hist_data:
- fig = ff.create_distplot(
- hist_data,
- group_labels,
- bin_size=5,
- histnorm="",
- colors=list(colors.values()),
- curve_type="normal",
- )
- fig.update_layout(showlegend=False)
- fig.layout.update(xaxis_title="Parameters in millions")
- fig.layout.update(yaxis_title="Models inside bin")
- fig.update_xaxes(range=[1, 200])
-
- fig.show()
-
-
-def throughput_plot(df):
- vitisep_results = df[df["runtime"] == "vitisep"]
- ort_results = df[df["runtime"] == "ort"]
-
- fig = go.Figure(
- data=[
- go.Bar(
- name="VitisEP",
- x=vitisep_results.model_name,
- y=vitisep_results.throughput,
- ),
- go.Bar(
- name="OnnxRuntime CPU EP",
- x=ort_results.model_name,
- y=ort_results.throughput,
- ),
- ]
- )
-
- # Set x and y axis labels
- fig.update_layout(barmode="group", xaxis_title="", yaxis_title="Throughput")
- fig.show()
-
-
-def compilation_time(df):
- # Add compilation time histogram
- all_models = [
- float(x)
- for x in df[df["runtime"] == "vitisep"]["ipu_compilation_seconds"]
- if x != "-"
- ]
-
- hist_data = []
- group_labels = []
-
-    # Only add the histogram group if compilation times are available
-    if all_models:
-        hist_data.append(all_models)
-        group_labels.append("All models")
-
- if hist_data:
- fig = ff.create_distplot(
- hist_data,
- group_labels,
- bin_size=5,
- histnorm="",
- colors=list(colors.values()),
- curve_type="normal",
- )
- fig.update_layout(showlegend=False)
- fig.layout.update(xaxis_title="Compilation time in seconds")
- fig.layout.update(yaxis_title="Models inside bin")
-
- fig.show()
-
-
-parameter_histogram(df)
-throughput_plot(df)
-throughput_acceleration(df)
-compilation_time(df)