diff --git a/skfem/assembly/basis/abstract_basis.py b/skfem/assembly/basis/abstract_basis.py index 579ade03..a29e9f7c 100644 --- a/skfem/assembly/basis/abstract_basis.py +++ b/skfem/assembly/basis/abstract_basis.py @@ -451,3 +451,23 @@ def draw(self, visuals='matplotlib', **kwargs): logger.warning("First argument, 'visuals', must be a string.") mod = importlib.import_module('skfem.visuals.{}'.format(visuals)) return mod.draw(self, **kwargs) + + def __mul__(self, other): + from copy import deepcopy + assert len(self.basis) == len(other.basis) + basis = [] + element_dofs = [] + for itr in range(len(self.basis)): + basis.append((self.basis[itr][0], other.basis[0][0].zeros())) + element_dofs.append(self.element_dofs[itr]) + for itr in range(len(other.basis)): + basis.append((self.basis[0][0].zeros(), other.basis[itr][0])) + element_dofs.append(other.element_dofs[itr] + self.N) + out = deepcopy(self) + out.basis = basis + out.dofs.N = self.N + other.N + out.Nbfun = self.Nbfun + other.Nbfun + out._element_dofs = element_dofs + out.interpolate = lambda w: (self.interpolate(w[:self.N]), + other.interpolate(w[self.N:])) + return out