-
Notifications
You must be signed in to change notification settings - Fork 161
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Jax ml operator fix #4041
base: master
Are you sure you want to change the base?
Jax ml operator fix #4041
Conversation
|
|
firedrake/ml/jax/ml_operator.py
Outdated
if isinstance(self._argument_slots[0], Cofunction): | ||
space = self.ufl_operands[0].function_space().dual() | ||
else: | ||
space = self.ufl_function_space() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If c = Cofunction(V.dual())
then V == c.arguments()[0].ufl_function_space()
. Could this be used to not have Cofunction
as a special case here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wrote this in another way. See the latest commit. I will be happy with your review.
firedrake/ml/pytorch/ml_operator.py
Outdated
@@ -98,7 +102,12 @@ def _pre_forward_callback(self, *operands, unsqueeze=False): | |||
|
|||
def _post_forward_callback(self, y_P): | |||
"""Callback function to convert the PyTorch output of the ML model to a Firedrake function.""" | |||
space = self.ufl_function_space() | |||
# At this point, ``len(self._argument_slots)`` must be greater than 0. | |||
if isinstance(self._argument_slots[0], ufl.coefficient.BaseCoefficient): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function
and Cofunction
are both subclasses of BaseCoefficient
. Don't we want different logic for these two cases?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function
inherits from ufl.Coefficient
that inherits from BaseCoefficient
. Cofunction
inherits directly from BaseCoefficient
. So, isinstance(self._argument_slots[0], ufl.coefficient.BaseCoefficient)
returns True
for both Function
and Cofunction
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So self._argument_slots[0]
could be either a Function
, Cofunction
, Argument
or Coargument
? What do we want for each case: do we just want V
, V.dual()
, V,
and V.dual()
, respectively?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or are we also expecting BaseForms
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self._argument_slots
contains either only ufl.argument.BaseArgument
(Argument
or Coargument
) objects or a combination of ufl.coefficient.BaseCoefficient
(at index 0) and ufl.argument.BaseArgument
(at index 1). Our goal is to set the function space based on it. I think determining the function space based on self._argument_slots[0]
is reasonable.
as a result of taking the action on a given function. | ||
Tuple containing the arguments of the linear form associated with the ML operator, | ||
i.e. the arguments with respect to which the ML operator is linear. Those arguments can | ||
retain only ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
contain?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
retain only ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both | |
contain ``ufl.argument.BaseArgument`` objects, as a result of differentiation, or both |
firedrake/ml/pytorch/ml_operator.py
Outdated
if isinstance(self._argument_slots[0], ufl.coefficient.BaseCoefficient): | ||
space = self._argument_slots[0].function_space() | ||
else: | ||
# When ``self._argument_slots[0]`` is an ``ufl.argument.BaseArgument``. | ||
space = self._argument_slots[0].arguments()[0].function_space() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if isinstance(self._argument_slots[0], ufl.coefficient.BaseCoefficient): | |
space = self._argument_slots[0].function_space() | |
else: | |
# When ``self._argument_slots[0]`` is an ``ufl.argument.BaseArgument``. | |
space = self._argument_slots[0].arguments()[0].function_space() | |
argument = self._argument_slots[0] | |
if isinstance(argument, ufl.Coargument): | |
argument = argument.arguments()[0] | |
space = argument.function_space() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that space
should be a property or method of the parent class MLOperator
, so this logic does not have to be repeated in every subclass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic would go in UFL: FEniCS/ufl#348
Description
Depends on FEniCS/ufl#348
Fixed an issue in
firedrake.ml.jax.ml_operator
whereargument_slots
was incorrectly specified. I wrote it as optional:Clarified that if
argument_slots
is not provided,ML_Operator
will automatically write it.Test the results involving the Neural operators are in right function space.
Test jax and pytorch operators in Firedrake CI.