From b291573ba1022495f34b6ce1f6be186b00f25d7a Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Thu, 16 Apr 2026 09:27:43 +0200 Subject: [PATCH 01/11] Make Stateful an iff --- src/torchjd/aggregation/_mixins.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py index 8481feab..7730b973 100644 --- a/src/torchjd/aggregation/_mixins.py +++ b/src/torchjd/aggregation/_mixins.py @@ -2,8 +2,19 @@ class Stateful(ABC): - """Mixin adding a reset method.""" + r""" + Mixin for stateful mappings. + + A mapping implements `Stateful` **if and only if** its behavior depends on an internal + state. + + Formally, a stateless mapping is a function :math:`f : x \mapsto y` whereas a stateful + mapping is a transition map :math:`A : (x, s) \mapsto (y, s')` where :math:`s` is the + internal state, :math:`s'` the updated state, and :math:`y` the output. + There exists an initial state :math:`s_0`, and the method `reset()` restores the state to + :math:`s_0`. A `Stateful` mapping must be constructed with the initial state :math:`s_0`. + """ @abstractmethod def reset(self) -> None: - """Resets the internal state.""" + """Resets the internal state :math:`s_0`.""" From 56ea02af66649e86a4b0863f52acf38f7dd2525b Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 17 Apr 2026 09:23:49 +0200 Subject: [PATCH 02/11] Add test of stateful/stateless (all must be one or the other). 
--- tests/unit/aggregation/_asserts.py | 33 ++++++++++++++++++++- tests/unit/aggregation/test_aligned_mtl.py | 7 ++++- tests/unit/aggregation/test_cagrad.py | 12 +++++++- tests/unit/aggregation/test_config.py | 6 ++++ tests/unit/aggregation/test_constant.py | 6 ++++ tests/unit/aggregation/test_dualproj.py | 6 ++++ tests/unit/aggregation/test_graddrop.py | 7 ++++- tests/unit/aggregation/test_gradvac.py | 7 ++++- tests/unit/aggregation/test_imtl_g.py | 6 ++++ tests/unit/aggregation/test_krum.py | 7 ++++- tests/unit/aggregation/test_mean.py | 6 ++++ tests/unit/aggregation/test_mgda.py | 6 ++++ tests/unit/aggregation/test_nash_mtl.py | 7 ++++- tests/unit/aggregation/test_pcgrad.py | 7 ++++- tests/unit/aggregation/test_random.py | 7 ++++- tests/unit/aggregation/test_sum.py | 6 ++++ tests/unit/aggregation/test_trimmed_mean.py | 7 ++++- tests/unit/aggregation/test_upgrad.py | 6 ++++ 18 files changed, 139 insertions(+), 10 deletions(-) diff --git a/tests/unit/aggregation/_asserts.py b/tests/unit/aggregation/_asserts.py index 4b85bf09..09d824c3 100644 --- a/tests/unit/aggregation/_asserts.py +++ b/tests/unit/aggregation/_asserts.py @@ -1,10 +1,11 @@ import torch +from numpy.ma.testutils import assert_allclose from pytest import raises from torch import Tensor from torch.testing import assert_close from utils.tensors import rand_, randperm_ -from torchjd.aggregation import Aggregator +from torchjd.aggregation import Aggregator, Stateful from torchjd.aggregation._utils.non_differentiable import NonDifferentiableError @@ -110,3 +111,33 @@ def assert_non_differentiable(aggregator: Aggregator, matrix: Tensor) -> None: vector = aggregator(matrix) with raises(NonDifferentiableError): vector.backward(torch.ones_like(vector)) + + +def assert_stateful(aggregator: Aggregator, matrix: Tensor) -> None: + """ + Test that a given `Aggregator` is stateful. 
Specifically: + - For a fixed state, the aggregator is deterministic on the matrix + - The reset method and the constructor both set the state to the initial state + """ + + assert isinstance(aggregator, Stateful) + + first_pair = (aggregator(matrix), aggregator(matrix)) + aggregator.reset() + second_pair = (aggregator(matrix), aggregator(matrix)) + + assert_allclose(first_pair[0], second_pair[0], atol=0.0, rtol=0.0) + assert_allclose(first_pair[1], second_pair[1], atol=0.0, rtol=0.0) + + +def assert_stateless(aggregator: Aggregator, matrix: Tensor) -> None: + """ + Test that a given `Aggregator` is stateless. Specifically, it must be deterministic. + """ + + assert not isinstance(aggregator, Stateful) + + first = aggregator(matrix) + second = aggregator(matrix) + + assert_allclose(first, second, atol=0.0, rtol=0.0) diff --git a/tests/unit/aggregation/test_aligned_mtl.py b/tests/unit/aggregation/test_aligned_mtl.py index d48e8855..24ce80a5 100644 --- a/tests/unit/aggregation/test_aligned_mtl.py +++ b/tests/unit/aggregation/test_aligned_mtl.py @@ -5,7 +5,7 @@ from torchjd.aggregation import AlignedMTL -from ._asserts import assert_expected_structure, assert_permutation_invariant +from ._asserts import assert_expected_structure, assert_permutation_invariant, assert_stateless from ._inputs import scaled_matrices, typical_matrices aggregators = [ @@ -28,6 +28,11 @@ def test_permutation_invariant(aggregator: AlignedMTL, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: AlignedMTL, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + def test_representations() -> None: A = AlignedMTL(pref_vector=None) assert repr(A) == "AlignedMTL(pref_vector=None, scale_mode='min')" diff --git a/tests/unit/aggregation/test_cagrad.py b/tests/unit/aggregation/test_cagrad.py index c7d18b1f..70aa3b16 100644 --- a/tests/unit/aggregation/test_cagrad.py +++ 
b/tests/unit/aggregation/test_cagrad.py @@ -12,7 +12,12 @@ pytest.skip("CAGrad dependencies not installed", allow_module_level=True) -from ._asserts import assert_expected_structure, assert_non_conflicting, assert_non_differentiable +from ._asserts import ( + assert_expected_structure, + assert_non_conflicting, + assert_non_differentiable, + assert_stateless, +) from ._inputs import scaled_matrices, typical_matrices scaled_pairs = [(CAGrad(c=0.5), matrix) for matrix in scaled_matrices] @@ -38,6 +43,11 @@ def test_non_conflicting(aggregator: CAGrad, matrix: Tensor) -> None: assert_non_conflicting(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: CAGrad, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + @mark.parametrize( ["c", "expectation"], [ diff --git a/tests/unit/aggregation/test_config.py b/tests/unit/aggregation/test_config.py index 2db2ea0f..c3ba275e 100644 --- a/tests/unit/aggregation/test_config.py +++ b/tests/unit/aggregation/test_config.py @@ -10,6 +10,7 @@ assert_linear_under_scaling, assert_non_differentiable, assert_permutation_invariant, + assert_stateless, ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices @@ -39,6 +40,11 @@ def test_non_differentiable(aggregator: ConFIG, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: ConFIG, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + def test_representations() -> None: A = ConFIG() assert repr(A) == "ConFIG(pref_vector=None)" diff --git a/tests/unit/aggregation/test_constant.py b/tests/unit/aggregation/test_constant.py index aa1332fc..07bcd110 100644 --- a/tests/unit/aggregation/test_constant.py +++ b/tests/unit/aggregation/test_constant.py @@ -11,6 +11,7 @@ from ._asserts import ( assert_expected_structure, assert_linear_under_scaling, + assert_stateless, 
assert_strongly_stationary, ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices @@ -42,6 +43,11 @@ def test_strongly_stationary(aggregator: Constant, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: Constant, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + @mark.parametrize( ["weights_shape", "expectation"], [ diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 5bd0e71a..051ec1a0 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -10,6 +10,7 @@ assert_non_conflicting, assert_non_differentiable, assert_permutation_invariant, + assert_stateless, assert_strongly_stationary, ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices @@ -45,6 +46,11 @@ def test_non_differentiable(aggregator: DualProj, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: DualProj, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + def test_representations() -> None: A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") assert ( diff --git a/tests/unit/aggregation/test_graddrop.py b/tests/unit/aggregation/test_graddrop.py index 2868dca0..586d94c8 100644 --- a/tests/unit/aggregation/test_graddrop.py +++ b/tests/unit/aggregation/test_graddrop.py @@ -9,7 +9,7 @@ from torchjd.aggregation import GradDrop -from ._asserts import assert_expected_structure, assert_non_differentiable +from ._asserts import assert_expected_structure, assert_non_differentiable, assert_stateful from ._inputs import scaled_matrices, typical_matrices scaled_pairs = [(GradDrop(), matrix) for matrix in scaled_matrices] @@ -27,6 +27,11 @@ def test_non_differentiable(aggregator: GradDrop, matrix: 
Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateful(aggregator: GradDrop, matrix: Tensor) -> None: + assert_stateful(aggregator, matrix) + + @mark.parametrize( ["leak_shape", "expectation"], [ diff --git a/tests/unit/aggregation/test_gradvac.py b/tests/unit/aggregation/test_gradvac.py index bde2e8fd..88fd4466 100644 --- a/tests/unit/aggregation/test_gradvac.py +++ b/tests/unit/aggregation/test_gradvac.py @@ -6,7 +6,7 @@ from torchjd.aggregation import GradVac, GradVacWeighting -from ._asserts import assert_expected_structure, assert_non_differentiable +from ._asserts import assert_expected_structure, assert_non_differentiable, assert_stateful from ._inputs import scaled_matrices, typical_matrices, typical_matrices_2_plus_rows scaled_pairs = [(GradVac(), m) for m in scaled_matrices] @@ -104,6 +104,11 @@ def test_non_differentiable(aggregator: GradVac, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateful(aggregator: GradVac, matrix: Tensor) -> None: + assert_stateful(aggregator, matrix) + + def test_weighting_beta_out_of_range() -> None: with raises(ValueError, match="beta"): GradVacWeighting(beta=-0.1) diff --git a/tests/unit/aggregation/test_imtl_g.py b/tests/unit/aggregation/test_imtl_g.py index 03c41d5e..3fa40ceb 100644 --- a/tests/unit/aggregation/test_imtl_g.py +++ b/tests/unit/aggregation/test_imtl_g.py @@ -9,6 +9,7 @@ assert_expected_structure, assert_non_differentiable, assert_permutation_invariant, + assert_stateless, ) from ._inputs import scaled_matrices, typical_matrices @@ -32,6 +33,11 @@ def test_non_differentiable(aggregator: IMTLG, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: IMTLG, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + 
def test_imtlg_zero() -> None: """ Tests that IMTLG correctly returns the 0 vector in the special case where input matrix only diff --git a/tests/unit/aggregation/test_krum.py b/tests/unit/aggregation/test_krum.py index 4097f2eb..bab3011f 100644 --- a/tests/unit/aggregation/test_krum.py +++ b/tests/unit/aggregation/test_krum.py @@ -7,7 +7,7 @@ from torchjd.aggregation import Krum -from ._asserts import assert_expected_structure +from ._asserts import assert_expected_structure, assert_stateless from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows scaled_pairs = [(Krum(n_byzantine=1), matrix) for matrix in scaled_matrices_2_plus_rows] @@ -19,6 +19,11 @@ def test_expected_structure(aggregator: Krum, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: Krum, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + @mark.parametrize( ["n_byzantine", "expectation"], [ diff --git a/tests/unit/aggregation/test_mean.py b/tests/unit/aggregation/test_mean.py index 88c28e93..628f1e2a 100644 --- a/tests/unit/aggregation/test_mean.py +++ b/tests/unit/aggregation/test_mean.py @@ -7,6 +7,7 @@ assert_expected_structure, assert_linear_under_scaling, assert_permutation_invariant, + assert_stateless, assert_strongly_stationary, ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices @@ -36,6 +37,11 @@ def test_strongly_stationary(aggregator: Mean, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: Mean, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + def test_representations() -> None: A = Mean() assert repr(A) == "Mean()" diff --git a/tests/unit/aggregation/test_mgda.py b/tests/unit/aggregation/test_mgda.py index 5c925b8f..69c9b9d8 100644 --- a/tests/unit/aggregation/test_mgda.py +++ 
b/tests/unit/aggregation/test_mgda.py @@ -11,6 +11,7 @@ assert_expected_structure, assert_non_conflicting, assert_permutation_invariant, + assert_stateless, ) from ._inputs import scaled_matrices, typical_matrices @@ -33,6 +34,11 @@ def test_permutation_invariant(aggregator: MGDA, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: MGDA, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + @mark.parametrize( "shape", [ diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py index d82fca41..68d78ecc 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -10,7 +10,7 @@ pytest.skip("NashMTL dependencies not installed", allow_module_level=True) -from ._asserts import assert_expected_structure, assert_non_differentiable +from ._asserts import assert_expected_structure, assert_non_differentiable, assert_stateful from ._inputs import nash_mtl_matrices @@ -48,6 +48,11 @@ def test_non_differentiable(aggregator: NashMTL, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], standard_pairs) +def test_stateful(aggregator: NashMTL, matrix: Tensor) -> None: + assert_stateful(aggregator, matrix) + + @mark.filterwarnings("ignore: You are solving a parameterized problem that is not DPP.") def test_nash_mtl_reset() -> None: """ diff --git a/tests/unit/aggregation/test_pcgrad.py b/tests/unit/aggregation/test_pcgrad.py index b776071d..79d67d27 100644 --- a/tests/unit/aggregation/test_pcgrad.py +++ b/tests/unit/aggregation/test_pcgrad.py @@ -8,7 +8,7 @@ from torchjd.aggregation._pcgrad import PCGradWeighting from torchjd.aggregation._upgrad import UPGradWeighting -from ._asserts import assert_expected_structure, assert_non_differentiable +from ._asserts import assert_expected_structure, assert_non_differentiable, 
assert_stateful from ._inputs import scaled_matrices, typical_matrices scaled_pairs = [(PCGrad(), matrix) for matrix in scaled_matrices] @@ -26,6 +26,11 @@ def test_non_differentiable(aggregator: PCGrad, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateful(aggregator: PCGrad, matrix: Tensor) -> None: + assert_stateful(aggregator, matrix) + + @mark.parametrize( "shape", [ diff --git a/tests/unit/aggregation/test_random.py b/tests/unit/aggregation/test_random.py index 77ab7f42..fcc8b08b 100644 --- a/tests/unit/aggregation/test_random.py +++ b/tests/unit/aggregation/test_random.py @@ -3,7 +3,7 @@ from torchjd.aggregation import Random -from ._asserts import assert_expected_structure, assert_strongly_stationary +from ._asserts import assert_expected_structure, assert_stateful, assert_strongly_stationary from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices scaled_pairs = [(Random(), matrix) for matrix in scaled_matrices] @@ -21,6 +21,11 @@ def test_strongly_stationary(aggregator: Random, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateful(aggregator: Random, matrix: Tensor) -> None: + assert_stateful(aggregator, matrix) + + def test_representations() -> None: A = Random() assert repr(A) == "Random()" diff --git a/tests/unit/aggregation/test_sum.py b/tests/unit/aggregation/test_sum.py index 386c507f..757e7e77 100644 --- a/tests/unit/aggregation/test_sum.py +++ b/tests/unit/aggregation/test_sum.py @@ -7,6 +7,7 @@ assert_expected_structure, assert_linear_under_scaling, assert_permutation_invariant, + assert_stateless, assert_strongly_stationary, ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices @@ -36,6 +37,11 @@ def test_strongly_stationary(aggregator: Sum, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix) 
+@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: Sum, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + def test_representations() -> None: A = Sum() assert repr(A) == "Sum()" diff --git a/tests/unit/aggregation/test_trimmed_mean.py b/tests/unit/aggregation/test_trimmed_mean.py index 3a6ccb2b..97c027d8 100644 --- a/tests/unit/aggregation/test_trimmed_mean.py +++ b/tests/unit/aggregation/test_trimmed_mean.py @@ -7,7 +7,7 @@ from torchjd.aggregation import TrimmedMean -from ._asserts import assert_expected_structure, assert_permutation_invariant +from ._asserts import assert_expected_structure, assert_permutation_invariant, assert_stateless from ._inputs import scaled_matrices_2_plus_rows, typical_matrices_2_plus_rows scaled_pairs = [(TrimmedMean(trim_number=1), matrix) for matrix in scaled_matrices_2_plus_rows] @@ -24,6 +24,11 @@ def test_permutation_invariant(aggregator: TrimmedMean, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: TrimmedMean, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + @mark.parametrize( ["trim_number", "expectation"], [ diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 1859b662..758c2d4b 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -11,6 +11,7 @@ assert_non_conflicting, assert_non_differentiable, assert_permutation_invariant, + assert_stateless, assert_strongly_stationary, ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices @@ -51,6 +52,11 @@ def test_non_differentiable(aggregator: UPGrad, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.parametrize(["aggregator", "matrix"], typical_pairs) +def test_stateless(aggregator: UPGrad, matrix: Tensor) -> None: + assert_stateless(aggregator, matrix) + + 
def test_representations() -> None: A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="quadprog") assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='quadprog')" From a6735c0037283135e90f4c9b4c65eb71b8fb74a3 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 17 Apr 2026 10:09:39 +0200 Subject: [PATCH 03/11] Make GradVac, GradDrop, PCGrad and Random Stochastic (broken). --- src/torchjd/aggregation/_graddrop.py | 14 +++++++--- src/torchjd/aggregation/_gradvac.py | 39 ++++++++++++++------------- src/torchjd/aggregation/_mixins.py | 40 ++++++++++++++++++++++++++++ src/torchjd/aggregation/_pcgrad.py | 23 ++++++++++++---- src/torchjd/aggregation/_random.py | 25 +++++++++++++---- 5 files changed, 109 insertions(+), 32 deletions(-) diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index 61c9354e..4a98cfb9 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -6,6 +6,7 @@ from torchjd._linalg import Matrix from ._aggregator_bases import Aggregator +from ._mixins import Stochastic from ._utils.non_differentiable import raise_non_differentiable_error @@ -13,7 +14,7 @@ def _identity(P: Tensor) -> Tensor: return P -class GradDrop(Aggregator): +class GradDrop(Aggregator, Stochastic): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that applies the gradient combination steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign: @@ -24,16 +25,21 @@ class GradDrop(Aggregator): increasing. Defaults to identity. :param leak: The tensor of leak values, determining how much each row is allowed to leak through. Defaults to None, which means no leak. + :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. 
""" - def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None: + def __init__( + self, f: Callable = _identity, leak: Tensor | None = None, seed: int | None = None + ) -> None: if leak is not None and leak.dim() != 1: raise ValueError( "Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = " f"{leak.shape}`.", ) - super().__init__() + Aggregator.__init__(self) + Stochastic.__init__(self, seed=seed) self.f = f self.leak = leak @@ -50,7 +56,7 @@ def forward(self, matrix: Matrix, /) -> Tensor: P = 0.5 * (torch.ones_like(matrix[0]) + matrix.sum(dim=0) / matrix.abs().sum(dim=0)) fP = self.f(P) - U = torch.rand(P.shape, dtype=matrix.dtype, device=matrix.device) + U = torch.rand(P.shape, dtype=matrix.dtype, device=matrix.device, generator=self.generator) vector = torch.zeros_like(matrix[0]) for i in range(len(matrix)): diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index efb55f44..f3e13abd 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -6,16 +6,16 @@ from torch import Tensor from torchjd._linalg import PSDMatrix -from torchjd.aggregation._mixins import Stateful +from torchjd.aggregation._mixins import Stochastic from ._aggregator_bases import GramianWeightedAggregator from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import Weighting -class GradVac(GramianWeightedAggregator, Stateful): +class GradVac(GramianWeightedAggregator, Stochastic): r""" - :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation._mixins.Stochastic` :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) @@ -35,16 +35,14 @@ class GradVac(GramianWeightedAggregator, Stateful): :param beta: EMA decay for 
:math:`\hat{\phi}`. :param eps: Small non-negative constant added to denominators. - - .. note:: - For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently - using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if - you need reproducibility. + :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. """ - def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: - weighting = GradVacWeighting(beta=beta, eps=eps) - super().__init__(weighting) + def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None: + weighting = GradVacWeighting(beta=beta, eps=eps, seed=seed) + GramianWeightedAggregator.__init__(self, weighting) + Stochastic.__init__(self, generator=weighting.generator) self._gradvac_weighting = weighting self.register_full_backward_pre_hook(raise_non_differentiable_error) @@ -65,17 +63,18 @@ def eps(self, value: float) -> None: self._gradvac_weighting.eps = value def reset(self) -> None: - """Clears EMA state so the next forward starts from zero targets.""" + """Resets the random number generator and clears the EMA state.""" + Stochastic.reset(self) self._gradvac_weighting.reset() def __repr__(self) -> str: return f"GradVac(beta={self.beta!r}, eps={self.eps!r})" -class GradVacWeighting(Weighting[PSDMatrix], Stateful): +class GradVacWeighting(Weighting[PSDMatrix], Stochastic): r""" - :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation._mixins.Stochastic` :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.GradVac`. @@ -97,10 +96,13 @@ class GradVacWeighting(Weighting[PSDMatrix], Stateful): :param beta: EMA decay for :math:`\hat{\phi}`. :param eps: Small non-negative constant added to denominators. + :param seed: Seed for the internal random number generator. 
If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. """ - def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: - super().__init__() + def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None: + Weighting.__init__(self) + Stochastic.__init__(self, seed=seed) if not (0.0 <= beta <= 1.0): raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.") if eps < 0.0: @@ -132,8 +134,9 @@ def eps(self, value: float) -> None: self._eps = value def reset(self) -> None: - """Clears EMA state so the next forward starts from zero targets.""" + """Resets the random number generator and clears the EMA state.""" + Stochastic.reset(self) self._phi_t = None self._state_key = None @@ -161,7 +164,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: cG = C[i] @ G others = [j for j in range(m) if j != i] - perm = torch.randperm(len(others)) + perm = torch.randperm(len(others), generator=self.generator) shuffled_js = [others[idx] for idx in perm.tolist()] for j in shuffled_js: diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py index 7730b973..f96346f8 100644 --- a/src/torchjd/aggregation/_mixins.py +++ b/src/torchjd/aggregation/_mixins.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod +import torch + class Stateful(ABC): r""" @@ -18,3 +20,41 @@ class Stateful(ABC): @abstractmethod def reset(self) -> None: """Resets the internal state :math:`s_0`.""" + + +class Stochastic(Stateful, ABC): + r""" + Stateful mixin that represents mappings that have inherent randomness. + + Internally, a ``Stochastic`` mapping holds a :class:`torch.Generator` that serves as an + independent random number stream. Implementing classes must pass this generator to all torch + random functions via their ``generator`` argument, e.g.: + + .. 
code-block:: python + + torch.rand(n, generator=self.generator) + torch.randn(n, generator=self.generator) + torch.randperm(n, generator=self.generator) + + :param seed: Seed for the internal :class:`torch.Generator`. If ``None``, a seed is drawn + from the global PyTorch RNG to fork an independent stream. + :param generator: An existing :class:`torch.Generator` to share, typically from a companion + :class:`Stochastic` instance (e.g. a :class:`Weighting` sharing the generator of its + :class:`Aggregator`). Mutually exclusive with ``seed``. + """ + + def __init__(self, seed: int | None = None, generator: torch.Generator | None = None) -> None: + if generator is not None and seed is not None: + raise ValueError("Parameters `seed` and `generator` are mutually exclusive.") + if generator is not None: + self.generator = generator + else: + self.generator = torch.Generator() + if seed is None: + seed = int(torch.randint(0, 2**62, size=(1,), dtype=torch.int64).item()) + self.generator.manual_seed(seed) + self._initial_rng_state = self.generator.get_state() + + def reset(self) -> None: + """Resets the random number generator to its initial state.""" + self.generator.set_state(self._initial_rng_state) diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index 0f1241df..40715f17 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -6,29 +6,42 @@ from torchjd._linalg import PSDMatrix from ._aggregator_bases import GramianWeightedAggregator +from ._mixins import Stochastic from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import Weighting -class PCGrad(GramianWeightedAggregator): +class PCGrad(GramianWeightedAggregator, Stochastic): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of `Gradient Surgery for Multi-Task Learning `_. + + :param seed: Seed for the internal random number generator. 
If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. """ - def __init__(self) -> None: - super().__init__(PCGradWeighting()) + def __init__(self, seed: int | None = None) -> None: + weighting = PCGradWeighting(seed=seed) + GramianWeightedAggregator.__init__(self, weighting) + Stochastic.__init__(self, generator=weighting.generator) # This prevents running into a RuntimeError due to modifying stored tensors in place. self.register_full_backward_pre_hook(raise_non_differentiable_error) -class PCGradWeighting(Weighting[PSDMatrix]): +class PCGradWeighting(Weighting[PSDMatrix], Stochastic): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.PCGrad`. + + :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. """ + def __init__(self, seed: int | None = None) -> None: + Weighting.__init__(self) + Stochastic.__init__(self, seed=seed) + def forward(self, gramian: PSDMatrix, /) -> Tensor: # Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration device = gramian.device @@ -40,7 +53,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: weights = torch.zeros(dimension, device=cpu, dtype=dtype) for i in range(dimension): - permutation = torch.randperm(dimension) + permutation = torch.randperm(dimension, generator=self.generator) current_weights = torch.zeros(dimension, device=cpu, dtype=dtype) current_weights[i] = 1.0 diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 734dfc17..0317c174 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -5,28 +5,43 @@ from torchjd._linalg import Matrix from ._aggregator_bases import WeightedAggregator +from ._mixins import Stochastic from ._weighting_bases import Weighting -class Random(WeightedAggregator): +class Random(WeightedAggregator, 
Stochastic): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning `_. + + :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. """ - def __init__(self) -> None: - super().__init__(RandomWeighting()) + def __init__(self, seed: int | None = None) -> None: + weighting = RandomWeighting(seed=seed) + WeightedAggregator.__init__(self, weighting) + Stochastic.__init__(self, generator=weighting.generator) -class RandomWeighting(Weighting[Matrix]): +class RandomWeighting(Weighting[Matrix], Stochastic): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights at each call. + + :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. 
""" + def __init__(self, seed: int | None = None) -> None: + Weighting.__init__(self) + Stochastic.__init__(self, seed=seed) + def forward(self, matrix: Tensor, /) -> Tensor: - random_vector = torch.randn(matrix.shape[0], device=matrix.device, dtype=matrix.dtype) + random_vector = torch.randn( + matrix.shape[0], device=matrix.device, dtype=matrix.dtype, generator=self.generator + ) weights = F.softmax(random_vector, dim=-1) return weights From df450a78c93c74cde8544c4c3d2ecffa92a96af9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 13:12:14 +0200 Subject: [PATCH 04/11] Fix mistakes when resolving merge conflict --- src/torchjd/aggregation/_gradvac.py | 27 ++++++++++----------------- src/torchjd/aggregation/_pcgrad.py | 13 ++++++++----- src/torchjd/aggregation/_random.py | 13 ++++++++----- 3 files changed, 26 insertions(+), 27 deletions(-) diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 19334844..edb68ff6 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -137,9 +137,9 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None: self._state_key = key -class GradVac(GramianWeightedAggregator, Stateful): +class GradVac(GramianWeightedAggregator, Stochastic): r""" - :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation._mixins.Stochastic` :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) @@ -159,22 +159,14 @@ class GradVac(GramianWeightedAggregator, Stateful): :param beta: EMA decay for :math:`\hat{\phi}`. :param eps: Small non-negative constant added to denominators. - - .. note:: - For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently - using the global PyTorch RNG (``torch.randperm``). 
Seed it with ``torch.manual_seed`` if - you need reproducibility. - - .. note:: - To apply GradVac with the `whole_model`, `enc_dec`, `all_layer` or `all_matrix` grouping - strategy, please refer to the :doc:`Grouping ` examples. + :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. """ - gramian_weighting: GradVacWeighting - - def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None: - weighting = GradVacWeighting(beta=beta, eps=eps) - super().__init__(weighting) + def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None: + weighting = GradVacWeighting(beta=beta, eps=eps, seed=seed) + GramianWeightedAggregator.__init__(self, weighting) + Stochastic.__init__(self, generator=weighting.generator) self._gradvac_weighting = weighting self.register_full_backward_pre_hook(raise_non_differentiable_error) @@ -195,8 +187,9 @@ def eps(self, value: float) -> None: self._gradvac_weighting.eps = value def reset(self) -> None: - """Clears EMA state so the next forward starts from zero targets.""" + """Resets the random number generator and clears the EMA state.""" + Stochastic.reset(self) self._gradvac_weighting.reset() def __repr__(self) -> str: diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index 691ca1b8..24e281e4 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -54,16 +54,19 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: return weights.to(device) -class PCGrad(GramianWeightedAggregator): +class PCGrad(GramianWeightedAggregator, Stochastic): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of `Gradient Surgery for Multi-Task Learning `_. - """ - gramian_weighting: PCGradWeighting + :param seed: Seed for the internal random number generator. 
If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. + """ - def __init__(self) -> None: - super().__init__(PCGradWeighting()) + def __init__(self, seed: int | None = None) -> None: + weighting = PCGradWeighting(seed=seed) + GramianWeightedAggregator.__init__(self, weighting) + Stochastic.__init__(self, generator=weighting.generator) # This prevents running into a RuntimeError due to modifying stored tensors in place. self.register_full_backward_pre_hook(raise_non_differentiable_error) diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index 2f146459..4d602b57 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -30,15 +30,18 @@ def forward(self, matrix: Tensor, /) -> Tensor: return weights -class Random(WeightedAggregator): +class Random(WeightedAggregator, Stochastic): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning `_. - """ - weighting: RandomWeighting + :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from + the global PyTorch RNG to fork an independent stream. 
+ """ - def __init__(self) -> None: - super().__init__(RandomWeighting()) + def __init__(self, seed: int | None = None) -> None: + weighting = RandomWeighting(seed=seed) + WeightedAggregator.__init__(self, weighting) + Stochastic.__init__(self, generator=weighting.generator) From b6c6d8b805692691edf1cd03385c408b5a763d65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 13:15:43 +0200 Subject: [PATCH 05/11] Fix --- src/torchjd/aggregation/_gradvac.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index edb68ff6..9c605a4b 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -161,6 +161,10 @@ class GradVac(GramianWeightedAggregator, Stochastic): :param eps: Small non-negative constant added to denominators. :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from the global PyTorch RNG to fork an independent stream. + + .. note:: + To apply GradVac with the `whole_model`, `enc_dec`, `all_layer` or `all_matrix` grouping + strategy, please refer to the :doc:`Grouping ` examples. 
""" def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None: From 0bf7626be7dcae10a0f425cada19f391922c7cb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 13:18:15 +0200 Subject: [PATCH 06/11] Fix nashmtl test --- tests/unit/aggregation/test_nash_mtl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py index 68d78ecc..0bc9ee64 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -48,6 +48,9 @@ def test_non_differentiable(aggregator: NashMTL, matrix: Tensor) -> None: assert_non_differentiable(aggregator, matrix) +@mark.filterwarnings( + "ignore:Solution may be inaccurate.", +) @mark.parametrize(["aggregator", "matrix"], standard_pairs) def test_stateful(aggregator: NashMTL, matrix: Tensor) -> None: assert_stateful(aggregator, matrix) From 1bc6be248a364be56437a42d57c4532ee6ba16e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 13:20:09 +0200 Subject: [PATCH 07/11] Fix nahsmtl --- tests/unit/aggregation/test_nash_mtl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/aggregation/test_nash_mtl.py b/tests/unit/aggregation/test_nash_mtl.py index 0bc9ee64..9b7cfd52 100644 --- a/tests/unit/aggregation/test_nash_mtl.py +++ b/tests/unit/aggregation/test_nash_mtl.py @@ -50,6 +50,7 @@ def test_non_differentiable(aggregator: NashMTL, matrix: Tensor) -> None: @mark.filterwarnings( "ignore:Solution may be inaccurate.", + "ignore:You are solving a parameterized problem that is not DPP.", ) @mark.parametrize(["aggregator", "matrix"], standard_pairs) def test_stateful(aggregator: NashMTL, matrix: Tensor) -> None: From 18aef6c44c3cb713314dc618e09e59afa0354d23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 13:22:23 +0200 Subject: [PATCH 08/11] Make Stochastic public --- 
src/torchjd/aggregation/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 400cfe27..5016458b 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -71,7 +71,7 @@ from ._krum import Krum, KrumWeighting from ._mean import Mean, MeanWeighting from ._mgda import MGDA, MGDAWeighting -from ._mixins import Stateful +from ._mixins import Stateful, Stochastic from ._pcgrad import PCGrad, PCGradWeighting from ._random import Random, RandomWeighting from ._sum import Sum, SumWeighting @@ -109,6 +109,7 @@ "Random", "RandomWeighting", "Stateful", + "Stochastic", "Sum", "SumWeighting", "TrimmedMean", From 62a269fd722ce2ef25d463cf4cbb0cf421dd70cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 13:22:28 +0200 Subject: [PATCH 09/11] Fix doc --- docs/source/docs/aggregation/index.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 73442a93..57ab1a8f 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -16,6 +16,9 @@ Abstract base classes .. autoclass:: torchjd.aggregation.Stateful :members: reset +.. autoclass:: torchjd.aggregation.Stochastic + :members: reset + .. 
toctree:: :hidden: From 477f19fe746835b3403961a049bc4dadd90c044e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 13:30:28 +0200 Subject: [PATCH 10/11] Fix asserts --- tests/unit/aggregation/_asserts.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/unit/aggregation/_asserts.py b/tests/unit/aggregation/_asserts.py index 09d824c3..b4d8c738 100644 --- a/tests/unit/aggregation/_asserts.py +++ b/tests/unit/aggregation/_asserts.py @@ -1,5 +1,4 @@ import torch -from numpy.ma.testutils import assert_allclose from pytest import raises from torch import Tensor from torch.testing import assert_close @@ -126,8 +125,8 @@ def assert_stateful(aggregator: Aggregator, matrix: Tensor) -> None: aggregator.reset() second_pair = (aggregator(matrix), aggregator(matrix)) - assert_allclose(first_pair[0], second_pair[0], atol=0.0, rtol=0.0) - assert_allclose(first_pair[1], second_pair[1], atol=0.0, rtol=0.0) + assert_close(first_pair[0], second_pair[0], atol=0.0, rtol=0.0) + assert_close(first_pair[1], second_pair[1], atol=0.0, rtol=0.0) def assert_stateless(aggregator: Aggregator, matrix: Tensor) -> None: @@ -140,4 +139,4 @@ def assert_stateless(aggregator: Aggregator, matrix: Tensor) -> None: first = aggregator(matrix) second = aggregator(matrix) - assert_allclose(first, second, atol=0.0, rtol=0.0) + assert_close(first, second, atol=0.0, rtol=0.0) From 1d85ba4869298ebec393fcbfa7f95ad5859c83f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 17 Apr 2026 14:00:08 +0200 Subject: [PATCH 11/11] [WIP] transform Stochastic into StochasticState --- src/torchjd/aggregation/_mixins.py | 18 +++++------------- src/torchjd/aggregation/_pcgrad.py | 17 ++++++++--------- 2 files changed, 13 insertions(+), 22 deletions(-) diff --git a/src/torchjd/aggregation/_mixins.py b/src/torchjd/aggregation/_mixins.py index f96346f8..90f0f9a6 100644 --- a/src/torchjd/aggregation/_mixins.py +++ 
b/src/torchjd/aggregation/_mixins.py @@ -22,25 +22,17 @@ def reset(self) -> None: """Resets the internal state :math:`s_0`.""" -class Stochastic(Stateful, ABC): +class StochasticState(Stateful): r""" - Stateful mixin that represents mappings that have inherent randomness. + State representing stochasticity. - Internally, a ``Stochastic`` mapping holds a :class:`torch.Generator` that serves as an - independent random number stream. Implementing classes must pass this generator to all torch - random functions via their ``generator`` argument, e.g.: - - .. code-block:: python - - torch.rand(n, generator=self.generator) - torch.randn(n, generator=self.generator) - torch.randperm(n, generator=self.generator) + Internally, a ``StochasticState`` holds a :class:`torch.Generator` that serves as an + independent random number stream. :param seed: Seed for the internal :class:`torch.Generator`. If ``None``, a seed is drawn from the global PyTorch RNG to fork an independent stream. :param generator: An existing :class:`torch.Generator` to share, typically from a companion - :class:`Stochastic` instance (e.g. a :class:`Weighting` sharing the generator of its - :class:`Aggregator`). Mutually exclusive with ``seed``. 
""" def __init__(self, seed: int | None = None, generator: torch.Generator | None = None) -> None: diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index 24e281e4..26ee42f4 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -4,14 +4,15 @@ from torch import Tensor from torchjd._linalg import PSDMatrix +from torchjd.aggregation import Stateful +from torchjd.aggregation._mixins import StochasticState from ._aggregator_bases import GramianWeightedAggregator -from ._mixins import Stochastic from ._utils.non_differentiable import raise_non_differentiable_error from ._weighting_bases import Weighting -class PCGradWeighting(Weighting[PSDMatrix], Stochastic): +class PCGradWeighting(Weighting[PSDMatrix], Stateful): """ :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of :class:`~torchjd.aggregation.PCGrad`. @@ -21,8 +22,8 @@ class PCGradWeighting(Weighting[PSDMatrix], Stochastic): """ def __init__(self, seed: int | None = None) -> None: - Weighting.__init__(self) - Stochastic.__init__(self, seed=seed) + super().__init__() + self.state = StochasticState(seed=seed) def forward(self, gramian: PSDMatrix, /) -> Tensor: # Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration @@ -35,7 +36,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: weights = torch.zeros(dimension, device=cpu, dtype=dtype) for i in range(dimension): - permutation = torch.randperm(dimension, generator=self.generator) + permutation = torch.randperm(dimension, generator=self.state.generator) current_weights = torch.zeros(dimension, device=cpu, dtype=dtype) current_weights[i] = 1.0 @@ -54,7 +55,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: return weights.to(device) -class PCGrad(GramianWeightedAggregator, Stochastic): +class PCGrad(GramianWeightedAggregator, Stateful): """ :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of 
`Gradient Surgery for Multi-Task Learning `_. @@ -64,9 +65,7 @@ class PCGrad(GramianWeightedAggregator, Stochastic): """ def __init__(self, seed: int | None = None) -> None: - weighting = PCGradWeighting(seed=seed) - GramianWeightedAggregator.__init__(self, weighting) - Stochastic.__init__(self, generator=weighting.generator) + super().__init__(PCGradWeighting(seed=seed)) # This prevents running into a RuntimeError due to modifying stored tensors in place. self.register_full_backward_pre_hook(raise_non_differentiable_error)