Skip to content
3 changes: 3 additions & 0 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ Abstract base classes
.. autoclass:: torchjd.aggregation.Stateful
:members: reset

.. autoclass:: torchjd.aggregation.Stochastic
:members: reset


.. toctree::
:hidden:
Expand Down
3 changes: 2 additions & 1 deletion src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
from ._krum import Krum, KrumWeighting
from ._mean import Mean, MeanWeighting
from ._mgda import MGDA, MGDAWeighting
from ._mixins import Stateful
from ._mixins import Stateful, Stochastic
from ._pcgrad import PCGrad, PCGradWeighting
from ._random import Random, RandomWeighting
from ._sum import Sum, SumWeighting
Expand Down Expand Up @@ -109,6 +109,7 @@
"Random",
"RandomWeighting",
"Stateful",
"Stochastic",
"Sum",
"SumWeighting",
"TrimmedMean",
Expand Down
14 changes: 10 additions & 4 deletions src/torchjd/aggregation/_graddrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from torchjd._linalg import Matrix

from ._aggregator_bases import Aggregator
from ._mixins import Stochastic
from ._utils.non_differentiable import raise_non_differentiable_error


def _identity(P: Tensor) -> Tensor:
return P


class GradDrop(Aggregator):
class GradDrop(Aggregator, Stochastic):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that applies the gradient combination
steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign:
Expand All @@ -24,16 +25,21 @@ class GradDrop(Aggregator):
increasing. Defaults to identity.
:param leak: The tensor of leak values, determining how much each row is allowed to leak
through. Defaults to None, which means no leak.
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
the global PyTorch RNG to fork an independent stream.
"""

def __init__(self, f: Callable = _identity, leak: Tensor | None = None) -> None:
def __init__(
    self, f: Callable = _identity, leak: Tensor | None = None, seed: int | None = None
) -> None:
    # `leak` must be a vector: one leak value per column of the aggregated
    # matrices — NOTE(review): exact alignment with matrix dims to be confirmed
    # against `forward`, which is outside this hunk.
    if leak is not None and leak.dim() != 1:
        raise ValueError(
            "Parameter `leak` should be a 1-dimensional tensor. Found `leak.shape = "
            f"{leak.shape}`.",
        )

    # Both bases are initialized explicitly (this hierarchy does not use
    # cooperative super()); Stochastic forks an independent RNG stream from `seed`.
    Aggregator.__init__(self)
    Stochastic.__init__(self, seed=seed)
    self.f = f
    self.leak = leak

Expand All @@ -50,7 +56,7 @@ def forward(self, matrix: Matrix, /) -> Tensor:

P = 0.5 * (torch.ones_like(matrix[0]) + matrix.sum(dim=0) / matrix.abs().sum(dim=0))
fP = self.f(P)
U = torch.rand(P.shape, dtype=matrix.dtype, device=matrix.device)
U = torch.rand(P.shape, dtype=matrix.dtype, device=matrix.device, generator=self.generator)

vector = torch.zeros_like(matrix[0])
for i in range(len(matrix)):
Expand Down
41 changes: 21 additions & 20 deletions src/torchjd/aggregation/_gradvac.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
from torch import Tensor

from torchjd._linalg import PSDMatrix
from torchjd.aggregation._mixins import Stateful
from torchjd.aggregation._mixins import Stochastic

from ._aggregator_bases import GramianWeightedAggregator
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import Weighting


class GradVacWeighting(Weighting[PSDMatrix], Stateful):
class GradVacWeighting(Weighting[PSDMatrix], Stochastic):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation._mixins.Stochastic`
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GradVac`.

Expand All @@ -37,10 +37,13 @@ class GradVacWeighting(Weighting[PSDMatrix], Stateful):

:param beta: EMA decay for :math:`\hat{\phi}`.
:param eps: Small non-negative constant added to denominators.
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
the global PyTorch RNG to fork an independent stream.
"""

def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
super().__init__()
def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None:
Weighting.__init__(self)
Stochastic.__init__(self, seed=seed)
if not (0.0 <= beta <= 1.0):
raise ValueError(f"Parameter `beta` must be in [0, 1]. Found beta={beta!r}.")
if eps < 0.0:
Expand Down Expand Up @@ -72,8 +75,9 @@ def eps(self, value: float) -> None:
self._eps = value

def reset(self) -> None:
    """Rewinds the random number generator and clears the EMA state."""
    # Rewind the private random stream, then drop the cached EMA targets so the
    # next forward() re-initializes them from scratch.
    Stochastic.reset(self)
    self._state_key = None
    self._phi_t = None

Expand Down Expand Up @@ -101,7 +105,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
cG = C[i] @ G

others = [j for j in range(m) if j != i]
perm = torch.randperm(len(others))
perm = torch.randperm(len(others), generator=self.generator)
shuffled_js = [others[idx] for idx in perm.tolist()]

for j in shuffled_js:
Expand Down Expand Up @@ -133,9 +137,9 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
self._state_key = key


class GradVac(GramianWeightedAggregator, Stateful):
class GradVac(GramianWeightedAggregator, Stochastic):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation._mixins.Stochastic`
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
Expand All @@ -155,22 +159,18 @@ class GradVac(GramianWeightedAggregator, Stateful):

:param beta: EMA decay for :math:`\hat{\phi}`.
:param eps: Small non-negative constant added to denominators.

.. note::
For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
you need reproducibility.
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
the global PyTorch RNG to fork an independent stream.

.. note::
To apply GradVac with the `whole_model`, `enc_dec`, `all_layer` or `all_matrix` grouping
strategy, please refer to the :doc:`Grouping </examples/grouping>` examples.
"""

gramian_weighting: GradVacWeighting

def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
weighting = GradVacWeighting(beta=beta, eps=eps)
super().__init__(weighting)
def __init__(self, beta: float = 0.5, eps: float = 1e-8, seed: int | None = None) -> None:
    # The weighting owns the RNG; the aggregator shares the *same* generator so
    # that seeding/resetting either object keeps the two in sync.
    weighting = GradVacWeighting(beta=beta, eps=eps, seed=seed)
    GramianWeightedAggregator.__init__(self, weighting)
    Stochastic.__init__(self, generator=weighting.generator)
    self._gradvac_weighting = weighting
    # Raise eagerly if a backward pass reaches this aggregator's output (the
    # aggregation is not differentiable).
    self.register_full_backward_pre_hook(raise_non_differentiable_error)

Expand All @@ -191,8 +191,9 @@ def eps(self, value: float) -> None:
self._gradvac_weighting.eps = value

def reset(self) -> None:
    """Rewinds the random number generator and clears the EMA state."""
    # Rewind the shared stream at the aggregator level, then delegate to the
    # weighting, which clears its EMA targets (its generator is the same object).
    Stochastic.reset(self)
    self._gradvac_weighting.reset()

def __repr__(self) -> str:
Expand Down
55 changes: 53 additions & 2 deletions src/torchjd/aggregation/_mixins.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,60 @@
from abc import ABC, abstractmethod

import torch


class Stateful(ABC):
"""Mixin adding a reset method."""
r"""
Mixin for stateful mappings.

A maping implements `Stateful` **if and only if** its behavior depends on an internal
state.

Formally, a stateless mapping is a function :math:`f : x \mapsto y` whereas a stateful
maping is a transition map :math:`A : (x, s) \mapsto (y, s')` where :math:`s` is the
internal state, :math:`s'` the updated state, and :math:`y` the output.
There exists an initial state :math:`s_0`, and the method `reset()` restores the state to
:math:`s_0`. A `Stateful` mapping must be constructed with the intial state :math:`s_0`.
"""

@abstractmethod
def reset(self) -> None:
"""Resets the internal state."""
"""Resets the internal state :math:`s_0`."""


class Stochastic(Stateful, ABC):
r"""
Stateful mixin that represents mappings that have inherent randomness.

Internally, a ``Stochastic`` mapping holds a :class:`torch.Generator` that serves as an
independent random number stream. Implementing classes must pass this generator to all torch
random functions via their ``generator`` argument, e.g.:

.. code-block:: python

torch.rand(n, generator=self.generator)
torch.randn(n, generator=self.generator)
torch.randperm(n, generator=self.generator)

:param seed: Seed for the internal :class:`torch.Generator`. If ``None``, a seed is drawn
from the global PyTorch RNG to fork an independent stream.
:param generator: An existing :class:`torch.Generator` to share, typically from a companion
:class:`Stochastic` instance (e.g. a :class:`Weighting` sharing the generator of its
:class:`Aggregator`). Mutually exclusive with ``seed``.
"""

def __init__(self, seed: int | None = None, generator: torch.Generator | None = None) -> None:
if generator is not None and seed is not None:
raise ValueError("Parameters `seed` and `generator` are mutually exclusive.")
if generator is not None:
self.generator = generator
else:
self.generator = torch.Generator()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generator requires a device, so this wont work for cuda I think. And we don't know the device at that point, so I don't think it's easy to fix.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This leads to RuntimeError: Expected a 'cuda' device type for generator but found 'cpu'

Need a different implementation for Stochastic I think.

if seed is None:
seed = int(torch.randint(0, 2**62, size=(1,), dtype=torch.int64).item())
self.generator.manual_seed(seed)
self._initial_rng_state = self.generator.get_state()

def reset(self) -> None:
"""Resets the random number generator to its initial state."""
self.generator.set_state(self._initial_rng_state)
25 changes: 18 additions & 7 deletions src/torchjd/aggregation/_pcgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,24 @@
from torchjd._linalg import PSDMatrix

from ._aggregator_bases import GramianWeightedAggregator
from ._mixins import Stochastic
from ._utils.non_differentiable import raise_non_differentiable_error
from ._weighting_bases import Weighting


class PCGradWeighting(Weighting[PSDMatrix]):
class PCGradWeighting(Weighting[PSDMatrix], Stochastic):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.PCGrad`.

:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
the global PyTorch RNG to fork an independent stream.
"""

def __init__(self, seed: int | None = None) -> None:
    # Both bases are initialized explicitly (no cooperative super() in this
    # hierarchy); Stochastic forks an independent RNG stream from `seed`.
    Weighting.__init__(self)
    Stochastic.__init__(self, seed=seed)

def forward(self, gramian: PSDMatrix, /) -> Tensor:
# Move all computations on cpu to avoid moving memory between cpu and gpu at each iteration
device = gramian.device
Expand All @@ -27,7 +35,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
weights = torch.zeros(dimension, device=cpu, dtype=dtype)

for i in range(dimension):
permutation = torch.randperm(dimension)
permutation = torch.randperm(dimension, generator=self.generator)
current_weights = torch.zeros(dimension, device=cpu, dtype=dtype)
current_weights[i] = 1.0

Expand All @@ -46,16 +54,19 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
return weights.to(device)


class PCGrad(GramianWeightedAggregator):
class PCGrad(GramianWeightedAggregator, Stochastic):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of
`Gradient Surgery for Multi-Task Learning <https://arxiv.org/pdf/2001.06782.pdf>`_.
"""

gramian_weighting: PCGradWeighting
:param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
the global PyTorch RNG to fork an independent stream.
"""

def __init__(self) -> None:
super().__init__(PCGradWeighting())
def __init__(self, seed: int | None = None) -> None:
    # The weighting owns the RNG; the aggregator shares the *same* generator so
    # that seeding/resetting either object keeps the two in sync.
    weighting = PCGradWeighting(seed=seed)
    GramianWeightedAggregator.__init__(self, weighting)
    Stochastic.__init__(self, generator=weighting.generator)

    # This prevents running into a RuntimeError due to modifying stored tensors in place.
    self.register_full_backward_pre_hook(raise_non_differentiable_error)
27 changes: 20 additions & 7 deletions src/torchjd/aggregation/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,43 @@
from torchjd._linalg import Matrix

from ._aggregator_bases import WeightedAggregator
from ._mixins import Stochastic
from ._weighting_bases import Weighting


class RandomWeighting(Weighting[Matrix]):
class RandomWeighting(Weighting[Matrix], Stochastic):
    """
    :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights
    at each call.

    :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
        the global PyTorch RNG to fork an independent stream.
    """

    def __init__(self, seed: int | None = None) -> None:
        Weighting.__init__(self)
        Stochastic.__init__(self, seed=seed)

    def forward(self, matrix: Tensor, /) -> Tensor:
        # Draw Gaussian logits from this instance's private stream, then
        # normalize them into positive weights summing to one.
        n_rows = matrix.shape[0]
        logits = torch.randn(
            n_rows, device=matrix.device, dtype=matrix.dtype, generator=self.generator
        )
        return F.softmax(logits, dim=-1)


class Random(WeightedAggregator):
class Random(WeightedAggregator, Stochastic):
    """
    :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of
    the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of
    Random Weighting: A Litmus Test for Multi-Task Learning
    <https://arxiv.org/pdf/2111.10603.pdf>`_.

    :param seed: Seed for the internal random number generator. If ``None``, a seed is drawn from
        the global PyTorch RNG to fork an independent stream.
    """

    def __init__(self, seed: int | None = None) -> None:
        # The weighting owns the RNG; the aggregator shares the *same* generator
        # so that seeding/resetting either object keeps the two in sync.
        weighting = RandomWeighting(seed=seed)
        WeightedAggregator.__init__(self, weighting)
        Stochastic.__init__(self, generator=weighting.generator)
Loading
Loading