From e5245369e9ba585aecbb57c9aafa03e9efb5f589 Mon Sep 17 00:00:00 2001 From: Jeremy Fowers Date: Thu, 29 Feb 2024 14:27:17 -0500 Subject: [PATCH 1/5] draft of solution --- src/turnkeyml/analyze/script.py | 23 +++++------------------ src/turnkeyml/analyze/status.py | 20 ++++++++++++++++++++ src/turnkeyml/files_api.py | 28 ++++++++++++++++------------ 3 files changed, 41 insertions(+), 30 deletions(-) diff --git a/src/turnkeyml/analyze/script.py b/src/turnkeyml/analyze/script.py index f75962f5..ddca7d39 100644 --- a/src/turnkeyml/analyze/script.py +++ b/src/turnkeyml/analyze/script.py @@ -761,24 +761,11 @@ def forward_spy(*args, **kwargs): model_info = tracer_args.models_found[model_hash] if invocation_hash not in model_info.unique_invocations: - model_info.unique_invocations[invocation_hash] = ( - status.UniqueInvocationInfo( - name=model_info.name, - script_name=model_info.script_name, - file=model_info.file, - line=model_info.line, - params=model_info.params, - depth=model_info.depth, - build_model=model_info.build_model, - model_type=model_info.model_type, - model_class=type(model_info.model), - invocation_hash=invocation_hash, - hash=model_info.hash, - is_target=invocation_hash in tracer_args.targets - or len(tracer_args.targets) == 0, - input_shapes=input_shapes, - parent_hash=parent_invocation_hash, - ) + model_info.add_unique_invocation( + invocation_hash, + tracer_args.targets, + input_shapes, + parent_invocation_hash, ) model_info.last_unique_invocation_executed = invocation_hash diff --git a/src/turnkeyml/analyze/status.py b/src/turnkeyml/analyze/status.py index 2abc9888..b1d3dd9a 100644 --- a/src/turnkeyml/analyze/status.py +++ b/src/turnkeyml/analyze/status.py @@ -327,6 +327,26 @@ class ModelInfo(BasicInfo): def __post_init__(self): self.params = analyze_model.count_parameters(self.model, self.model_type) + def add_unique_invocation( + self, invocation_hash: int, targets, input_shapes, parent_invocation_hash + ): + unique_invocations[invocation_hash] = UniqueInvocationInfo( + name=self.name, + script_name=self.script_name, + file=self.file, + line=self.line, + params=self.params, + depth=self.depth, + build_model=self.build_model, + model_type=self.model_type, + model_class=type(self.model), + invocation_hash=invocation_hash, + hash=self.hash, + is_target=invocation_hash in targets or len(targets) == 0, + input_shapes=input_shapes, + parent_hash=parent_invocation_hash, + ) + def update( models_found: Dict[str, ModelInfo], diff --git a/src/turnkeyml/files_api.py b/src/turnkeyml/files_api.py index 4ee56d91..784db25a 100644 --- a/src/turnkeyml/files_api.py +++ b/src/turnkeyml/files_api.py @@ -380,17 +380,17 @@ def benchmark_files( # - is_target=True is required or else traceback wont be printed for # in the event of any errors # - Most other values can be left as default - invocation_info = UniqueInvocationInfo( - name=onnx_name, - script_name=onnx_name, - file=file_path_absolute, - build_model=not build_only, - model_type=build.ModelType.ONNX_FILE, - executed=1, - input_shapes=input_shapes, - hash=onnx_hash, - is_target=True, - ) + # invocation_info = UniqueInvocationInfo( + # name=onnx_name, + # script_name=onnx_name, + # file=file_path_absolute, + # build_model=not build_only, + # model_type=build.ModelType.ONNX_FILE, + # executed=1, + # input_shapes=input_shapes, + # hash=onnx_hash, + # is_target=True, + # ) # Create the ModelInfo model_info = ModelInfo( @@ -400,9 +400,13 @@ def benchmark_files( file=file_path_absolute, build_model=not build_only, model_type=build.ModelType.ONNX_FILE, - unique_invocations={onnx_hash: invocation_info}, + # unique_invocations={onnx_hash: invocation_info}, hash=onnx_hash, ) + model_info.add_unique_invocation() + + # THIS IS BAD + # invocation_info.params = model_info.params # Begin evaluating the ONNX model tracer_args.script_name = onnx_name From 4498358d6f94d73a09ca30f89f9717b846d6ad9a Mon Sep 17 00:00:00 2001 From: Holanda Noronha Date: Thu, 29 Feb 2024 16:08:48 -0800 Subject: [PATCH 2/5] Complete implementation - Failing status print --- src/turnkeyml/analyze/script.py | 9 +++++---- src/turnkeyml/analyze/status.py | 12 ++++++++---- src/turnkeyml/files_api.py | 11 +++++++---- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/src/turnkeyml/analyze/script.py b/src/turnkeyml/analyze/script.py index ddca7d39..891e0c86 100644 --- a/src/turnkeyml/analyze/script.py +++ b/src/turnkeyml/analyze/script.py @@ -762,10 +762,11 @@ def forward_spy(*args, **kwargs): if invocation_hash not in model_info.unique_invocations: model_info.add_unique_invocation( - invocation_hash, - tracer_args.targets, - input_shapes, - parent_invocation_hash, + invocation_hash=invocation_hash, + is_target=invocation_hash in tracer_args.targets + or len(tracer_args.targets) == 0, + input_shapes=input_shapes, + parent_hash=parent_invocation_hash, ) model_info.last_unique_invocation_executed = invocation_hash diff --git a/src/turnkeyml/analyze/status.py b/src/turnkeyml/analyze/status.py index b1d3dd9a..19557196 100644 --- a/src/turnkeyml/analyze/status.py +++ b/src/turnkeyml/analyze/status.py @@ -328,9 +328,13 @@ def __post_init__(self): self.params = analyze_model.count_parameters(self.model, self.model_type) def add_unique_invocation( - self, invocation_hash: int, targets, input_shapes, parent_invocation_hash + self, + invocation_hash: int, + is_target: bool, + input_shapes: Dict, + parent_hash: Union[str, None] = None, ): - unique_invocations[invocation_hash] = UniqueInvocationInfo( + self.unique_invocations[invocation_hash] = UniqueInvocationInfo( name=self.name, script_name=self.script_name, file=self.file, @@ -342,9 +346,9 @@ def add_unique_invocation( model_class=type(self.model), invocation_hash=invocation_hash, hash=self.hash, - is_target=invocation_hash in targets or len(targets) == 0, + is_target=is_target, input_shapes=input_shapes, - parent_hash=parent_invocation_hash, + parent_hash=parent_hash, ) diff --git a/src/turnkeyml/files_api.py b/src/turnkeyml/files_api.py index 784db25a..91f4c4be 100644 --- a/src/turnkeyml/files_api.py +++ b/src/turnkeyml/files_api.py @@ -380,6 +380,7 @@ def benchmark_files( # - is_target=True is required or else traceback wont be printed for # in the event of any errors # - Most other values can be left as default + # invocation_info = UniqueInvocationInfo( # name=onnx_name, # script_name=onnx_name, @@ -400,10 +401,12 @@ def benchmark_files( file=file_path_absolute, build_model=not build_only, model_type=build.ModelType.ONNX_FILE, - # unique_invocations={onnx_hash: invocation_info}, - hash=onnx_hash, ) - model_info.add_unique_invocation() + model_info.add_unique_invocation( + invocation_hash=onnx_hash, + is_target=True, + input_shapes=input_shapes, + ) # THIS IS BAD # invocation_info.params = model_info.params @@ -414,7 +417,7 @@ def benchmark_files( explore_invocation( model_inputs=onnx_inputs, model_info=model_info, - invocation_info=invocation_info, + invocation_info=model_info.unique_invocations[onnx_hash], tracer_args=tracer_args, ) models_found = tracer_args.models_found From 34ea5052cffd2be6f9165ebdbaf9dcaf871ea83f Mon Sep 17 00:00:00 2001 From: Holanda Noronha Date: Thu, 29 Feb 2024 16:38:35 -0800 Subject: [PATCH 3/5] Solved issue --- src/turnkeyml/analyze/status.py | 2 ++ src/turnkeyml/files_api.py | 30 ++++++++---------------------- 2 files changed, 10 insertions(+), 22 deletions(-) diff --git a/src/turnkeyml/analyze/status.py b/src/turnkeyml/analyze/status.py index 19557196..3b91e5c7 100644 --- a/src/turnkeyml/analyze/status.py +++ b/src/turnkeyml/analyze/status.py @@ -333,6 +333,7 @@ def add_unique_invocation( is_target: bool, input_shapes: Dict, parent_hash: Union[str, None] = None, + executed: int = 0, ): self.unique_invocations[invocation_hash] = UniqueInvocationInfo( name=self.name, @@ -349,6 +350,7 @@ def add_unique_invocation( is_target=is_target, input_shapes=input_shapes, parent_hash=parent_hash, + executed=executed, ) diff --git a/src/turnkeyml/files_api.py b/src/turnkeyml/files_api.py index 91f4c4be..9222fb26 100644 --- a/src/turnkeyml/files_api.py +++ b/src/turnkeyml/files_api.py @@ -374,25 +374,6 @@ def benchmark_files( onnx_inputs = onnx_helpers.dummy_inputs(file_path_absolute) input_shapes = {key: value.shape for key, value in onnx_inputs.items()} - # Create the UniqueInvocationInfo - # - execute=1 is required or else the ONNX model will be - # skipped in later stages of evaluation - # - is_target=True is required or else traceback wont be printed for - # in the event of any errors - # - Most other values can be left as default - - # invocation_info = UniqueInvocationInfo( - # name=onnx_name, - # script_name=onnx_name, - # file=file_path_absolute, - # build_model=not build_only, - # model_type=build.ModelType.ONNX_FILE, - # executed=1, - # input_shapes=input_shapes, - # hash=onnx_hash, - # is_target=True, - # ) - # Create the ModelInfo model_info = ModelInfo( model=file_path_absolute, @@ -401,16 +382,21 @@ def benchmark_files( file=file_path_absolute, build_model=not build_only, model_type=build.ModelType.ONNX_FILE, + hash=onnx_hash, ) + + # Add UniqueInvocationInfo + # - is_target=True is required or else traceback wont be printed for + # in the event of any errors + # - execute=1 is required or else the ONNX model will be + # skipped in later stages of evaluation model_info.add_unique_invocation( invocation_hash=onnx_hash, is_target=True, input_shapes=input_shapes, + executed=1, ) - # THIS IS BAD - # invocation_info.params = model_info.params - # Begin evaluating the ONNX model tracer_args.script_name = onnx_name tracer_args.models_found[tracer_args.script_name] = model_info From 40e0efdb5f0f1ca4c7fd90b2e4408e50e1e44d55 Mon Sep 17 00:00:00 2001 From: Holanda Noronha Date: Thu, 29 Feb 2024 16:52:10 -0800 Subject: [PATCH 4/5] Improvements --- src/turnkeyml/analyze/status.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/turnkeyml/analyze/status.py b/src/turnkeyml/analyze/status.py index 3b91e5c7..c31cdb35 100644 --- a/src/turnkeyml/analyze/status.py +++ b/src/turnkeyml/analyze/status.py @@ -335,6 +335,9 @@ def add_unique_invocation( parent_hash: Union[str, None] = None, executed: int = 0, ): + model_class = ( + type(self.model) if self.model_type == build.ModelType.PYTORCH else None + ) self.unique_invocations[invocation_hash] = UniqueInvocationInfo( name=self.name, script_name=self.script_name, @@ -344,7 +347,7 @@ def add_unique_invocation( depth=self.depth, build_model=self.build_model, model_type=self.model_type, - model_class=type(self.model), + model_class=model_class, invocation_hash=invocation_hash, hash=self.hash, is_target=is_target, From b4aa3d7196157922562d665e5a5bfee5cc87629f Mon Sep 17 00:00:00 2001 From: Holanda Noronha Date: Thu, 29 Feb 2024 17:05:24 -0800 Subject: [PATCH 5/5] lint --- src/turnkeyml/files_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/turnkeyml/files_api.py b/src/turnkeyml/files_api.py index 9222fb26..211de65a 100644 --- a/src/turnkeyml/files_api.py +++ b/src/turnkeyml/files_api.py @@ -20,7 +20,7 @@ explore_invocation, get_model_hash, ) -from turnkeyml.analyze.status import ModelInfo, UniqueInvocationInfo, Verbosity +from turnkeyml.analyze.status import ModelInfo, Verbosity import turnkeyml.common.build as build import turnkeyml.build.onnx_helpers as onnx_helpers