Improve TorchRT (#118)
jeremyfowers authored Feb 26, 2024
1 parent 44a59dc commit e92b760
Showing 4 changed files with 138 additions and 24 deletions.
11 changes: 9 additions & 2 deletions src/turnkeyml/analyze/script.py
@@ -800,14 +800,21 @@ def forward_spy(*args, **kwargs):
and invocation_info.is_target
and (model_info.build_model)
):
# Disable all modifications while we evaluate the model.
# This is needed in case a tool called during evaluation wants to
# trace the model. There are some scenarios (e.g., ipex.quantization.prepare)
# that raise an exception when they encounter forward_spy().
local_var.forward = old_forward

explore_invocation(
model_inputs=[args, kwargs],
model_info=model_info,
invocation_info=invocation_info,
tracer_args=tracer_args,
)
# Ensure that explore_invocation() doesn't interfere with our execution count
model_info.executed = 1

# Re-enable modifications
local_var.forward = forward_spy

build_name = fs.get_build_name(
tracer_args.script_name,
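The hunk above temporarily restores the module's original forward method while explore_invocation() runs, then re-installs forward_spy() afterwards, so any tool that traces the model during evaluation never encounters the spy. Below is a minimal sketch of that save/patch/restore pattern; evaluate_model() and the torch.jit.trace call are illustrative stand-ins for explore_invocation() and the tracing tools it may invoke, not turnkeyml code:

import torch

model = torch.nn.Linear(4, 2)
old_forward = model.forward  # keep a handle to the original bound method

def evaluate_model(m: torch.nn.Module) -> None:
    # Stand-in for explore_invocation(): a tool that traces the model and
    # would raise if it encountered forward_spy() as the forward method
    torch.jit.trace(m, torch.randn(1, 4))

def forward_spy(*args, **kwargs):
    # Disable the spy while the model is evaluated
    model.forward = old_forward
    try:
        evaluate_model(model)
    finally:
        # Re-enable the spy once evaluation is done
        model.forward = forward_spy
    return old_forward(*args, **kwargs)

model.forward = forward_spy
print(model(torch.randn(1, 4)).shape)  # torch.Size([1, 2])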
3 changes: 2 additions & 1 deletion src/turnkeyml/run/basert.py
@@ -87,6 +87,8 @@ def __init__(
f"supports runtimes: {runtimes_supported}"
)

os.makedirs(self.local_output_dir, exist_ok=True)

self._setup()

def posix_path_format(self, path) -> str:
@@ -140,7 +142,6 @@ def _transfer_files(self, files_to_transfer: List[str]):
files_to_transfer: absolute paths to files
"""

os.makedirs(self.local_output_dir, exist_ok=True)
for file in files_to_transfer:
shutil.copy(
file, os.path.join(self.local_output_dir, os.path.basename(file))
146 changes: 126 additions & 20 deletions src/turnkeyml/run/torchrt/runtime.py
@@ -1,5 +1,6 @@
import os
from typing import Dict, Any
import logging
from typing import Dict, Any, List, Optional
from statistics import mean
import time
from packaging import version
@@ -11,6 +12,7 @@
import turnkeyml.build.ignition as ignition
import turnkeyml.common.build as build
import turnkeyml.common.exceptions as exp
import turnkeyml.common.filesystem as fs
from turnkeyml.common.filesystem import Stats


@@ -26,34 +28,48 @@ def __init__(
model: torch.nn.Module,
inputs: Dict[str, Any],
tensor_type=np.array,
runtimes_supported: Optional[List[str]] = None,
runtime_version: str = str(torch.__version__),
):
# Torch Dynamo is pretty verbose with its warnings,
# so we set the logging level to ERROR
torch._logging.set_logs(dynamo=logging.ERROR)

self.throughput_ips = None
self.mean_latency_ms = None

# Allow children of this class to pass different values than
# the defaults for torch-eager and torch-compiled
if runtimes_supported:
init_runtimes_supported = runtimes_supported
else:
init_runtimes_supported = ["torch-eager", "torch-compiled"]

super().__init__(
cache_dir=cache_dir,
build_name=build_name,
stats=stats,
device_type=device_type,
runtime=runtime,
iterations=iterations,
runtimes_supported=["torch-eager", "torch-compiled"],
runtime_version=str(torch.__version__),
runtimes_supported=init_runtimes_supported,
runtime_version=runtime_version,
base_path=os.path.dirname(__file__),
tensor_type=tensor_type,
model=model,
inputs=inputs,
)

def _setup(self) -> None:
# Ensure we have the correct model type
model_type = ignition.identify_model_type(self.model)
if model_type != build.ModelType.PYTORCH:
raise exp.IntakeError(
f"Only Pytorch models are valid when runtime is {self.runtime}"
)
def _compile(self) -> None:
"""
Perform any requested compilation actions on the PyTorch model.
Note: This method is expected to be overloaded by most children of
this class.
"""

self.model.eval()

# Compile the model
if self.runtime == "torch-compiled":
# First ensure we have the required version of Pytorch
clean_torch_version = self.runtime_version.split("+")[0]
@@ -67,18 +83,42 @@ def _setup(self) -> None:

self.model = torch.compile(self.model)

def benchmark(self) -> MeasuredPerformance:
per_iteration_latency = [0] * self.iterations
for idx in range(self.iterations):
start_time = time.perf_counter()
self.model(**self.inputs)
end_time = time.perf_counter()
per_iteration_latency[idx] = end_time - start_time
def _setup(self) -> None:
"""
Validate the parameters of this class and invoke compilation.
Note: the implementation of this method is intentionally generic. Any
runtime-specific actions should go into the _compile() method if
possible.
"""

# Ensure we have the correct model type
model_type = ignition.identify_model_type(self.model)
if model_type != build.ModelType.PYTORCH:
raise exp.IntakeError(
f"Only Pytorch models are valid when runtime is {self.runtime}"
)

# Compile the model and record how long compilation takes
start_time = time.perf_counter()
with build.Logger("Preparing torch model", self.logfile_path):
self._compile()
end_time = time.perf_counter()
total_time = end_time - start_time

self.stats.save_model_eval_stat("torch_compilation_seconds", total_time)

def _calculate_performance(
self, per_iteration_latency: List[float]
) -> MeasuredPerformance:
"""
Calculate performance statistics from the per-iteration latencies
acquired during execution.
"""

# Calculate perf from per_iteration_latency
self.mean_latency_ms = mean(per_iteration_latency) * 1000
self.throughput_ips = float(
1 / (np.sum(per_iteration_latency) / self.iterations)
1 / (np.sum(per_iteration_latency) / len(per_iteration_latency))
)

return MeasuredPerformance(
@@ -91,6 +131,72 @@ def benchmark(self) -> MeasuredPerformance:
build_name=self.build_name,
)

def _run_model(self, iterations: int, time_limit: int) -> List[float]:
"""
Run the model repeatedly, collecting the performance of each
iteration. Stop running when the iterations target or time limit
is reached, whichever comes first.
Note: this method is intended to be useful in the following ways:
1. Generic across any child class of TorchRT
2. Useful for both cache warmup and benchmarking
"""

counter = 0
total_time = 0
per_iteration_latency = []

while counter < iterations and total_time < time_limit:
start_time = time.perf_counter()
self.model(**self.inputs)
end_time = time.perf_counter()
total_time = total_time + end_time - start_time
counter = counter + 1
per_iteration_latency.append(end_time - start_time)

return per_iteration_latency

def _benchmark_inner(self) -> MeasuredPerformance:
"""
The logic for benchmarking a torch model to collect performance data.
This method is meant to be called by the benchmark() method, which is
why it is named _benchmark_inner().
Note: this method is intended to be generic across any child class
of TorchRT.
"""

# Cache warmup for 1 minute or 10 iterations, whichever
# comes first
self._run_model(iterations=10, time_limit=60)

# Run the benchmark for the specified number of iterations,
# or 2 minutes, whichever comes first
per_iteration_latency = self._run_model(
iterations=self.iterations, time_limit=120
)

# Record the number of iterations actually used for the benchmark,
# which will be less than the `iterations` argument if the time
# limit was reached
self.stats.save_model_eval_stat(fs.Keys.ITERATIONS, len(per_iteration_latency))

return self._calculate_performance(per_iteration_latency)

def benchmark(self) -> MeasuredPerformance:
"""
Wrapper function for self._benchmark_inner()
The reason this wrapper exists is to allow plugin developers to apply various
settings to execution on a per-runtime basis. For example, selectively
enabling torch.no_grad().
Note: it is expected that most child classes of TorchRT will overload
this method.
"""
with torch.no_grad():
return self._benchmark_inner()

@property
def mean_latency(self) -> float:
if self.mean_latency_ms is not None:
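This refactor turns TorchRT into an extensible base: children can pass their own runtimes_supported and runtime_version to __init__, put runtime-specific work in _compile(), and override benchmark() to choose their own execution context while reusing the generic warmup and timing loop. A hedged sketch of a hypothetical plugin runtime built on these hooks follows; the import path, class name, "torch-inference-mode" runtime string, and the inference_mode choice are assumptions for illustration, not part of this commit:

import numpy as np
import torch
from turnkeyml.run.torchrt.runtime import TorchRT


class TorchInferenceModeRT(TorchRT):
    """Hypothetical child runtime built on the refactored TorchRT."""

    def __init__(self, cache_dir, build_name, stats, device_type, iterations, model, inputs, tensor_type=np.array):
        super().__init__(
            cache_dir=cache_dir,
            build_name=build_name,
            stats=stats,
            device_type=device_type,
            runtime="torch-inference-mode",  # hypothetical runtime name
            iterations=iterations,
            model=model,
            inputs=inputs,
            tensor_type=tensor_type,
            runtimes_supported=["torch-inference-mode"],
        )

    def _compile(self) -> None:
        # Runtime-specific preparation only; _setup() still performs the
        # generic model-type check and records torch_compilation_seconds
        self.model.eval()

    def benchmark(self):
        # Swap torch.no_grad() for torch.inference_mode() in this runtime,
        # reusing the generic warmup/benchmark loop in _benchmark_inner()
        with torch.inference_mode():
            return self._benchmark_inner()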
2 changes: 1 addition & 1 deletion src/turnkeyml/version.py
@@ -1 +1 @@
__version__ = "1.1.2"
__version__ = "1.1.3"
