Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for specifying interpolation modes for time and rho_norm when loading InterpolatedVarTimeRho objects. #422

Merged
merged 1 commit into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions docs/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,24 @@ Time-varying arrays
Time-varying arrays can be defined using either primitives, an
``xarray.DataArray`` or a ``tuple`` of ``Array``.

Specifying interpolation methods
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
By default piecewise linear interpolation is used to interpolate values in time.
To specify a different interpolation method, use the following syntax of a tuple
with two elements. The first element in the tuple is the usual value for the
time-varying-array (as defined below), the second value is a dict with keys
``time_interpolation_mode`` and ``rho_interpolation_mode`` and values the
desired interpolation modes.

.. code-block:: python

(time_varying_array_value, {'time_interpolation_mode': 'STEP', 'rho_interpolation_mode': 'PIECEWISE_LINEAR'})

Currently two interpolation modes are supported:

* ``'STEP'``
* ``'PIECEWISE_LINEAR'``

Using primitives
^^^^^^^^^^^^^^^^

Expand Down
35 changes: 30 additions & 5 deletions torax/config/config_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
# TypeVar for generic dataclass types.
_T = TypeVar('_T')
RHO_NORM = 'rho_norm'
TIME_INTERPOLATION_MODE = 'time_interpolation_mode'
RHO_INTERPOLATION_MODE = 'rho_interpolation_mode'


def input_is_an_interpolated_var_single_axis(
Expand Down Expand Up @@ -283,6 +285,11 @@ def get_interpolated_var_2d(
2. An xr.DataArray is passed in, see _load_from_xr_array for details.
3. A tuple of arrays is passed in, see _load_from_arrays for details.

Additionally the interpolation mode for rhon and time can be specified as
strings by passing a 3-tuple with the first element being the input, the
second element being the time interpolation mode and the third element
being the rhon interpolation mode.

Args:
time_rho_interpolated_input: An input that can be used to construct a
InterpolatedVarTimeRho object.
Expand All @@ -292,11 +299,27 @@ def get_interpolated_var_2d(
An InterpolatedVarTimeRho object which has been preinterpolated onto the
provided rho_norm values.
"""
# if isinstance(time_rho_interpolated_input, float):
# values = _load_from_float(
# time_rho_interpolated_input,
# rho_norm,
# )
# Potentially parse the interpolation modes from the input.
time_interpolation_mode = (
interpolated_param.InterpolationMode.PIECEWISE_LINEAR
)
rho_interpolation_mode = interpolated_param.InterpolationMode.PIECEWISE_LINEAR
if isinstance(time_rho_interpolated_input, tuple):
if (
len(time_rho_interpolated_input) == 2
and isinstance(time_rho_interpolated_input[1], dict)
):
# Second and third elements in tuple are interpreted as interpolation
# modes.
time_interpolation_mode = interpolated_param.InterpolationMode[
time_rho_interpolated_input[1][TIME_INTERPOLATION_MODE].upper()
]
rho_interpolation_mode = interpolated_param.InterpolationMode[
time_rho_interpolated_input[1][RHO_INTERPOLATION_MODE].upper()
]
# First element in tuple assumed to be the input.
time_rho_interpolated_input = time_rho_interpolated_input[0]

if isinstance(time_rho_interpolated_input, xr.DataArray):
values = _load_from_xr_array(
time_rho_interpolated_input,
Expand All @@ -319,6 +342,8 @@ def get_interpolated_var_2d(
time_rho_interpolated = interpolated_param.InterpolatedVarTimeRho(
values=values,
rho_norm=rho_norm,
time_interpolation_mode=time_interpolation_mode,
rho_interpolation_mode=rho_interpolation_mode,
)
return time_rho_interpolated

Expand Down
Loading