diff --git a/dagmc/dagnav.py b/dagmc/dagnav.py index 7e117a1..385451e 100644 --- a/dagmc/dagnav.py +++ b/dagmc/dagnav.py @@ -1,12 +1,13 @@ +from __future__ import annotations from abc import abstractmethod from functools import cached_property -from pathlib import Path -from typing import Any +from itertools import chain +from typing import Optional, Dict import numpy as np - from pymoab import core, types, rng + class DAGModel: def __init__(self, moab_file): @@ -31,7 +32,7 @@ def volumes(self): return {v.id: v for v in volumes} @property - def groups(self): + def groups(self) -> Dict[str, Group]: group_handles = self._sets_by_category('Group') group_mapping = {} @@ -75,7 +76,7 @@ class DAGSet: """ Generic functionality for a DAGMC EntitySet. """ - def __init__(self, model, handle): + def __init__(self, model: DAGModel, handle): self.model = model self.handle = handle @@ -224,6 +225,41 @@ def _get_triangle_sets(self): class Volume(DAGSet): + @property + def groups(self) -> list[Group]: + """Get list of groups containing this volume.""" + return [group for group in self.model.groups.values() if self in group] + + @property + def material(self) -> Optional[str]: + """Name of the material assigned to this volume.""" + for group in self.groups: + if self in group and group.name.startswith("mat:"): + return group.name[4:] + return None + + @material.setter + def material(self, name: str): + existing_group = False + for group in self.model.groups.values(): + if f"mat:{name}" == group.name: + # Add volume to group matching specified name, unless the volume + # is already in it + if self in group: + return + group.add_set(self) + existing_group = True + + elif self in group and group.name.startswith("mat:"): + # Remove volume from existing group + group.remove_set(self) + + if not existing_group: + # Create new group and add entity + group_id = max((g.id for g in self.model.groups.values()), default=0) + 1 + new_group = Group.create(self.model, name=f"mat:{name}", group_id=group_id) + new_group.add_set(self) + def get_surfaces(self): """Returns surface objects for all surfaces making up this vollume""" surfs = [Surface(self.model, h) for h in self.model.mb.get_child_meshsets(self.handle)] @@ -238,13 +274,17 @@ def _get_triangle_sets(self): class Group(DAGSet): + def __contains__(self, ent_set: DAGSet): + return any(vol.handle == ent_set.handle for vol in chain( + self.get_volumes().values(), self.get_surfaces().values())) + @property - def name(self): + def name(self) -> str: """Returns the name of this group.""" return self.model.mb.tag_get_data(self.model.name_tag, self.handle, flat=True)[0] @name.setter - def name(self, val): + def name(self, val: str): self.model.mb.tag_set_data(self.model.name_tag, self.handle, val) def _get_geom_ent_by_id(self, entity_type, id): @@ -333,7 +373,7 @@ def merge(self, other_group): other_group.handle = self.handle @classmethod - def create(cls, model, name=None, group_id=None): + def create(cls, model, name=None, group_id=None) -> Group: """Create a new group instance with the given name""" mb = model.mb # add necessary tags for this meshset to be identified as a group diff --git a/pyproject.toml b/pyproject.toml index 60969c5..4034efb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,8 @@ classifiers = [ ] [project.urls] -"Homepage" = "https://github.com/pshriwise/pydagmc" +"Homepage" = "https://github.com/svalinn/pydagmc" [project.optional-dependencies] -test = ["pytest"] \ No newline at end of file +test = ["pytest"] diff --git a/test/test_basic.py b/test/test_basic.py index 8bf1c74..5ed63af 100644 --- a/test/test_basic.py +++ b/test/test_basic.py @@ -111,6 +111,21 @@ def test_group_merge(request): assert 3 in fuel_group.get_volumes() +def test_volume(request): + test_file = str(request.path.parent / 'fuel_pin.h5m') + model = dagmc.DAGModel(test_file) + + v1 = model.volumes[1] + assert v1.material == 'fuel' + assert v1 in model.groups['mat:fuel'] + + v1.material = 'olive oil' + assert v1.material == 'olive oil' + assert 'mat:olive oil' in model.groups + assert v1 in model.groups['mat:olive oil'] + assert v1 not in model.groups['mat:fuel'] + + def test_compressed_coords(request, capfd): test_file = str(request.path.parent / 'fuel_pin.h5m') groups = dagmc.DAGModel(test_file).groups