diff --git a/lib/Autograd.chpl b/lib/Autograd.chpl index 6ede4668a..34d2308b9 100644 --- a/lib/Autograd.chpl +++ b/lib/Autograd.chpl @@ -132,6 +132,7 @@ class TensorResource : BaseTensorResource(?), serializable { override proc forward() { if operationData.type != baseValue { + on dataResource.device { // ref data = dataResource.access().data; // data = operationData.forward(); diff --git a/lib/DynamicTensor.chpl b/lib/DynamicTensor.chpl index 22e79258b..3864548a8 100644 --- a/lib/DynamicTensor.chpl +++ b/lib/DynamicTensor.chpl @@ -47,9 +47,9 @@ record Tensor : serializable { this.meta = u.meta; this.runtimeRank = u.meta.runtimeRank; } else { - this.meta = t.meta; - this.runtimeRank = t.meta.runtimeRank; - } + this.meta = t.meta; + this.runtimeRank = t.meta.runtimeRank; + } } proc init(a: ndarray(?rank,?eltType)) do diff --git a/lib/StaticTensor.chpl b/lib/StaticTensor.chpl index 7d9c30e67..f4186616d 100644 --- a/lib/StaticTensor.chpl +++ b/lib/StaticTensor.chpl @@ -23,11 +23,10 @@ record tensor : serializable { this.resource = new shared TensorResource(rank,eltType,baseValue); } - proc init(resource: shared BaseTensorResource(?eltType,?rank), param strict: bool = false) { + proc init(resource: shared BaseTensorResource(?eltType,?rank)) { this.rank = rank; this.eltType = eltType; this.resource = resource; - if strict then resource.forward(); } proc init(nda: ndarray(?rank,?eltType)) { @@ -94,14 +93,12 @@ proc tensorFromCtx(param rank: int, type eltType, ctx): tensor(rank,eltType) { operator +(a: tensor(?rank,?eltType), b: tensor(rank,eltType)) { var ctx = new addOp(rank,eltType,a.meta,b.meta); - var newMeta = new shared TensorResource(rank,eltType,ctx); - return new tensor(newMeta, strict = true); + return tensorFromCtx(rank,eltType,ctx); } operator -(a: tensor(?rank,?eltType), b: tensor(rank,eltType)) { var ctx = new subOp(a.meta,b.meta); - var newMeta = new shared TensorResource(rank,eltType,ctx); - return new tensor(newMeta, strict = true); + return tensorFromCtx(rank,eltType,ctx); } operator *(a: tensor(?rank,?eltType), b: tensor(rank,eltType)) {