Skip to content

Commit

Permalink
Add tensor history erasure via detach param
Browse files Browse the repository at this point in the history
  • Loading branch information
Iainmon committed Aug 6, 2024
1 parent 45c1f9b commit 8469065
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 6 deletions.
5 changes: 4 additions & 1 deletion ModuleSpec.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ use Network;

import Time;

// Command-line toggle (e.g. --detach=false) for erasing tensor autograd
// history during this inference benchmark.
config const detach = true;

// NOTE(review): Tensor.detachMode(bool) in lib/DynamicTensor.chpl currently
// has a commented-out body, so this call appears to be a silent no-op —
// confirm the default actually changes.
Tensor.detachMode(detach);

// Construct the model from specification.
var model: owned Module(real) = modelFromSpecFile("scripts/models/cnn/specification.json");
Expand Down Expand Up @@ -48,4 +51,4 @@ if printResults {
}
}

// Report the mean inference time for the batch.
// (A duplicated copy of this statement — diff residue from a missing
// trailing-newline change — was removed.)
writeln("The average inference time for batch of size ", numImages, " was ", time, " seconds.");
9 changes: 9 additions & 0 deletions lib/Autograd.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,15 @@ class TensorResource : BaseTensorResource(?), serializable {
this.operationData = operationData;
}

// "Forgetting" initializer: wraps an existing tensor resource while erasing
// its autograd history.  The data buffer is shared with `tr`, a fresh empty
// gradient buffer is allocated, and the operation is reset to `baseValue`
// so the new resource is a leaf with no backward graph behind it.
// NOTE(review): `param forget` is constrained to `true` by the `where`
// clause — presumably other overloads cover the non-forgetting case.
proc init(tr: shared BaseTensorResource(?eltType,?rank),param forget: bool) where forget == true {
    super.init(eltType,rank,tr.dataResource,new remote(ndarray(rank,eltType)));
    // super.rank = rank;
    // super.dataResource = tr.dataResource;
    // super.gradResource = new remote(ndarray(rank,eltType));
    // NOTE(review): `operation` is assigned the bare name `baseValue` while
    // `operationData` gets an instance — confirm the field types intend this.
    this.operation = baseValue;
    this.operationData = new baseValue();
}

// proc init(param rank: int, type eltType, operationData: ?operation, device: locale) {
// var res = new remote(ndarray(rank,eltType),device);
// this.init(res,operationData,device);
Expand Down
29 changes: 25 additions & 4 deletions lib/DynamicTensor.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ config param maxRank = 6;

import LoadNumpy;

// Compile-time default for Tensor.detachMode().  NOTE(review): because this
// is a `param`, it cannot be reassigned at runtime — which is why the
// runtime setter overload below is currently inert.
param defaultDetachedMode = false;

record Tensor : serializable {
type eltType = real;
Expand All @@ -39,11 +40,17 @@ record Tensor : serializable {
this.runtimeRank = meta.runtimeRank;
}

proc init(t: tensor(?rank,?eltType)) {
// Build a rank-erased Tensor from a statically-ranked `tensor`.
// When `detached` is true (default taken from Tensor.detachMode()),
// the tensor's autograd history is erased via `t.detach()` before its
// metadata is captured.
proc init(t: tensor(?rank,?eltType), detached: bool = Tensor.detachMode()) {
    this.eltType = eltType;
    if detached {
        // Detach first so `meta` references a history-free resource.
        var u = t.detach();
        this.meta = u.meta;
        this.runtimeRank = u.meta.runtimeRank;
    } else {
        this.meta = t.meta;
        this.runtimeRank = t.meta.runtimeRank;
    }
}

proc init(a: ndarray(?rank,?eltType)) do
this.init(new tensor(a));
Expand Down Expand Up @@ -74,7 +81,7 @@ record Tensor : serializable {
}

// Reinterpret this rank-erased Tensor as a static tensor of the given rank
// by downcasting the stored metadata.
// Fixed: a stale pre-change line (`..., strict=false`) had been left in
// front of the current return, leaving two consecutive `return` statements
// in a `do`-form proc.
proc forceRank(param rank: int): tensor(rank,eltType) do
    return new tensor(meta : shared BaseTensorResource(eltType,rank));

// Downcast the stored metadata to its statically-ranked resource type.
proc forceRankMeta(param rank: int): shared BaseTensorResource(eltType,rank) {
    return meta : shared BaseTensorResource(eltType,rank);
}
Expand Down Expand Up @@ -131,13 +138,27 @@ record Tensor : serializable {
// Materialize this tensor's data as a Chapel array with the given rank.
proc toArray(param rank: int) : [] eltType {
    return toNDArray(rank).data;
}

// Produce a history-free copy of this dynamic Tensor by probing each
// possible static rank, detaching at that rank, and re-erasing the rank.
proc detach(): Tensor(eltType) {
    for param r in 1..maxRank {
        if checkRank(r) {
            const ranked = tensorize(r);
            return ranked.detach().eraseRank();
        }
    }
    halt("Could not identify rank for this: ", this);
}
}

// Compile-time query for the global detach default consumed by the
// Tensor initializers that erase autograd history.
proc type Tensor.detachMode() param : bool do
    return defaultDetachedMode;

// Intended runtime setter for the global detach default.
// NOTE(review): currently a silent no-op — the target is a compile-time
// `param` and the assignment is commented out, so calls such as
// `Tensor.detachMode(detach)` in ModuleSpec.chpl have no effect.  Confirm
// whether the default should become a module-level `var` instead.
proc type Tensor.detachMode(detachMode: bool) {
    // defaultDetachedMode = detachMode;
}

// Wrap this ndarray in a rank-erased Tensor.
inline proc ndarray.toTensor(): Tensor(eltType) {
    return new Tensor(this);
}

// Erase this tensor's compile-time rank, producing a dynamic Tensor.
// `detach` (default: the global detach mode) controls whether autograd
// history is dropped in the process.
// Fixed: the stale zero-argument overload from before this change was left
// adjacent to the new definition, which would make `eraseRank()` calls
// ambiguous; only the parameterized version (backward-compatible via the
// default) is kept.
proc tensor.eraseRank(detach: bool = Tensor.detachMode()): Tensor(eltType) do
    return new Tensor(this,detach);

// Cast support: `t : Tensor(eltType)` erases the static rank, using the
// default detach behavior of eraseRank.
operator :(t: tensor(?rank,?eltType), type T: Tensor(eltType)): Tensor(eltType) {
    return t.eraseRank();
}
Expand Down
10 changes: 9 additions & 1 deletion lib/StaticTensor.chpl
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,19 @@ record tensor : serializable {
dat = devVal;
}
}

// Return a history-free version of this statically-ranked tensor.
// If the metadata is already a leaf (`baseValue`) TensorResource there is
// no autograd history to erase, so `this` is returned unchanged; otherwise
// the data is rewrapped in a fresh "forgetting" TensorResource.
// NOTE(review): the bound `tr` is unused — the nilable cast serves only as
// a dynamic-type test.  Also confirm `this.resource` is the intended
// accessor (Autograd.chpl uses `dataResource`).
proc detach(): tensor(rank,eltType) {
    if var tr = this.meta : borrowed TensorResource(eltType,rank,baseValue)? then
        return this;
    else
        return new tensor(new shared TensorResource(this.resource,forget = true));
}
}

// Build a tracked tensor by wrapping an operation context in a new
// TensorResource and eagerly running its forward pass so the data buffer
// is populated before the tensor is handed back.
// Fixed: a stale pre-change line (`return new tensor(newMeta, strict =
// true);`) preceded the current logic, making the `forward()` call and the
// real return unreachable.
proc tensorFromCtx(param rank: int, type eltType, ctx): tensor(rank,eltType) {
    var newMeta = new shared TensorResource(rank,eltType,ctx);
    newMeta.forward();
    return new tensor(newMeta);
}


Expand Down

0 comments on commit 8469065

Please sign in to comment.