Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jax ml operator fix #4041

Open
wants to merge 21 commits into
base: master
Choose a base branch
from
Open

Jax ml operator fix #4041

wants to merge 21 commits into from

Conversation

Ig-dolci
Copy link
Contributor

@Ig-dolci Ig-dolci commented Feb 14, 2025

Description

Depends on FEniCS/ufl#348

  • Fixed an issue in firedrake.ml.jax.ml_operator where argument_slots was incorrectly specified. I wrote it as optional:

    argument_slots: Optional[tuple[Union[ufl.coefficient.BaseCoefficient, ufl.argument.BaseArgument]]] = ()
  • Clarified that if argument_slots is not provided, ML_Operator will automatically write it.

  • Test the results involving the Neural operators are in right function space.

  • Test jax and pytorch operators in Firedrake CI.

Copy link

github-actions bot commented Feb 14, 2025

TestsPassed ✅Skipped ⏭️Failed ❌
Firedrake complex0 ran0 passed0 skipped0 failed

Copy link

github-actions bot commented Feb 14, 2025

TestsPassed ✅Skipped ⏭️Failed ❌
Firedrake real0 ran0 passed0 skipped0 failed

@Ig-dolci Ig-dolci changed the title Does this should be optional? Jax ml operator fix Feb 14, 2025
Comment on lines 93 to 96
if isinstance(self._argument_slots[0], Cofunction):
space = self.ufl_operands[0].function_space().dual()
else:
space = self.ufl_function_space()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If c = Cofunction(V.dual()) then V == c.arguments()[0].ufl_function_space(). Could this be used to not have Cofunction as a special case here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wrote this in another way. See the latest commit. I will be happy with your review.

@Ig-dolci Ig-dolci marked this pull request as ready for review February 17, 2025 14:11
@@ -98,7 +102,12 @@ def _pre_forward_callback(self, *operands, unsqueeze=False):

def _post_forward_callback(self, y_P):
"""Callback function to convert the PyTorch output of the ML model to a Firedrake function."""
space = self.ufl_function_space()
# At this point, ``len(self._argument_slots)`` must be greater than 0.
if isinstance(self._argument_slots[0], ufl.coefficient.BaseCoefficient):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function and Cofunction are both subclasses of BaseCoefficient. Don't we want different logic for these two cases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function inherits from ufl.Coefficient that inherits from BaseCoefficient. Cofunction inherits directly from BaseCoefficient. So, isinstance(self._argument_slots[0], ufl.coefficient.BaseCoefficient) returns True for both Function and Cofunction.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So self._argument_slots[0] could be either a Function, Cofunction, Argument or Coargument? What do we want for each case: do we just want V, V.dual(), V, and V.dual(), respectively?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or are we also expecting BaseForms?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self._argument_slots contains either only ufl.argument.BaseArgument (Argument or Coargument) objects or a combination of ufl.coefficient.BaseCoefficient (at index 0) and ufl.argument.BaseArgument (at index 1). Our goal is to set the function space based on it. I think determining the function space based on self._argument_slots[0] is reasonable.

as a result of taking the action on a given function.
Tuple containing the arguments of the linear form associated with the ML operator,
i.e. the arguments with respect to which the ML operator is linear. Those arguments can
retain only ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

contain?

Copy link
Contributor Author

@Ig-dolci Ig-dolci Feb 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
retain only ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both
contain ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both

Comment on lines 106 to 110
if isinstance(self._argument_slots[0], ufl.coefficient.BaseCoefficient):
space = self._argument_slots[0].function_space()
else:
# When ``self._argument_slots[0]`` is an ``ufl.argument.BaseArgument``.
space = self._argument_slots[0].arguments()[0].function_space()
Copy link
Contributor

@pbrubeck pbrubeck Feb 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if isinstance(self._argument_slots[0], ufl.coefficient.BaseCoefficient):
space = self._argument_slots[0].function_space()
else:
# When ``self._argument_slots[0]`` is an ``ufl.argument.BaseArgument``.
space = self._argument_slots[0].arguments()[0].function_space()
argument = self._argument_slots[0]
if isinstance(argument, ufl.Coargument):
argument = argument.arguments()[0]
space = argument.function_space()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that space should be a property or method of the parent class MLOperator, so this logic does not have to be repeated in every subclass.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic would go in UFL: FEniCS/ufl#348

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants