Skip to content

Commit

Permalink
Make DofsView a dataclass (#787)
Browse files Browse the repository at this point in the history
  • Loading branch information
kinnala authored Nov 14, 2021
1 parent 08a695d commit 62f72d3
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 43 deletions.
10 changes: 6 additions & 4 deletions docs/examples/ex17.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,13 @@ def conduction(u, v, w):
element = ElementQuad1()
basis = Basis(mesh, element)

conductivity = basis.zero_w()
for subdomain, elements in mesh.subdomains.items():
conductivity[elements] = thermal_conductivity[subdomain]
basis0 = basis.with_element(ElementQuad0())

L = asm(conduction, basis, conductivity=conductivity)
conductivity = basis0.zeros()
conductivity[basis0.get_dofs(elements='core')] = 101.
conductivity[basis0.get_dofs(elements='annulus')] = 11.

L = asm(conduction, basis, conductivity=basis0.interpolate(conductivity))

facet_basis = FacetBasis(mesh, element, facets=mesh.boundaries['perimeter'])
H = heat_transfer_coefficient * asm(convection, facet_basis)
Expand Down
7 changes: 5 additions & 2 deletions docs/examples/ex28.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,15 @@ def advection(u, v, w):
return v * velocity_x * grad(u)[0]


conductivity = basis['heat'].zero_w() + 1
basis0 = basis['heat'].with_element(ElementTriP0())
conductivity = basis0.zeros() + 1
conductivity[mesh.subdomains['solid']] = kratio

longitudinal_gradient = 3 / 4 / peclet

A = (asm(conduction, basis['heat'], conductivity=conductivity)
A = (asm(conduction,
basis['heat'],
conductivity=basis0.interpolate(conductivity))
+ peclet * asm(advection, basis['fluid']))
b = (asm(unit_load, basis['heated'])
+ longitudinal_gradient
Expand Down
7 changes: 5 additions & 2 deletions skfem/assembly/basis/abstract_basis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from warnings import warn
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import numpy as np
Expand Down Expand Up @@ -117,7 +118,7 @@ def complement_dofs(self, *D):
def find_dofs(self,
facets: Dict[str, ndarray] = None,
skip: List[str] = None) -> Dict[str, DofsView]:
"""Deprecated in favor of :meth:`~skfem.AbstractBasis.get_dofs`."""
warn("find_dofs deprecated in favor of get_dofs.", DeprecationWarning)
if facets is None:
if self.mesh.boundaries is None:
facets = {'all': self.mesh.boundary_facets()}
Expand Down Expand Up @@ -222,11 +223,13 @@ def get_dofs(self,
"""
if isinstance(facets, dict):
# deprecate
warn("Passing dict to get_dofs is deprecated.", DeprecationWarning)

def to_indices(f):
if callable(f):
return self.mesh.facets_satisfying(f)
return f

return {k: self.dofs.get_facet_dofs(to_indices(facets[k]),
skip_dofnames=skip)
for k in facets}
Expand Down
74 changes: 40 additions & 34 deletions skfem/assembly/dofs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Union, NamedTuple, Any, List, Optional
from dataclasses import dataclass, replace
from typing import Union, Any, List, Optional
from warnings import warn

import numpy as np
Expand All @@ -8,7 +9,8 @@
from skfem.mesh import Mesh


class DofsView(NamedTuple):
@dataclass(repr=False)
class DofsView:
"""A subset of :class:`skfem.assembly.Dofs`."""

obj: Any = None
Expand Down Expand Up @@ -108,19 +110,21 @@ def keep(self, dofnames: List[str]):
An array of DOF names, e.g. `["u", "u_n"]`.
"""
return DofsView(
self.obj,
self.nodal_ix,
self.facet_ix,
self.edge_ix,
self.interior_ix,
*self._intersect_tuples(
(self.nodal_rows,
self.facet_rows,
self.edge_rows,
self.interior_rows),
self._dofnames_to_rows(dofnames)
)
nrows = self._intersect_tuples(
(
self.nodal_rows,
self.facet_rows,
self.edge_rows,
self.interior_rows,
),
self._dofnames_to_rows(dofnames)
)
return replace(
self,
nodal_rows=nrows[0],
facet_rows=nrows[1],
edge_rows=nrows[2],
interior_rows=nrows[3],
)

def drop(self, dofnames):
Expand All @@ -132,19 +136,21 @@ def drop(self, dofnames):
An array of DOF names, e.g. `["u", "u_n"]`.
"""
return DofsView(
self.obj,
self.nodal_ix,
self.facet_ix,
self.edge_ix,
self.interior_ix,
*self._intersect_tuples(
(self.nodal_rows,
self.facet_rows,
self.edge_rows,
self.interior_rows),
self._dofnames_to_rows(dofnames, skip=True)
)
nrows = self._intersect_tuples(
(
self.nodal_rows,
self.facet_rows,
self.edge_rows,
self.interior_rows,
),
self._dofnames_to_rows(dofnames, skip=True)
)
return replace(
self,
nodal_rows=nrows[0],
facet_rows=nrows[1],
edge_rows=nrows[2],
interior_rows=nrows[3],
)

def all(self, key=None):
Expand Down Expand Up @@ -190,12 +196,12 @@ def __getattr__(self, attr):

def __or__(self, other):
warn("Use numpy.hstack to combine sets of DOFs", DeprecationWarning)
return DofsView(
self.obj,
np.union1d(self.nodal_ix, other.nodal_ix),
np.union1d(self.facet_ix, other.facet_ix),
np.union1d(self.edge_ix, other.edge_ix),
np.union1d(self.interior_ix, other.interior_ix)
return replace(
self,
nodal_ix=np.union1d(self.nodal_ix, other.nodal_ix),
facet_ix=np.union1d(self.facet_ix, other.facet_ix),
edge_ix=np.union1d(self.edge_ix, other.edge_ix),
interior_ix=np.union1d(self.interior_ix, other.interior_ix),
)

def __add__(self, other):
Expand Down
5 changes: 4 additions & 1 deletion skfem/assembly/form/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ def dictify(w, basis):
"""Support additional input formats for 'w'."""
for k in w:
if isinstance(w[k], DiscreteField):
continue
if w[k][0].shape[-1] != basis.X.shape[1]:
raise ValueError("Quadrature mismatch: '{}' should have "
"same number of integration points as "
"the basis object.".format(k))
elif isinstance(w[k], tuple):
# asm() product index is of type tuple
continue
Expand Down

0 comments on commit 62f72d3

Please sign in to comment.