From 754c0744596d90882f541533e9fccebe81458640 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Wed, 3 Jun 2026 13:16:40 +0200 Subject: [PATCH 1/4] perf: scatter groupby-sum terms directly instead of unstacking The fast path of LinearExpression.groupby(...).sum() used ds.unstack(group_dim, fill_value=...) followed by a stack, which materializes 2-3 intermediate copies of the padded result (n_groups x max_group_size x nterm) and goes through pandas MultiIndex machinery sized by the number of elements. Instead, factorize the groups and scatter coeffs/vars directly into the preallocated padded result arrays; constants are group-summed with np.add.at. Peak memory drops to input + result (the minimum for the padded layout) and the grouping itself gets considerably faster. The result is unchanged: same dims, coords, term ordering and padding. The unstack-based implementation is kept as _sum_by_unstack and still used for chunked (dask-backed) data, which cannot be scattered into numpy arrays. NaN group labels now raise an informative ValueError instead of failing inside unstack. Co-Authored-By: Claude Opus 4.8 (1M context) --- linopy/expressions.py | 140 +++++++++++++++++++++++++++++---- test/test_linear_expression.py | 124 +++++++++++++++++++++++++++++ 2 files changed, 250 insertions(+), 14 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index ea8588d2..0e42af19 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -340,20 +340,13 @@ def sum( # At this point, group is always a pandas Series assert isinstance(group, pd.Series) - group_dim = group.index.name - - arrays = [group, group.groupby(group).cumcount()] - idx = pd.MultiIndex.from_arrays(arrays, names=[GROUP_DIM, GROUPED_TERM_DIM]) - new_coords = Coordinates.from_pandas_multiindex(idx, group_dim) - # collapsing group_dim invalidates every coordinate aligned to it - names_to_drop = [ - name - for name, coord in self.data.coords.items() - if group_dim in coord.dims - ] - ds = self.data.drop_vars(names_to_drop).assign_coords(new_coords) - ds = ds.unstack(group_dim, fill_value=LinearExpression._fill_value) - ds = LinearExpression._sum(ds, dim=GROUPED_TERM_DIM) + + if self._can_sum_by_scatter(group): + ds = self._sum_by_scatter(group) + else: + # chunked (e.g. dask-backed) data or exotic coordinates on the + # grouped dimension: use xarray's unstack machinery + ds = self._sum_by_unstack(group) if int_map is not None: index = ds.indexes[GROUP_DIM].map({v: k for k, v in int_map.items()}) @@ -374,6 +367,125 @@ def func(ds: Dataset) -> Dataset: return self.map(func, **kwargs, shortcut=True) + def _can_sum_by_scatter(self, group: pd.Series) -> bool: + """ + Whether :meth:`_sum_by_scatter` covers the structure of the data. + + The scatter kernel requires numpy-backed arrays (chunked data cannot be + scattered into preallocated numpy arrays) and no coordinates tied to + the grouped dimension besides its own index. Everything else falls + back to :meth:`_sum_by_unstack`. + """ + data = self.data + group_dim = group.index.name + + numpy_backed = all( + isinstance(data[k].data, np.ndarray) for k in ("coeffs", "vars", "const") + ) + if not numpy_backed: + return False + + index = data.indexes.get(group_dim) + index_names = {group_dim, *(index.names if index is not None else ())} + return all( + coord.dims == (group_dim,) and name in index_names + for name, coord in data.coords.items() + if group_dim in coord.dims + ) + + def _sum_by_scatter(self, group: pd.Series) -> Dataset: + """ + Sum groups by scattering all terms directly into the final padded arrays. + + Every group member keeps its block of ``nterm`` terms, so the resulting + term dimension has size ``max_group_size * nterm`` and smaller groups are + padded with fill values. In contrast to :meth:`_sum_by_unstack` only the + result arrays are allocated, without intermediate copies of that size. + + Only the term and constant values are computed with numpy; the result + structure (dimensions, coordinates and their order) is assembled by + xarray. :meth:`_can_sum_by_scatter` decides whether the data is simple + enough for this kernel. + """ + data = self.data + group_dim = group.index.name + fill_value = LinearExpression._fill_value + + codes, unique_groups = pd.factorize(group, sort=True) + if (codes == -1).any(): + raise ValueError( + "Cannot group by a pandas object containing NaN values. " + "Drop or fill the corresponding entries before grouping." + ) + + n_groups = len(unique_groups) + sizes = np.bincount(codes, minlength=n_groups) + max_size = int(sizes.max()) if n_groups else 0 + + # position of each element within its group (order of appearance) + positions = pd.Series(codes).groupby(codes).cumcount().to_numpy() + + def scatter( + da: DataArray, fill: Any + ) -> tuple[tuple[Hashable, ...], np.ndarray]: + """Scatter one term-array into its padded (group x term) layout.""" + rest_dims = [d for d in da.dims if d not in (group_dim, TERM_DIM)] + values = da.transpose(group_dim, *rest_dims, TERM_DIM).values + rest_shape = values.shape[1:-1] + nterm = values.shape[-1] + + out = np.full( + (n_groups, *rest_shape, nterm, max_size), fill, dtype=values.dtype + ) + locs = (codes, *(slice(None),) * (len(rest_shape) + 1), positions) + out[locs] = values + # collapsing (nterm, max_size) into one axis keeps all terms of one + # group member together, with padding at the end of each block + out = out.reshape((n_groups, *rest_shape, nterm * max_size)) + return (GROUP_DIM, *rest_dims, TERM_DIM), out + + coeffs_dims, coeffs = scatter(data.coeffs, fill_value["coeffs"]) + vars_dims, vars = scatter(data.vars, fill_value["vars"]) + + # constants are summed up within each group, skipping NaN values + const_dims = [d for d in data.const.dims if d != group_dim] + const_values = data.const.transpose(group_dim, *const_dims).values + const = np.zeros((n_groups, *const_values.shape[1:]), dtype=const_values.dtype) + np.add.at(const, codes, np.where(np.isnan(const_values), 0, const_values)) + + # only the values above are computed with numpy, the result structure + # (dimensions, coordinates and their order) is assembled by xarray + # itself and thereby matches a result of unstacking the group dimension + structure = data.drop_vars(["coeffs", "vars", "const"]) + structure = structure.drop_dims(group_dim) + structure = structure.expand_dims({GROUP_DIM: unique_groups}) + + return structure.assign( + coeffs=(coeffs_dims, coeffs), + vars=(vars_dims, vars), + const=((GROUP_DIM, *const_dims), const), + ) + + def _sum_by_unstack(self, group: pd.Series) -> Dataset: + """ + Sum groups by unstacking the group dimension into a padded helper + dimension and summing over it. + + Equivalent to :meth:`_sum_by_scatter` but goes through xarray's + unstack/stack machinery, which also supports chunked (dask) data. + """ + group_dim = group.index.name + arrays = [group, group.groupby(group).cumcount()] + idx = pd.MultiIndex.from_arrays(arrays, names=[GROUP_DIM, GROUPED_TERM_DIM]) + new_coords = Coordinates.from_pandas_multiindex(idx, group_dim) + # collapsing group_dim invalidates every coordinate aligned to it + names_to_drop = [ + name for name, coord in self.data.coords.items() if group_dim in coord.dims + ] + ds = self.data.drop_vars(names_to_drop).assign_coords(new_coords) + ds = ds.unstack(group_dim, fill_value=LinearExpression._fill_value) + return LinearExpression._sum(ds, dim=GROUPED_TERM_DIM) + def roll(self, **kwargs: Any) -> LinearExpression: """ Roll the groupby object. diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 5ffd7de1..19850999 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1908,6 +1908,130 @@ def test_linear_expression_groupby_from_variable(v: Variable) -> None: assert grouped.nterm == 10 +def test_linear_expression_groupby_skewed_unsorted_groups(v: Variable) -> None: + """ + The scatter-based fast path must match the xarray fallback for groups that + are unsorted, non-contiguous and of very different sizes. + """ + expr = 2 * v + 5 + # 'b' appears 14 times, 'c' 5 times, 'a' once, scattered over the dimension + labels = ["b"] * 4 + ["c", "a"] + ["b"] * 5 + ["c"] * 4 + ["b"] * 5 + groups = pd.Series(labels, index=v.indexes["dim_2"], name="letter") + + grouped = expr.groupby(groups).sum() + fallback = expr.groupby(groups.to_xarray()).sum(use_fallback=True) + + assert list(grouped.data.letter) == ["a", "b", "c"] + # padded to the largest group times the number of terms of the input + assert grouped.nterm == 14 * expr.nterm + assert_linequal(grouped, fallback) + + # every group must carry exactly the variables of its members, the rest is fill + for letter in ["a", "b", "c"]: + members = np.where(np.array(labels) == letter)[0] + vars_of_group = grouped.data.vars.sel(letter=letter).values + assert set(vars_of_group[vars_of_group >= 0]) == set(v.labels.values[members]) + assert (vars_of_group >= 0).sum() == len(members) * expr.nterm + assert grouped.const.sel(letter=letter).item() == 5 * len(members) + + +def test_linear_expression_groupby_chunked(v: Variable) -> None: + """Chunked (dask-backed) expressions group via xarray's unstack machinery.""" + pytest.importorskip("dask") + expr = 2 * v + 5 + groups = pd.Series([1] * 12 + [2] * 8, index=v.indexes["dim_2"], name="group") + + chunked = LinearExpression(expr.data.chunk({"dim_2": 5}), expr.model) + grouped_chunked = chunked.groupby(groups).sum() + grouped = expr.groupby(groups).sum() + + assert grouped_chunked.nterm == grouped.nterm + assert_linequal( + LinearExpression(grouped_chunked.data.compute(), expr.model), grouped + ) + + +def test_linear_expression_groupby_with_nan_groups(v: Variable) -> None: + expr = 1 * v + groups = pd.Series([1.0, np.nan] * 10, index=v.indexes["dim_2"], name="with_nans") + with pytest.raises(ValueError, match="NaN"): + expr.groupby(groups).sum() + + +@pytest.mark.parametrize( + "case", + [ + "skewed_int_groups", + "multidim_with_const", + "nan_const", + "masked_vars", + "quadratic", + "single_group", + "identity_groups", + ], +) +def test_linear_expression_groupby_scatter_equals_unstack(case: str) -> None: + """ + Lock the two groupby-sum kernels together. + + The fast path of groupby(...).sum() scatters terms into numpy arrays + (_sum_by_scatter); the xarray unstack implementation (_sum_by_unstack) is + kept for chunked data and exotic coordinates. Both must stay + interchangeable — if an xarray/pandas update changes the unstack output or + an edge case diverges, this fails. + """ + m = Model() + rng = np.random.default_rng(0) + idx = pd.RangeIndex(60, name="elem") + skewed = pd.Series(rng.choice(8, 60, p=[0.5] + [0.5 / 7] * 7), index=idx, name="g") + groups = skewed + + if case == "skewed_int_groups": + x = m.add_variables(coords=[idx], name="x") + expr: LinearExpression | QuadraticExpression = 3 * x - 2 * x + 7 + elif case == "multidim_with_const": + other = pd.Index(list("abc"), name="other") + y = m.add_variables(coords=[other, idx], name="y") + const = xr.DataArray(rng.normal(size=(3, 60)), coords=[other, idx]) + expr = 2 * y + 1 * y + const + elif case == "nan_const": + x = m.add_variables(coords=[idx], name="x") + expr = 1 * x + np.where(np.arange(60) % 3, np.nan, 5.0) + elif case == "masked_vars": + mask = xr.DataArray(np.arange(60) % 4 != 0, coords=[idx]) + x = m.add_variables(coords=[idx], name="x", mask=mask) + expr = 1 * x + elif case == "quadratic": + x = m.add_variables(coords=[idx], name="x") + expr = x * x + 2 * x + elif case == "single_group": + x = m.add_variables(coords=[idx], name="x") + expr = 1 * x + groups = pd.Series(1, index=idx, name="g") + else: # identity_groups + x = m.add_variables(coords=[idx], name="x") + expr = 1 * x + groups = pd.Series(np.arange(60), index=idx, name="g") + + gb = expr.groupby(groups) + assert gb._can_sum_by_scatter(groups) + scatter = LinearExpression(gb._sum_by_scatter(groups).rename(_group="g"), m) + unstack = LinearExpression(gb._sum_by_unstack(groups).rename(_group="g"), m) + + # identical structure: dims, dim order, coordinates + assert scatter.data.coeffs.dims == unstack.data.coeffs.dims + assert scatter.data.const.dims == unstack.data.const.dims + assert list(scatter.data.coords) == list(unstack.data.coords) + for name in scatter.data.coords: + assert_equal(scatter.data[name], unstack.data[name]) + + # identical values: vars and coeffs bit-exact, including padding positions + np.testing.assert_array_equal(scatter.vars.values, unstack.vars.values) + np.testing.assert_array_equal(scatter.coeffs.values, unstack.coeffs.values) + # constants may differ by floating-point summation order + np.testing.assert_allclose(scatter.const.values, unstack.const.values, rtol=1e-12) + + def test_linear_expression_rolling(v: Variable) -> None: expr = 1 * v rolled = expr.rolling(dim_2=2).sum() From 7f3edea01a5dbb2c4c9a78808a44458a23cec952 Mon Sep 17 00:00:00 2001 From: Fabian Date: Tue, 30 Jun 2026 13:18:06 +0200 Subject: [PATCH 2/4] test: cover empty group dim in scatter groupby-sum Add a test for grouping over an empty group dimension, which the scatter fast path handles cleanly but the unstack fallback cannot. Trim comments that duplicated the helper docstrings. --- linopy/expressions.py | 10 +++------- test/test_linear_expression.py | 12 ++++++++++++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 0e42af19..405ae7fc 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -344,8 +344,6 @@ def sum( if self._can_sum_by_scatter(group): ds = self._sum_by_scatter(group) else: - # chunked (e.g. dask-backed) data or exotic coordinates on the - # grouped dimension: use xarray's unstack machinery ds = self._sum_by_unstack(group) if int_map is not None: @@ -404,8 +402,9 @@ def _sum_by_scatter(self, group: pd.Series) -> Dataset: Only the term and constant values are computed with numpy; the result structure (dimensions, coordinates and their order) is assembled by - xarray. :meth:`_can_sum_by_scatter` decides whether the data is simple - enough for this kernel. + xarray itself and thereby matches the result of unstacking the group + dimension. :meth:`_can_sum_by_scatter` decides whether the data is + simple enough for this kernel. """ data = self.data group_dim = group.index.name @@ -453,9 +452,6 @@ def scatter( const = np.zeros((n_groups, *const_values.shape[1:]), dtype=const_values.dtype) np.add.at(const, codes, np.where(np.isnan(const_values), 0, const_values)) - # only the values above are computed with numpy, the result structure - # (dimensions, coordinates and their order) is assembled by xarray - # itself and thereby matches a result of unstacking the group dimension structure = data.drop_vars(["coeffs", "vars", "const"]) structure = structure.drop_dims(group_dim) structure = structure.expand_dims({GROUP_DIM: unique_groups}) diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index 19850999..b710f0b1 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -1958,6 +1958,18 @@ def test_linear_expression_groupby_with_nan_groups(v: Variable) -> None: expr.groupby(groups).sum() +def test_linear_expression_groupby_empty_groups() -> None: + """An empty group dimension scatters into an empty, well-formed result.""" + m = Model() + idx = pd.RangeIndex(0, name="elem") + x = m.add_variables(coords=[idx], name="x") + groups = pd.Series([], index=idx, name="g", dtype=int) + + grouped = (1 * x).groupby(groups).sum() + assert grouped.nterm == 0 + assert dict(grouped.data.sizes) == {"g": 0, "_term": 0} + + @pytest.mark.parametrize( "case", [ From 7598180d8829f87741058936523d7fd2ba9976e5 Mon Sep 17 00:00:00 2001 From: Fabian Date: Tue, 30 Jun 2026 13:19:01 +0200 Subject: [PATCH 3/4] docs: add release note for scatter groupby-sum --- doc/release_notes.rst | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/doc/release_notes.rst b/doc/release_notes.rst index 7e849a42..fb5cd3e2 100644 --- a/doc/release_notes.rst +++ b/doc/release_notes.rst @@ -21,6 +21,10 @@ Upcoming Version * ``add_variables(binary=True, ...)`` now accepts ``lower``/``upper`` bounds, as long as they are 0 or 1. Previously binary bounds could only be set via the ``.lower``/``.upper`` setters after creation. (https://github.com/PyPSA/linopy/issues/776) +**Performance** + +* ``LinearExpression.groupby(...).sum()`` now scatters terms directly into the padded result arrays instead of unstacking through pandas ``MultiIndex`` machinery, cutting peak memory to input + result and speeding up the grouping. + **Deprecations** * Mutation via assignment to ``Variable.lower`` / ``Variable.upper`` / ``Constraint.coeffs`` / ``Constraint.vars`` / ``Constraint.lhs`` / ``Constraint.sign`` / ``Constraint.rhs`` is deprecated and emits a ``DeprecationWarning``. Use ``Variable.update(...)`` / ``Constraint.update(...)`` instead — the canonical mutation API with one validation path and one place that flips the persistent-solver dirty flag. Read access to these properties is unchanged. The setters will be removed in a future release. From 81d7c23a513c484804aa176760e4d6c2d84cbf7a Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 30 Jun 2026 21:07:33 +0200 Subject: [PATCH 4/4] perf(groupby): widen scatter fast path to all numpy-backed data MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Relax the groupby-sum scatter gate to a pure numpy/dask check: auxiliary coordinates on the grouped dimension no longer force the slow unstack path. Summing over groups collapses that dimension, so both kernels drop every coordinate tied to it — the scatter result is identical, just cheaper. The unstack kernel now serves only chunked (dask) data, and a debug log records when that fallback is taken. Inline the now-trivial predicate into the dispatch and consolidate the kernel tests into a TestGroupbySumScatterKernel class: a one-line case table over a shared fixture, with added coverage for combined structures, auxiliary coords, and a MultiIndex grouped dimension. Co-Authored-By: Claude Opus 4.8 (1M context) --- linopy/expressions.py | 45 ++--- test/test_linear_expression.py | 303 +++++++++++++++++++-------------- 2 files changed, 191 insertions(+), 157 deletions(-) diff --git a/linopy/expressions.py b/linopy/expressions.py index 405ae7fc..a3b10ba8 100644 --- a/linopy/expressions.py +++ b/linopy/expressions.py @@ -341,9 +341,17 @@ def sum( # At this point, group is always a pandas Series assert isinstance(group, pd.Series) - if self._can_sum_by_scatter(group): + numpy_backed = all( + isinstance(self.data[k].data, np.ndarray) + for k in ("coeffs", "vars", "const") + ) + if numpy_backed: ds = self._sum_by_scatter(group) else: + logger.debug( + "groupby-sum: non-numpy-backed (e.g. dask) data, " + "falling back to the unstack kernel." + ) ds = self._sum_by_unstack(group) if int_map is not None: @@ -365,32 +373,6 @@ def func(ds: Dataset) -> Dataset: return self.map(func, **kwargs, shortcut=True) - def _can_sum_by_scatter(self, group: pd.Series) -> bool: - """ - Whether :meth:`_sum_by_scatter` covers the structure of the data. - - The scatter kernel requires numpy-backed arrays (chunked data cannot be - scattered into preallocated numpy arrays) and no coordinates tied to - the grouped dimension besides its own index. Everything else falls - back to :meth:`_sum_by_unstack`. - """ - data = self.data - group_dim = group.index.name - - numpy_backed = all( - isinstance(data[k].data, np.ndarray) for k in ("coeffs", "vars", "const") - ) - if not numpy_backed: - return False - - index = data.indexes.get(group_dim) - index_names = {group_dim, *(index.names if index is not None else ())} - return all( - coord.dims == (group_dim,) and name in index_names - for name, coord in data.coords.items() - if group_dim in coord.dims - ) - def _sum_by_scatter(self, group: pd.Series) -> Dataset: """ Sum groups by scattering all terms directly into the final padded arrays. @@ -403,8 +385,8 @@ def _sum_by_scatter(self, group: pd.Series) -> Dataset: Only the term and constant values are computed with numpy; the result structure (dimensions, coordinates and their order) is assembled by xarray itself and thereby matches the result of unstacking the group - dimension. :meth:`_can_sum_by_scatter` decides whether the data is - simple enough for this kernel. + dimension. The caller dispatches here only for numpy-backed data + (chunked data uses :meth:`_sum_by_unstack`). """ data = self.data group_dim = group.index.name @@ -467,8 +449,9 @@ def _sum_by_unstack(self, group: pd.Series) -> Dataset: Sum groups by unstacking the group dimension into a padded helper dimension and summing over it. - Equivalent to :meth:`_sum_by_scatter` but goes through xarray's - unstack/stack machinery, which also supports chunked (dask) data. + Equivalent to :meth:`_sum_by_scatter`, but goes through xarray's + unstack/stack machinery. It is the fallback for chunked (dask) data, + which cannot be scattered into preallocated numpy buffers. """ group_dim = group.index.name arrays = [group, group.groupby(group).cumcount()] diff --git a/test/test_linear_expression.py b/test/test_linear_expression.py index b710f0b1..20032351 100644 --- a/test/test_linear_expression.py +++ b/test/test_linear_expression.py @@ -7,7 +7,10 @@ from __future__ import annotations +import logging import warnings +from collections.abc import Callable +from types import SimpleNamespace from typing import Any import numpy as np @@ -1908,140 +1911,188 @@ def test_linear_expression_groupby_from_variable(v: Variable) -> None: assert grouped.nterm == 10 -def test_linear_expression_groupby_skewed_unsorted_groups(v: Variable) -> None: - """ - The scatter-based fast path must match the xarray fallback for groups that - are unsorted, non-contiguous and of very different sizes. - """ - expr = 2 * v + 5 - # 'b' appears 14 times, 'c' 5 times, 'a' once, scattered over the dimension - labels = ["b"] * 4 + ["c", "a"] + ["b"] * 5 + ["c"] * 4 + ["b"] * 5 - groups = pd.Series(labels, index=v.indexes["dim_2"], name="letter") - - grouped = expr.groupby(groups).sum() - fallback = expr.groupby(groups.to_xarray()).sum(use_fallback=True) - - assert list(grouped.data.letter) == ["a", "b", "c"] - # padded to the largest group times the number of terms of the input - assert grouped.nterm == 14 * expr.nterm - assert_linequal(grouped, fallback) - - # every group must carry exactly the variables of its members, the rest is fill - for letter in ["a", "b", "c"]: - members = np.where(np.array(labels) == letter)[0] - vars_of_group = grouped.data.vars.sel(letter=letter).values - assert set(vars_of_group[vars_of_group >= 0]) == set(v.labels.values[members]) - assert (vars_of_group >= 0).sum() == len(members) * expr.nterm - assert grouped.const.sel(letter=letter).item() == 5 * len(members) - - -def test_linear_expression_groupby_chunked(v: Variable) -> None: - """Chunked (dask-backed) expressions group via xarray's unstack machinery.""" - pytest.importorskip("dask") - expr = 2 * v + 5 - groups = pd.Series([1] * 12 + [2] * 8, index=v.indexes["dim_2"], name="group") - - chunked = LinearExpression(expr.data.chunk({"dim_2": 5}), expr.model) - grouped_chunked = chunked.groupby(groups).sum() - grouped = expr.groupby(groups).sum() - - assert grouped_chunked.nterm == grouped.nterm - assert_linequal( - LinearExpression(grouped_chunked.data.compute(), expr.model), grouped +@pytest.fixture +def scatter_ctx() -> SimpleNamespace: + """Shared 60-element building blocks for the scatter-vs-unstack case table.""" + m = Model() + rng = np.random.default_rng(0) + idx = pd.RangeIndex(60, name="elem") + other = pd.Index(list("abc"), name="other") + p, q = pd.Index(list("pq"), name="p"), pd.Index([10, 20, 30], name="q") + a, b = pd.Index(range(12), name="a"), pd.Index(list("vwxyz"), name="b") + + const = xr.DataArray(rng.normal(size=(3, 60)), coords=[other, idx]) + y = m.add_variables(coords=[other, idx], name="y") + yab = m.add_variables(coords=[a, b], name="yab") + stacked = LinearExpression((2 * yab + 1 * yab).data.stack(elem=["a", "b"]), m) + skewed = pd.Series(rng.choice(8, 60, p=[0.5] + [0.5 / 7] * 7), index=idx, name="g") + + return SimpleNamespace( + m=m, + x=m.add_variables(coords=[idx], name="x"), + y=y, + y3=m.add_variables(coords=[p, q, idx], name="y3"), + mx=m.add_variables( + coords=[idx], name="mx", mask=xr.DataArray(np.arange(60) % 4 != 0, [idx]) + ), + my=m.add_variables( + coords=[other, idx], + name="my", + mask=xr.DataArray(rng.random((3, 60)) > 0.25, [other, idx]), + ), + const=const, + nan_const=const.where(rng.random((3, 60)) > 0.3), + nan_vec=np.where(np.arange(60) % 3, np.nan, 5.0), + y_aux=(2 * y + 1 * y).assign_coords( + carrier=("elem", rng.choice(list("PQ"), 60)), + tag=(("other", "elem"), rng.integers(0, 9, (3, 60))), + ), + stacked=stacked, + skewed=skewed, + one_group=pd.Series(1, index=idx, name="g"), + identity=pd.Series(np.arange(60), index=idx, name="g"), + mi_groups=skewed.set_axis(stacked.data.indexes["elem"]), ) -def test_linear_expression_groupby_with_nan_groups(v: Variable) -> None: - expr = 1 * v - groups = pd.Series([1.0, np.nan] * 10, index=v.indexes["dim_2"], name="with_nans") - with pytest.raises(ValueError, match="NaN"): - expr.groupby(groups).sum() +# Each case maps a structure to (expr, groups) from `scatter_ctx`. The skewed +# group puts ~half the elements in group 0 and spreads 1..7 over the rest. +SCATTER_EQUALS_UNSTACK_CASES = { + "skewed_int_groups": lambda c: (3 * c.x - 2 * c.x + 7, c.skewed), + "multidim_with_const": lambda c: (2 * c.y + 1 * c.y + c.const, c.skewed), + "nan_const": lambda c: (1 * c.x + c.nan_vec, c.skewed), + "masked_vars": lambda c: (1 * c.mx, c.skewed), + "quadratic": lambda c: (c.x * c.x + 2 * c.x, c.skewed), + "single_group": lambda c: (1 * c.x, c.one_group), + "identity_groups": lambda c: (1 * c.x, c.identity), + # combined structures exercising several features at once + "multidim_const_nan": lambda c: ( + 2 * c.y - 3 * c.y + 1 * c.y + c.nan_const, + c.skewed, + ), + "three_dims": lambda c: (4 * c.y3 + 1 * c.y3, c.skewed), + "quadratic_multidim_const": lambda c: (c.y * c.y + 2 * c.y + c.const, c.skewed), + "masked_multidim": lambda c: (5 * c.my - 2 * c.my, c.skewed), + # both collapse the grouped dim, dropping every coordinate tied to it + "aux_coords_on_group_dim": lambda c: (c.y_aux, c.skewed), + "multiindex_dim": lambda c: (c.stacked, c.mi_groups), +} + + +class TestGroupbySumScatterKernel: + """ + ``groupby(...).sum()`` takes a scatter fast path (``_sum_by_scatter``) for + numpy-backed expressions and falls back to the xarray unstack machinery + (``_sum_by_unstack``) for chunked data and exotic coordinates. These tests + pin the two kernels together and cover the structural edge cases. + """ + @staticmethod + def _assert_kernels_identical(gb: Any, groups: pd.Series, m: Model) -> None: + """Force both kernels and assert they produce the same expression.""" + scatter = LinearExpression(gb._sum_by_scatter(groups).rename(_group="g"), m) + unstack = LinearExpression(gb._sum_by_unstack(groups).rename(_group="g"), m) + + assert scatter.data.coeffs.dims == unstack.data.coeffs.dims + assert scatter.data.const.dims == unstack.data.const.dims + assert list(scatter.data.coords) == list(unstack.data.coords) + for name in scatter.data.coords: + assert_equal(scatter.data[name], unstack.data[name]) + + np.testing.assert_array_equal(scatter.vars.values, unstack.vars.values) + np.testing.assert_array_equal(scatter.coeffs.values, unstack.coeffs.values) + # constants may differ only by floating-point summation order + np.testing.assert_allclose( + scatter.const.values, unstack.const.values, rtol=1e-12 + ) -def test_linear_expression_groupby_empty_groups() -> None: - """An empty group dimension scatters into an empty, well-formed result.""" - m = Model() - idx = pd.RangeIndex(0, name="elem") - x = m.add_variables(coords=[idx], name="x") - groups = pd.Series([], index=idx, name="g", dtype=int) - - grouped = (1 * x).groupby(groups).sum() - assert grouped.nterm == 0 - assert dict(grouped.data.sizes) == {"g": 0, "_term": 0} - - -@pytest.mark.parametrize( - "case", - [ - "skewed_int_groups", - "multidim_with_const", - "nan_const", - "masked_vars", - "quadratic", - "single_group", - "identity_groups", - ], -) -def test_linear_expression_groupby_scatter_equals_unstack(case: str) -> None: - """ - Lock the two groupby-sum kernels together. + def test_skewed_unsorted_groups(self, v: Variable) -> None: + """ + The scatter-based fast path must match the xarray fallback for groups + that are unsorted, non-contiguous and of very different sizes. + """ + expr = 2 * v + 5 + # 'b' appears 14 times, 'c' 5 times, 'a' once, scattered over the dimension + labels = ["b"] * 4 + ["c", "a"] + ["b"] * 5 + ["c"] * 4 + ["b"] * 5 + groups = pd.Series(labels, index=v.indexes["dim_2"], name="letter") + + grouped = expr.groupby(groups).sum() + fallback = expr.groupby(groups.to_xarray()).sum(use_fallback=True) + + assert list(grouped.data.letter) == ["a", "b", "c"] + # padded to the largest group times the number of terms of the input + assert grouped.nterm == 14 * expr.nterm + assert_linequal(grouped, fallback) + + # every group carries exactly the variables of its members, rest is fill + for letter in ["a", "b", "c"]: + members = np.where(np.array(labels) == letter)[0] + vars_of_group = grouped.data.vars.sel(letter=letter).values + present = set(vars_of_group[vars_of_group >= 0]) + assert present == set(v.labels.values[members]) + assert (vars_of_group >= 0).sum() == len(members) * expr.nterm + assert grouped.const.sel(letter=letter).item() == 5 * len(members) + + def test_chunked_uses_unstack( + self, v: Variable, caplog: pytest.LogCaptureFixture + ) -> None: + """Chunked (dask-backed) expressions group via xarray's unstack path.""" + pytest.importorskip("dask") + expr = 2 * v + 5 + groups = pd.Series([1] * 12 + [2] * 8, index=v.indexes["dim_2"], name="group") + + chunked = LinearExpression(expr.data.chunk({"dim_2": 5}), expr.model) + with caplog.at_level(logging.DEBUG, logger="linopy.expressions"): + grouped_chunked = chunked.groupby(groups).sum() + assert "falling back to the unstack kernel" in caplog.text + + grouped = expr.groupby(groups).sum() + assert grouped_chunked.nterm == grouped.nterm + assert_linequal( + LinearExpression(grouped_chunked.data.compute(), expr.model), grouped + ) - The fast path of groupby(...).sum() scatters terms into numpy arrays - (_sum_by_scatter); the xarray unstack implementation (_sum_by_unstack) is - kept for chunked data and exotic coordinates. Both must stay - interchangeable — if an xarray/pandas update changes the unstack output or - an edge case diverges, this fails. - """ - m = Model() - rng = np.random.default_rng(0) - idx = pd.RangeIndex(60, name="elem") - skewed = pd.Series(rng.choice(8, 60, p=[0.5] + [0.5 / 7] * 7), index=idx, name="g") - groups = skewed + def test_nan_groups_raise(self, v: Variable) -> None: + expr = 1 * v + groups = pd.Series( + [1.0, np.nan] * 10, index=v.indexes["dim_2"], name="with_nans" + ) + with pytest.raises(ValueError, match="NaN"): + expr.groupby(groups).sum() - if case == "skewed_int_groups": - x = m.add_variables(coords=[idx], name="x") - expr: LinearExpression | QuadraticExpression = 3 * x - 2 * x + 7 - elif case == "multidim_with_const": - other = pd.Index(list("abc"), name="other") - y = m.add_variables(coords=[other, idx], name="y") - const = xr.DataArray(rng.normal(size=(3, 60)), coords=[other, idx]) - expr = 2 * y + 1 * y + const - elif case == "nan_const": - x = m.add_variables(coords=[idx], name="x") - expr = 1 * x + np.where(np.arange(60) % 3, np.nan, 5.0) - elif case == "masked_vars": - mask = xr.DataArray(np.arange(60) % 4 != 0, coords=[idx]) - x = m.add_variables(coords=[idx], name="x", mask=mask) - expr = 1 * x - elif case == "quadratic": - x = m.add_variables(coords=[idx], name="x") - expr = x * x + 2 * x - elif case == "single_group": - x = m.add_variables(coords=[idx], name="x") - expr = 1 * x - groups = pd.Series(1, index=idx, name="g") - else: # identity_groups + def test_empty_groups(self) -> None: + """An empty group dimension scatters into an empty, well-formed result.""" + m = Model() + idx = pd.RangeIndex(0, name="elem") x = m.add_variables(coords=[idx], name="x") - expr = 1 * x - groups = pd.Series(np.arange(60), index=idx, name="g") - - gb = expr.groupby(groups) - assert gb._can_sum_by_scatter(groups) - scatter = LinearExpression(gb._sum_by_scatter(groups).rename(_group="g"), m) - unstack = LinearExpression(gb._sum_by_unstack(groups).rename(_group="g"), m) - - # identical structure: dims, dim order, coordinates - assert scatter.data.coeffs.dims == unstack.data.coeffs.dims - assert scatter.data.const.dims == unstack.data.const.dims - assert list(scatter.data.coords) == list(unstack.data.coords) - for name in scatter.data.coords: - assert_equal(scatter.data[name], unstack.data[name]) - - # identical values: vars and coeffs bit-exact, including padding positions - np.testing.assert_array_equal(scatter.vars.values, unstack.vars.values) - np.testing.assert_array_equal(scatter.coeffs.values, unstack.coeffs.values) - # constants may differ by floating-point summation order - np.testing.assert_allclose(scatter.const.values, unstack.const.values, rtol=1e-12) + groups = pd.Series([], index=idx, name="g", dtype=int) + + grouped = (1 * x).groupby(groups).sum() + assert grouped.nterm == 0 + assert dict(grouped.data.sizes) == {"g": 0, "_term": 0} + + @pytest.mark.parametrize( + "build", + SCATTER_EQUALS_UNSTACK_CASES.values(), + ids=SCATTER_EQUALS_UNSTACK_CASES.keys(), + ) + def test_scatter_equals_unstack( + self, + build: Callable[[SimpleNamespace], tuple[LinearExpression, pd.Series]], + scatter_ctx: SimpleNamespace, + ) -> None: + """ + Lock the two groupby-sum kernels together. + + The fast path scatters terms into numpy arrays (``_sum_by_scatter``); + the unstack implementation (``_sum_by_unstack``) is kept for chunked + data. Both must stay interchangeable — if an xarray/pandas update + changes the unstack output or an edge case diverges, this fails. See + ``SCATTER_EQUALS_UNSTACK_CASES`` for the structures covered. + """ + expr, groups = build(scatter_ctx) + gb = expr.groupby(groups) + self._assert_kernels_identical(gb, groups, scatter_ctx.m) def test_linear_expression_rolling(v: Variable) -> None: