diff --git a/.vscode/settings.json b/.vscode/settings.json index 59bf351..e2a9723 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +1,5 @@ { + "python.analysis.typeCheckingMode": "strict", "python.testing.pytestEnabled": true, "[python]": { "editor.defaultFormatter": "charliermarsh.ruff", diff --git a/src/scverse_plotting_api/__init__.py b/src/scverse_plotting_api/__init__.py index fa1b5ab..50bcc06 100644 --- a/src/scverse_plotting_api/__init__.py +++ b/src/scverse_plotting_api/__init__.py @@ -1,37 +1,25 @@ from __future__ import annotations -from typing import TYPE_CHECKING, overload +from typing import TYPE_CHECKING import scanpy as sc -from scverse_plotting_api.helpers import plot_context +from .helpers import plot_decorator if TYPE_CHECKING: from anndata import AnnData - from matplotlib.axes import Axes from matplotlib.figure import Figure - from matplotlib.gridspec import SubplotSpec - - -@overload -def plot_umap(adata: AnnData, color: str, *, ax: None = None) -> Figure: - ... - - -@overload -def plot_umap(adata: AnnData, color: str, *, ax: Axes | SubplotSpec) -> None: - ... + from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec +@plot_decorator(nrows=1, ncols=2, width_ratios=(0.7, 0.1)) def plot_umap( - adata: AnnData, color: str, *, ax: Axes | SubplotSpec | None = None -) -> Figure | None: - with plot_context(ax, ncols=2, width_ratios=(0.7, 0.1)) as (fig, gs): - main_ax = fig.add_subplot(gs[0, 0]) - cbar_ax = fig.add_subplot(gs[0, 1]) - x, y = zip(*adata.obsm["X_umap"]) - scatter = main_ax.scatter( - x, y, s=3, c=sc.get.obs_df(adata, color).values, cmap="viridis" - ) - fig.colorbar(scatter, cax=cbar_ax) - return fig if ax is None else None + adata: AnnData, color: str, *, fig: Figure, gs: GridSpec | GridSpecFromSubplotSpec +) -> None: + main_ax = fig.add_subplot(gs[0, 0]) + cbar_ax = fig.add_subplot(gs[0, 1]) + x, y = zip(*adata.obsm["X_umap"]) + scatter = main_ax.scatter( + x, y, s=3, c=sc.get.obs_df(adata, color).values, cmap="viridis" + ) + fig.colorbar(scatter, cax=cbar_ax) diff --git a/src/scverse_plotting_api/helpers.py b/src/scverse_plotting_api/helpers.py index e720961..df40315 100644 --- a/src/scverse_plotting_api/helpers.py +++ b/src/scverse_plotting_api/helpers.py @@ -1,7 +1,10 @@ from __future__ import annotations +import inspect from contextlib import contextmanager -from typing import TYPE_CHECKING, TypedDict +from functools import wraps +from itertools import islice +from typing import TYPE_CHECKING, TypedDict, overload from matplotlib import get_backend from matplotlib.axes import Axes @@ -9,11 +12,34 @@ from matplotlib.gridspec import GridSpec, SubplotSpec if TYPE_CHECKING: - from collections.abc import Generator - from typing import Unpack + from collections.abc import Callable, Generator + from typing import Generic, ParamSpec, Protocol, Unpack from matplotlib.gridspec import GridSpecFromSubplotSpec + P = ParamSpec("P") + + class PlottingImpl(Protocol, Generic[P]): + def __call__( + self, + *args: P.args, + fig: Figure, + gs: GridSpec | GridSpecFromSubplotSpec, + **kwds: P.kwargs, + ) -> None: + ... + + class PlottingAPI(Protocol, Generic[P]): + @overload + def __call__(self, *args: P.args, ax: None = None, **kwds: P.kwargs) -> Figure: + ... + + @overload + def __call__( + self, *args: P.args, ax: Axes | SubplotSpec, **kwds: P.kwargs + ) -> None: + ... + class GSParams(TypedDict, total=False): nrows: int @@ -21,6 +47,59 @@ class GSParams(TypedDict, total=False): width_ratios: tuple[float, float] +def plot_decorator( + **gridspec_params: Unpack[GSParams], +) -> Callable[[PlottingImpl[P]], PlottingAPI[P]]: + def decorator(f: PlottingImpl[P]) -> PlottingAPI[P]: + @overload + def wrapper( + *args: P.args, + ax: None = None, + **kwargs: P.kwargs, + ) -> Figure: + ... + + @overload + def wrapper( + *args: P.args, + ax: Axes | SubplotSpec, + **kwargs: P.kwargs, + ) -> None: + ... + + @wraps(f) + def wrapper( + *args: P.args, + ax: Axes | SubplotSpec | None = None, + **kwargs: P.kwargs, + ) -> Figure | None: + with plot_context(ax, **gridspec_params) as (fig, gs): + rv = f(*args, fig=fig, gs=gs, **kwargs) # type: ignore[func-returns-value] + if rv is not None: + name = getattr(f, "__name__", "function") + msg = f"{name}’ definition should not return anything" + raise TypeError(msg) + return fig if ax is None else None + + ax_param = inspect.Parameter( + "ax", inspect.Parameter.POSITIONAL_ONLY, annotation=Axes | SubplotSpec + ) + del wrapper.__annotations__["fig"] + del wrapper.__annotations__["gs"] + wrapper.__annotations__["ax"] = ax_param.annotation + + wrapper.__signature__ = (sig := inspect.signature(f)).replace( # type: ignore[attr-defined] + parameters=[ + ax_param, + *islice(sig.parameters.values(), 2, None), + ] + ) + + return wrapper + + return decorator + + @contextmanager def plot_context( ax: Axes | SubplotSpec | None = None,