Skip to content

Commit db127a4

Browse files
authored
Support Bernstein on tensor product cells (#165)
* Support Bernstein on tensor product cells * Add a test * Rename lmbda -> make_finat_element
1 parent 34160d2 commit db127a4

File tree

4 files changed

+43
-26
lines changed

4 files changed

+43
-26
lines changed

finat/element_factory.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,18 @@ def convert_finiteelement(element, **kwargs):
157157

158158
codim = 1 if element.family() == "Boundary Quadrature" else 0
159159
return finat.make_quadrature_element(cell, degree, scheme, codim), set()
160-
lmbda = supported_elements[element.family()]
161-
if element.family() == "Real" and element.cell.cellname() in {"quadrilateral", "hexahedron"}:
162-
lmbda = None
163-
element = finat.ufl.FiniteElement("DQ", element.cell, 0)
164-
if lmbda is None:
160+
161+
make_finat_element = supported_elements[element.family()]
162+
163+
if element.cell.cellname() in {"quadrilateral", "hexahedron"}:
164+
# Reconstruct Real and Bernstein on tensor product cells
165+
if element.family() == "Real":
166+
make_finat_element = None
167+
element = finat.ufl.FiniteElement("DQ", element.cell, 0)
168+
elif element.family() == "Bernstein":
169+
make_finat_element = None
170+
171+
if make_finat_element is None:
165172
if element.cell.cellname() == "quadrilateral":
166173
# Handle quadrilateral short names like RTCF and RTCE.
167174
element = element.reconstruct(cell=quadrilateral_tpc)
@@ -174,49 +181,51 @@ def convert_finiteelement(element, **kwargs):
174181
finat_elem, deps = _create_element(element, **kwargs)
175182
return finat.FlattenedDimensions(finat_elem), deps
176183

184+
deps = set()
177185
finat_kwargs = {}
178186
kind = element.variant()
179187
if kind is None:
180188
kind = 'spectral' # default variant
181189

182190
if element.family() == "Lagrange":
183191
if kind in ['spectral', 'mimetic']:
184-
lmbda = finat.GaussLobattoLegendre
192+
make_finat_element = finat.GaussLobattoLegendre
185193
elif element.cell.cellname() == "interval" and kind in cg_interval_variants:
186-
lmbda = cg_interval_variants[kind]
194+
make_finat_element = cg_interval_variants[kind]
187195
elif any(map(kind.startswith, ['integral', 'demkowicz', 'fdm'])):
188-
lmbda = finat.IntegratedLegendre
196+
make_finat_element = finat.IntegratedLegendre
189197
finat_kwargs["variant"] = kind
190198
elif kind in ['mgd', 'feec', 'qb', 'mse']:
191-
degree = element.degree()
192-
shift_axes = kwargs["shift_axes"]
193-
restriction = kwargs["restriction"]
199+
make_finat_element = finat.RuntimeTabulated
200+
finat_kwargs["variant"] = kind
201+
finat_kwargs["shift_axes"] = kwargs["shift_axes"]
202+
finat_kwargs["restriction"] = kwargs["restriction"]
194203
deps = {"shift_axes", "restriction"}
195-
return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction), deps
196204
else:
197205
# Let FIAT handle the general case
198-
lmbda = finat.Lagrange
206+
make_finat_element = finat.Lagrange
199207
finat_kwargs["variant"] = kind
200208

201209
elif element.family() in ["Discontinuous Lagrange", "Discontinuous Lagrange L2"]:
202210
if kind == 'spectral':
203-
lmbda = finat.GaussLegendre
211+
make_finat_element = finat.GaussLegendre
204212
elif kind == 'mimetic':
205-
lmbda = finat.Histopolation
213+
make_finat_element = finat.Histopolation
206214
elif element.cell.cellname() == "interval" and kind in dg_interval_variants:
207-
lmbda = dg_interval_variants[kind]
215+
make_finat_element = dg_interval_variants[kind]
208216
elif any(map(kind.startswith, ['integral', 'demkowicz', 'fdm'])):
209-
lmbda = finat.Legendre
217+
make_finat_element = finat.Legendre
210218
finat_kwargs["variant"] = kind
211219
elif kind in ['mgd', 'feec', 'qb', 'mse']:
212-
degree = element.degree()
213-
shift_axes = kwargs["shift_axes"]
214-
restriction = kwargs["restriction"]
220+
make_finat_element = finat.RuntimeTabulated
221+
finat_kwargs["variant"] = kind
222+
finat_kwargs["shift_axes"] = kwargs["shift_axes"]
223+
finat_kwargs["restriction"] = kwargs["restriction"]
224+
finat_kwargs["continuous"] = False
215225
deps = {"shift_axes", "restriction"}
216-
return finat.RuntimeTabulated(cell, degree, variant=kind, shift_axes=shift_axes, restriction=restriction, continuous=False), deps
217226
else:
218227
# Let FIAT handle the general case
219-
lmbda = finat.DiscontinuousLagrange
228+
make_finat_element = finat.DiscontinuousLagrange
220229
finat_kwargs["variant"] = kind
221230

222231
elif element.family() in {"HDiv Trace", "Bubble", "FacetBubble"}:
@@ -225,7 +234,7 @@ def convert_finiteelement(element, **kwargs):
225234
elif element.variant() is not None:
226235
finat_kwargs["variant"] = element.variant()
227236

228-
return lmbda(cell, element.degree(), **finat_kwargs), set()
237+
return make_finat_element(cell, element.degree(), **finat_kwargs), deps
229238

230239

231240
# Element modifiers and compound element types

finat/ufl/elementlist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def show_elements():
152152
register_alias("Lob",
153153
lambda family, dim, order, degree: ("Gauss-Lobatto-Legendre", order))
154154

155-
register_element("Bernstein", None, 0, H1, "identity", (1, None), simplices)
155+
register_element("Bernstein", None, 0, H1, "identity", (1, None), any_cell)
156156

157157

158158
# Let Nedelec H(div) elements be aliases to BDMs/RTs

finat/ufl/finiteelement.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,10 @@ def __new__(cls,
9494
return EnrichedElement(HCurl(TensorProductElement(Qc_elt, Id_elt, cell=cell)),
9595
HCurl(TensorProductElement(Qd_elt, Ic_elt, cell=cell)))
9696

97-
elif family == "Q":
98-
return TensorProductElement(*[FiniteElement("CG", c, degree, variant=variant)
97+
elif family in {"Q", "Bernstein"}:
98+
if family == "Q":
99+
family = "CG"
100+
return TensorProductElement(*[FiniteElement(family, c, degree, variant=variant)
99101
for c in cell.sub_cells()],
100102
cell=cell)
101103

test/finat/test_create_finat_element.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,12 @@ def test_quadrilateral_variant_spectral_q():
106106
assert isinstance(element.product.factors[1], finat.GaussLobattoLegendre)
107107

108108

109+
def test_quadrilateral_bernstein():
110+
element = create_element(finat.ufl.FiniteElement('Bernstein', ufl.quadrilateral, 3))
111+
assert isinstance(element.product.factors[0], finat.Bernstein)
112+
assert isinstance(element.product.factors[1], finat.Bernstein)
113+
114+
109115
def test_quadrilateral_variant_spectral_dq():
110116
element = create_element(finat.ufl.FiniteElement('DQ', ufl.quadrilateral, 1, variant='spectral'))
111117
assert isinstance(element.product.factors[0], finat.GaussLegendre)

0 commit comments

Comments
 (0)