From 07748d574b4cf3a6d04e295f54007a15c8ba6c1f Mon Sep 17 00:00:00 2001 From: JosePizarro3 Date: Wed, 18 Sep 2024 10:48:28 +0200 Subject: [PATCH] Add testing for BeyondDFT workflow --- .../workflow/base_workflows.py | 7 ++ tests/workflow/test_base_workflows.py | 85 ++++++++++++++++++- 2 files changed, 89 insertions(+), 3 deletions(-) diff --git a/src/nomad_simulations/schema_packages/workflow/base_workflows.py b/src/nomad_simulations/schema_packages/workflow/base_workflows.py index 31e438ee..b52f2efa 100644 --- a/src/nomad_simulations/schema_packages/workflow/base_workflows.py +++ b/src/nomad_simulations/schema_packages/workflow/base_workflows.py @@ -138,8 +138,15 @@ def resolve_all_outputs(self) -> list[Outputs]: Returns: list[Outputs]: A list of all the `Outputs` sections from the `tasks`. """ + # Initial check + if not self.tasks: + return [] + + # Populate the list of outputs from the last element in `tasks` all_outputs = [] for task in self.tasks: + if not task.outputs: + continue all_outputs.append(task.outputs[-1]) return all_outputs diff --git a/tests/workflow/test_base_workflows.py b/tests/workflow/test_base_workflows.py index ee50f0f6..905e6e3a 100644 --- a/tests/workflow/test_base_workflows.py +++ b/tests/workflow/test_base_workflows.py @@ -26,8 +26,8 @@ from nomad_simulations.schema_packages.model_system import ModelSystem from nomad_simulations.schema_packages.outputs import Outputs from nomad_simulations.schema_packages.workflow import ( + BeyondDFT, BeyondDFTMethod, - BeyondDFTWorkflow, SimulationWorkflow, ) @@ -211,5 +211,84 @@ def test_resolve_beyonddft_method_ref( class TestBeyondDFT: - def test_resolve_all_outputs(self): - assert True + @pytest.mark.parametrize( + 'tasks, result', + [ + # no task + (None, []), + # empty task + ([Task()], []), + # task only contains inputs + ( + [Task(inputs=[Link(name='Input Model System', section=ModelSystem())])], + [], + ), + # one task with one output + ( + [Task(outputs=[Link(name='Output Data 1', section=Outputs())])], + [Link(name='Output Data 1', section=Outputs())], + ), + # one task with multiple outputs (only last is resolved) + ( + [ + Task( + outputs=[ + Link(name='Output Data 1', section=Outputs()), + Link(name='Output Data 2', section=Outputs()), + ] + ) + ], + [Link(name='Output Data 2', section=Outputs())], + ), + # multiple task with one output each + ( + [ + Task( + outputs=[Link(name='Task 1:Output Data 1', section=Outputs())] + ), + Task( + outputs=[Link(name='Task 2:Output Data 1', section=Outputs())] + ), + ], + [ + Link(name='Task 1:Output Data 1', section=Outputs()), + Link(name='Task 2:Output Data 1', section=Outputs()), + ], + ), + # multiple task with two outputs each (only last is resolved) + ( + [ + Task( + outputs=[ + Link(name='Task 1:Output Data 1', section=Outputs()), + Link(name='Task 1:Output Data 2', section=Outputs()), + ] + ), + Task( + outputs=[ + Link(name='Task 2:Output Data 1', section=Outputs()), + Link(name='Task 2:Output Data 2', section=Outputs()), + ] + ), + ], + [ + Link(name='Task 1:Output Data 2', section=Outputs()), + Link(name='Task 2:Output Data 2', section=Outputs()), + ], + ), + ], + ) + def test_resolve_all_outputs( + self, tasks: Optional[list[Task]], result: list[Outputs] + ): + """ + Test the `resolve_all_outputs` method of the `BeyondDFT` section. + """ + workflow = BeyondDFT() + if tasks is not None: + workflow.tasks = tasks + if result is not None: + for i, output in enumerate(workflow.resolve_all_outputs()): + assert output.name == result[i].name + else: + assert workflow.resolve_all_outputs() == result