diff --git a/lib/NDArray.chpl b/lib/NDArray.chpl index 319f7b16e..720d19cac 100644 --- a/lib/NDArray.chpl +++ b/lib/NDArray.chpl @@ -27,8 +27,7 @@ record ndarray : serializable { forwarding _domain only shape; inline - proc init(type eltType, dom: domainType(?)) - where dom.isRectangular() { + proc init(type eltType, dom: domainType(?)) { this.rank = dom.rank; this.eltType = eltType; this._domain = dom; @@ -36,8 +35,7 @@ record ndarray : serializable { } inline - proc init(type eltType, dom: domainType(?), const in fill: eltType) - where dom.isRectangular() { + proc init(type eltType, dom: domainType(?), const in fill: eltType) { this.rank = dom.rank; this.eltType = eltType; this._domain = dom; @@ -58,13 +56,15 @@ record ndarray : serializable { this.data = arr; } - proc init(type eltType, const shape: ?rank * int) { + proc init(type eltType, shape: ?rank * int) { this.init(eltType,util.domainFromShape((...shape))); } proc init(param rank: int, type eltType = real(32)) { - const shape: rank * int; - this.init(eltType,shape); + this.rank = rank; + this.eltType = eltType; + this._domain = util.emptyDomain(rank); + this.data = noinit; } proc init(type eltType = real(32), const shape: int ...?rank) do @@ -74,8 +74,7 @@ record ndarray : serializable { this.init(eltType,dom); // This could be optimized by refactoring whole init system. proc init(const dom: ?t,type eltType = real(32)) - where isDomainType(t) - && dom.isRectangular() { + where isDomainType(t) { this.init(eltType,dom); } @@ -127,8 +126,9 @@ record ndarray : serializable { _domain = dom; } - proc reshape(dom: domain(rank,int)): ndarray(rank,eltType) { - var arr = new ndarray(eltType,dom,fill=0:eltType); + proc reshape(dom: domainType(?)): ndarray(rank,eltType) + where dom.rank == rank { + var arr = new ndarray(eltType,dom); const arrDom = arr.domain; const selfDom = this.domain; @@ -138,20 +138,20 @@ record ndarray : serializable { } proc reshape(dom: domainType(?)): ndarray(dom.rank,eltType) - where dom.isRectangular() - && dom.rank != rank { + where dom.rank != rank { var arr: ndarray(dom.rank,eltType) = new ndarray(eltType,dom); const selfDom = this.domain; const newDom = arr.domain; + const ref selfData = this.data; ref arrData = arr.data; const zero: eltType = 0; forall (i,meIdx) in newDom.everyZip() { const selfIdx = selfDom.indexAt(i); - const a = if selfDom.contains(selfIdx) then data[selfIdx] else zero; + const a = if selfDom.contains(selfIdx) then selfData[selfIdx] else zero; arrData[meIdx] = a; } return arr; @@ -694,7 +694,7 @@ proc type ndarray.convolve(features: ndarray(3,?eltType),kernel: ndarray(4,eltTy const outHeight: int = ((inHeight - kernelHeight) / stride) + 1; const outWidth: int = ((inWidth - kernelWidth) / stride) + 1; const outShape = (filters,outHeight,outWidth); - const outDom: rect(3) = outShape; + const outDom = util.domainFromShape((...outShape)); var outFeatures = new ndarray(outDom,eltType); const chanR = 0..