Skip to content

Commit

Permalink
ufl_domains() now returns a set
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Nov 13, 2024
1 parent c6609d1 commit f91615c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
6 changes: 3 additions & 3 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1610,7 +1610,7 @@ def _integral_type(self):

@cached_property
def _mesh(self):
return self._form.ufl_domains()[self._kinfo.domain_number]
return tuple(self._form.ufl_domains())[self._kinfo.domain_number]

@cached_property
def _needs_subset(self):
Expand Down Expand Up @@ -1731,7 +1731,7 @@ def _as_global_kernel_arg_coefficient(_, self):

ufl_element = V.ufl_element()
if ufl_element.family() == "Real":
return op2.GlobalKernelArg((ufl_element.value_size,))
return op2.GlobalKernelArg((V.value_size,))
else:
return self._make_dat_global_kernel_arg(V, index=index)

Expand Down Expand Up @@ -1945,7 +1945,7 @@ def _indexed_tensor(self):

@cached_property
def _mesh(self):
return self._form.ufl_domains()[self._kinfo.domain_number]
return tuple(self._form.ufl_domains())[self._kinfo.domain_number]

@cached_property
def _iterset(self):
Expand Down
3 changes: 2 additions & 1 deletion firedrake/slate/slac/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ def _compile_expression_hashkey(slate_expr, compiler_parameters=None):

def _compile_expression_comm(*args, **kwargs):
# args[0] is a slate_expr
return args[0].ufl_domains()[0].comm
domain, = args[0].ufl_domains()
return domain.comm


@memory_and_disk_cache(
Expand Down
10 changes: 5 additions & 5 deletions firedrake/slate/slate.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,11 +247,11 @@ def ufl_domain(self):
The function will fail if multiple domains are found.
"""
domains = self.ufl_domains()
assert all(domain == domains[0] for domain in domains), (
"All integrals must share the same domain of integration."
)
return domains[0]
try:
domain, = self.ufl_domains()
except ValueError:
raise ValueError("All integrals must share the same domain of integration.")
return domain

@abstractmethod
def ufl_domains(self):
Expand Down

0 comments on commit f91615c

Please sign in to comment.