Skip to content

Commit

Permalink
Fix performance issue caused by simple domain.
Browse files Browse the repository at this point in the history
  • Loading branch information
Iainmon committed Aug 31, 2024
1 parent 6aefbcf commit b6f26d2
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 21 deletions.
32 changes: 16 additions & 16 deletions lib/NDArray.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,15 @@ record ndarray : serializable {
forwarding _domain only shape;

inline
proc init(type eltType, dom: domainType(?))
where dom.isRectangular() {
proc init(type eltType, dom: domainType(?)) {
this.rank = dom.rank;
this.eltType = eltType;
this._domain = dom;
this.data = noinit;
}

inline
proc init(type eltType, dom: domainType(?), const in fill: eltType)
where dom.isRectangular() {
proc init(type eltType, dom: domainType(?), const in fill: eltType) {
this.rank = dom.rank;
this.eltType = eltType;
this._domain = dom;
Expand All @@ -58,13 +56,15 @@ record ndarray : serializable {
this.data = arr;
}

proc init(type eltType, const shape: ?rank * int) {
proc init(type eltType, shape: ?rank * int) {
this.init(eltType,util.domainFromShape((...shape)));
}

proc init(param rank: int, type eltType = real(32)) {
const shape: rank * int;
this.init(eltType,shape);
this.rank = rank;
this.eltType = eltType;
this._domain = util.emptyDomain(rank);
this.data = noinit;
}

proc init(type eltType = real(32), const shape: int ...?rank) do
Expand All @@ -74,8 +74,7 @@ record ndarray : serializable {
this.init(eltType,dom); // This could be optimized by refactoring whole init system.

proc init(const dom: ?t,type eltType = real(32))
where isDomainType(t)
&& dom.isRectangular() {
where isDomainType(t) {
this.init(eltType,dom);
}

Expand Down Expand Up @@ -127,8 +126,9 @@ record ndarray : serializable {
_domain = dom;
}

proc reshape(dom: domain(rank,int)): ndarray(rank,eltType) {
var arr = new ndarray(eltType,dom,fill=0:eltType);
proc reshape(dom: domainType(?)): ndarray(rank,eltType)
where dom.rank == rank {
var arr = new ndarray(eltType,dom);
const arrDom = arr.domain;
const selfDom = this.domain;

Expand All @@ -138,20 +138,20 @@ record ndarray : serializable {
}

proc reshape(dom: domainType(?)): ndarray(dom.rank,eltType)
where dom.isRectangular()
&& dom.rank != rank {
where dom.rank != rank {

var arr: ndarray(dom.rank,eltType) = new ndarray(eltType,dom);

const selfDom = this.domain;
const newDom = arr.domain;
const ref selfData = this.data;
ref arrData = arr.data;

const zero: eltType = 0;

forall (i,meIdx) in newDom.everyZip() {
const selfIdx = selfDom.indexAt(i);
const a = if selfDom.contains(selfIdx) then data[selfIdx] else zero;
const a = if selfDom.contains(selfIdx) then selfData[selfIdx] else zero;
arrData[meIdx] = a;
}
return arr;
Expand Down Expand Up @@ -694,7 +694,7 @@ proc type ndarray.convolve(features: ndarray(3,?eltType),kernel: ndarray(4,eltTy
const outHeight: int = ((inHeight - kernelHeight) / stride) + 1;
const outWidth: int = ((inWidth - kernelWidth) / stride) + 1;
const outShape = (filters,outHeight,outWidth);
const outDom: rect(3) = outShape;
const outDom = util.domainFromShape((...outShape));
var outFeatures = new ndarray(outDom,eltType);

const chanR = 0..<channels; // don't trust daniel's codemotion.
Expand All @@ -706,7 +706,7 @@ proc type ndarray.convolve(features: ndarray(3,?eltType),kernel: ndarray(4,eltTy
ref ker = kernel.data;

// @assertOnGpu
forall (f,h_,w_) in outDom {
forall (f,h_,w_) in outDom.every() {
const hi: int = h_ * stride;
const wi: int = w_ * stride;
var sum: eltType = 0;
Expand Down
8 changes: 4 additions & 4 deletions lib/Utilities.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ module Utilities {
// tup(i) = v;
// }

inline proc _domain.simple() const : rect(rank) do
/*inline proc _domain.simple() const : rect(rank) do
return new rect(this);
inline iter _domain.every() {
Expand All @@ -390,10 +390,10 @@ module Utilities {
where tag == iterKind.standalone {
const simple = this.simple();
foreach idx in simple.eachOrder() do yield idx;
}
}*/



/*
inline iter _domain.each {
const shape = this.shape;
var prod = 1;
Expand Down Expand Up @@ -715,7 +715,7 @@ module Utilities {
}
}
}
}*/
}

inline proc _domain.indexAt(n: int) where rank == 1 {
return n;
Expand Down
28 changes: 27 additions & 1 deletion simpleDomainTest.chpl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use SimpleDomain;
use Utilities.Standard;

import Utilities as util;

proc compare(dom:domain(?),d: rect(?rank)) where dom.rank == rank {
inline proc compareI(i) do
Expand Down Expand Up @@ -86,3 +86,29 @@ fo(1,2,3);
// writeln(__primitive("field by num",t,i+1).type:string);
// }

proc domainType(param rank: int) type do
return DefaultRectangularDom(rank,int,strideKind.one);

{
const d = util.domainFromShape(2,3,4);
writeln(d.type:string);
writeln(d._instance.type:string);
writeln(d.distribution.type:string);
writeln(d._instance.type:string);
writeln(d._value.type:string);
}

{
const d = util.domainFromShape(2,3,4);
writeln(d._instance.type:string);
writeln(domainType(3):string);

writeln(domainType(3) == d.type);
const x: domain(3,int) = util.domainFromShape(2,3,4);
writeln(x.type:string);
writeln(x.type == d.type);

const y: _domain(?) = x;
writeln(x.type == y.type);

}

0 comments on commit b6f26d2

Please sign in to comment.