From 494bbb0ca1274241067117668c4e534ee4b4e5a9 Mon Sep 17 00:00:00 2001 From: Paolo Di Tommaso Date: Sun, 5 Jan 2025 20:09:18 +0800 Subject: [PATCH] Task provenance - poc #2 Signed-off-by: Paolo Di Tommaso --- .../src/main/groovy/nextflow/Session.groovy | 9 +- .../groovy/nextflow/extension/BranchOp.groovy | 2 +- .../nextflow/extension/CollectFileOp.groovy | 12 +- .../nextflow/extension/CollectOp.groovy | 7 +- .../groovy/nextflow/extension/ConcatOp.groovy | 4 +- .../nextflow/extension/DataflowHelper.groovy | 132 ++++++++--- .../nextflow/extension/GroupTupleOp.groovy | 2 +- .../groovy/nextflow/extension/MapOp.groovy | 9 +- .../main/groovy/nextflow/extension/Op.groovy | 166 ++++++++++++++ .../nextflow/extension/OperatorImpl.groovy | 15 +- .../groovy/nextflow/processor/TaskId.groovy | 6 + .../nextflow/processor/TaskProcessor.groovy | 36 +-- .../groovy/nextflow/processor/TaskRun.groovy | 7 +- .../groovy/nextflow/prov/OperatorRun.groovy | 40 ++++ .../src/main/groovy/nextflow/prov/Prov.groovy | 46 ++++ .../main/groovy/nextflow/prov/Tracker.groovy | 162 ++++++++++++++ .../main/groovy/nextflow/prov/TrailRun.groovy | 25 +++ .../nextflow/provenance/ProvTracker.groovy | 92 -------- .../nextflow/extension/ConcatOpTest.groovy | 2 +- .../groovy/nextflow/extension/OpTest.groovy | 47 ++++ .../nextflow/extension/UntilManyOpTest.groovy | 4 +- .../test/groovy/nextflow/prov/ProvTest.groovy | 206 ++++++++++++++++++ .../groovy/nextflow/prov/TrackerTest.groovy | 123 +++++++++++ .../provenance/ProvTrackerTest.groovy | 94 -------- .../testFixtures/groovy/test/Dsl2Spec.groovy | 5 +- .../groovy/test/MockHelpers.groovy | 2 +- .../groovy/test/TestHelper.groovy | 36 +++ 27 files changed, 1024 insertions(+), 267 deletions(-) create mode 100644 modules/nextflow/src/main/groovy/nextflow/extension/Op.groovy create mode 100644 modules/nextflow/src/main/groovy/nextflow/prov/OperatorRun.groovy create mode 100644 modules/nextflow/src/main/groovy/nextflow/prov/Prov.groovy create mode 100644 modules/nextflow/src/main/groovy/nextflow/prov/Tracker.groovy create mode 100644 modules/nextflow/src/main/groovy/nextflow/prov/TrailRun.groovy delete mode 100644 modules/nextflow/src/main/groovy/nextflow/provenance/ProvTracker.groovy create mode 100644 modules/nextflow/src/test/groovy/nextflow/extension/OpTest.groovy create mode 100644 modules/nextflow/src/test/groovy/nextflow/prov/ProvTest.groovy create mode 100644 modules/nextflow/src/test/groovy/nextflow/prov/TrackerTest.groovy delete mode 100644 modules/nextflow/src/test/groovy/nextflow/provenance/ProvTrackerTest.groovy diff --git a/modules/nextflow/src/main/groovy/nextflow/Session.groovy b/modules/nextflow/src/main/groovy/nextflow/Session.groovy index 5d91cfe6a6..5e4c3dac0a 100644 --- a/modules/nextflow/src/main/groovy/nextflow/Session.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/Session.groovy @@ -53,7 +53,7 @@ import nextflow.processor.ErrorStrategy import nextflow.processor.TaskFault import nextflow.processor.TaskHandler import nextflow.processor.TaskProcessor -import nextflow.provenance.ProvTracker +import nextflow.prov.Tracker import nextflow.script.BaseScript import nextflow.script.ProcessConfig import nextflow.script.ProcessFactory @@ -224,8 +224,6 @@ class Session implements ISession { private DAG dag - private ProvTracker provenance - private CacheDB cache private Barrier processesBarrier = new Barrier() @@ -384,9 +382,6 @@ class Session implements ISession { // -- DAG object this.dag = new DAG() - // -- create the provenance tracker - this.provenance = new 
ProvTracker() - // -- init output dir this.outputDir = FileHelper.toCanonicalPath(config.outputDir ?: 'results') @@ -862,7 +857,7 @@ class Session implements ISession { DAG getDag() { this.dag } - ProvTracker getProvenance() { provenance } + Tracker getProvenance() { provenance } ExecutorService getExecService() { execService } diff --git a/modules/nextflow/src/main/groovy/nextflow/extension/BranchOp.groovy b/modules/nextflow/src/main/groovy/nextflow/extension/BranchOp.groovy index 7fc5067620..868db7131f 100644 --- a/modules/nextflow/src/main/groovy/nextflow/extension/BranchOp.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/extension/BranchOp.groovy @@ -53,7 +53,7 @@ class BranchOp { protected void doNext(it) { TokenBranchChoice ret = switchDef.closure.call(it) if( ret ) { - targets[ret.choice].bind(ret.value) + Op.bind(targets[ret.choice], ret.value) } } diff --git a/modules/nextflow/src/main/groovy/nextflow/extension/CollectFileOp.groovy b/modules/nextflow/src/main/groovy/nextflow/extension/CollectFileOp.groovy index 6f3dcd2b75..4b8c181d50 100644 --- a/modules/nextflow/src/main/groovy/nextflow/extension/CollectFileOp.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/extension/CollectFileOp.groovy @@ -16,6 +16,9 @@ package nextflow.extension +import static nextflow.util.CacheHelper.* +import static nextflow.util.CheckHelper.* + import java.nio.file.Path import groovy.util.logging.Slf4j @@ -28,8 +31,6 @@ import nextflow.file.FileHelper import nextflow.file.SimpleFileCollector import nextflow.file.SortFileCollector import nextflow.util.CacheHelper -import static nextflow.util.CacheHelper.HashMode -import static nextflow.util.CheckHelper.checkParams /** * Implements the body of {@link OperatorImpl#collectFile(groovyx.gpars.dataflow.DataflowReadChannel)} operator * @@ -185,10 +186,10 @@ class CollectFileOp { protected emitItems( obj ) { // emit collected files to 'result' channel collector.saveTo(storeDir).each { - result.bind(it) + Op.bind(result,it) } // close the channel - result.bind(Channel.STOP) + Op.bind(result,Channel.STOP) // close the collector collector.safeClose() } @@ -261,9 +262,8 @@ class CollectFileOp { return collector } - DataflowWriteChannel apply() { - DataflowHelper.subscribeImpl( channel, [onNext: this.&processItem, onComplete: this.&emitItems] ) + DataflowHelper.subscribeImpl( channel, true, [onNext: this.&processItem, onComplete: this.&emitItems] ) return result } } diff --git a/modules/nextflow/src/main/groovy/nextflow/extension/CollectOp.groovy b/modules/nextflow/src/main/groovy/nextflow/extension/CollectOp.groovy index 9d7b02558f..71cdf5aedb 100644 --- a/modules/nextflow/src/main/groovy/nextflow/extension/CollectOp.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/extension/CollectOp.groovy @@ -16,7 +16,7 @@ package nextflow.extension -import static nextflow.util.CheckHelper.checkParams +import static nextflow.util.CheckHelper.* import groovy.transform.CompileStatic import groovyx.gpars.dataflow.DataflowReadChannel @@ -55,7 +55,10 @@ class CollectOp { Map events = [:] events.onNext = { append(result, it) } - events.onComplete = { target << ( result ? new ArrayBag(normalise(result)) : Channel.STOP ) } + events.onComplete = { + final msg = result ? 
new ArrayBag(normalise(result)) : Channel.STOP + Op.bind(target, msg) + } DataflowHelper.subscribeImpl(source, events) return target diff --git a/modules/nextflow/src/main/groovy/nextflow/extension/ConcatOp.groovy b/modules/nextflow/src/main/groovy/nextflow/extension/ConcatOp.groovy index f21a34d76b..34e9ceca38 100644 --- a/modules/nextflow/src/main/groovy/nextflow/extension/ConcatOp.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/extension/ConcatOp.groovy @@ -56,10 +56,10 @@ class ConcatOp { def next = index < channels.size() ? channels[index] : null def events = new HashMap(2) - events.onNext = { result.bind(it) } + events.onNext = { Op.bind(result, it) } events.onComplete = { if(next) append(result, channels, index) - else result.bind(Channel.STOP) + else Op.bind(result, Channel.STOP) } DataflowHelper.subscribeImpl(current, events) diff --git a/modules/nextflow/src/main/groovy/nextflow/extension/DataflowHelper.groovy b/modules/nextflow/src/main/groovy/nextflow/extension/DataflowHelper.groovy index 14c775b04e..4543a89710 100644 --- a/modules/nextflow/src/main/groovy/nextflow/extension/DataflowHelper.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/extension/DataflowHelper.groovy @@ -36,7 +36,6 @@ import nextflow.Channel import nextflow.Global import nextflow.Session import nextflow.dag.NodeMarker -import static java.util.Arrays.asList /** * This class provides helper methods to implement nextflow operators * @@ -45,6 +44,70 @@ import static java.util.Arrays.asList @Slf4j class DataflowHelper { + static class OpParams { + List inputs + List outputs + List listeners + boolean accumulator + + OpParams() { } + + OpParams(Map params) { + this.inputs = params.inputs as List ?: List.of() + this.outputs = params.outputs as List ?: List.of() + this.listeners = params.listeners as List ?: List.of() + } + + OpParams withInput(DataflowReadChannel channel) { + assert channel != null + this.inputs = List.of(channel) + return this + } + + OpParams withInputs(List channels) { + assert channels != null + this.inputs = channels + return this + } + + OpParams withOutput(DataflowWriteChannel channel) { + assert channel != null + this.outputs = List.of(channel) + return this + } + + OpParams withOutputs(List channels) { + assert channels != null + this.outputs = channels + return this + } + + OpParams withListener(DataflowEventListener listener) { + assert listener != null + this.listeners = List.of(listener) + return this + } + + OpParams withListeners(List listeners) { + assert listeners != null + this.listeners = listeners + return this + } + + OpParams withAccumulator(boolean acc) { + this.accumulator = acc + return this + } + + Map toMap() { + final ret = new HashMap() + ret.inputs = inputs ?: List.of() + ret.outputs = outputs ?: List.of() + ret.listeners = listeners ?: List.of() + return ret + } + } + private static Session getSession() { Global.getSession() as Session } /** @@ -141,6 +204,7 @@ class DataflowHelper { * @param params The map holding inputs, outputs channels and other parameters * @param code The closure to be executed by the operator */ + @Deprecated static DataflowProcessor newOperator( Map params, Closure code ) { // -- add a default error listener @@ -149,13 +213,13 @@ class DataflowHelper { params.listeners = [ DEF_ERROR_LISTENER ] } - final op = Dataflow.operator(params, code) - NodeMarker.appendOperator(op) - if( session && session.allOperators != null ) { - session.allOperators.add(op) - } + return newOperator0(new OpParams(params), code) + } - return op + static 
DataflowProcessor newOperator( OpParams params, Closure code ) { + if( !params.listeners ) + params.withListener(DEF_ERROR_LISTENER) + return newOperator0(params, code) } /** @@ -195,16 +259,25 @@ class DataflowHelper { * @param code The closure to be executed by the operator */ static DataflowProcessor newOperator( DataflowReadChannel input, DataflowWriteChannel output, DataflowEventListener listener, Closure code ) { - if( !listener ) listener = DEF_ERROR_LISTENER - def params = [:] + final params = [:] params.inputs = [input] params.outputs = [output] params.listeners = [listener] - final op = Dataflow.operator(params, code) + return newOperator0(new OpParams(params), code) + } + + static private DataflowProcessor newOperator0(OpParams params, Closure code) { + assert params + assert params.inputs + assert params.listeners + + // create the underlying dataflow operator + final op = Dataflow.operator(params.toMap(), Op.instrument(code, params.accumulator)) + // track the operator as dag node NodeMarker.appendOperator(op) if( session && session.allOperators != null ) { session.allOperators << op @@ -236,14 +309,11 @@ class DataflowHelper { } - /** - * Subscribe *onNext*, *onError* and *onComplete* - * - * @param source - * @param closure - * @return - */ static final DataflowProcessor subscribeImpl(final DataflowReadChannel source, final Map events ) { + subscribeImpl(source, false, events) + } + + static final DataflowProcessor subscribeImpl(final DataflowReadChannel source, final boolean accumulator, final Map events ) { checkSubscribeHandlers(events) def error = false @@ -276,13 +346,12 @@ class DataflowHelper { } } + final params = new OpParams() + .withInput(source) + .withListener(listener) + .withAccumulator(accumulator) - final Map parameters = new HashMap(); - parameters.put("inputs", [source]) - parameters.put("outputs", []) - parameters.put('listeners', [listener]) - - newOperator (parameters) { + newOperator (params) { if( events.onNext ) { events.onNext.call(it) } @@ -292,7 +361,7 @@ class DataflowHelper { } } - + @Deprecated static DataflowProcessor chainImpl(final DataflowReadChannel source, final DataflowWriteChannel target, final Map params, final Closure closure) { final Map parameters = new HashMap(params) @@ -302,6 +371,10 @@ class DataflowHelper { newOperator(parameters, new ChainWithClosure(closure)) } + static DataflowProcessor chainImpl(OpParams params, final Closure closure) { + newOperator(params, new ChainWithClosure(closure)) + } + /** * Implements the {@code #reduce} operator * @@ -321,7 +394,7 @@ class DataflowHelper { * call the passed closure each time */ void afterRun(final DataflowProcessor processor, final List messages) { - final item = messages.get(0) + final item = Op.unwrap(messages).get(0) final value = accum == null ? 
item : closure.call(accum, item) if( value == Channel.VOID ) { @@ -339,7 +412,7 @@ class DataflowHelper { * when terminates bind the result value */ void afterStop(final DataflowProcessor processor) { - result.bind(accum) + Op.bind(result, accum) } boolean onException(final DataflowProcessor processor, final Throwable e) { @@ -349,7 +422,12 @@ class DataflowHelper { } } - chainImpl(channel, CH.create(), [listeners: [listener]], {true}) + final params = new OpParams() + .withInput(channel) + .withOutput(CH.create()) + .withListener(listener) + .withAccumulator(true) + chainImpl(params, {true}) } @PackageScope diff --git a/modules/nextflow/src/main/groovy/nextflow/extension/GroupTupleOp.groovy b/modules/nextflow/src/main/groovy/nextflow/extension/GroupTupleOp.groovy index 660ae6cf63..ddde2b825b 100644 --- a/modules/nextflow/src/main/groovy/nextflow/extension/GroupTupleOp.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/extension/GroupTupleOp.groovy @@ -224,7 +224,7 @@ class GroupTupleOp { target = CH.create() /* - * apply the logic the the source channel + * apply the logic to the source channel */ DataflowHelper.subscribeImpl(channel, [onNext: this.&collect, onComplete: this.&finalise]) diff --git a/modules/nextflow/src/main/groovy/nextflow/extension/MapOp.groovy b/modules/nextflow/src/main/groovy/nextflow/extension/MapOp.groovy index 78e05a0777..ecea4fbc2f 100644 --- a/modules/nextflow/src/main/groovy/nextflow/extension/MapOp.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/extension/MapOp.groovy @@ -16,7 +16,6 @@ package nextflow.extension - import groovyx.gpars.dataflow.DataflowReadChannel import groovyx.gpars.dataflow.DataflowWriteChannel import groovyx.gpars.dataflow.expression.DataflowExpression @@ -53,16 +52,16 @@ class MapOp { final stopOnFirst = source instanceof DataflowExpression DataflowHelper.newOperator(source, target) { it -> - def result = mapper.call(it) - def proc = (DataflowProcessor) getDelegate() + final result = mapper.call(it) + final proc = (DataflowProcessor) getDelegate() // bind the result value if (result != Channel.VOID) - proc.bindOutput(result) + Op.bind(target, result) // when the `map` operator is applied to a dataflow flow variable // terminate the processor after the first emission -- Issue #44 - if( result == Channel.STOP || stopOnFirst ) + if( stopOnFirst ) proc.terminate() } diff --git a/modules/nextflow/src/main/groovy/nextflow/extension/Op.groovy b/modules/nextflow/src/main/groovy/nextflow/extension/Op.groovy new file mode 100644 index 0000000000..7e63ea3f0b --- /dev/null +++ b/modules/nextflow/src/main/groovy/nextflow/extension/Op.groovy @@ -0,0 +1,166 @@ +/* + * Copyright 2013-2024, Seqera Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package nextflow.extension + +import groovy.transform.CompileDynamic +import groovy.transform.CompileStatic +import groovy.transform.PackageScope +import groovy.util.logging.Slf4j +import groovyx.gpars.dataflow.DataflowWriteChannel +import groovyx.gpars.dataflow.operator.PoisonPill +import nextflow.Global +import nextflow.Session +import nextflow.prov.OperatorRun +import nextflow.prov.Prov +import nextflow.prov.Tracker + +/** + * Operator helpers methods + * + * @author Paolo Di Tommaso + */ +@Slf4j +@CompileStatic +class Op { + + static final @PackageScope ThreadLocal currentOperator = new ThreadLocal<>() + + static List unwrap(List messages) { + return messages.collect(it -> it instanceof Tracker.Msg ? it.value : it) + } + + static Object unwrap(Object it) { + return it instanceof Tracker.Msg ? it.value : it + } + + static Tracker.Msg wrap(Object obj) { + obj instanceof Tracker.Msg ? obj : Tracker.Msg.of(obj) + } + + static void bind(DataflowWriteChannel channel, Object msg) { + try { + if( msg instanceof PoisonPill ) + channel.bind(msg) + else + Prov.getTracker().bindOutput(currentOperator.get(), channel, msg) + } + catch (Throwable t) { + log.error("Unexpected resolving execution provenance: ${t.message}", t) + (Global.session as Session).abort(t) + } + } + + static Closure instrument(Closure op, boolean accumulator=false) { + return new InvokeOperatorAdapter(op, accumulator) + } + + static class InvokeOperatorAdapter extends Closure { + + private final Closure target + + private final boolean accumulator + + private OperatorRun previousRun + + private InvokeOperatorAdapter(Closure code, boolean accumulator) { + super(code.owner, code.thisObject) + this.target = code + this.target.delegate = code.delegate + this.target.setResolveStrategy(code.resolveStrategy) + this.accumulator = accumulator + } + + @Override + Class[] getParameterTypes() { + return target.getParameterTypes() + } + + @Override + int getMaximumNumberOfParameters() { + return target.getMaximumNumberOfParameters() + } + + @Override + Object getDelegate() { + return target.getDelegate() + } + + @Override + Object getProperty(String propertyName) { + return target.getProperty(propertyName) + } + + @Override + int getDirective() { + return target.getDirective() + } + + @Override + void setDelegate(Object delegate) { + target.setDelegate(delegate) + } + + @Override + void setDirective(int directive) { + target.setDirective(directive) + } + + @Override + void setResolveStrategy(int resolveStrategy) { + target.setResolveStrategy(resolveStrategy) + } + + @Override + void setProperty(String propertyName, Object newValue) { + target.setProperty(propertyName, newValue) + } + + @Override + @CompileDynamic + Object call(final Object... args) { + // when the accumulator flag true, re-use the previous run object + final run = !accumulator || previousRun==null + ? 
new OperatorRun() + : previousRun + // set as the current run in the thread local + currentOperator.set(run) + // map the inputs + final inputs = Prov.getTracker().receiveInputs(run, args.toList()) + final arr = inputs.toArray() + // todo: the spread operator should be replaced with proper array + final ret = target.call(*arr) + // track the previous run + if( accumulator ) + previousRun = run + // return the operation result + return ret + } + + Object call(Object args) { + // todo: this should invoke the above one + target.call(args) + } + + @Override + Object call() { + // todo: this should invoke the above one + target.call() + } + } + +} diff --git a/modules/nextflow/src/main/groovy/nextflow/extension/OperatorImpl.groovy b/modules/nextflow/src/main/groovy/nextflow/extension/OperatorImpl.groovy index 3614de19db..28de8e1521 100644 --- a/modules/nextflow/src/main/groovy/nextflow/extension/OperatorImpl.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/extension/OperatorImpl.groovy @@ -139,32 +139,31 @@ class OperatorImpl { newOperator(source, target, listener) { item -> - def result = closure != null ? closure.call(item) : item - def proc = ((DataflowProcessor) getDelegate()) + final result = closure != null ? closure.call(item) : item switch( result ) { case Collection: - result.each { it -> proc.bindOutput(it) } + result.each { it -> Op.bind(target,it) } break case (Object[]): - result.each { it -> proc.bindOutput(it) } + result.each { it -> Op.bind(target,it) } break case Map: - result.each { it -> proc.bindOutput(it) } + result.each { it -> Op.bind(target,it) } break case Map.Entry: - proc.bindOutput( (result as Map.Entry).key ) - proc.bindOutput( (result as Map.Entry).value ) + Op.bind(target, (result as Map.Entry).key ) + Op.bind(target, (result as Map.Entry).value ) break case Channel.VOID: break default: - proc.bindOutput(result) + Op.bind(target,result) } } diff --git a/modules/nextflow/src/main/groovy/nextflow/processor/TaskId.groovy b/modules/nextflow/src/main/groovy/nextflow/processor/TaskId.groovy index 3576b9fde7..bbb467ca5f 100644 --- a/modules/nextflow/src/main/groovy/nextflow/processor/TaskId.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/processor/TaskId.groovy @@ -19,6 +19,8 @@ package nextflow.processor import java.util.concurrent.atomic.AtomicInteger import groovy.transform.CompileStatic +import nextflow.util.TestOnly + /** * TaskRun unique identifier * @@ -32,12 +34,16 @@ class TaskId extends Number implements Comparable, Serializable, Cloneable { */ static final private AtomicInteger allCount = new AtomicInteger() + @TestOnly static void clear() { allCount.set(0) } + static TaskId next() { new TaskId(allCount.incrementAndGet()) } private final int value + int getValue() { value } + static TaskId of( value ) { if( value instanceof Integer ) return new TaskId(value) diff --git a/modules/nextflow/src/main/groovy/nextflow/processor/TaskProcessor.groovy b/modules/nextflow/src/main/groovy/nextflow/processor/TaskProcessor.groovy index e0698be111..19ce31fd8d 100644 --- a/modules/nextflow/src/main/groovy/nextflow/processor/TaskProcessor.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/processor/TaskProcessor.groovy @@ -15,9 +15,6 @@ */ package nextflow.processor -import nextflow.provenance.ProvTracker -import nextflow.trace.TraceRecord - import static nextflow.processor.ErrorStrategy.* import java.lang.reflect.InvocationTargetException @@ -36,6 +33,7 @@ import java.util.regex.Pattern import ch.artecat.grengine.Grengine import 
com.google.common.hash.HashCode import groovy.json.JsonOutput +import groovy.transform.Canonical import groovy.transform.CompileStatic import groovy.transform.Memoized import groovy.transform.PackageScope @@ -83,6 +81,7 @@ import nextflow.file.FilePatternSplitter import nextflow.file.FilePorter import nextflow.plugin.Plugins import nextflow.processor.tip.TaskTipProvider +import nextflow.prov.Prov import nextflow.script.BaseScript import nextflow.script.BodyDef import nextflow.script.ProcessConfig @@ -107,6 +106,7 @@ import nextflow.script.params.TupleInParam import nextflow.script.params.TupleOutParam import nextflow.script.params.ValueInParam import nextflow.script.params.ValueOutParam +import nextflow.trace.TraceRecord import nextflow.util.ArrayBag import nextflow.util.BlankSeparatedList import nextflow.util.CacheHelper @@ -134,6 +134,11 @@ class TaskProcessor { RunType(String str) { message=str }; } + @Canonical + static class FairEntry { + TaskRun task + Map emissions + } static final public String TASK_CONTEXT_PROPERTY_NAME = 'task' final private static Pattern ENV_VAR_NAME = ~/[a-zA-Z_]+[a-zA-Z0-9_]*/ @@ -144,6 +149,8 @@ class TaskProcessor { @TestOnly static TaskProcessor currentProcessor() { currentProcessor0 } + @TestOnly static Map allTasks = new HashMap<>() + /** * Keeps track of the task instance executed by the current thread */ @@ -252,16 +259,14 @@ class TaskProcessor { private static LockManager lockManager = new LockManager() - private List> fairBuffers = new ArrayList<>() + private List fairBuffers = new ArrayList<>() - private int currentEmission + private volatile int currentEmission private Boolean isFair0 private TaskArrayCollector arrayCollector - private ProvTracker provenance - private CompilerConfiguration compilerConfig() { final config = new CompilerConfiguration() config.addCompilationCustomizers( new ASTTransformationCustomizer(TaskTemplateVarsXform) ) @@ -321,7 +326,6 @@ class TaskProcessor { final arraySize = config.getArray() this.arrayCollector = arraySize > 0 ? new TaskArrayCollector(this, executor, arraySize) : null - this.provenance = session.getProvenance() } /** @@ -635,8 +639,9 @@ class TaskProcessor { final task = createTaskRun(params) // -- set the task instance as the current in this thread currentTask.set(task) + allTasks.put(task.id, task) // track the task provenance for the given inputs - final values = provenance.beforeRun(task, inputs) + final values = Prov.tracker.receiveInputs(task, inputs) // -- validate input lengths validateInputTuples(values) @@ -1479,19 +1484,20 @@ class TaskProcessor { synchronized (isFair0) { // decrement -1 because tasks are 1-based final index = task.index-1 + FairEntry entry = new FairEntry(task,emissions) // store the task emission values in a buffer - fairBuffers[index-currentEmission] = emissions + fairBuffers[index-currentEmission] = entry // check if the current task index matches the expected next emission index if( currentEmission == index ) { - while( emissions!=null ) { + while( entry!=null ) { // bind the emission values - bindOutputs0(emissions) + bindOutputs0(entry.emissions, entry.task) // remove the head and try with the following fairBuffers.remove(0) // increase the index of the next emission currentEmission++ // take the next emissions - emissions = fairBuffers[0] + entry = fairBuffers[0] } } } @@ -1524,7 +1530,7 @@ class TaskProcessor { // and result in a potential error. See https://github.com/nextflow-io/nextflow/issues/3768 final copy = x instanceof List && x instanceof Cloneable ? 
x.clone() : x // emit the final value - provenance.bindOutput(task, ch, copy) + Prov.tracker.bindOutput(task, ch, copy) } } @@ -2374,7 +2380,7 @@ class TaskProcessor { * @param task The {@code TaskRun} instance to finalize */ @PackageScope - final finalizeTask( TaskHandler handler) { + final finalizeTask(TaskHandler handler) { def task = handler.task log.trace "finalizing process > ${safeTaskName(task)} -- $task" diff --git a/modules/nextflow/src/main/groovy/nextflow/processor/TaskRun.groovy b/modules/nextflow/src/main/groovy/nextflow/processor/TaskRun.groovy index 49851e035c..922d67a13a 100644 --- a/modules/nextflow/src/main/groovy/nextflow/processor/TaskRun.groovy +++ b/modules/nextflow/src/main/groovy/nextflow/processor/TaskRun.groovy @@ -38,6 +38,7 @@ import nextflow.exception.ProcessTemplateException import nextflow.exception.ProcessUnrecoverableException import nextflow.file.FileHelper import nextflow.file.FileHolder +import nextflow.prov.TrailRun import nextflow.script.BodyDef import nextflow.script.ScriptType import nextflow.script.TaskClosure @@ -59,7 +60,7 @@ import nextflow.spack.SpackCache */ @Slf4j -class TaskRun implements Cloneable { +class TaskRun implements Cloneable, TrailRun { final private ConcurrentHashMap cache0 = new ConcurrentHashMap() @@ -578,8 +579,8 @@ class TaskRun implements Cloneable { static final public String CMD_ENV = '.command.env' - String toString( ) { - "id: $id; name: $name; type: $type; exit: ${exitStatus==Integer.MAX_VALUE ? '-' : exitStatus}; error: $error; workDir: $workDir" + String toString() { + "TaskRun[id: $id; name: $name; type: $type; upstreams: ${upstreamTasks} exit: ${exitStatus==Integer.MAX_VALUE ? '-' : exitStatus}; error: $error; workDir: $workDir]" } diff --git a/modules/nextflow/src/main/groovy/nextflow/prov/OperatorRun.groovy b/modules/nextflow/src/main/groovy/nextflow/prov/OperatorRun.groovy new file mode 100644 index 0000000000..fbbb52101e --- /dev/null +++ b/modules/nextflow/src/main/groovy/nextflow/prov/OperatorRun.groovy @@ -0,0 +1,40 @@ +/* + * Copyright 2013-2024, Seqera Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package nextflow.prov + +import groovy.transform.Canonical +import groovy.transform.CompileStatic + +/** + * Model an operator run + * + * @author Paolo Di Tommaso + */ +@Canonical +@CompileStatic +class OperatorRun implements TrailRun { + /** + * The list of (object) ids that was received as input by a operator run + */ + List inputIds = new ArrayList<>(10) + + @Override + String toString() { + "OperatorRun[id=${System.identityHashCode(this)}; inputs=${inputIds}]" + } +} diff --git a/modules/nextflow/src/main/groovy/nextflow/prov/Prov.groovy b/modules/nextflow/src/main/groovy/nextflow/prov/Prov.groovy new file mode 100644 index 0000000000..ddf0a02f88 --- /dev/null +++ b/modules/nextflow/src/main/groovy/nextflow/prov/Prov.groovy @@ -0,0 +1,46 @@ +/* + * Copyright 2013-2024, Seqera Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package nextflow.prov + + +import groovy.transform.CompileStatic +import groovy.util.logging.Slf4j +import nextflow.util.TestOnly +/** + * Provenance tracker facade class + * + * @author Paolo Di Tommaso + */ +@Slf4j +@CompileStatic +class Prov { + + static private volatile Tracker tracker0 + + static Tracker getTracker() { + if( tracker0==null ) + tracker0 = new Tracker() + return tracker0 + } + + @TestOnly + static void clear() { + tracker0 = null + } + +} diff --git a/modules/nextflow/src/main/groovy/nextflow/prov/Tracker.groovy b/modules/nextflow/src/main/groovy/nextflow/prov/Tracker.groovy new file mode 100644 index 0000000000..38d0d7d3db --- /dev/null +++ b/modules/nextflow/src/main/groovy/nextflow/prov/Tracker.groovy @@ -0,0 +1,162 @@ +/* + * Copyright 2013-2024, Seqera Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package nextflow.prov + +import java.util.concurrent.ConcurrentHashMap + +import groovy.transform.Canonical +import groovy.transform.CompileStatic +import groovy.util.logging.Slf4j +import groovyx.gpars.dataflow.DataflowWriteChannel +import nextflow.extension.Op +import nextflow.processor.TaskId +import nextflow.processor.TaskRun +/** + * + * @author Paolo Di Tommaso + */ +@Slf4j +@CompileStatic +class Tracker { + + static @Canonical class Msg { + final Object value + + String toString() { + "Msg[id=${System.identityHashCode(this)}; value=${value}]" + } + + static Msg of(Object o) { + new Msg(o) + } + } + + /** + * Associate an output value with the corresponding task run that emitted it + */ + private Map messages = new ConcurrentHashMap<>() + + List receiveInputs(TaskRun task, List inputs) { + // find the upstream tasks id + findUpstreamTasks(task, inputs) + // log for debugging purposes + logInputs(task, inputs) + // the second entry of messages list represent the run inputs list + // apply the de-normalization before returning it + return Op.unwrap(inputs) + } + + private logInputs(TaskRun task, List inputs) { + if( log.isDebugEnabled() ) { + def msg = "Task input" + msg += "\n - id : ${task.id} " + msg += "\n - name : '${task.name}'" + msg += "\n - upstream: ${task.upstreamTasks*.value.join(',')}" + for( Object it : inputs ) { + msg += "\n<= ${it}" + } + log.debug(msg) + } + } + + private logInputs(OperatorRun run, List inputs) { + if( log.isDebugEnabled() ) { + def msg = "Operator input" + msg += "\n - id: ${System.identityHashCode(run)} " + for( Object it : inputs ) { + msg += "\n<= ${it}" + } + log.debug(msg) + } + } + + List receiveInputs(OperatorRun run, List inputs) { + // find the upstream tasks id + run.inputIds.addAll(inputs.collect(msg-> System.identityHashCode(msg))) + // log for debugging purposes + logInputs(run, inputs) + // the second entry of messages list represent the task inputs list + // apply the de-normalization before returning it + return Op.unwrap(inputs) + } + + protected void findUpstreamTasks(TaskRun task, List messages) { + // find upstream tasks and restore nulls + final result = new HashSet() + for( Object msg : messages ) { + if( msg==null ) + throw new IllegalArgumentException("Message cannot be a null object") + if( msg !instanceof Msg ) + continue + final msgId = System.identityHashCode(msg) + result.addAll(findUpstreamTasks0(msgId,result)) + } + // finally bind the result to the task record + task.upstreamTasks = result + } + + protected Set findUpstreamTasks0(final int msgId, Set upstream) { + final run = messages.get(msgId) + if( run instanceof TaskRun ) { + upstream.add(run.id) + return upstream + } + if( run instanceof OperatorRun ) { + for( Integer it : run.inputIds ) { + if( it!=msgId ) { + findUpstreamTasks0(it, upstream) + } + else { + log.debug "Skip duplicate provenance message id=${msgId}" + } + } + } + return upstream + } + + Msg bindOutput(TrailRun run, DataflowWriteChannel channel, Object out) { + assert run!=null, "Argument 'run' cannot be null" + assert channel!=null, "Argument 'channel' cannot be null" + + final msg = Op.wrap(out) + logOutput(run, msg) + // map the message with the run where it has been output + messages.put(System.identityHashCode(msg), run) + // now emit the value + channel.bind(msg) + return msg + } + + private void logOutput(TrailRun run, Msg msg) { + String str + if( run instanceof OperatorRun ) { + str = "Operator output" + str += "\n - id : ${System.identityHashCode(run)}" + } + else if( run 
instanceof TaskRun ) { + str = "Task output" + str += "\n - id : ${run.id}" + str += "\n - name: '${run.name}'" + } + else + throw new IllegalArgumentException("Unknown run type: ${run}") + str += "\n=> ${msg}" + log.debug(str) + } + +} diff --git a/modules/nextflow/src/main/groovy/nextflow/prov/TrailRun.groovy b/modules/nextflow/src/main/groovy/nextflow/prov/TrailRun.groovy new file mode 100644 index 0000000000..0f04c9e6d9 --- /dev/null +++ b/modules/nextflow/src/main/groovy/nextflow/prov/TrailRun.groovy @@ -0,0 +1,25 @@ +/* + * Copyright 2013-2024, Seqera Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package nextflow.prov + +/** + * + * @author Paolo Di Tommaso + */ +interface TrailRun { +} diff --git a/modules/nextflow/src/main/groovy/nextflow/provenance/ProvTracker.groovy b/modules/nextflow/src/main/groovy/nextflow/provenance/ProvTracker.groovy deleted file mode 100644 index 5eebc2edb1..0000000000 --- a/modules/nextflow/src/main/groovy/nextflow/provenance/ProvTracker.groovy +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright 2013-2024, Seqera Labs - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- * - */ - -package nextflow.provenance - -import java.util.concurrent.ConcurrentHashMap - -import groovy.transform.CompileStatic -import groovyx.gpars.dataflow.DataflowWriteChannel -import nextflow.processor.TaskId -import nextflow.processor.TaskRun -/** - * - * @author Paolo Di Tommaso - */ -@CompileStatic -class ProvTracker { - - static class NullMessage { } - - private Map messages = new ConcurrentHashMap<>() - - List beforeRun(TaskRun task, List messages) { - // find the upstream tasks id - findUpstreamTasks(task, messages) - // the second entry of messages list represent the task inputs list - // apply the de-normalization before returning it - return denormalizeMessages(messages) - } - - protected void findUpstreamTasks(TaskRun task, List messages) { - // find upstream tasks and restore nulls - final result = new HashSet() - for( Object msg : messages ) { - if( msg==null ) - throw new IllegalArgumentException("Message cannot be a null object") - final msgId = System.identityHashCode(msg) - result.addAll(findUpstreamTasks0(msgId,result)) - } - // finally bind the result to the task record - task.upstreamTasks = result - } - - protected Set findUpstreamTasks0(final int msgId, Set upstream) { - final task = messages.get(msgId) - if( task==null ) { - return upstream - } - if( task ) { - upstream.add(task.id) - return upstream - } - return upstream - } - - protected Object denormalizeMessage(Object msg) { - return msg !instanceof NullMessage ? msg : null - } - - protected List denormalizeMessages(List messages) { - return messages.collect(it-> denormalizeMessage(it)) - } - - protected Object normalizeMessage(Object message) { - // map a "null" value into an instance of "NullMessage" - // because it's needed the object identity to track the message flow - return message!=null ? message : new NullMessage() - } - - void bindOutput(TaskRun task, DataflowWriteChannel ch, Object msg) { - final value = normalizeMessage(msg) - // map the message with the run where it has been output - messages.put(System.identityHashCode(value), task) - // now emit the value - ch.bind(value) - } - -} diff --git a/modules/nextflow/src/test/groovy/nextflow/extension/ConcatOpTest.groovy b/modules/nextflow/src/test/groovy/nextflow/extension/ConcatOpTest.groovy index 5a40b18927..1600300d97 100644 --- a/modules/nextflow/src/test/groovy/nextflow/extension/ConcatOpTest.groovy +++ b/modules/nextflow/src/test/groovy/nextflow/extension/ConcatOpTest.groovy @@ -25,7 +25,7 @@ import test.Dsl2Spec * @author Paolo Di Tommaso */ @Timeout(5) -class ConcatOp2Test extends Dsl2Spec { +class ConcatOpTest extends Dsl2Spec { def 'should concat two channel'() { diff --git a/modules/nextflow/src/test/groovy/nextflow/extension/OpTest.groovy b/modules/nextflow/src/test/groovy/nextflow/extension/OpTest.groovy new file mode 100644 index 0000000000..e55decaa7d --- /dev/null +++ b/modules/nextflow/src/test/groovy/nextflow/extension/OpTest.groovy @@ -0,0 +1,47 @@ +/* + * Copyright 2013-2024, Seqera Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package nextflow.extension + + +import nextflow.prov.Prov +import spock.lang.Specification + +/** + * + * @author Paolo Di Tommaso + */ +class OpTest extends Specification { + + def 'should instrument a closure'() { + given: + def code = { int x, int y -> x+y } + def v1 = 1 + def v2 = 2 + + when: + def c = Op.instrument(code) + def z = c.call([v1, v2] as Object[]) + then: + z == 3 + and: + Op.currentOperator.get().inputIds == [ System.identityHashCode(v1), System.identityHashCode(v2) ] + + cleanup: + Prov.clear() + } +} diff --git a/modules/nextflow/src/test/groovy/nextflow/extension/UntilManyOpTest.groovy b/modules/nextflow/src/test/groovy/nextflow/extension/UntilManyOpTest.groovy index c64bee82ce..706771b1b9 100644 --- a/modules/nextflow/src/test/groovy/nextflow/extension/UntilManyOpTest.groovy +++ b/modules/nextflow/src/test/groovy/nextflow/extension/UntilManyOpTest.groovy @@ -19,10 +19,13 @@ package nextflow.extension import nextflow.Channel import spock.lang.Specification +import spock.lang.Timeout + /** * * @author Paolo Di Tommaso */ +@Timeout(10) class UntilManyOpTest extends Specification { def 'should emit channel items until the condition is verified' () { @@ -94,5 +97,4 @@ class UntilManyOpTest extends Specification { Z.val == Channel.STOP } - } diff --git a/modules/nextflow/src/test/groovy/nextflow/prov/ProvTest.groovy b/modules/nextflow/src/test/groovy/nextflow/prov/ProvTest.groovy new file mode 100644 index 0000000000..e4ee22054f --- /dev/null +++ b/modules/nextflow/src/test/groovy/nextflow/prov/ProvTest.groovy @@ -0,0 +1,206 @@ +package nextflow.prov + +import static test.TestHelper.* + +import nextflow.config.ConfigParser +import nextflow.processor.TaskId +import nextflow.processor.TaskProcessor +import test.Dsl2Spec +/** + * + * @author Paolo Di Tommaso + */ +class ProvTest extends Dsl2Spec { + + def setup() { + Prov.clear() + TaskId.clear() + TaskProcessor.allTasks.clear() + } + + ConfigObject globalConfig() { + new ConfigParser().parse(''' + process.fair = true + ''') + } + + def 'should chain two process'() { + + when: + dsl_eval(globalConfig(), ''' + workflow { + p1 | map { x-> x } | map { x-> x+1 } | p2 + } + + process p1 { + output: val(x) + exec: + x =1 + } + + process p2 { + input: val(x) + exec: + println x + } + ''') + + then: + def upstream = upstreamTasksOf('p2') + upstream.size() == 1 + upstream.first.name == 'p1' + } + + def 'should branch two process'() { + + when: + dsl_eval(globalConfig(), ''' + workflow { + channel.of(1,10,20) \ + | p1 \ + | branch { left: it <=10; right: it >10 } \ + | set { result } + + result.left | p2 + result.right | p3 + } + + process p1 { + input: val(x) + output: val(y) + exec: + y = x+1 + } + + process p2 { + input: val(x) + exec: + println x + } + + process p3 { + input: val(x) + exec: + println x + } + ''') + then: + def t1 = upstreamTasksOf('p2 (1)') + t1.first.name == 'p1 (1)' + t1.size() == 1 + + and: + def t2 = upstreamTasksOf('p3 (1)') + t2.first.name == 'p1 (2)' + t2.size() == 1 + + and: + def t3 = upstreamTasksOf('p3 (2)') + t3.first.name == 'p1 (3)' + t3.size() == 1 + } + + def 'should track provenance with flatMap operator' () { + when: + dsl_eval(globalConfig(), ''' + workflow { + channel.of(1,2) \ + | p1 \ + | flatMap \ + | p2 + } + + process p1 { + input: val(x) + output: val(y) + exec: + y = [x, x*x] + } + + process p2 { + input: val(x) + exec: + println x + } + ''') + then: + def t1 = 
upstreamTasksOf('p2 (1)') + t1.first.name == 'p1 (1)' + t1.size() == 1 + + and: + def t2 = upstreamTasksOf('p2 (2)') + t2.first.name == 'p1 (1)' + t2.size() == 1 + + and: + def t3 = upstreamTasksOf('p2 (3)') + t3.first.name == 'p1 (2)' + t3.size() == 1 + + and: + def t4 = upstreamTasksOf('p2 (4)') + t4.first.name == 'p1 (2)' + t4.size() == 1 + } + + def 'should track the provenance of two processes and reduce operator'() { + + when: + dsl_eval(globalConfig(), ''' + workflow { + channel.of(1,2,3) \ + | p1 \ + | reduce {a,b -> return a+b} \ + | p2 + } + + process p1 { + input: val(x) + output: val(y) + exec: + y = x + } + + process p2 { + input: val(x) + exec: + println x + } + ''') + + then: + def t1 = upstreamTasksOf('p2') + t1.name == ['p1 (1)', 'p1 (2)', 'p1 (3)'] + } + + def 'should track the provenance of two tasks and collectFile operator' () { + when: + dsl_eval(globalConfig(), ''' + workflow { + channel.of('a','b','c') \ + | p1 \ + | collectFile(name: 'sample.txt') \ + | p2 + } + + process p1 { + input: val(x) + output: val(y) + exec: + y = x + } + + process p2 { + input: file(x) + exec: + println x + } + ''') + + then: + def t1 = upstreamTasksOf('p2 (1)') + t1.name == ['p1 (1)', 'p1 (2)', 'p1 (3)'] + + } +} diff --git a/modules/nextflow/src/test/groovy/nextflow/prov/TrackerTest.groovy b/modules/nextflow/src/test/groovy/nextflow/prov/TrackerTest.groovy new file mode 100644 index 0000000000..8a2adc1518 --- /dev/null +++ b/modules/nextflow/src/test/groovy/nextflow/prov/TrackerTest.groovy @@ -0,0 +1,123 @@ +/* + * Copyright 2013-2024, Seqera Labs + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + */ + +package nextflow.prov + +import groovyx.gpars.dataflow.DataflowQueue +import groovyx.gpars.dataflow.DataflowWriteChannel +import nextflow.processor.TaskConfig +import nextflow.processor.TaskId +import nextflow.processor.TaskProcessor +import nextflow.processor.TaskRun +import spock.lang.Specification + +/** + * + * @author Paolo Di Tommaso + */ +class TrackerTest extends Specification { + + def 'should normalize null values' () { + given: + def prov = new Tracker() + and: + def t1 = new TaskRun(id: new TaskId(1), processor: Mock(TaskProcessor), config: Mock(TaskConfig)) + def t2 = new TaskRun(id: new TaskId(2), processor: Mock(TaskProcessor), config: Mock(TaskConfig)) + and: + def msg1 = [Tracker.Msg.of('foo')] + def msg2 = [Tracker.Msg.of('foo'), Tracker.Msg.of(null)] + + when: + def result1 = prov.receiveInputs(t1, msg1) + then: + result1 == msg1.value + and: + t1.upstreamTasks == [] as Set + + when: + def result2 = prov.receiveInputs(t2, msg2) + then: + result2 == ['foo', null] + and: + t2.upstreamTasks == [] as Set + } + + def 'should bind value to task run' () { + given: + def prov = new Tracker() + and: + def t1 = new TaskRun(id: new TaskId(1), processor: Mock(TaskProcessor), config: Mock(TaskConfig)) + def c1 = new DataflowQueue() + def v1 = 'foo' + + when: + def m1 = prov.bindOutput(t1, c1, v1) + then: + c1.val.is(m1) + and: + prov.@messages.get(System.identityHashCode(m1)) == t1 + } + + def 'should determine upstream tasks' () { + given: + def prov = new Tracker() + and: + def t1 = new TaskRun(id: new TaskId(1), processor: Mock(TaskProcessor), config: Mock(TaskConfig)) + def t2 = new TaskRun(id: new TaskId(2), processor: Mock(TaskProcessor), config: Mock(TaskConfig)) + def t3 = new TaskRun(id: new TaskId(3), processor: Mock(TaskProcessor), config: Mock(TaskConfig)) + and: + def v1 = new Object() + def v2 = new Object() + + when: + def m1 = prov.bindOutput(t1, Mock(DataflowWriteChannel), v1) + def m2 = prov.bindOutput(t2, Mock(DataflowWriteChannel), v2) + and: + prov.receiveInputs(t3, [m1, m2]) + then: + t3.upstreamTasks == [t1.id, t2.id] as Set + } + + def 'should determine upstream task with operator' () { + given: + def prov = new Tracker() + and: + def v1 = Integer.valueOf(1) + def v2 = Integer.valueOf(2) + def v3 = Integer.valueOf(3) + and: + def t1 = new TaskRun(id: new TaskId(1), processor: Mock(TaskProcessor), config: Mock(TaskConfig)) + def p2 = new OperatorRun() + def t3 = new TaskRun(id: new TaskId(3), processor: Mock(TaskProcessor), config: Mock(TaskConfig)) + + when: + prov.receiveInputs(t1, []) + def m1 = prov.bindOutput(t1, Mock(DataflowWriteChannel), v1) + and: + prov.receiveInputs(p2, [m1]) + and: + def m2 = prov.bindOutput(p2, Mock(DataflowWriteChannel), v2) + and: + prov.receiveInputs(t3, [m2]) + and: + def m3 = prov.bindOutput(t3, Mock(DataflowWriteChannel), v3) + + then: + t3.upstreamTasks == [t1.id] as Set + } + +} diff --git a/modules/nextflow/src/test/groovy/nextflow/provenance/ProvTrackerTest.groovy b/modules/nextflow/src/test/groovy/nextflow/provenance/ProvTrackerTest.groovy deleted file mode 100644 index 571e127b8c..0000000000 --- a/modules/nextflow/src/test/groovy/nextflow/provenance/ProvTrackerTest.groovy +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright 2013-2024, Seqera Labs - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -package nextflow.provenance - -import groovyx.gpars.dataflow.DataflowQueue -import groovyx.gpars.dataflow.DataflowWriteChannel -import nextflow.processor.TaskId -import nextflow.processor.TaskRun -import spock.lang.Specification - -/** - * - * @author Paolo Di Tommaso - */ -class ProvTrackerTest extends Specification { - - def 'should normalize null values' () { - given: - def prov = new ProvTracker() - and: - def t1 = new TaskRun(id: new TaskId(1)) - def t2 = new TaskRun(id: new TaskId(2)) - and: - def msg1 = ['foo'] - def msg2 = ['foo', new ProvTracker.NullMessage()] - - when: - def result1 = prov.beforeRun(t1, msg1) - then: - result1 == msg1 - and: - t1.upstreamTasks == [] as Set - - when: - def result2 = prov.beforeRun(t2, msg2) - then: - result2 == ['foo', null] - and: - t2.upstreamTasks == [] as Set - } - - def 'should bind value to task run' () { - given: - def prov = new ProvTracker() - and: - def t1 = new TaskRun(id: new TaskId(1)) - def c1 = new DataflowQueue() - def v1 = 'foo' - - when: - prov.bindOutput(t1, c1, v1) - then: - c1.val == 'foo' - and: - prov.@messages.get(System.identityHashCode(v1)) == t1 - } - - def 'should determine upstream tasks' () { - given: - def prov = new ProvTracker() - and: - def t1 = new TaskRun(id: new TaskId(1)) - def t2 = new TaskRun(id: new TaskId(2)) - def t3 = new TaskRun(id: new TaskId(3)) - and: - def v1 = new Object() - def v2 = new Object() - def v3 = new Object() - - when: - prov.bindOutput(t1, Mock(DataflowWriteChannel), v1) - prov.bindOutput(t2, Mock(DataflowWriteChannel), v2) - and: - prov.beforeRun(t3, [v1, v2]) - then: - t3.upstreamTasks == [t1.id, t2.id] as Set - } - -} diff --git a/modules/nextflow/src/testFixtures/groovy/test/Dsl2Spec.groovy b/modules/nextflow/src/testFixtures/groovy/test/Dsl2Spec.groovy index a30c57f916..63c736b365 100644 --- a/modules/nextflow/src/testFixtures/groovy/test/Dsl2Spec.groovy +++ b/modules/nextflow/src/testFixtures/groovy/test/Dsl2Spec.groovy @@ -46,11 +46,14 @@ class Dsl2Spec extends BaseSpec { new MockScriptRunner().setScript(str).execute() } + def dsl_eval(Map config, String str) { + new MockScriptRunner(config).setScript(str).execute() + } + def dsl_eval(Path path) { new MockScriptRunner().setScript(path).execute() } - def dsl_eval(String entry, String str) { new MockScriptRunner() .setScript(str).execute(null, entry) diff --git a/modules/nextflow/src/testFixtures/groovy/test/MockHelpers.groovy b/modules/nextflow/src/testFixtures/groovy/test/MockHelpers.groovy index d4f5065c0d..be772ff133 100644 --- a/modules/nextflow/src/testFixtures/groovy/test/MockHelpers.groovy +++ b/modules/nextflow/src/testFixtures/groovy/test/MockHelpers.groovy @@ -127,7 +127,7 @@ class MockExecutor extends Executor { @Override TaskHandler createTaskHandler(TaskRun task) { - return new MockTaskHandler(task) + return new MockTaskHandler(task) } } diff --git a/modules/nextflow/src/testFixtures/groovy/test/TestHelper.groovy b/modules/nextflow/src/testFixtures/groovy/test/TestHelper.groovy index dce6834090..8ca3fcb841 100644 --- a/modules/nextflow/src/testFixtures/groovy/test/TestHelper.groovy +++ 
b/modules/nextflow/src/testFixtures/groovy/test/TestHelper.groovy @@ -15,6 +15,7 @@ */ package test + import java.nio.file.Files import java.nio.file.Path import java.util.zip.GZIPInputStream @@ -22,6 +23,9 @@ import java.util.zip.GZIPInputStream import com.google.common.jimfs.Configuration import com.google.common.jimfs.Jimfs import groovy.transform.Memoized +import nextflow.processor.TaskId +import nextflow.processor.TaskProcessor +import nextflow.processor.TaskRun /** * * @author Paolo Di Tommaso @@ -91,4 +95,36 @@ class TestHelper { // Convert the decoded bytes into a string return new String(decodedBytes); } + + static List upstreamTasksOf(v) { + if( v instanceof TaskRun ) + return upstreamTasksOf(v as TaskRun) + + if( v instanceof CharSequence ) { + TaskRun t = getTaskByName(v.toString()) + if( t ) + return upstreamTasksOf(t) + else + throw new IllegalArgumentException("Cannot find any task with name: $v") + } + + TaskRun t = getTaskById(v) + if( !t ) + throw new IllegalArgumentException("Cannot find any task with id: $v") + return upstreamTasksOf(t) + } + + static List upstreamTasksOf(TaskRun t) { + final ids = t.upstreamTasks ?: Set.of() + return ids.collect(it -> getTaskById(it)) + } + + static TaskRun getTaskByName(String name) { + TaskProcessor.allTasks.values().find( it -> it.name==name ) + } + + static TaskRun getTaskById(id) { + TaskProcessor.allTasks.get(TaskId.of(id)) + } + }
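---

For reference, a minimal usage sketch (not part of the patch) of how the new `Tracker` API introduced above is expected to be exercised, mirroring the flow covered by `TrackerTest`: a task binds an output, an operator run consumes and re-emits it, and a downstream task resolves the original task as its upstream. Constructing `TaskRun` with only `id` and `name` is an assumption made to keep the sketch self-contained; the actual tests mock `TaskProcessor` and `TaskConfig` instead.

import groovyx.gpars.dataflow.DataflowQueue
import nextflow.processor.TaskId
import nextflow.processor.TaskRun
import nextflow.prov.OperatorRun
import nextflow.prov.Tracker

def tracker = new Tracker()
def ch = new DataflowQueue()

// a task run emits a value; the tracker wraps it into a Tracker.Msg and
// records which run produced that message (keyed by identity hash)
def producer = new TaskRun(id: new TaskId(1), name: 'p1')   // simplified construction (assumption)
def m1 = tracker.bindOutput(producer, ch, 'hello')

// an operator run receives the message; the tracker records the message id
// in the run's inputIds and returns the unwrapped value
def opRun = new OperatorRun()
assert tracker.receiveInputs(opRun, [m1]) == ['hello']

// the operator emits a derived value, which is attributed to the operator run
def m2 = tracker.bindOutput(opRun, ch, 'HELLO')

// a downstream task receives the operator output; walking the message graph
// back through the operator run resolves the original producer task
def consumer = new TaskRun(id: new TaskId(2), name: 'p2')   // simplified construction (assumption)
tracker.receiveInputs(consumer, [m2])
assert consumer.upstreamTasks == [producer.id] as Set

This is the same traversal performed by Tracker#findUpstreamTasks0: operator runs are transparent hops, so only task ids end up in upstreamTasks.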