Skip to content

Commit

Permalink
Fixed NewArrayLiteralFromAny to also create scalars if dimensions is …
Browse files Browse the repository at this point in the history
…empty.
  • Loading branch information
janpfeifer committed Jul 17, 2024
1 parent bd08d5f commit 9611862
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 4 deletions.
1 change: 1 addition & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* Added `xlabuilder.Op.Builder()`
* Added comments support to op_types.txt and added comments to several of the operations.
* Renamed `xlabuilder.BatchNorm{Inference,Training}` to `xlabuilder.BatchNormFor{Inference,Training}`
* Fixed `NewArrayLiteralFromAny` to also accept scalar values, if dimensions is empty.

# v0.1.2 SuppressAbseilLoggingHack

Expand Down
4 changes: 0 additions & 4 deletions xlabuilder/literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ func NewScalarLiteralFromAny(value any) *Literal {
}

// NewArrayLiteralFromAny creates a scalar Literal with the given dynamically typed flat values and its underlying dimensions.
// If dimensions is omitted, it's assumed to be a 1D tensor with dimension matching the length of flat.
// It uses reflection to inspect the type.
func NewArrayLiteralFromAny(flatAny any, dimensions ...int) *Literal {
flatV := reflect.ValueOf(flatAny)
Expand All @@ -107,9 +106,6 @@ func NewArrayLiteralFromAny(flatAny any, dimensions ...int) *Literal {
if dtype == dtypes.InvalidDType {
exceptions.Panicf("NewArrayLiteralFromAny expects a slice of valid DTypes, got %T instead", flatAny)
}
if len(dimensions) == 0 {
dimensions = []int{flatV.Len()}
}
shape := MakeShape(dtype, dimensions...)
if shape.Size() != flatV.Len() {
exceptions.Panicf("NewArrayLiteralFromAny got a slice of length %d, but the shape %s given has %d elements",
Expand Down

0 comments on commit 9611862

Please sign in to comment.