From 9611862ae7fa539638628a47a10d07b99f0b6ce2 Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Wed, 17 Jul 2024 08:52:13 +0200 Subject: [PATCH] Fixed NewArrayLiteralFromAny to also create scalars if dimensions is empty. --- docs/CHANGELOG.md | 1 + xlabuilder/literal.go | 4 ---- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index b9d11e4..68dbd34 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -17,6 +17,7 @@ * Added `xlabuilder.Op.Builder()` * Added comments support to op_types.txt and added comments to several of the operations. * Renamed `xlabuilder.BatchNorm{Inference,Training}` to `xlabuilder.BatchNormFor{Inference,Training}` +* Fixed `NewArrayLiteralFromAny` to also accept scalar values, if dimensions is empty. # v0.1.2 SuppressAbseilLoggingHack diff --git a/xlabuilder/literal.go b/xlabuilder/literal.go index 7743d5f..8d52743 100644 --- a/xlabuilder/literal.go +++ b/xlabuilder/literal.go @@ -96,7 +96,6 @@ func NewScalarLiteralFromAny(value any) *Literal { } // NewArrayLiteralFromAny creates a scalar Literal with the given dynamically typed flat values and its underlying dimensions. -// If dimensions is omitted, it's assumed to be a 1D tensor with dimension matching the length of flat. // It uses reflection to inspect the type. func NewArrayLiteralFromAny(flatAny any, dimensions ...int) *Literal { flatV := reflect.ValueOf(flatAny) @@ -107,9 +106,6 @@ func NewArrayLiteralFromAny(flatAny any, dimensions ...int) *Literal { if dtype == dtypes.InvalidDType { exceptions.Panicf("NewArrayLiteralFromAny expects a slice of valid DTypes, got %T instead", flatAny) } - if len(dimensions) == 0 { - dimensions = []int{flatV.Len()} - } shape := MakeShape(dtype, dimensions...) if shape.Size() != flatV.Len() { exceptions.Panicf("NewArrayLiteralFromAny got a slice of length %d, but the shape %s given has %d elements",