diff --git a/firedrake/mg/mesh.py b/firedrake/mg/mesh.py index e4eb1f67b8..0966eda114 100644 --- a/firedrake/mg/mesh.py +++ b/firedrake/mg/mesh.py @@ -9,6 +9,11 @@ from firedrake.cython import mgimpl as impl from .utils import set_level +try: + import netgen +except ImportError: + netgen = None + ngsPETSc = None __all__ = ("HierarchyBase", "MeshHierarchy", "ExtrudedMeshHierarchy", "NonNestedHierarchy", "SemiCoarsenedExtrudedHierarchy") @@ -80,6 +85,7 @@ def __getitem__(self, idx): def MeshHierarchy(mesh, refinement_levels, refinements_per_level=1, + order=1, reorder=None, distribution_parameters=None, callbacks=None, mesh_builder=firedrake.Mesh): @@ -102,6 +108,13 @@ def MeshHierarchy(mesh, refinement_levels, callback receives the refined DM (and the current level). :arg mesh_builder: Function to turn a DM into a ``Mesh``. Used by pyadjoint. """ + if netgen and isinstance(mesh, netgen.libngpy._meshing.Mesh): + try: + from ngsPETSc import NetgenHierarchy + except ImportError: + raise ImportError("Unable to import ngsPETSc. Please ensure that ngsolve is installed and available to Firedrake.") + return NetgenHierarchy(mesh, refinement_levels, order=order, digits=8, adaptive=False) + cdm = mesh.topology_dm cdm.setRefinementUniform(True) dms = [] diff --git a/tests/multigrid/test_netgen.py b/tests/multigrid/test_netgen.py new file mode 100644 index 0000000000..b231b010cd --- /dev/null +++ b/tests/multigrid/test_netgen.py @@ -0,0 +1,38 @@ +from firedrake import * +import pytest + + +def test_netgen_mg_circle(): + try: + from netgen.geom2d import Circle, CSG2d + except ImportError: + pytest.skip(reason="Netgen unavailable, skipping Netgen test.") + geo = CSG2d() + + circle = Circle(center=(0, 0), radius=1.0, mat="mat1", bc="circle") + geo.Add(circle) + + ngmesh = geo.GenerateMesh(maxh=0.75) + + nh = MeshHierarchy(ngmesh, 2, order=3) + mesh = nh[-1] + + V = FunctionSpace(mesh, "CG", 3) + + u = TrialFunction(V) + v = TestFunction(V) + + a = dot(grad(u), grad(v))*dx + labels = [i+1 for i, name in enumerate(ngmesh.GetRegionNames(codim=1)) if name in ["circle"]] + bcs = DirichletBC(V, zero(), labels) + x, y = SpatialCoordinate(mesh) + + f = 4+0*x + L = f*v*dx + exact = (1-x**2-y**2) + + u = Function(V) + solve(a == L, u, bcs=bcs, solver_parameters={"ksp_type": "cg", + "pc_type": "mg"}) + expect = Function(V).interpolate(exact) + assert (norm(assemble(u - expect)) <= 1e-6)