|
10 | 10 | import operator |
11 | 11 | import os |
12 | 12 | from collections.abc import Callable, Generator, Hashable, Iterable, Sequence |
13 | | -from functools import reduce, wraps |
| 13 | +from functools import partial, reduce, wraps |
14 | 14 | from pathlib import Path |
15 | 15 | from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload |
16 | 16 | from warnings import warn |
@@ -1205,7 +1205,7 @@ def check_common_keys_values(list_of_dicts: list[dict[str, Any]]) -> bool: |
1205 | 1205 |
|
1206 | 1206 | def align( |
1207 | 1207 | *objects: LinearExpression | QuadraticExpression | Variable | T_Alignable, |
1208 | | - join: JoinOptions = "exact", |
| 1208 | + join: JoinOptions | None = None, |
1209 | 1209 | copy: bool = True, |
1210 | 1210 | indexes: Any = None, |
1211 | 1211 | exclude: str | Iterable[Hashable] = frozenset(), |
@@ -1265,41 +1265,56 @@ def align( |
1265 | 1265 |
|
1266 | 1266 |
|
1267 | 1267 | """ |
| 1268 | + from linopy.config import options |
1268 | 1269 | from linopy.expressions import LinearExpression, QuadraticExpression |
1269 | 1270 | from linopy.variables import Variable |
1270 | 1271 |
|
1271 | | - # Extract underlying Datasets for index computation. |
| 1272 | + if join is None: |
| 1273 | + join = options["arithmetic_convention"] |
| 1274 | + |
| 1275 | + if join == "legacy": |
| 1276 | + from linopy.config import LEGACY_DEPRECATION_MESSAGE, LinopyDeprecationWarning |
| 1277 | + |
| 1278 | + warn( |
| 1279 | + LEGACY_DEPRECATION_MESSAGE, |
| 1280 | + LinopyDeprecationWarning, |
| 1281 | + stacklevel=2, |
| 1282 | + ) |
| 1283 | + join = "inner" |
| 1284 | + |
| 1285 | + elif join == "v1": |
| 1286 | + join = "exact" |
| 1287 | + |
| 1288 | + finisher: list[partial[Any] | Callable[[Any], Any]] = [] |
1272 | 1289 | das: list[Any] = [] |
1273 | 1290 | for obj in objects: |
1274 | | - if isinstance(obj, LinearExpression | QuadraticExpression | Variable): |
| 1291 | + if isinstance(obj, LinearExpression | QuadraticExpression): |
| 1292 | + finisher.append(partial(obj.__class__, model=obj.model)) |
| 1293 | + das.append(obj.data) |
| 1294 | + elif isinstance(obj, Variable): |
| 1295 | + finisher.append( |
| 1296 | + partial( |
| 1297 | + obj.__class__, |
| 1298 | + model=obj.model, |
| 1299 | + name=obj.data.attrs["name"], |
| 1300 | + skip_broadcast=True, |
| 1301 | + ) |
| 1302 | + ) |
1275 | 1303 | das.append(obj.data) |
1276 | 1304 | else: |
| 1305 | + finisher.append(lambda x: x) |
1277 | 1306 | das.append(obj) |
1278 | 1307 |
|
1279 | 1308 | exclude = frozenset(exclude).union(HELPER_DIMS) |
1280 | | - |
1281 | | - # Compute target indexes. |
1282 | | - target_aligned = xr_align( |
1283 | | - *das, join=join, copy=False, indexes=indexes, exclude=exclude |
| 1309 | + aligned = xr_align( |
| 1310 | + *das, |
| 1311 | + join=join, |
| 1312 | + copy=copy, |
| 1313 | + indexes=indexes, |
| 1314 | + exclude=exclude, |
| 1315 | + fill_value=fill_value, |
1284 | 1316 | ) |
1285 | | - |
1286 | | - # Reindex each object to target indexes. |
1287 | | - reindex_kwargs: dict[str, Any] = {} |
1288 | | - if fill_value is not dtypes.NA: |
1289 | | - reindex_kwargs["fill_value"] = fill_value |
1290 | | - results: list[Any] = [] |
1291 | | - for obj, target in zip(objects, target_aligned): |
1292 | | - indexers = { |
1293 | | - dim: target.indexes[dim] |
1294 | | - for dim in target.dims |
1295 | | - if dim not in exclude and dim in target.indexes |
1296 | | - } |
1297 | | - # Variable.reindex has no fill_value — it always uses sentinels |
1298 | | - if isinstance(obj, Variable): |
1299 | | - results.append(obj.reindex(indexers)) |
1300 | | - else: |
1301 | | - results.append(obj.reindex(indexers, **reindex_kwargs)) # type: ignore[union-attr] |
1302 | | - return tuple(results) |
| 1317 | + return tuple([f(da) for f, da in zip(finisher, aligned)]) |
1303 | 1318 |
|
1304 | 1319 |
|
1305 | 1320 | LocT = TypeVar( |
|
0 commit comments