Skip to content

Commit

Permalink
Improve performance of random color sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
paulromano committed Mar 27, 2024
1 parent c8ec2b7 commit 5a2fb1f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
6 changes: 0 additions & 6 deletions openmc_plotter/plot_colors.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@

import numpy as np


# for consistent, but random, colors
def reset_seed():
np.random.seed(10)


def random_rgb():
return tuple(np.random.choice(range(256), size=3))

Expand Down
20 changes: 11 additions & 9 deletions openmc_plotter/plotmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from . import __version__
from .statepointmodel import StatePointModel
from .plot_colors import random_rgb, reset_seed
from .plot_colors import random_rgb

ID, NAME, COLOR, COLORLABEL, MASK, HIGHLIGHT = range(6)

Expand Down Expand Up @@ -170,9 +170,6 @@ def __init__(self, use_settings_pkl, model_path):
self.appliedScores = ()
self.appliedNuclides = ()

# reset random number seed for consistent
# coloring when reloading a model
reset_seed()
self.previousViews = []
self.subsequentViews = []

Expand Down Expand Up @@ -999,8 +996,9 @@ def __init__(self, origin=(0, 0, 0), width=10, height=10, restore_view=None,
self.materials = restore_view.materials
self.selectedTally = restore_view.selectedTally
else:
self.cells = self.getDomains('cell')
self.materials = self.getDomains('material')
rng = np.random.RandomState(10)
self.cells = self.getDomains('cell', rng)
self.materials = self.getDomains('material', rng)
self.selectedTally = None

def __getattr__(self, name):
Expand All @@ -1025,7 +1023,7 @@ def __hash__(self):
return hash(self.__dict__.__str__() + self.__str__())

@staticmethod
def getDomains(domain_type):
def getDomains(domain_type, rng):
""" Return dictionary of domain settings.
Retrieve cell or material ID numbers and names from .xml files
Expand Down Expand Up @@ -1054,9 +1052,13 @@ def getDomains(domain_type):
lib_domain = openmc.lib.materials
domains = DEFAULT_MATERIAL_DOMAIN_VIEW

for domain, domain_obj in lib_domain.items():
# Sample default colors for each domain
num_domain = len(lib_domain)
colors = rng.randint(256, size=(num_domain, 3))

for (domain, domain_obj), color in zip(lib_domain.items(), colors):
name = domain_obj.name
domains[domain] = DomainView(domain, name, random_rgb())
domains[domain] = DomainView(domain, name, color)

# always add void to a material domain at the end
if domain_type == 'material':
Expand Down

0 comments on commit 5a2fb1f

Please sign in to comment.