Skip to content

Commit

Permalink
fix fillna for variable and expressions
Browse files Browse the repository at this point in the history
harmonize fill_value for constants in expression
  • Loading branch information
FabianHofmann committed Oct 26, 2023
1 parent cf62967 commit 76a1092
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 21 deletions.
4 changes: 3 additions & 1 deletion doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ Upcoming Release
----------------

* It is now possible to set the sense of the objective function to `minimize` or `maximize`. Therefore, a new class `Objective` was introduced which is used in `Model.objective`. It supports the same arithmetic operations as `LinearExpression` and `QuadraticExpression` and contains a `sense` attribute which can be set to `minimize` or `maximize`.

* The default `fill_value` for constants in the `LinearExpression` and `QuadraticExpression` classes was changed to ``NaN``.
* The `fillna` function for variables was made safer: a warning is now raised if the fill value is not of a variable-like type.
* The `where` and `fillna` functions for expressions were made more flexible: when a scalar value or a DataArray is passed, its values are added as constants to the expression where there were missing values before. If another expression is passed, its values are added to the expression where there were missing values before.

Version 0.2.5
-------------
Expand Down
2 changes: 1 addition & 1 deletion linopy/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def print_line(expr, const):

res.append(f"{coeff_string}{var_string}")

if const != 0:
if not np.isnan(const):
const_string = f"{const:+.4g}"
if len(res):
res.append(f"{const_string[0]} {const_string[1:]}")
Expand Down
59 changes: 52 additions & 7 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def _exprwrap(expr, *args, **kwargs):


def _expr_unwrap(maybe_expr):
if isinstance(maybe_expr, LinearExpression):
if isinstance(maybe_expr, (LinearExpression, QuadraticExpression)):
return maybe_expr.data

return maybe_expr
Expand Down Expand Up @@ -271,7 +271,7 @@ class LinearExpression:
__array_ufunc__ = None
__array_priority__ = 10000

_fill_value = {"vars": -1, "coeffs": np.nan, "const": 0}
_fill_value = {"vars": -1, "coeffs": np.nan, "const": np.nan}

def __init__(self, data, model):
from linopy.model import Model
Expand Down Expand Up @@ -814,7 +814,18 @@ def reset_const(self):
"""
return self.__class__(self.data[["coeffs", "vars"]], self.model)

def where(self, cond, other=xr.core.dtypes.NA, **kwargs):
def isnull(self):
    """
    Return a boolean mask that is True where the expression is entirely missing.

    An entry counts as missing when every variable label along the helper
    dimensions present in ``vars`` equals -1 and the constant is null.

    Returns
    -------
    xr.DataArray
    """
    # Only reduce over helper dimensions that actually occur in `vars`.
    reduce_dims = set(HELPER_DIMS).intersection(self.vars.dims)
    no_vars = (self.vars == -1).all(reduce_dims)
    no_const = self.const.isnull()
    return no_vars & no_const

def where(self, cond, other=None, **kwargs):
"""
Filter variables based on a condition.
Expand All @@ -826,6 +837,11 @@ def where(self, cond, other=xr.core.dtypes.NA, **kwargs):
cond : DataArray or callable
Locations at which to preserve this object's values. dtype must be `bool`.
If a callable, it must expect this object as its only parameter.
other : expression-like, DataArray or scalar, optional
Data to use in place of values where cond is False.
If a DataArray or a scalar is provided, it is only used to fill
the missing values of constant values (`const`).
If a DataArray, its coordinates must match this object's.
**kwargs :
Keyword arguments passed to ``xarray.Dataset.where``
Expand All @@ -834,14 +850,45 @@ def where(self, cond, other=xr.core.dtypes.NA, **kwargs):
linopy.LinearExpression
"""
# Cannot set `other` if drop=True
if other is xr.core.dtypes.NA:
if other is None or other is np.nan:
if not kwargs.get("drop", False):
other = self._fill_value
elif isinstance(other, (DataArray, np.floating, np.integer, int, float)):
other = {**self._fill_value, "const": other}
else:
other = _expr_unwrap(other)
cond = _expr_unwrap(cond)
if isinstance(cond, DataArray):
if helper_dims := set(HELPER_DIMS).intersection(cond.dims):
raise ValueError(
f"Filtering by a DataArray with a helper dimension(s) ({helper_dims!r}) is not supported."
)
return self.__class__(self.data.where(cond, other=other, **kwargs), self.model)

def fillna(self, value):
    """
    Fill missing values of the expression.

    The value is first unwrapped (an expression is replaced by its
    underlying data). A scalar or DataArray is then applied to the
    constant (`const`) only; anything else is passed unchanged to the
    ``fillna`` method of the underlying data.

    Parameters
    ----------
    value : scalar, DataArray or expression-like
        Value(s) used to fill missing entries. Scalars and DataArrays
        fill the constant term only.

    Returns
    -------
    linopy.LinearExpression
        A new object of the same class as ``self`` with missing values filled.
    """
    fill = _expr_unwrap(value)
    if isinstance(fill, (DataArray, np.floating, np.integer, int, float)):
        # Restrict plain numeric / array fills to the constant term.
        fill = {"const": fill}
    filled = self.data.fillna(fill)
    return self.__class__(filled, self.model)

def diff(self, dim, n=1):
"""
Calculate the n-th order discrete difference along given axis.
Expand Down Expand Up @@ -1065,8 +1112,6 @@ def mask_func(data):

ffill = exprwrap(Dataset.ffill)

fillna = exprwrap(Dataset.fillna, value=_fill_value)

sel = exprwrap(Dataset.sel)

isel = exprwrap(Dataset.isel)
Expand Down Expand Up @@ -1097,7 +1142,7 @@ class QuadraticExpression(LinearExpression):
__array_ufunc__ = None
__array_priority__ = 10000

_fill_value = {"vars": -1, "coeffs": np.nan, "const": 0}
_fill_value = {"vars": -1, "coeffs": np.nan, "const": np.nan}

def __init__(self, data, model):
super().__init__(data, model)
Expand Down
36 changes: 31 additions & 5 deletions linopy/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,13 @@ def diff(self, dim, n=1):
"""
return self.to_linexpr().diff(dim, n)

def where(self, cond, other=-1, **kwargs):
def isnull(self):
    """
    Return a boolean mask that is True wherever the variable label equals -1,
    i.e. where values are missing.
    """
    missing = self.labels == -1
    return missing

def where(self, cond, other=None, **kwargs):
"""
Filter variables based on a condition.
Expand All @@ -703,14 +709,36 @@ def where(self, cond, other=-1, **kwargs):
-------
linopy.Variable
"""
if isinstance(other, Variable):
if other is None:
other = self._fill_value
elif isinstance(other, Variable):
other = other.data
elif isinstance(other, ScalarVariable):
other = {"labels": other.label, "lower": other.lower, "upper": other.upper}
elif not isinstance(other, (dict, Dataset)):
warn(
"other argument of Variable.where should be a Variable, ScalarVariable or dict. "
"Other types will not be supported in the future.",
FutureWarning,
)
return self.__class__(
self.data.where(cond, other, **kwargs), self.model, self.name
)

def fillna(self, fill_value):
    """
    Fill missing values with a variable.

    Delegates to ``self.where``: entries that are not null are kept and
    all remaining entries are taken from `fill_value`.

    Parameters
    ----------
    fill_value : Variable/ScalarVariable
        Variable to use for filling.
    """
    keep = ~self.isnull()
    return self.where(keep, fill_value)

def ffill(self, dim, limit=None):
"""
Forward fill the variable along a dimension.
Expand Down Expand Up @@ -757,7 +785,7 @@ def bfill(self, dim, limit=None):
linopy.Variable
"""
data = (
self.data.where(self.labels != -1)
self.data.where(~self.isnull())
# .bfill(dim, limit=limit)
# breaks with Dataset.bfill, use map instead
.map(DataArray.bfill, dim=dim, limit=limit).fillna(self._fill_value)
Expand Down Expand Up @@ -796,8 +824,6 @@ def equals(self, other):

drop_isel = varwrap(Dataset.drop_isel)

fillna = varwrap(Dataset.fillna)

sel = varwrap(Dataset.sel)

isel = varwrap(Dataset.isel)
Expand Down
73 changes: 68 additions & 5 deletions test/test_linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,40 +374,103 @@ def test_linear_expression_loc(x, y):
assert expr.loc[0].size < expr.loc[:5].size


def test_linear_expression_isnull(v):
    """Masking out the first ten entries must yield ten null entries."""
    scaled = np.arange(20) * v
    keep = (scaled.coeffs >= 10).any(TERM_DIM)
    masked = scaled.where(keep)
    assert masked.isnull().sum() == 10


def test_linear_expression_where(v):
expr = np.arange(20) * v
expr = expr.where(expr.coeffs >= 10)
filter = (expr.coeffs >= 10).any(TERM_DIM)
expr = expr.where(filter)
assert isinstance(expr, LinearExpression)
assert expr.nterm == 1

expr = np.arange(20) * v
expr = expr.where(expr.coeffs >= 10, drop=True).sum()
expr = expr.where(filter, drop=True).sum()
assert isinstance(expr, LinearExpression)
assert expr.nterm == 10


def test_linear_expression_where_with_const(v):
expr = np.arange(20) * v + 10
expr = expr.where(expr.coeffs >= 10)
filter = (expr.coeffs >= 10).any(TERM_DIM)
expr = expr.where(filter)
assert isinstance(expr, LinearExpression)
assert expr.nterm == 1
assert (expr.const[:10] == 0).all()
assert expr.const[:10].isnull().all()
assert (expr.const[10:] == 10).all()

expr = np.arange(20) * v + 10
expr = expr.where(expr.coeffs >= 10, drop=True).sum()
expr = expr.where(filter, drop=True).sum()
assert isinstance(expr, LinearExpression)
assert expr.nterm == 10
assert expr.const == 100


def test_linear_expression_where_scalar_fill_value(v):
    """A scalar `other` replaces the constant where the condition is False."""
    base = np.arange(20) * v + 10
    keep = (base.coeffs >= 10).any(TERM_DIM)
    result = base.where(keep, 200)
    assert isinstance(result, LinearExpression)
    assert result.nterm == 1
    assert (result.const[:10] == 200).all()
    assert (result.const[10:] == 10).all()


def test_linear_expression_where_array_fill_value(v):
    """A DataArray `other` replaces the constant where the condition is False."""
    base = np.arange(20) * v + 10
    keep = (base.coeffs >= 10).any(TERM_DIM)
    # Use the unfiltered coefficients as the fill array.
    other = base.coeffs
    result = base.where(keep, other)
    assert isinstance(result, LinearExpression)
    assert result.nterm == 1
    assert (result.const[:10] == other[:10]).all()
    assert (result.const[10:] == 10).all()


def test_linear_expression_where_expr_fill_value(v):
    """An expression `other` supplies the constant where the condition is False."""
    primary = np.arange(20) * v + 10
    fallback = np.arange(20) * v + 5
    keep = (primary.coeffs >= 10).any(TERM_DIM)
    res = primary.where(keep, fallback)
    assert isinstance(res, LinearExpression)
    assert res.nterm == 1
    assert (res.const[:10] == fallback.const[:10]).all()
    assert (res.const[10:] == 10).all()


def test_where_with_helper_dim_false(v):
    """Filtering with a condition that still carries a helper dimension must raise."""
    expr = np.arange(20) * v
    # Build the mask outside the `raises` block so that only `where` itself
    # can satisfy the expectation; a ValueError from the comparison would
    # otherwise make the test pass for the wrong reason.
    mask = expr.coeffs >= 10
    with pytest.raises(ValueError):
        expr.where(mask)


def test_linear_expression_shift(v):
    """Shifting along `dim_2` leaves the leading entries empty."""
    shifted = v.to_linexpr().shift(dim_2=2)
    assert shifted.nterm == 1
    head = slice(None, 1)
    assert shifted.coeffs.loc[head].isnull().all()
    assert (shifted.vars.loc[head] == -1).all()


def test_linear_expression_fillna(v):
    """Filling a filtered expression with a scalar restores the constant only."""
    base = np.arange(20) * v + 10
    assert base.const.sum() == 200

    keep = (base.coeffs >= 10).any(TERM_DIM)
    filtered = base.where(keep)
    assert isinstance(filtered, LinearExpression)
    assert filtered.const.sum() == 100

    # The scalar fills `const`; the masked coefficients stay missing.
    filled = filtered.fillna(10)
    assert isinstance(filled, LinearExpression)
    assert filled.const.sum() == 200
    assert filled.coeffs.isnull().sum() == 10


def test_linear_expression_diff(v):
diff = v.to_linexpr().diff("dim_2")
assert diff.nterm == 2
Expand Down
10 changes: 10 additions & 0 deletions test/test_quadratic_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
import pytest
from scipy.sparse import csc_matrix
from xarray import DataArray

from linopy import Model, merge
from linopy.constants import FACTOR_DIM, TERM_DIM
Expand Down Expand Up @@ -124,6 +125,15 @@ def test_quadratic_expression_loc(x):
assert expr.loc[0].size < expr.loc[:5].size


def test_quadratic_expression_isnull(x):
    """`isnull` on a filtered quadratic expression returns a DataArray mask."""
    quad = np.arange(2) * x * x
    keep = (quad.coeffs > 0).any(TERM_DIM)
    mask = quad.where(keep).isnull()
    assert isinstance(mask, DataArray)
    assert mask.sum() == 1


def test_quadratic_expression_flat(x, y):
expr = x * y + x + 5
df = expr.flat
Expand Down
20 changes: 18 additions & 2 deletions test/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,21 @@ def test_variable_shift(x):
assert x.labels[0] == -1


def test_isnull(x):
    """`isnull` flags exactly the entries that were masked out by `where`."""
    kept = [True] * 4 + [False] * 6
    x = x.where(kept)
    mask = x.isnull()
    assert isinstance(mask, xr.DataArray)
    assert (mask == [False] * 4 + [True] * 6).all()


def test_variable_fillna(x):
    """`fillna` warns on a non-variable fill value and keeps the Variable type."""
    # NOTE(review): another test named `test_variable_fillna` appears later in
    # this file's diff; if both live in the same module the later definition
    # shadows this one and it is never collected — confirm and rename one.
    x = x.where([True] * 4 + [False] * 6)

    # A plain scalar is not variable-like and should emit a FutureWarning.
    with pytest.warns(FutureWarning):
        x.fillna(0)

    # The original evaluated this isinstance check without asserting it.
    assert isinstance(x.fillna(x[0]), linopy.variables.Variable)


def test_variable_bfill(x):
x = x.where([False] * 4 + [True] * 6)
x = x.bfill("first")
Expand All @@ -185,13 +200,14 @@ def test_variable_ffill(x):


def test_variable_fillna(x):
result = x.fillna(-1)
result = x.fillna(x[0])
assert isinstance(result, linopy.variables.Variable)


def test_variable_sanitize(x):
# convert intentionally to float with nans
x = x.where([True] * 4 + [False] * 6, np.nan)
fill_value = {"labels": np.nan, "lower": np.nan, "upper": np.nan}
x = x.where([True] * 4 + [False] * 6, fill_value)
x = x.sanitize()
assert isinstance(x, linopy.variables.Variable)
assert x.labels[9] == -1

0 comments on commit 76a1092

Please sign in to comment.