From 293e903ba2e3ee41d2667a76d875aeb7f49e8934 Mon Sep 17 00:00:00 2001 From: Daniel Holanda Date: Fri, 1 Mar 2024 10:15:40 -0800 Subject: [PATCH] Avoid parameter overflow when counting parameters of very large models (#126) --- src/turnkeyml/analyze/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/turnkeyml/analyze/model.py b/src/turnkeyml/analyze/model.py index 671efc82..5590f9ca 100644 --- a/src/turnkeyml/analyze/model.py +++ b/src/turnkeyml/analyze/model.py @@ -25,7 +25,7 @@ def count_parameters(model: torch.nn.Module, model_type: build.ModelType) -> int onnx_model = onnx.load(model) return int( sum( - np.prod(tensor.dims) + np.prod(tensor.dims, dtype=np.int64) for tensor in onnx_model.graph.initializer if tensor.name not in onnx_model.graph.input )