Skip to content

Commit

Permalink
Remove strict param from StaticTensor module.
Browse files Browse the repository at this point in the history
  • Loading branch information
Iainmon committed Aug 6, 2024
1 parent 8469065 commit 21c6d54
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
1 change: 1 addition & 0 deletions lib/Autograd.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ class TensorResource : BaseTensorResource(?), serializable {

override proc forward() {
if operationData.type != baseValue {

on dataResource.device {
// ref data = dataResource.access().data;
// data = operationData.forward();
Expand Down
6 changes: 3 additions & 3 deletions lib/DynamicTensor.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ record Tensor : serializable {
this.meta = u.meta;
this.runtimeRank = u.meta.runtimeRank;
} else {
this.meta = t.meta;
this.runtimeRank = t.meta.runtimeRank;
}
this.meta = t.meta;
this.runtimeRank = t.meta.runtimeRank;
}
}

proc init(a: ndarray(?rank,?eltType)) do
Expand Down
9 changes: 3 additions & 6 deletions lib/StaticTensor.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ record tensor : serializable {
this.resource = new shared TensorResource(rank,eltType,baseValue);
}

proc init(resource: shared BaseTensorResource(?eltType,?rank), param strict: bool = false) {
proc init(resource: shared BaseTensorResource(?eltType,?rank)) {
this.rank = rank;
this.eltType = eltType;
this.resource = resource;
if strict then resource.forward();
}

proc init(nda: ndarray(?rank,?eltType)) {
Expand Down Expand Up @@ -94,14 +93,12 @@ proc tensorFromCtx(param rank: int, type eltType, ctx): tensor(rank,eltType) {

operator +(a: tensor(?rank,?eltType), b: tensor(rank,eltType)) {
var ctx = new addOp(rank,eltType,a.meta,b.meta);
var newMeta = new shared TensorResource(rank,eltType,ctx);
return new tensor(newMeta, strict = true);
return tensorFromCtx(rank,eltType,ctx);
}

operator -(a: tensor(?rank,?eltType), b: tensor(rank,eltType)) {
var ctx = new subOp(a.meta,b.meta);
var newMeta = new shared TensorResource(rank,eltType,ctx);
return new tensor(newMeta, strict = true);
return tensorFromCtx(rank,eltType,ctx);
}

operator *(a: tensor(?rank,?eltType), b: tensor(rank,eltType)) {
Expand Down

0 comments on commit 21c6d54

Please sign in to comment.