Ksagiyam/add DAGTraverser #365
bffe57a to b99a902
I think this is pretty great. I have written a lot of tree traversals and the `DAGTraverser.postorder` (and friends) solves an issue I've been thinking about.
It's almost a shame that this only lives in UFL.
ufl/corealg/dag_traverser.py
class DAGTraverser(ABC):
    """Base class for dag traversers."""
You should document the `__init__` parameters here.
Done.
ufl/corealg/dag_traverser.py
    @staticmethod
    def postorder(method):
        """Suppress processed operands in arguments.
Could this be clearer?
I found it hard to keep it short. Please let me know if you have a suggestion. We at least have an expanded explanation right below.
"Decorator indicating that child nodes are visited first in the traversal."?
I'm afraid I have no idea what the current docstring means.
It would also be fine to have something trivial like "Postorder decorator." provided that there is an explanation in the rest of the docstring.
Changed to "Postorder decorator".
            Processed object.

        """
        raise AssertionError(f"UFL expression expected: got {o}")
Is this now boilerplate that needs reproducing for each subclass? Might it be better on the ABC?
Do you mean putting this base case in the base class? I think the problem is that a `singledispatchmethod` object needs to be made in each child class. What I had to do here is probably relevant: https://github.com/firedrakeproject/ufl/blob/c16f038378e17ad8bd730bf9a020028ad72a00c1/ufl/algorithms/apply_derivatives.py#L716.
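One possible reading of the reply above, as a minimal sketch rather than the code in this PR: a `singledispatchmethod` owns a single registry, so handlers registered through an inherited dispatcher become visible to every subclass, whereas a dispatcher re-created in the child class keeps its registrations local. The classes `Base`, `Shared`, `Sibling`, and `Own` are hypothetical and exist only for illustration.

```python
from functools import singledispatchmethod


class Base:
    """Hypothetical base traverser that only defines the fallback case."""

    @singledispatchmethod
    def process(self, o):
        raise AssertionError(f"Unsupported object: {o!r}")


class Shared(Base):
    # Registers on Base's dispatcher, so the handler is shared
    # with every other subclass of Base.
    @Base.process.register(int)
    def _(self, o):
        return "Shared handles int"


class Sibling(Base):
    pass


class Own(Base):
    # A fresh dispatcher in the child class keeps its registry local.
    # The fallback is repeated here and delegates to the parent class.
    @singledispatchmethod
    def process(self, o):
        return super().process(o)

    @process.register(str)
    def _(self, o):
        return "Own handles str"


leaked = Sibling().process(1)  # dispatched to the handler defined in Shared
local = Own().process("x")     # dispatched through Own's own registry
```

Under this reading, the repeated base case is the price of keeping each traverser's registry separate.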
def component_tensor(self, o, Ap, ii):
@process.register(ComponentTensor)
@DAGTraverser.postorder
This is pretty sweet
c74c7be to c16f038
- replace MultiFunction with DAGTraverser in apply_derivatives.py Co-authored-by: Connor Ward <[email protected]>
c16f038 to 449b44e
Attempt to implement what @wence- described in this issue.

We basically would like to introduce classes (`DAGTraverser`s) that define node processing (using `singledispatchmethod`) and hold caches.

We considered two DAG traversal approaches (for post-order and mostly post-order traversals).

Approach 1. Two-step approach.

Collect nodes first post-order and then process them in order. This is the current UFL approach (`cutoff_unique_post_traversal` + `MultiFunction` wrapped in `map_expr_dags`).

This "bottom-up" approach is optimal, but when we process a node, we no longer know the local DAG structure around that node (e.g., we do not know its parent). So if we want to pass down some state from the parent, we need to (indirectly) specify that the parent is a "cutoff" node (so that we do not collect the child nodes in the first step), and perform a sub-DAG traversal under that parent; see RestrictionPropagator.restricted(self, o). In the current implementation, if the number of arguments to the method is two (`self` and `o` in the above case), the corresponding node type is regarded as a cutoff node type. The `MultiFunction` constructor currently identifies cutoff node types by `inspect`ing the signature of each method. If we wanted to do a similar thing with `singledispatchmethod`, we could subclass `singledispatchmethod` as David suggested, but we found that we would end up overriding the `singledispatchmethod.register()` method in a way that relies on the current implementation of `singledispatchmethod`, which would be a drawback.

Approach 2. Monolithic approach.

Use recursion with caching (each node is processed once and only once) described in this issue.

I observed an overhead of roughly 3% to 20% after rewriting `apply_derivatives.py` (please see below), presumably due to recursion, but no special consideration of cutoff is required. This approach is claimed to be more robust.

In this PR we take Approach 2, and replace `MultiFunction`s with `DAGTraverser`s in `apply_derivatives.py`. We should be able to incrementally/systematically remove all `MultiFunction`s in the future.

Performance checks:
`holzapfel_ogden.py` (holzapfel_ogden):
- `main`: average time required: 0.27037227153778076
- this PR: average time required: 0.27785348892211914 (+2.8%)

With #339 (<- performance regression) reverted:
- `main`: average time required: 0.04610729217529297
- this PR: average time required: 0.055258870124816895 (+20%)

`wence_test` (#69):
- `main`: average time required: 0.23360061645507812
- this PR: average time required: 0.2485201358795166 (+6.4%)

With #339 reverted:
- `main`: average time required: 0.060405850410461426
- this PR: average time required: 0.06431174278259277 (+6.5%)
Firedrake CI
firedrakeproject/firedrake#4145 (no notable performance regression).
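
For readers unfamiliar with the pattern, Approach 2 from the description can be sketched roughly as follows. This is a toy illustration under stated assumptions, not the `DAGTraverser` implementation in this PR: the node classes (`Node`, `Leaf`, `Add`, `Mul`), the `Evaluator` class, the module-level `postorder` helper, and the `id`-keyed cache are all hypothetical stand-ins.

```python
from functools import singledispatchmethod, wraps


class Node:
    """Hypothetical expression node; not a UFL class."""

    def __init__(self, *operands):
        self.operands = operands


class Leaf(Node):
    def __init__(self, value):
        super().__init__()
        self.value = value


class Add(Node):
    pass


class Mul(Node):
    pass


def postorder(method):
    """Process the operands first, then pass the results to ``method``."""

    @wraps(method)
    def wrapper(self, o):
        processed = [self(op) for op in o.operands]
        return method(self, o, *processed)

    return wrapper


class Evaluator:
    """Toy traverser: recursion plus a per-instance cache."""

    def __init__(self):
        self.cache = {}
        self.calls = 0  # number of nodes actually processed (cache misses)

    def __call__(self, o):
        key = id(o)  # nodes stay alive inside the DAG, so ids remain unique
        if key not in self.cache:
            self.calls += 1
            self.cache[key] = self.process(o)
        return self.cache[key]

    @singledispatchmethod
    def process(self, o):
        raise AssertionError(f"Node expected: got {o!r}")

    @process.register(Leaf)
    def _(self, o):
        return o.value

    @process.register(Add)
    @postorder
    def _(self, o, a, b):
        return a + b

    @process.register(Mul)
    @postorder
    def _(self, o, a, b):
        return a * b


# The Add node is shared by both operands of Mul, so it is processed only once.
shared = Add(Leaf(2), Leaf(3))
expr = Mul(shared, shared)
evaluator = Evaluator()
result = evaluator(expr)
```

Because a handler without the decorator receives only the node itself, it can choose how (and whether) to recurse into its operands, which is why no separate notion of a cutoff node is needed in this sketch.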