From 98d1d1beed5206f6659ebdb612d7de4f82ab1a0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 03:17:37 +0200 Subject: [PATCH 1/3] Split structure extraction and matrix ignoring from weight extraction - The idea is to make it explicit that some weightings are based on the structure of the matrix (sum, mean, random) and that some are independent of their input (constant), by separating the part where we extract the structure (or the none for constant) and the part where we use this structure (or the none for constant) to make a vector of weigths. --- src/torchjd/_linalg/__init__.py | 3 ++ src/torchjd/_linalg/_structure.py | 15 ++++++++ src/torchjd/aggregation/_constant.py | 17 ++++----- src/torchjd/aggregation/_mean.py | 18 ++++++---- src/torchjd/aggregation/_random.py | 14 +++++--- src/torchjd/aggregation/_sum.py | 21 ++++++------ src/torchjd/aggregation/_weighting_bases.py | 38 +++++++++++++++++---- tests/unit/aggregation/test_constant.py | 23 ------------- 8 files changed, 88 insertions(+), 61 deletions(-) create mode 100644 src/torchjd/_linalg/_structure.py diff --git a/src/torchjd/_linalg/__init__.py b/src/torchjd/_linalg/__init__.py index 29b8cd0b3..9db3b8a77 100644 --- a/src/torchjd/_linalg/__init__.py +++ b/src/torchjd/_linalg/__init__.py @@ -1,8 +1,11 @@ from ._generalized_gramian import flatten, movedim, reshape from ._gramian import compute_gramian, normalize, regularize from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor +from ._structure import Structure, extract_structure __all__ = [ + "extract_structure", + "Structure", "compute_gramian", "normalize", "regularize", diff --git a/src/torchjd/_linalg/_structure.py b/src/torchjd/_linalg/_structure.py new file mode 100644 index 000000000..59be4f674 --- /dev/null +++ b/src/torchjd/_linalg/_structure.py @@ -0,0 +1,15 @@ +import torch +from attr import dataclass + +from torchjd._linalg import Matrix + + +@dataclass +class Structure: + m: int + device: torch.device + dtype: torch.dtype + + +def extract_structure(matrix: Matrix) -> Structure: + return Structure(m=matrix.shape[0], device=matrix.device, dtype=matrix.dtype) diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 0485e7261..a05ab8fe0 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -1,13 +1,13 @@ from torch import Tensor -from torchjd._linalg import Matrix +from torchjd.aggregation._weighting_bases import FromNothingWeighting from ._aggregator_bases import WeightedAggregator from ._utils.str import vector_to_str from ._weighting_bases import Weighting -class ConstantWeighting(Weighting[Matrix]): +class _ConstantWeighting(Weighting[None]): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined weights. @@ -25,16 +25,13 @@ def __init__(self, weights: Tensor) -> None: super().__init__() self.weights = weights - def forward(self, matrix: Tensor, /) -> Tensor: - self._check_matrix_shape(matrix) + def forward(self, _: None, /) -> Tensor: return self.weights - def _check_matrix_shape(self, matrix: Tensor) -> None: - if matrix.shape[0] != len(self.weights): - raise ValueError( - f"Parameter `matrix` should have {len(self.weights)} rows (the number of specified " - f"weights). Found `matrix` with {matrix.shape[0]} rows.", - ) + +class ConstantWeighting(FromNothingWeighting): + def __init__(self, weights: Tensor) -> None: + super().__init__(_ConstantWeighting(weights)) class Constant(WeightedAggregator): diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index 2ebe208de..4513c6208 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -1,27 +1,33 @@ import torch from torch import Tensor -from torchjd._linalg import Matrix +from torchjd._linalg import Structure +from torchjd.aggregation._weighting_bases import FromStructureWeighting from ._aggregator_bases import WeightedAggregator from ._weighting_bases import Weighting -class MeanWeighting(Weighting[Matrix]): +class _MeanWeighting(Weighting[Structure]): r""" :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. """ - def forward(self, matrix: Tensor, /) -> Tensor: - device = matrix.device - dtype = matrix.dtype - m = matrix.shape[0] + def forward(self, structure: Structure, /) -> Tensor: + device = structure.device + dtype = structure.dtype + m = structure.m weights = torch.full(size=[m], fill_value=1 / m, device=device, dtype=dtype) return weights +class MeanWeighting(FromStructureWeighting): + def __init__(self) -> None: + super().__init__(_MeanWeighting()) + + class Mean(WeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 8345a15cb..612f15326 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -2,24 +2,30 @@ from torch import Tensor from torch.nn import functional as F -from torchjd._linalg import Matrix +from torchjd._linalg import Structure +from torchjd.aggregation._weighting_bases import FromStructureWeighting from ._aggregator_bases import WeightedAggregator from ._weighting_bases import Weighting -class RandomWeighting(Weighting[Matrix]): +class _RandomWeighting(Weighting[Structure]): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights at each call. """ - def forward(self, matrix: Tensor, /) -> Tensor: - random_vector = torch.randn(matrix.shape[0], device=matrix.device, dtype=matrix.dtype) + def forward(self, structure: Structure, /) -> Tensor: + random_vector = torch.randn(structure.m, device=structure.device, dtype=structure.dtype) weights = F.softmax(random_vector, dim=-1) return weights +class RandomWeighting(FromStructureWeighting): + def __init__(self) -> None: + super().__init__(_RandomWeighting()) + + class Random(WeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index 0754f4668..65ef32a48 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -1,25 +1,24 @@ import torch from torch import Tensor -from torchjd._linalg import Matrix +from torchjd._linalg import Structure +from torchjd.aggregation._weighting_bases import FromStructureWeighting from ._aggregator_bases import WeightedAggregator from ._weighting_bases import Weighting -class SumWeighting(Weighting[Matrix]): - r""" - :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights - :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`. - """ - - def forward(self, matrix: Tensor, /) -> Tensor: - device = matrix.device - dtype = matrix.dtype - weights = torch.ones(matrix.shape[0], device=device, dtype=dtype) +class _SumWeighting(Weighting[Structure]): + def forward(self, structure: Structure, /) -> Tensor: + weights = torch.ones(structure.m, device=structure.device, dtype=structure.dtype) return weights +class SumWeighting(FromStructureWeighting): + def __init__(self) -> None: + super().__init__(_SumWeighting()) + + class Sum(WeightedAggregator): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that sums of the rows of the input diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index e321169c3..ee91f2347 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -6,11 +6,11 @@ from torch import Tensor, nn -from torchjd._linalg import PSDTensor, is_psd_tensor +from torchjd._linalg import Matrix, PSDTensor, Structure, extract_structure, is_psd_tensor -_T = TypeVar("_T", contravariant=True, bound=Tensor) -_FnInputT = TypeVar("_FnInputT", bound=Tensor) -_FnOutputT = TypeVar("_FnOutputT", bound=Tensor) +_T = TypeVar("_T", contravariant=True) +_FnInputT = TypeVar("_FnInputT") +_FnOutputT = TypeVar("_FnOutputT") class Weighting(nn.Module, ABC, Generic[_T]): @@ -27,11 +27,9 @@ def __init__(self) -> None: def forward(self, stat: _T, /) -> Tensor: """Computes the vector of weights from the input stat.""" - def __call__(self, stat: Tensor, /) -> Tensor: + def __call__(self, stat: object, /) -> Tensor: """Computes the vector of weights from the input stat and applies all registered hooks.""" - # The value of _T (e.g. PSDMatrix) is not public, so we need the user-facing type hint of - # stat to be Tensor. return super().__call__(stat) def _compose(self, fn: Callable[[_FnInputT], _T]) -> Weighting[_FnInputT]: @@ -55,6 +53,32 @@ def forward(self, stat: _T, /) -> Tensor: return self.weighting(self.fn(stat)) +class FromStructureWeighting(_Composition[Matrix]): + """ + Weighting that extracts the structure of the input matrix before applying a Weighting to it. + + :param structure_weighting: The object responsible for extracting the vector of weights from the + structure. + """ + + def __init__(self, structure_weighting: Weighting[Structure]) -> None: + super().__init__(structure_weighting, extract_structure) + self.structure_weighting = structure_weighting + + +class FromNothingWeighting(_Composition[Matrix]): + """ + Weighting that extracts nothing from the input matrix before applying a Weighting to it (i.e. to + None). + + :param none_weighting: The object responsible for extracting the vector of weights from nothing. + """ + + def __init__(self, none_weighting: Weighting[None]) -> None: + super().__init__(none_weighting, lambda _: None) + self.none_weighting = none_weighting + + class GeneralizedWeighting(nn.Module, ABC): r""" Abstract base class for all weightings that operate on generalized Gramians. It has the role of diff --git a/tests/unit/aggregation/test_constant.py b/tests/unit/aggregation/test_constant.py index aa1332fcb..4fa4488ef 100644 --- a/tests/unit/aggregation/test_constant.py +++ b/tests/unit/aggregation/test_constant.py @@ -63,29 +63,6 @@ def test_weights_shape_check(weights_shape: list[int], expectation: ExceptionCon _ = Constant(weights=weights) -@mark.parametrize( - ["weights_shape", "n_rows", "expectation"], - [ - ([0], 0, does_not_raise()), - ([1], 1, does_not_raise()), - ([5], 5, does_not_raise()), - ([0], 1, raises(ValueError)), - ([1], 0, raises(ValueError)), - ([4], 5, raises(ValueError)), - ([5], 4, raises(ValueError)), - ], -) -def test_matrix_shape_check( - weights_shape: list[int], n_rows: int, expectation: ExceptionContext -) -> None: - matrix = ones_([n_rows, 5]) - weights = ones_(weights_shape) - aggregator = Constant(weights) - - with expectation: - _ = aggregator(matrix) - - def test_representations() -> None: A = Constant(weights=torch.tensor([1.0, 2.0], device="cpu")) assert repr(A) == "Constant(weights=tensor([1., 2.]))" From c139fc1a82526580c60057fffca2d3a158db548e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 03:20:21 +0200 Subject: [PATCH 2/3] Fix import of dataclass --- src/torchjd/_linalg/_structure.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchjd/_linalg/_structure.py b/src/torchjd/_linalg/_structure.py index 59be4f674..303a6c112 100644 --- a/src/torchjd/_linalg/_structure.py +++ b/src/torchjd/_linalg/_structure.py @@ -1,5 +1,6 @@ +from dataclasses import dataclass + import torch -from attr import dataclass from torchjd._linalg import Matrix From ba21fa040bb726ae6c7870d0b6c79735e4d3c6f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 03:28:20 +0200 Subject: [PATCH 3/3] Fix documentation --- src/torchjd/aggregation/_constant.py | 14 +++++++------- src/torchjd/aggregation/_mean.py | 12 ++++++------ src/torchjd/aggregation/_random.py | 10 +++++----- src/torchjd/aggregation/_sum.py | 5 +++++ 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index a05ab8fe0..bf63cf8c8 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -8,13 +8,6 @@ class _ConstantWeighting(Weighting[None]): - """ - :class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined - weights. - - :param weights: The weights to return at each call. - """ - def __init__(self, weights: Tensor) -> None: if weights.dim() != 1: raise ValueError( @@ -30,6 +23,13 @@ def forward(self, _: None, /) -> Tensor: class ConstantWeighting(FromNothingWeighting): + """ + :class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined + weights. + + :param weights: The weights to return at each call. + """ + def __init__(self, weights: Tensor) -> None: super().__init__(_ConstantWeighting(weights)) diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index 4513c6208..a4bde82f0 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -9,12 +9,6 @@ class _MeanWeighting(Weighting[Structure]): - r""" - :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights - :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in - \mathbb{R}^m`. - """ - def forward(self, structure: Structure, /) -> Tensor: device = structure.device dtype = structure.dtype @@ -24,6 +18,12 @@ def forward(self, structure: Structure, /) -> Tensor: class MeanWeighting(FromStructureWeighting): + r""" + :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights + :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in + \mathbb{R}^m`. + """ + def __init__(self) -> None: super().__init__(_MeanWeighting()) diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 612f15326..b61cc0fb5 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -10,11 +10,6 @@ class _RandomWeighting(Weighting[Structure]): - """ - :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights - at each call. - """ - def forward(self, structure: Structure, /) -> Tensor: random_vector = torch.randn(structure.m, device=structure.device, dtype=structure.dtype) weights = F.softmax(random_vector, dim=-1) @@ -22,6 +17,11 @@ def forward(self, structure: Structure, /) -> Tensor: class RandomWeighting(FromStructureWeighting): + """ + :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights + at each call. + """ + def __init__(self) -> None: super().__init__(_RandomWeighting()) diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index 65ef32a48..b46888326 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -15,6 +15,11 @@ def forward(self, structure: Structure, /) -> Tensor: class SumWeighting(FromStructureWeighting): + r""" + :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights + :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`. + """ + def __init__(self) -> None: super().__init__(_SumWeighting())