Ksagiyam/add DAGTraverser #365

Open, wants to merge 1 commit into base: main

@ksagiyam ksagiyam commented Mar 26, 2025

Attempt to implement what @wence- described in this issue.

We basically would like to introduce classes (DAGTraversers) that define node processing (using singledispatchmethod) and hold caches.

We considered two DAG traversal approaches (for post-order and mostly post-order traversals).

Approach 1. Two-step approach.

Collect the nodes in post-order first, then process them in that order.

This is the current UFL approach (cutoff_unique_post_traversal + MultiFunction wrapped in map_expr_dags).
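
For illustration, a minimal sketch of the two-step idea on a toy DAG. The `Node` class and the `post_order_unique` / `map_dag` helpers are hypothetical names made up for this example; they are not UFL's `cutoff_unique_post_traversal` or `map_expr_dags`, and cutoff handling is left out.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical expression node (not a UFL type)."""

    label: str
    operands: tuple = ()


def post_order_unique(root):
    """Step 1: list every distinct node once, children before parents."""
    ordered, seen = [], set()
    stack = [(root, False)]
    while stack:
        node, expanded = stack.pop()
        if expanded:
            ordered.append(node)
        elif id(node) not in seen:
            seen.add(id(node))
            stack.append((node, True))  # emit after all operands
            stack.extend((op, False) for op in reversed(node.operands))
    return ordered


def map_dag(root, handler):
    """Step 2: call handler(node, *processed_operands) in the collected order."""
    results = {}
    for node in post_order_unique(root):
        processed = [results[id(op)] for op in node.operands]
        results[id(node)] = handler(node, *processed)
    return results[id(root)]
```

Note that `handler` only ever sees a node and its processed operands, which is the loss of local DAG context discussed next.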

This "bottom-up" approach is optimal, but when we process a node we no longer know the local DAG structure around it (e.g., we do not know its parent). So if we want to pass some state down from the parent, we need to (indirectly) mark that parent as a "cutoff" node (so that its child nodes are not collected in the first step) and perform a sub-DAG traversal under that parent; see RestrictionPropagator.restricted(self, o). In the current implementation, if a method takes exactly two arguments (self and o in the above case), the corresponding node type is regarded as a cutoff node type: the MultiFunction constructor identifies cutoff node types by inspecting the signature of each method. If we wanted to do something similar with singledispatchmethod, we could subclass singledispatchmethod as David suggested, but we found that we would end up overriding singledispatchmethod.register() in a way that relies on the current implementation of singledispatchmethod, which would be a drawback.
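
As a rough illustration of the signature-based convention (not the actual MultiFunction code), the check could look like the sketch below. `ToyHandlers` and `cutoff_handler_names` are hypothetical names introduced only for this example.

```python
import inspect


class ToyHandlers:
    """Hypothetical handler collection (not UFL's MultiFunction)."""

    def add(self, o, a, b):
        # Receives processed operands: an ordinary post-order handler.
        return ("add", a, b)

    def restricted(self, o):
        # Only (self, o): by the convention above, a cutoff handler that
        # is responsible for traversing the sub-DAG under o itself.
        return ("restricted", o)


def cutoff_handler_names(handlers):
    """Return names of public methods whose signature is exactly (self, o)."""
    names = set()
    for name, func in inspect.getmembers(type(handlers), inspect.isfunction):
        if name.startswith("_"):
            continue
        if len(inspect.signature(func).parameters) == 2:
            names.add(name)
    return names
```

The point of the sketch is that "cutoff" is implied by the arity of the handler rather than stated explicitly, which is hard to reproduce on top of singledispatchmethod.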

Approach 2. Monolithic approach.

Use recursion with caching (each node is processed once and only once), as described in this issue.

I observed an overhead of a few percent up to about 20% after rewriting apply_derivatives.py (please see below), presumably due to recursion, but no special handling of cutoff nodes is required. This approach is claimed to be more robust.
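
A minimal sketch of the monolithic idea, assuming a toy node hierarchy. `Leaf`, `Add`, `Scale` and `ToyEvaluator` are hypothetical and only mimic the shape of a DAGTraverser; they are not the classes added in this PR. Each handler recurses through the cached `__call__`, so a parent can pass state (here `scale`) down to its children without any cutoff convention.

```python
from dataclasses import dataclass
from functools import singledispatchmethod


@dataclass(frozen=True)
class Leaf:
    value: float


@dataclass(frozen=True)
class Add:
    left: object
    right: object


@dataclass(frozen=True)
class Scale:
    factor: float
    operand: object


class ToyEvaluator:
    """Hypothetical traverser: one cache plus one dispatching method."""

    def __init__(self):
        self._cache = {}
        self.calls = 0  # number of handler invocations, for illustration

    def __call__(self, node, scale=1.0):
        # State passed down from the parent is part of the cache key.
        # id() is used because every node stays alive during the traversal.
        key = (id(node), scale)
        if key not in self._cache:
            self.calls += 1
            self._cache[key] = self.process(node, scale)
        return self._cache[key]

    @singledispatchmethod
    def process(self, node, scale):
        raise AssertionError(f"unhandled node type: {node!r}")

    @process.register(Leaf)
    def _(self, node, scale):
        return scale * node.value

    @process.register(Add)
    def _(self, node, scale):
        return self(node.left, scale) + self(node.right, scale)

    @process.register(Scale)
    def _(self, node, scale):
        # The parent decides what its child sees.
        return self(node.operand, scale * node.factor)
```

The extra Python-level call per node is where the recursion overhead mentioned above would come from.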

In this PR we take Approach 2, and replace MultiFunctions with DAGTraversers in apply_derivatives.py. We should be able to incrementally/systematically remove all MultiFunctions in the future.

Performance checks:

holzapfel_ogden.py (holzapfel_ogden):

import time

from ufl import (
    Coefficient,
    Constant,
    FunctionSpace,
    Identity,
    Mesh,
    TestFunction,
    derivative,
    det,
    diff,
    dot,
    dx,
    exp,
    grad,
    inner,
    ln,
    tetrahedron,
    tr,
    variable,
)
from ufl.algorithms import compute_form_data
from ufl.finiteelement import FiniteElement
from ufl.pullback import identity_pullback
from ufl.sobolevspace import H1


mesh = Mesh(FiniteElement("Lagrange", tetrahedron, 1, (3,), identity_pullback, H1))
lamda = Constant(mesh)
a = Constant(mesh)
b = Constant(mesh)
a_s = Constant(mesh)
b_s = Constant(mesh)
a_f = Constant(mesh)
b_f = Constant(mesh)
a_fs = Constant(mesh)
b_fs = Constant(mesh)
e_s = Constant(mesh, shape=(3,))
e_f = Constant(mesh, shape=(3,))

def isochoric(F):
    C = F.T*F
    I_1 = tr(C)
    I4_f = dot(e_f, C*e_f)
    I4_s = dot(e_s, C*e_s)
    I8_fs = dot(e_f, C*e_s)

    def cutoff(x):
        return 1.0/(1.0 + exp(-(x - 1.0)*30.0))

    def scaled_exp(a0, a1, argument):
        return a0/(2.0*a1)*(exp(b*argument) - 1)

    E_1 = scaled_exp(a, b, I_1 - 3.)
    E_f = cutoff(I4_f)*scaled_exp(a_f, b_f, (I4_f - 1.)**2)
    E_s = cutoff(I4_s)*scaled_exp(a_s, b_s, (I4_s - 1.)**2)
    E_3 = scaled_exp(a_fs, b_fs, I8_fs**2)
    E = E_1 + E_f + E_s + E_3
    return E

elem = FiniteElement("Lagrange", tetrahedron, 1, (3,), identity_pullback, H1)
V = FunctionSpace(mesh, elem)
u = Coefficient(V)
v = TestFunction(V)
I = Identity(mesh.ufl_cell().topological_dimension())
F = grad(u) + I
F = variable(F)
J = det(F)
Fbar = J**(-1.0/3.0)*F
E_volumetric = lamda*0.5*ln(J)**2
psi = isochoric(Fbar) + E_volumetric
P = diff(psi, F)
F = inner(P, grad(v))*dx
a = derivative(F, u)

ntest = 2
time_sum = 0
for _ in range(ntest):
    start = time.time()
    fd = compute_form_data(
        a,
        do_apply_function_pullbacks=True,
        do_apply_default_restrictions=True,
        do_apply_geometry_lowering=True,
        do_apply_restrictions=True,
        complex_mode=False,
    )
    end = time.time()
    time_sum += end - start
print("average time required: ", time_sum / ntest)

main:
average time required: 0.27037227153778076

this PR:
average time required: 0.27785348892211914 (+2.8%)

With #339 (which introduced a performance regression) reverted:

main:
average time required: 0.04610729217529297

this PR:
average time required: 0.055258870124816895 (+20%)

wence_test (#69):

import time

from ufl import (
    Coefficient,
    Constant,
    FacetNormal,
    FunctionSpace,
    Identity,
    Mesh,
    avg,
    derivative,
    det,
    diff,
    dS,
    grad,
    hexahedron,
    inner,
    ln,
    outer,
    variable,
)
from ufl.algorithms import compute_form_data
from ufl.finiteelement import FiniteElement
from ufl.pullback import identity_pullback, contravariant_piola
from ufl.sobolevspace import H1, HDiv

cell = hexahedron
mesh = Mesh(FiniteElement("Q", cell, 1, (3,), identity_pullback, H1))
V = FunctionSpace(mesh, FiniteElement("NCF", cell, 2, (3,), contravariant_piola, HDiv))
u = Coefficient(V)
mu = Constant(mesh)
psi = lambda F: (mu/2) * (inner(F, F) - u.ufl_shape[0]) - mu*ln(det(F))
flux = lambda F: diff(psi(F), F)
n = FacetNormal(mesh)
uxn = outer(u('+'), n('+')) + outer(u('-'), n('-'))
eye = Identity(u.ufl_shape[0])
grad_u = grad(u)
F_ = variable(grad_u + eye)
Fm = variable(grad_u('-') + eye)
Fp = variable(grad_u('+') + eye)
avg_flux = avg(flux(F_))
U = inner(avg_flux, uxn)*dS
a = derivative(derivative(U, u), u)

ntest = 2
time_sum = 0
for _ in range(ntest):
    start = time.time()
    fd = compute_form_data(
        a,
        do_apply_function_pullbacks=True,
        do_apply_default_restrictions=True,
        do_apply_geometry_lowering=True,
        do_apply_restrictions=True,
        complex_mode=False,
    )
    end = time.time()
    time_sum += end - start
print("average time required: ", time_sum / ntest)

main:
average time required: 0.23360061645507812

this PR:
average time required: 0.2485201358795166 (+6.4%)

With #339 reverted:

main:
average time required: 0.060405850410461426

this PR:
average time required: 0.06431174278259277 (+6.5%)
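
For reference, the percentages quoted above are relative slowdowns with respect to main. A small sketch that recomputes them from the reported averages (the `overhead_percent` helper is a name made up for this example):

```python
def overhead_percent(baseline, candidate):
    """Relative slowdown of candidate with respect to baseline, in percent."""
    return 100.0 * (candidate - baseline) / baseline


# (main, this PR) average times in seconds, copied from the runs above.
timings = {
    "holzapfel_ogden": (0.27037227153778076, 0.27785348892211914),
    "holzapfel_ogden, #339 reverted": (0.04610729217529297, 0.055258870124816895),
    "wence_test": (0.23360061645507812, 0.2485201358795166),
    "wence_test, #339 reverted": (0.060405850410461426, 0.06431174278259277),
}

for name, (main, pr) in timings.items():
    print(f"{name}: {overhead_percent(main, pr):+.1f}%")
```

With only two repetitions per case, these figures are probably best read as rough indications rather than precise measurements.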

Firedrake CI: firedrakeproject/firedrake#4145 (no notable performance regression).

@ksagiyam ksagiyam force-pushed the ksagiyam/add_dag_visitor branch 4 times, most recently from bffe57a to b99a902 Compare March 28, 2025 01:32
@connorjward connorjward left a comment

I think this is pretty great. I have written a lot of tree traversals and the DAGTraverser.postorder (and friends) solves an issue I've been thinking about.

It's almost a shame that this only lives in UFL.



class DAGTraverser(ABC):
    """Base class for dag traversers."""
Contributor

You should document the __init__ parameters here.

Contributor Author

Done.


@staticmethod
def postorder(method):
    """Suppress processed operands in arguments.
Contributor

Could this be clearer?

Contributor Author

I found it hard to keep it short. Please let me know if you have a suggestion. We at least have an expanded explanation right below.

Contributor

"Decorator indicating that child nodes are visited first in the traversal."?

I'm afraid I have no idea what the current docstring means.

It would also be fine to have something trivial like "Postorder decorator." provided that there is an explanation in the rest of the docstring.

Contributor Author

Changed to "Postorder decorator".
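
To make the discussion concrete, here is one way such a post-order decorator could work, as a sketch under assumptions rather than the code in this PR: the handler is written as `(self, o, *processed_operands)`, and the decorator turns it into one taking `(self, o)` by processing the operands first. `Sym`, `Neg`, `ToyPrinter` and the `operands` attribute are hypothetical stand-ins.

```python
from dataclasses import dataclass
from functools import singledispatchmethod, wraps


@dataclass(frozen=True)
class Sym:
    name: str
    operands: tuple = ()


@dataclass(frozen=True)
class Neg:
    operands: tuple


def postorder(method):
    """Convert a (self, o, *processed_operands) handler into a (self, o) one."""

    @wraps(method)
    def wrapper(self, o):
        # Visit the child nodes first, then hand their results to the handler.
        processed = [self(op) for op in o.operands]
        return method(self, o, *processed)

    return wrapper


class ToyPrinter:
    """Hypothetical traverser used only to exercise the decorator."""

    def __init__(self):
        self._cache = {}

    def __call__(self, o):
        if id(o) not in self._cache:
            self._cache[id(o)] = self.process(o)
        return self._cache[id(o)]

    @singledispatchmethod
    def process(self, o):
        raise AssertionError(f"unhandled node type: {o!r}")

    @process.register(Sym)
    def _(self, o):
        return o.name

    @process.register(Neg)
    @postorder
    def _(self, o, operand):
        return f"-({operand})"
```

Handlers without the decorator keep full control over whether and how their operands are traversed.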

Processed object.

"""
raise AssertionError(f"UFL expression expected: got {o}")
Contributor

Is this now boilerplate that needs reproducing for each subclass? Might it be better on the ABC?

Contributor Author

Do you mean putting this base case in the base class? I think the problem is that a singledispatchmethod object needs to be created in each child class. What I had to do here is probably relevant: https://github.com/firedrakeproject/ufl/blob/c16f038378e17ad8bd730bf9a020028ad72a00c1/ufl/algorithms/apply_derivatives.py#L716.
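
One possible reading of this point, as a small standalone sketch: a singledispatchmethod defined on a base class owns a single registry, so registering handlers through it from a subclass also changes the base class and its other subclasses. `Base`, `Leaky` and `Independent` are hypothetical classes for illustration only.

```python
from functools import singledispatchmethod


class Base:
    @singledispatchmethod
    def process(self, o):
        raise AssertionError(f"expression expected: got {o!r}")


class Leaky(Base):
    # Registers on Base's dispatcher: the handler leaks into Base itself.
    @Base.process.register(int)
    def _(self, o):
        return 2 * o


class Independent(Base):
    # A fresh singledispatchmethod, so the base case is repeated here.
    @singledispatchmethod
    def process(self, o):
        raise AssertionError(f"expression expected: got {o!r}")

    @process.register(int)
    def _(self, o):
        return 3 * o


print(Base().process(1))         # handled by the function defined in Leaky
print(Independent().process(1))  # handled by Independent's own registry
```

Under this reading, repeating the base case in every subclass is the cost of keeping the registries separate.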


def component_tensor(self, o, Ap, ii):
@process.register(ComponentTensor)
@DAGTraverser.postorder
Contributor

This is pretty sweet

@ksagiyam ksagiyam force-pushed the ksagiyam/add_dag_visitor branch from c74c7be to c16f038 Compare April 3, 2025 13:07
- replace MultiFunction with DAGTraverser in apply_derivatives.py

Co-authored-by: Connor Ward <[email protected]>
@ksagiyam ksagiyam force-pushed the ksagiyam/add_dag_visitor branch from c16f038 to 449b44e Compare April 3, 2025 13:46
@ksagiyam ksagiyam marked this pull request as ready for review April 3, 2025 15:23