diff --git a/src/turnkeyml/analyze/model.py b/src/turnkeyml/analyze/model.py index 671efc82..5590f9ca 100644 --- a/src/turnkeyml/analyze/model.py +++ b/src/turnkeyml/analyze/model.py @@ -25,7 +25,7 @@ def count_parameters(model: torch.nn.Module, model_type: build.ModelType) -> int onnx_model = onnx.load(model) return int( sum( - np.prod(tensor.dims) + np.prod(tensor.dims, dtype=np.int64) for tensor in onnx_model.graph.initializer if tensor.name not in onnx_model.graph.input )