Skip to content

Commit

Permalink
Improved schema with new workflow base sections
Browse files Browse the repository at this point in the history
Deleted unused methods

Improved docstrings
  • Loading branch information
JosePizarro3 committed Oct 10, 2024
1 parent 20e0277 commit 7f46488
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 99 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ maintainers = [
]
license = { file = "LICENSE" }
dependencies = [
"nomad-lab>=1.3.0",
"nomad-lab@git+https://gitlab.mpcdf.mpg.de/nomad-lab/nomad-FAIR.git@6b7149a71b2999abbb2225fcb67a5acafc811806",
"matid>=2.0.0.dev2",
]

Expand Down
29 changes: 13 additions & 16 deletions src/nomad_simulations/schema_packages/workflow/base_workflows.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,31 @@
from functools import wraps
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from nomad.datamodel.datamodel import EntryArchive
from structlog.stdlib import BoundLogger

from nomad.datamodel.data import ArchiveSection
from nomad.datamodel.metainfo.workflow import TaskReference, Workflow
from nomad.datamodel.metainfo.workflow_new import BaseTask
from nomad.datamodel.metainfo.workflow_new import Workflow2 as Workflow
from nomad.metainfo import SubSection

from nomad_simulations.schema_packages.model_method import BaseModelMethod
from nomad_simulations.schema_packages.outputs import Outputs


def check_n_tasks(n_tasks: Optional[int] = None):
def check_n_tasks(n_tasks: int = 1):
"""
Check if the `tasks` of a workflow exist. If the `n_tasks` input specified, it checks whether `tasks`
is of the same length as `n_tasks`.
Check if the `tasks` of a workflow exist. It checks whether `tasks` is of the same length as `n_tasks`.
Args:
n_tasks (Optional[int], optional): The length of the `tasks` needs to be checked if set to an integer. Defaults to None.
n_tasks (int): The length of the `tasks` needs to be checked if set to an integer. Defaults to 1.
"""

def decorator(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if not self.tasks:
return None
if n_tasks is not None and len(self.tasks) != n_tasks:
if not self.tasks or len(self.tasks) != n_tasks:
return None

return func(self, *args, **kwargs)
Expand All @@ -39,14 +37,14 @@ def wrapper(self, *args, **kwargs):

class SimulationWorkflow(Workflow):
"""
A base section used to define the workflows of a simulation with references to specific `tasks`, `inputs`, and `outputs`. The
A base section used to define the workflows of a simulation with specific `tasks`, `inputs`, and `outputs`. The
normalize function checks the definition of these sections and sets the name of the workflow.
A `SimulationWorkflow` will be composed of:
- a `method` section containing methodological parameters used specifically during the workflow,
- a list of `inputs` with references to the `ModelSystem` and, optionally, `ModelMethod` input sections,
- a list of `outputs` with references to the `Outputs` section,
- a list of `tasks` containing references to the activity `Simulation` used in the workflow,
- a list of `tasks` containing references or the section information of the `task` used in the workflow,
"""

method = SubSection(
Expand All @@ -66,7 +64,7 @@ class BeyondDFTMethod(ArchiveSection):
"""
An abstract section used to store references to the `ModelMethod` sections of each of the
archives defining the `tasks` and used to build the standard `BeyondDFT` workflow. This section needs to be
inherit and the method references need to be defined for each specific case (see, e.g., dft_plus_tb.py module).
inherit and the method references need to be defined for each specific case (see, e.g., `dft_plus_tb.py` module).
"""

pass
Expand Down Expand Up @@ -104,16 +102,15 @@ def resolve_all_outputs(self) -> list[Outputs]:
all_outputs.append(task.outputs[-1])
return all_outputs

@check_n_tasks()
def resolve_method_refs(
self, tasks: list[TaskReference], tasks_names: list[str]
self, tasks: list[BaseTask], tasks_names: list[str]
) -> list[BaseModelMethod]:
"""
Resolve the references to the `BaseModelMethod` sections in the list of `tasks`. This is useful
when defining the `method` section of the `BeyondDFT` workflow.
Args:
tasks (list[TaskReference]): The list of tasks from which resolve the `BaseModelMethod` sections.
tasks (list[BaseTask]): The list of tasks from which resolve the `BaseModelMethod` sections.
tasks_names (list[str]): The list of names for each of the tasks forming the BeyondDFT workflow.
Returns:
Expand All @@ -132,7 +129,7 @@ def resolve_method_refs(
if not task.m_xpath('task.inputs'):
continue

# Resolve the method of each task.inputs
# Resolve the method of each `tasks[*].task.inputs`
for input in task.task.inputs:
if isinstance(input.section, BaseModelMethod):
method_refs.append(input.section)
Expand Down
105 changes: 36 additions & 69 deletions src/nomad_simulations/schema_packages/workflow/dft_plus_tb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from nomad.datamodel.datamodel import EntryArchive
from structlog.stdlib import BoundLogger

from nomad.datamodel.metainfo.workflow import Link, TaskReference
from nomad.metainfo import Quantity, Reference
from nomad.datamodel.metainfo.workflow_new import LinkReference
from nomad.metainfo import Quantity
from nomad.utils import extract_section

from nomad_simulations.schema_packages.model_method import DFT, TB
from nomad_simulations.schema_packages.workflow import BeyondDFT, BeyondDFTMethod
Expand All @@ -21,13 +22,13 @@ class DFTPlusTBMethod(BeyondDFTMethod):
"""

dft_method_ref = Quantity(
type=Reference(DFT),
type=DFT,
description="""
Reference to the DFT `ModelMethod` section in the DFT task.
""",
)
tb_method_ref = Quantity(
type=Reference(TB),
type=TB,
description="""
Reference to the TB `ModelMethod` section in the TB task.
""",
Expand All @@ -40,12 +41,10 @@ class DFTPlusTB(BeyondDFT):
two tasks: the initial DFT calculation + the final TB projection.
The section only needs to be populated with (everything else is handled by the `normalize` function):
i. The `tasks` as `TaskReference` sections, adding `task` to the specific archive.workflow2 sections.
ii. The `inputs` and `outputs` as `Link` sections pointing to the specific archives.
i. The `tasks` as `TaskReference` sections, adding `task` to the specific `archive.workflow2` sections.
Note 1: the `inputs[0]` of the `DFTPlusTB` coincides with the `inputs[0]` of the DFT task (`ModelSystem` section).
Note 2: the `outputs[-1]` of the `DFTPlusTB` coincides with the `outputs[-1]` of the TB task (`Outputs` section).
Note 3: the `outputs[-1]` of the DFT task is used as `inputs[0]` of the TB task.
The archive.workflow2 section is:
- name = 'DFT+TB'
Expand All @@ -54,68 +53,39 @@ class DFTPlusTB(BeyondDFT):
tb_method_ref=tb_archive.data.model_method[-1],
)
- inputs = [
Link(name='Input Model System', section=dft_archive.data.model_system[0]),
LinkReference(name='Input Model System', section=dft_archive.data.model_system[0]),
]
- outputs = [
Link(name='Output TB Data', section=tb_archive.data.outputs[-1]),
LinkReference(name='Output TB Data', section=tb_archive.data.outputs[-1]),
]
- tasks = [
TaskReference(
name='DFT SinglePoint Task',
task=dft_archive.workflow2
inputs=[
Link(name='Input Model System', section=dft_archive.data.model_system[0]),
],
outputs=[
Link(name='Output DFT Data', section=dft_archive.data.outputs[-1]),
]
),
TaskReference(
name='TB SinglePoint Task',
task=tb_archive.workflow2,
inputs=[
Link(name='Output DFT Data', section=dft_archive.data.outputs[-1]),
],
outputs=[
Link(name='Output tb Data', section=tb_archive.data.outputs[-1]),
]
),
TaskReference(task=dft_archive.workflow2),
TaskReference(task=tb_archive.workflow2),
]
"""

@check_n_tasks(n_tasks=2)
def link_task_inputs_outputs(
self, tasks: list[TaskReference], logger: 'BoundLogger'
) -> None:
if not self.inputs or not self.outputs:
logger.warning(
'The `DFTPlusTB` workflow needs to have `inputs` and `outputs` defined in order to link with the `tasks`.'
)
def resolve_inputs_outputs(self) -> None:
"""
Resolve the `inputs` and `outputs` of the `DFTPlusTB` workflow.
"""
input = extract_section(self.tasks[0], ['task', 'inputs[0]', 'section'])
print(
self.tasks[0],
extract_section(self.tasks[0], ['task']),
extract_section(self.tasks[0], ['task', 'inputs[0]']),
extract_section(self.tasks[0], ['task', 'inputs[0]', 'section']),
)
if not input:
return None
print(input)
self.inputs = [LinkReference(name='Input Model System', section=input)]

dft_task = tasks[0]
tb_task = tasks[1]

# Initial check
if not dft_task.m_xpath('task.outputs'):
output = extract_section(self.tasks[1], ['task', 'outputs[-1]', 'section'])
if not output:
return None

# Input of DFT Task is the ModelSystem
dft_task.inputs = [
Link(name='Input Model System', section=self.inputs[0]),
]
# Output of DFT Task is the output section of the DFT entry
dft_task.outputs = [
Link(name='Output DFT Data', section=dft_task.task.outputs[-1]),
]
# Input of TB Task is the output of the DFT task
tb_task.inputs = [
Link(name='Output DFT Data', section=dft_task.task.outputs[-1]),
]
# Output of TB Task is the output section of the TB entry
tb_task.outputs = [
Link(name='Output TB Data', section=self.outputs[-1]),
]
print(output)
self.outputs = [LinkReference(name='Output TB Data', section=output)]

# TODO check if implementing overwritting the FermiLevel.value in the TB entry from the DFT entry

Expand Down Expand Up @@ -144,14 +114,11 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
tasks=self.tasks,
tasks_names=['DFT SinglePoint Task', 'TB SinglePoint Task'],
)
if method_refs is not None:
method_workflow = DFTPlusTBMethod()
for method in method_refs:
if isinstance(method, DFT):
method_workflow.dft_method_ref = method
elif isinstance(method, TB):
method_workflow.tb_method_ref = method
self.method = method_workflow

# Resolve `tasks[*].inputs` and `tasks[*].outputs`
self.link_task_inputs_outputs(tasks=self.tasks, logger=logger)
if method_refs is not None and len(method_refs) == 2:
print(method_refs)
self.method = DFTPlusTBMethod(
dft_method_ref=method_refs[0], tb_method_ref=method_refs[1]
)

# Resolve `inputs` and `outputs` from the `tasks`
self.resolve_inputs_outputs()
30 changes: 17 additions & 13 deletions src/nomad_simulations/schema_packages/workflow/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
from nomad.datamodel.datamodel import EntryArchive
from structlog.stdlib import BoundLogger

from nomad.datamodel.metainfo.workflow import Link
from nomad.datamodel.metainfo.workflow_new import LinkReference
from nomad.metainfo import Quantity
from nomad.utils import extract_section

from nomad_simulations.schema_packages.outputs import SCFOutputs
from nomad_simulations.schema_packages.utils import extract_all_simulation_subsections
from nomad_simulations.schema_packages.workflow import SimulationWorkflow


Expand All @@ -26,11 +26,11 @@ class SinglePoint(SimulationWorkflow):
The archive.workflow2 section is:
- name = 'SinglePoint'
- inputs = [
Link(name='Input Model System', section=archive.data.model_system[0]),
Link(name='Input Model Method', section=archive.data.model_method[-1]),
LinkReference(name='Input Model System', section=archive.data.model_system[0]),
LinkReference(name='Input Model Method', section=archive.data.model_method[-1]),
]
- outputs = [
Link(name='Output Data', section=archive.data.outputs[-1]),
LinkReference(name='Output Data', section=archive.data.outputs[-1]),
]
- tasks = []
"""
Expand All @@ -53,19 +53,23 @@ def normalize(self, archive: 'EntryArchive', logger: 'BoundLogger') -> None:
self.name = 'SinglePoint'

# Define `inputs` and `outputs`
input_model_system, input_model_method, output = (
extract_all_simulation_subsections(archive=archive)
)
if not input_model_system or not input_model_method or not output:
input_model_system = extract_section(archive, ['data', 'model_system'])
output = extract_section(archive, ['data', 'outputs'])
if not input_model_system or not output:
logger.warning(
'Could not find the ModelSystem, ModelMethod, or Outputs section in the archive.data section of the SinglePoint entry.'
'Could not find the `ModelSystem` or `Outputs` section in the archive.data section of the SinglePoint entry.'
)
return
self.inputs = [
Link(name='Input Model System', section=input_model_system),
Link(name='Input Model Method', section=input_model_method),
LinkReference(name='Input Model System', section=input_model_system),
]
self.outputs = [Link(name='Output Data', section=output)]
self.outputs = [LinkReference(name='Output Data', section=output)]
# `ModelMethod` is optional when defining workflows like the `SinglePoint`
input_model_method = extract_section(archive, ['data', 'model_method'])
if input_model_method is not None:
self.inputs.append(
LinkReference(name='Input Model Method', section=input_model_method)
)

# Resolve the `n_scf_steps` if the output is of `SCFOutputs` type
if isinstance(output, SCFOutputs):
Expand Down

1 comment on commit 7f46488

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
src/nomad_simulations
   __init__.py4250%3–4
   _version.py11282%5–6
src/nomad_simulations/schema_packages
   __init__.py15287%39–41
   atoms_state.py19012932%13–15, 159–179, 197–232, 245–276, 279–297, 344–355, 358–372, 491–506, 518–526, 529–550, 608–615, 627–634, 637–643
   basis_set.py24015336%8–9, 78–79, 115–119, 122–133, 162–165, 168–169, 172–185, 208, 264, 270–271, 275–279, 290–294, 298–302, 306–327, 335, 370–381, 385–397, 417–418, 453–466, 470, 475–476, 513–521, 530, 562–589, 595–600, 603–618, 647–698
   general.py894253%4–7, 51–54, 121, 185, 214–216, 228–286, 289–311
   model_method.py26915443%10–12, 73, 92, 125, 171–174, 177–184, 276–277, 297, 318–339, 355–381, 384–401, 493, 516–559, 562–587, 637–645, 704–706, 732–753, 756–762, 780, 791, 833–840, 878, 897, 977, 1034, 1109, 1223
   model_system.py31722031%25–27, 182–190, 194–200, 283–291, 302–319, 322, 373–392, 404–413, 427–458, 469–486, 489–492, 620–647, 665–747, 750–766, 832–836, 839–860, 1065–1103, 1106–1145
   numerical_settings.py25919325%12–14, 36, 147, 175–191, 213–284, 374–394, 410–425, 444–466, 469–496, 561–582, 594–610, 632–662, 680–742, 745–766, 792–794, 809–827, 830–835, 889
   outputs.py1206843%9–10, 153–159, 169–173, 183–187, 190–198, 241–259, 276–309, 312–329, 362, 381
   physical_property.py1024952%20–22, 46–62, 200–202, 223, 235–238, 251, 258, 265–291, 301–303, 306–309, 331–333
   variables.py862176%8–10, 63–70, 73–76, 98, 121, 145, 167, 189, 211, 233, 256, 272–273, 276
src/nomad_simulations/schema_packages/properties
   band_gap.py513237%8–10, 79–92, 104–126, 129–144
   band_structure.py1237638%9–11, 53–55, 58, 131–132, 144–159, 174–199, 209–218, 232–265, 274–286, 289–308, 321–322, 325, 372–373, 378
   energies.py421857%7–9, 36, 57, 77–79, 82, 99–100, 103, 115–116, 119, 130–131, 134
   fermi_surface.py17759%7–9, 34–37, 40
   forces.py22864%7–9, 36, 56, 75–76, 79
   greens_function.py995247%7–9, 135–137, 149–179, 182–189, 210–211, 214, 235–236, 239, 260–261, 264, 349–355, 365–367, 376–384, 387–403
   hopping_matrix.py291162%7–9, 50–53, 58, 88–91, 94
   permittivity.py483233%7–9, 53–56, 59–62, 71–94, 97–105
   spectral_profile.py26021517%9–11, 39–40, 49–51, 54–60, 80, 94–106, 109–112, 176–177, 199–300, 313–345, 356–368, 382–400, 415–463, 466–502, 521–523, 526, 555–557, 568–593, 598–604
   thermodynamics.py752764%7–9, 35, 56, 72, 81, 90, 101, 110, 137, 147, 157, 172–174, 177, 193, 213–215, 218, 234, 254–256, 259
src/nomad_simulations/schema_packages/utils
   utils.py745427%8–16, 53–63, 70–79, 83–89, 92–97, 101, 105, 122–127, 144–150, 159–161
src/nomad_simulations/schema_packages/workflow
   base_workflows.py512551%5–6, 28–31, 60, 98–103, 120–137, 140
   dft_plus_tb.py442739%4–5, 72–88, 94–124
   single_point.py281739%6–7, 50–77
TOTAL2679163639% 

Tests Skipped Failures Errors Time
2 0 💤 0 ❌ 2 🔥 2.025s ⏱️

Please sign in to comment.