diff --git a/docs/examples/ex17.py b/docs/examples/ex17.py index ecae4a22a..654656642 100644 --- a/docs/examples/ex17.py +++ b/docs/examples/ex17.py @@ -53,11 +53,13 @@ def conduction(u, v, w): element = ElementQuad1() basis = Basis(mesh, element) -conductivity = basis.zero_w() -for subdomain, elements in mesh.subdomains.items(): - conductivity[elements] = thermal_conductivity[subdomain] +basis0 = basis.with_element(ElementQuad0()) -L = asm(conduction, basis, conductivity=conductivity) +conductivity = basis0.zeros() +conductivity[basis0.get_dofs(elements='core')] = 101. +conductivity[basis0.get_dofs(elements='annulus')] = 11. + +L = asm(conduction, basis, conductivity=basis0.interpolate(conductivity)) facet_basis = FacetBasis(mesh, element, facets=mesh.boundaries['perimeter']) H = heat_transfer_coefficient * asm(convection, facet_basis) diff --git a/docs/examples/ex28.py b/docs/examples/ex28.py index a3e708849..247484efa 100644 --- a/docs/examples/ex28.py +++ b/docs/examples/ex28.py @@ -92,12 +92,15 @@ def advection(u, v, w): return v * velocity_x * grad(u)[0] -conductivity = basis['heat'].zero_w() + 1 +basis0 = basis['heat'].with_element(ElementTriP0()) +conductivity = basis0.zeros() + 1 conductivity[mesh.subdomains['solid']] = kratio longitudinal_gradient = 3 / 4 / peclet -A = (asm(conduction, basis['heat'], conductivity=conductivity) +A = (asm(conduction, + basis['heat'], + conductivity=basis0.interpolate(conductivity)) + peclet * asm(advection, basis['fluid'])) b = (asm(unit_load, basis['heated']) + longitudinal_gradient diff --git a/skfem/assembly/basis/abstract_basis.py b/skfem/assembly/basis/abstract_basis.py index eba682fcb..1a4b69ca7 100644 --- a/skfem/assembly/basis/abstract_basis.py +++ b/skfem/assembly/basis/abstract_basis.py @@ -1,4 +1,5 @@ import logging +from warnings import warn from typing import Any, Dict, List, Optional, Tuple, Type, Union import numpy as np @@ -117,7 +118,7 @@ def complement_dofs(self, *D): def find_dofs(self, facets: Dict[str, ndarray] = None, skip: List[str] = None) -> Dict[str, DofsView]: - """Deprecated in favor of :meth:`~skfem.AbstractBasis.get_dofs`.""" + warn("find_dofs deprecated in favor of get_dofs.", DeprecationWarning) if facets is None: if self.mesh.boundaries is None: facets = {'all': self.mesh.boundary_facets()} @@ -222,11 +223,13 @@ def get_dofs(self, """ if isinstance(facets, dict): - # deprecate + warn("Passing dict to get_dofs is deprecated.", DeprecationWarning) + def to_indices(f): if callable(f): return self.mesh.facets_satisfying(f) return f + return {k: self.dofs.get_facet_dofs(to_indices(facets[k]), skip_dofnames=skip) for k in facets} diff --git a/skfem/assembly/dofs.py b/skfem/assembly/dofs.py index 3185f5de9..441f8fda2 100644 --- a/skfem/assembly/dofs.py +++ b/skfem/assembly/dofs.py @@ -1,4 +1,5 @@ -from typing import Union, NamedTuple, Any, List, Optional +from dataclasses import dataclass, replace +from typing import Union, Any, List, Optional from warnings import warn import numpy as np @@ -8,7 +9,8 @@ from skfem.mesh import Mesh -class DofsView(NamedTuple): +@dataclass(repr=False) +class DofsView: """A subset of :class:`skfem.assembly.Dofs`.""" obj: Any = None @@ -108,19 +110,21 @@ def keep(self, dofnames: List[str]): An array of DOF names, e.g. `["u", "u_n"]`. """ - return DofsView( - self.obj, - self.nodal_ix, - self.facet_ix, - self.edge_ix, - self.interior_ix, - *self._intersect_tuples( - (self.nodal_rows, - self.facet_rows, - self.edge_rows, - self.interior_rows), - self._dofnames_to_rows(dofnames) - ) + nrows = self._intersect_tuples( + ( + self.nodal_rows, + self.facet_rows, + self.edge_rows, + self.interior_rows, + ), + self._dofnames_to_rows(dofnames) + ) + return replace( + self, + nodal_rows=nrows[0], + facet_rows=nrows[1], + edge_rows=nrows[2], + interior_rows=nrows[3], ) def drop(self, dofnames): @@ -132,19 +136,21 @@ def drop(self, dofnames): An array of DOF names, e.g. `["u", "u_n"]`. """ - return DofsView( - self.obj, - self.nodal_ix, - self.facet_ix, - self.edge_ix, - self.interior_ix, - *self._intersect_tuples( - (self.nodal_rows, - self.facet_rows, - self.edge_rows, - self.interior_rows), - self._dofnames_to_rows(dofnames, skip=True) - ) + nrows = self._intersect_tuples( + ( + self.nodal_rows, + self.facet_rows, + self.edge_rows, + self.interior_rows, + ), + self._dofnames_to_rows(dofnames, skip=True) + ) + return replace( + self, + nodal_rows=nrows[0], + facet_rows=nrows[1], + edge_rows=nrows[2], + interior_rows=nrows[3], ) def all(self, key=None): @@ -190,12 +196,12 @@ def __getattr__(self, attr): def __or__(self, other): warn("Use numpy.hstack to combine sets of DOFs", DeprecationWarning) - return DofsView( - self.obj, - np.union1d(self.nodal_ix, other.nodal_ix), - np.union1d(self.facet_ix, other.facet_ix), - np.union1d(self.edge_ix, other.edge_ix), - np.union1d(self.interior_ix, other.interior_ix) + return replace( + self, + nodal_ix=np.union1d(self.nodal_ix, other.nodal_ix), + facet_ix=np.union1d(self.facet_ix, other.facet_ix), + edge_ix=np.union1d(self.edge_ix, other.edge_ix), + interior_ix=np.union1d(self.interior_ix, other.interior_ix), ) def __add__(self, other): diff --git a/skfem/assembly/form/form.py b/skfem/assembly/form/form.py index 3a7b62ba5..e5bb324c6 100644 --- a/skfem/assembly/form/form.py +++ b/skfem/assembly/form/form.py @@ -73,7 +73,10 @@ def dictify(w, basis): """Support additional input formats for 'w'.""" for k in w: if isinstance(w[k], DiscreteField): - continue + if w[k][0].shape[-1] != basis.X.shape[1]: + raise ValueError("Quadrature mismatch: '{}' should have " + "same number of integration points as " + "the basis object.".format(k)) elif isinstance(w[k], tuple): # asm() product index is of type tuple continue