From 8469065597c5dd8a08fd3327e5a4d3ddd0ba7472 Mon Sep 17 00:00:00 2001 From: Iain Moncrief Date: Tue, 6 Aug 2024 11:53:54 -1000 Subject: [PATCH] Add tensor history erasure via detach param --- ModuleSpec.chpl | 5 ++++- lib/Autograd.chpl | 9 +++++++++ lib/DynamicTensor.chpl | 29 +++++++++++++++++++++++++---- lib/StaticTensor.chpl | 10 +++++++++- 4 files changed, 47 insertions(+), 6 deletions(-) diff --git a/ModuleSpec.chpl b/ModuleSpec.chpl index d73e4a90b..11fa112f4 100644 --- a/ModuleSpec.chpl +++ b/ModuleSpec.chpl @@ -4,6 +4,9 @@ use Network; import Time; +config const detach = true; + +Tensor.detachMode(detach); // Construct the model from specification. var model: owned Module(real) = modelFromSpecFile("scripts/models/cnn/specification.json"); @@ -48,4 +51,4 @@ if printResults { } } -writeln("The average inference time for batch of size ", numImages, " was ", time, " seconds."); \ No newline at end of file +writeln("The average inference time for batch of size ", numImages, " was ", time, " seconds."); diff --git a/lib/Autograd.chpl b/lib/Autograd.chpl index 4bf54afc0..6ede4668a 100644 --- a/lib/Autograd.chpl +++ b/lib/Autograd.chpl @@ -108,6 +108,15 @@ class TensorResource : BaseTensorResource(?), serializable { this.operationData = operationData; } + proc init(tr: shared BaseTensorResource(?eltType,?rank),param forget: bool) where forget == true { + super.init(eltType,rank,tr.dataResource,new remote(ndarray(rank,eltType))); + // super.rank = rank; + // super.dataResource = tr.dataResource; + // super.gradResource = new remote(ndarray(rank,eltType)); + this.operation = baseValue; + this.operationData = new baseValue(); + } + // proc init(param rank: int, type eltType, operationData: ?operation, device: locale) { // var res = new remote(ndarray(rank,eltType),device); // this.init(res,operationData,device); diff --git a/lib/DynamicTensor.chpl b/lib/DynamicTensor.chpl index b722a0e0e..22e79258b 100644 --- a/lib/DynamicTensor.chpl +++ b/lib/DynamicTensor.chpl @@ -13,6 +13,7 @@ config param maxRank = 6; import LoadNumpy; +param defaultDetachedMode = false; record Tensor : serializable { type eltType = real; @@ -39,11 +40,17 @@ record Tensor : serializable { this.runtimeRank = meta.runtimeRank; } - proc init(t: tensor(?rank,?eltType)) { + proc init(t: tensor(?rank,?eltType), detached: bool = Tensor.detachMode()) { this.eltType = eltType; + if detached { + var u = t.detach(); + this.meta = u.meta; + this.runtimeRank = u.meta.runtimeRank; + } else { this.meta = t.meta; this.runtimeRank = t.meta.runtimeRank; } + } proc init(a: ndarray(?rank,?eltType)) do this.init(new tensor(a)); @@ -74,7 +81,7 @@ record Tensor : serializable { } proc forceRank(param rank: int): tensor(rank,eltType) do - return new tensor(meta : shared BaseTensorResource(eltType,rank),strict=false); + return new tensor(meta : shared BaseTensorResource(eltType,rank)); proc forceRankMeta(param rank: int): shared BaseTensorResource(eltType,rank) do return meta : shared BaseTensorResource(eltType,rank); @@ -131,13 +138,27 @@ record Tensor : serializable { proc toArray(param rank: int) : [] eltType do return toNDArray(rank).data; + proc detach(): Tensor(eltType) { + for param rank in 1..maxRank do + if checkRank(rank) then + return tensorize(rank).detach().eraseRank(); + halt("Could not identify rank for this: ", this); + } +} + +proc type Tensor.detachMode() param : bool { + return defaultDetachedMode; +} + +proc type Tensor.detachMode(detachMode: bool) { + // defaultDetachedMode = detachMode; } inline proc ndarray.toTensor(): Tensor(eltType) do return new Tensor(this); -proc tensor.eraseRank(): Tensor(eltType) do - return new Tensor(this); +proc tensor.eraseRank(detach: bool = Tensor.detachMode()): Tensor(eltType) do + return new Tensor(this,detach); operator :(t: tensor(?rank,?eltType), type T: Tensor(eltType)): Tensor(eltType) do return t.eraseRank(); diff --git a/lib/StaticTensor.chpl b/lib/StaticTensor.chpl index e43486762..7d9c30e67 100644 --- a/lib/StaticTensor.chpl +++ b/lib/StaticTensor.chpl @@ -76,11 +76,19 @@ record tensor : serializable { dat = devVal; } } + + proc detach(): tensor(rank,eltType) { + if var tr = this.meta : borrowed TensorResource(eltType,rank,baseValue)? then + return this; + else + return new tensor(new shared TensorResource(this.resource,forget = true)); + } } proc tensorFromCtx(param rank: int, type eltType, ctx): tensor(rank,eltType) { var newMeta = new shared TensorResource(rank,eltType,ctx); - return new tensor(newMeta, strict = true); + newMeta.forward(); + return new tensor(newMeta); }