From 2ad37635fe80346326592bc999c3b68ecc34ef83 Mon Sep 17 00:00:00 2001 From: Tom Gustafsson Date: Fri, 26 Nov 2021 22:55:08 +0200 Subject: [PATCH] Fix svg plotting of quadratic 2D meshes (#809) --- README.md | 2 ++ skfem/mesh/mesh.py | 9 +++++++++ skfem/mesh/mesh_2d_2.py | 4 ++++ tests/test_visuals.py | 23 ++++++++++++++++++++++- 4 files changed, 37 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 6c12d5f0b..81bcb6801 100644 --- a/README.md +++ b/README.md @@ -225,6 +225,8 @@ with respect to documented and/or tested features. ### Unreleased - Fixed: `ElementDG` was not included in the wildcard import +- Fixed: Automatic visualization of `MeshTri2` and `MeshQuad2` in Jupyter + notebooks raised exception ### [5.0.0] - 2021-11-21 diff --git a/skfem/mesh/mesh.py b/skfem/mesh/mesh.py index 93861f86a..0888a9763 100644 --- a/skfem/mesh/mesh.py +++ b/skfem/mesh/mesh.py @@ -649,6 +649,15 @@ def init_refdom(cls): """Initialize a mesh corresponding to the reference domain.""" return cls(cls.elem.refdom.p, cls.elem.refdom.t, validate=False) + def morphed(self, *args): + """Morph the mesh using functions.""" + p = self.p.copy() + for i, arg in enumerate(args): + if arg is None: + continue + p[i] = arg(self.p) + return replace(self, doflocs=p) + def refined(self, times_or_ix: Union[int, ndarray] = 1): """Return a refined mesh. diff --git a/skfem/mesh/mesh_2d_2.py b/skfem/mesh/mesh_2d_2.py index c9c68b1b0..63ba261c8 100644 --- a/skfem/mesh/mesh_2d_2.py +++ b/skfem/mesh/mesh_2d_2.py @@ -7,3 +7,7 @@ def _repr_svg_(self) -> str: def element_finder(self, *args, **kwargs): raise NotImplementedError + + @classmethod + def init_refdom(cls): + return cls.__bases__[-1].init_refdom() diff --git a/tests/test_visuals.py b/tests/test_visuals.py index fbf0dfdb4..b0fabf128 100644 --- a/tests/test_visuals.py +++ b/tests/test_visuals.py @@ -1,7 +1,12 @@ import unittest +import pytest -from skfem.mesh import MeshTri, MeshQuad, MeshTet, MeshLine1 +from skfem.assembly import CellBasis +from skfem.mesh import (MeshTri, MeshQuad, MeshTet, MeshLine1, MeshTri2, + MeshQuad2) from skfem.visuals.matplotlib import draw, plot, plot3 +from skfem.visuals.svg import draw as drawsvg +from skfem.visuals.svg import plot as plotsvg class CallDraw(unittest.TestCase): @@ -42,3 +47,19 @@ class CallPlot3(unittest.TestCase): def runTest(self): m = self.mesh_type() plot3(m, m.p[0]) + + +@pytest.mark.parametrize( + "mtype", + [ + MeshTri, + MeshQuad, + MeshTri2, + MeshQuad2, + ] +) +def test_call_svg_plot(mtype): + m = mtype() + svg = drawsvg(m, nrefs=2) + basis = CellBasis(m, mtype.elem()) + svg_plot = plotsvg(basis, m.p[0], nrefs=2)