From e92b760396f9bf385d10d31c59b4955d393b9118 Mon Sep 17 00:00:00 2001
From: Jeremy Fowers <80718789+jeremyfowers@users.noreply.github.com>
Date: Mon, 26 Feb 2024 14:14:55 -0500
Subject: [PATCH] Improve TorchRT (#118)

---
 src/turnkeyml/analyze/script.py      |  11 +-
 src/turnkeyml/run/basert.py          |   3 +-
 src/turnkeyml/run/torchrt/runtime.py | 146 +++++++++++++++++++++++----
 src/turnkeyml/version.py             |   2 +-
 4 files changed, 138 insertions(+), 24 deletions(-)

diff --git a/src/turnkeyml/analyze/script.py b/src/turnkeyml/analyze/script.py
index f8af2cbc..f75962f5 100644
--- a/src/turnkeyml/analyze/script.py
+++ b/src/turnkeyml/analyze/script.py
@@ -800,14 +800,21 @@ def forward_spy(*args, **kwargs):
                 and invocation_info.is_target
                 and (model_info.build_model)
             ):
+                # Disable all modifications while we evaluate the model
+                # This is needed in case a tool called during evaluation wants to
+                # trace the model. There are some scenarios (e.g., ipex.quantization.prepare)
+                # that raise an exception when they encounter forward_spy()
+                local_var.forward = old_forward
+
                 explore_invocation(
                     model_inputs=[args, kwargs],
                     model_info=model_info,
                     invocation_info=invocation_info,
                     tracer_args=tracer_args,
                 )
-                # Ensure that explore_invocation() doesn't interfere with our execution count
-                model_info.executed = 1
+
+                # Re-enable modifications
+                local_var.forward = forward_spy
 
             build_name = fs.get_build_name(
                 tracer_args.script_name,
diff --git a/src/turnkeyml/run/basert.py b/src/turnkeyml/run/basert.py
index 8c1dc6e6..e3229a1d 100644
--- a/src/turnkeyml/run/basert.py
+++ b/src/turnkeyml/run/basert.py
@@ -87,6 +87,8 @@ def __init__(
                 f"supports runtimes: {runtimes_supported}"
             )
 
+        os.makedirs(self.local_output_dir, exist_ok=True)
+
         self._setup()
 
     def posix_path_format(self, path) -> str:
@@ -140,7 +142,6 @@ def _transfer_files(self, files_to_transfer: List[str]):
             files_to_transfer: absolute paths to files
         """
 
-        os.makedirs(self.local_output_dir, exist_ok=True)
         for file in files_to_transfer:
             shutil.copy(
                 file, os.path.join(self.local_output_dir, os.path.basename(file))
diff --git a/src/turnkeyml/run/torchrt/runtime.py b/src/turnkeyml/run/torchrt/runtime.py
index d9212861..8324c135 100644
--- a/src/turnkeyml/run/torchrt/runtime.py
+++ b/src/turnkeyml/run/torchrt/runtime.py
@@ -1,5 +1,6 @@
 import os
-from typing import Dict, Any
+import logging
+from typing import Dict, Any, List, Optional
 from statistics import mean
 import time
 from packaging import version
@@ -11,6 +12,7 @@
 import turnkeyml.build.ignition as ignition
 import turnkeyml.common.build as build
 import turnkeyml.common.exceptions as exp
+import turnkeyml.common.filesystem as fs
 from turnkeyml.common.filesystem import Stats
 
 
@@ -26,10 +28,23 @@ def __init__(
         model: torch.nn.Module,
         inputs: Dict[str, Any],
         tensor_type=np.array,
+        runtimes_supported: Optional[List[str]] = None,
+        runtime_version: str = str(torch.__version__),
     ):
+        # Torch Dynamo is pretty verbose with its warnings,
+        # so we set the logging level to ERROR
+        torch._logging.set_logs(dynamo=logging.ERROR)
+
         self.throughput_ips = None
         self.mean_latency_ms = None
 
+        # Allow children of this class to pass different values than
+        # the defaults for torch-eager and torch-compiled
+        if runtimes_supported:
+            init_runtimes_supported = runtimes_supported
+        else:
+            init_runtimes_supported = ["torch-eager", "torch-compiled"]
+
         super().__init__(
             cache_dir=cache_dir,
             build_name=build_name,
@@ -37,23 +52,24 @@
             stats=stats,
             device_type=device_type,
             runtime=runtime,
             iterations=iterations,
-            runtimes_supported=["torch-eager", "torch-compiled"],
-            runtime_version=str(torch.__version__),
+            runtimes_supported=init_runtimes_supported,
+            runtime_version=runtime_version,
             base_path=os.path.dirname(__file__),
             tensor_type=tensor_type,
             model=model,
             inputs=inputs,
         )
 
-    def _setup(self) -> None:
-        # Ensure we have the correct model type
-        model_type = ignition.identify_model_type(self.model)
-        if model_type != build.ModelType.PYTORCH:
-            raise exp.IntakeError(
-                f"Only Pytorch models are valid when runtime is {self.runtime}"
-            )
+    def _compile(self) -> None:
+        """
+        Perform any requested compilation actions on the PyTorch model.
+
+        Note: This method is expected to be overloaded by most children of
+        this class.
+        """
+
+        self.model.eval()
 
-        # Compile the model
         if self.runtime == "torch-compiled":
             # First ensure we have the required version of Pytorch
             clean_torch_version = self.runtime_version.split("+")[0]
@@ -67,18 +83,42 @@ def _setup(self) -> None:
 
             self.model = torch.compile(self.model)
 
-    def benchmark(self) -> MeasuredPerformance:
-        per_iteration_latency = [0] * self.iterations
-        for idx in range(self.iterations):
-            start_time = time.perf_counter()
-            self.model(**self.inputs)
-            end_time = time.perf_counter()
-            per_iteration_latency[idx] = end_time - start_time
+    def _setup(self) -> None:
+        """
+        Validate the parameters of this class and invoke compilation.
+
+        Note: the implementation of this method is intentionally generic. Any
+        runtime-specific actions should go into the _compile() method if
+        possible.
+        """
+
+        # Ensure we have the correct model type
+        model_type = ignition.identify_model_type(self.model)
+        if model_type != build.ModelType.PYTORCH:
+            raise exp.IntakeError(
+                f"Only Pytorch models are valid when runtime is {self.runtime}"
+            )
+
+        # Compile the model
+        start_time = time.perf_counter()
+        with build.Logger("Preparing torch model", self.logfile_path):
+            self._compile()
+        end_time = time.perf_counter()
+        total_time = end_time - start_time
+
+        self.stats.save_model_eval_stat("torch_compilation_seconds", total_time)
+
+    def _calculate_performance(
+        self, per_iteration_latency: List[float]
+    ) -> MeasuredPerformance:
+        """
+        Calculate performance statistics from the per-iteration latencies
+        acquired during execution.
+        """
 
-        # Calculate perf from per_iteration_latency
         self.mean_latency_ms = mean(per_iteration_latency) * 1000
         self.throughput_ips = float(
-            1 / (np.sum(per_iteration_latency) / self.iterations)
+            1 / (np.sum(per_iteration_latency) / len(per_iteration_latency))
         )
 
         return MeasuredPerformance(
@@ -91,6 +131,72 @@
             build_name=self.build_name,
         )
 
+    def _run_model(self, iterations: int, time_limit: int) -> List[float]:
+        """
+        Run the model repeatedly, collecting the performance of each
+        iteration. Stop running when the iterations target or time limit
+        is reached, whichever comes first.
+
+        Note: this method is intended to be useful in the following ways:
+        1. Generic across any child class of TorchRT
+        2. Useful for both cache warmup and benchmarking
+        """
+
+        counter = 0
+        total_time = 0
+        per_iteration_latency = []
+
+        while counter < iterations and total_time < time_limit:
+            start_time = time.perf_counter()
+            self.model(**self.inputs)
+            end_time = time.perf_counter()
+            total_time = total_time + end_time - start_time
+            counter = counter + 1
+            per_iteration_latency.append(end_time - start_time)
+
+        return per_iteration_latency
+
+    def _benchmark_inner(self) -> MeasuredPerformance:
+        """
+        The logic for benchmarking a torch model to collect performance data.
+        This method is meant to be called by the benchmark() method, which is
+        why it is named _benchmark_inner().
+
+        Note: this method is intended to be generic across any child class
+        of TorchRT.
+        """
+
+        # Cache warmup for 1 minute or 10 iterations, whichever
+        # comes first
+        self._run_model(iterations=10, time_limit=60)
+
+        # Run the benchmark for the specified number of iterations,
+        # or 2 minutes, whichever comes first
+        per_iteration_latency = self._run_model(
+            iterations=self.iterations, time_limit=120
+        )
+
+        # Record the number of iterations actually used for the benchmark,
+        # which will be less than the `iterations` argument if the time
+        # limit was reached
+        self.stats.save_model_eval_stat(fs.Keys.ITERATIONS, len(per_iteration_latency))
+
+        return self._calculate_performance(per_iteration_latency)
+
+    def benchmark(self) -> MeasuredPerformance:
+        """
+        Wrapper function for self._benchmark_inner()
+
+        The reason this wrapper exists is to allow plugin developers to apply various
+        settings to execution on a per-runtime basis. For example, selectively
+        enabling torch.no_grad().
+
+        Note: it is expected that most child classes of TorchRT will overload
+        this method.
+        """
+        with torch.no_grad():
+            return self._benchmark_inner()
+
     @property
     def mean_latency(self) -> float:
         if self.mean_latency_ms is not None:
diff --git a/src/turnkeyml/version.py b/src/turnkeyml/version.py
index 72f26f59..0b2f79db 100644
--- a/src/turnkeyml/version.py
+++ b/src/turnkeyml/version.py
@@ -1 +1 @@
-__version__ = "1.1.2"
+__version__ = "1.1.3"
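
For plugin developers, the extension points introduced by this patch are the runtimes_supported and runtime_version constructor arguments plus the overloadable _compile() and benchmark() methods. The sketch below is a hypothetical illustration and is not part of the patch: the class name ExampleTorchRT, the runtime name "example-rt", the version string, the keyword pass-through, and the torch.compile backend choice are all assumptions, and the import path is inferred from the file location in the diff above.

# Hypothetical sketch of a TorchRT plugin built on the hooks added above.
# Anything marked "assumed" is illustrative and not defined by this patch.
import torch

from turnkeyml.run.torchrt.runtime import TorchRT  # path inferred from the diff


class ExampleTorchRT(TorchRT):  # assumed class name
    def __init__(self, **kwargs):
        super().__init__(
            # New in this patch: children can advertise their own runtime
            # name and version instead of the torch-eager/torch-compiled defaults
            runtimes_supported=["example-rt"],  # assumed runtime name
            runtime_version="0.0.1",  # assumed version string
            **kwargs,
        )

    def _compile(self) -> None:
        # Runtime-specific preparation; the generic _setup() times this call
        # and records it as torch_compilation_seconds
        self.model.eval()
        self.model = torch.compile(self.model, backend="inductor")  # assumed backend

    def benchmark(self):
        # Opt out of the default torch.no_grad() wrapper, for example when
        # gradients are needed during measurement
        return self._benchmark_inner()

Because _setup(), _run_model(), and _calculate_performance() stay generic, a subclass like this inherits the cache warmup, time-limit handling, and iteration-count bookkeeping without duplicating them.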