From 5ebfe351e339960b090b7ba817eb3e0039c408d6 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Mon, 20 Apr 2026 13:08:14 +1000 Subject: [PATCH 1/8] refactor step 1 --- ultraplot/axes/base.py | 2 +- ultraplot/figure.py | 388 ++++------------------------------------- ultraplot/gridspec.py | 6 +- 3 files changed, 36 insertions(+), 360 deletions(-) diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 29a962e14..613a9cf1a 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -1365,7 +1365,7 @@ def shared(paxs): child._sharey_setup(parent) # Global sharing, use the reference subplot where compatible - ref = self.figure._subplot_dict.get(self.figure._refnum, None) + ref = self.figure._subplots.subplot_dict.get(self.figure._refnum, None) if self is not ref and ref is not None: if self.figure._sharex > 3: ok, reason = self.figure._share_axes_compatible(ref, self, "x") diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 334c61a44..8c63b4d9e 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -43,6 +43,7 @@ labels, warnings, ) +from ._subplots import SubplotManager from .utils import _Crawler, units __all__ = [ @@ -881,10 +882,8 @@ def _init_figure_state(self, figwidth, figheight, kwargs): Initialize internal state, call matplotlib's Figure.__init__, set up super labels, and apply initial formatting. """ - self._gridspec = None + self._subplots = SubplotManager(self) self._panel_dict = {"left": [], "right": [], "bottom": [], "top": []} - self._subplot_dict = {} - self._subplot_counter = 0 self._is_adjusting = False self._is_authorized = False self._layout_initialized = False @@ -1533,107 +1532,12 @@ def _context_authorized(self): @staticmethod def _parse_backend(backend=None, basemap=None): - """ - Handle deprecation of basemap and cartopy package. - """ - # Basemap is currently being developed again so are removing the deprecation warning - if backend == "basemap": - warnings._warn_ultraplot( - f"{backend=} will be deprecated in next major release (v2.0). See https://github.com/Ultraplot/ultraplot/pull/243" - ) - return backend + """Delegate to SubplotManager.""" + return SubplotManager.parse_backend(backend, basemap) - def _parse_proj( - self, - proj=None, - projection=None, - proj_kw=None, - projection_kw=None, - backend=None, - basemap=None, - **kwargs, - ): - """ - Translate the user-input projection into a registered matplotlib - axes class. Input projection can be a string, `matplotlib.axes.Axes`, - `cartopy.crs.Projection`, or `mpl_toolkits.basemap.Basemap`. - """ - # Parse arguments - proj = _not_none(proj=proj, projection=projection, default="cartesian") - proj_kw = _not_none(proj_kw=proj_kw, projection_kw=projection_kw, default={}) - backend = self._parse_backend(backend, basemap) - if isinstance(proj, str): - proj = proj.lower() - if isinstance(self, paxes.Axes): - proj = self._name - elif isinstance(self, maxes.Axes): - raise ValueError("Matplotlib axes cannot be added to ultraplot figures.") - - # Search axes projections - name = None - - # Handle cartopy/basemap Projection objects directly - # These should be converted to Ultraplot GeoAxes - if not isinstance(proj, str): - # Check if it's a cartopy or basemap projection object - if constructor.Projection is not object and isinstance( - proj, constructor.Projection - ): - # It's a cartopy projection - use cartopy backend - name = "ultraplot_cartopy" - kwargs["map_projection"] = proj - elif constructor.Basemap is not object and isinstance( - proj, constructor.Basemap - ): - # It's a basemap projection - name = "ultraplot_basemap" - kwargs["map_projection"] = proj - # If not recognized, leave name as None and it will pass through - - if name is None and isinstance(proj, str): - try: - mproj.get_projection_class("ultraplot_" + proj) - except (KeyError, ValueError): - pass - else: - name = "ultraplot_" + proj - if name is None and isinstance(proj, str): - # Try geographic projections first if cartopy/basemap available - if ( - constructor.Projection is not object - or constructor.Basemap is not object - ): - try: - proj_obj = constructor.Proj( - proj, backend=backend, include_axes=True, **proj_kw - ) - name = "ultraplot_" + proj_obj._proj_backend - kwargs["map_projection"] = proj_obj - except ValueError: - # Not a geographic projection, will try matplotlib registry below - pass - - # If not geographic, check if registered globally in Matplotlib (e.g., 'ternary', 'polar', '3d') - if name is None and proj in mproj.get_projection_names(): - name = proj - - # Helpful error message if still not found - if name is None and isinstance(proj, str): - raise ValueError( - f"Invalid projection name {proj!r}. If you are trying to generate a " - "GeoAxes with a cartopy.crs.Projection or mpl_toolkits.basemap.Basemap " - "then cartopy or basemap must be installed. Otherwise the known axes " - f"subclasses are:\n{paxes._cls_table}" - ) - - # Only set projection if we found a named projection - # Otherwise preserve the original projection (e.g., cartopy Projection objects) - if name is not None: - kwargs["projection"] = name - # If name is None and proj is not a string, it means we have a non-string - # projection (e.g., cartopy.crs.Projection object) that should be passed through - # The original projection kwarg is already in kwargs, so no action needed - return kwargs + def _parse_proj(self, *args, **kwargs): + """Delegate to SubplotManager.""" + return self._subplots.parse_proj(*args, **kwargs) def _get_align_axes(self, side): """ @@ -1642,7 +1546,7 @@ def _get_align_axes(self, side): For 'left'/'right': select one extreme axis per row (leftmost/rightmost). For 'top'/'bottom': select one extreme axis per column (topmost/bottommost). """ - axs = tuple(self._subplot_dict.values()) + axs = tuple(self._subplots.subplot_dict.values()) if not axs: return [] if side not in ("left", "right", "top", "bottom"): @@ -1650,7 +1554,7 @@ def _get_align_axes(self, side): from .utils import _get_subplot_layout grid = _get_subplot_layout( - self._gridspec, list(self._iter_axes(panels=False, hidden=False)) + self.gridspec, list(self._iter_axes(panels=False, hidden=False)) )[0] # From the @side we find the first non-zero # entry in each row or column and collect the axes @@ -2057,142 +1961,8 @@ def _add_figure_panel( @_clear_border_cache def _add_subplot(self, *args, **kwargs): - """ - The driver function for adding single subplots. - """ - self._layout_dirty = True - # Parse arguments - kwargs = self._parse_proj(**kwargs) - - args = args or (1, 1, 1) - gs = self.gridspec - - # Integer arg - if len(args) == 1 and isinstance(args[0], Integral): - if not 111 <= args[0] <= 999: - raise ValueError(f"Input {args[0]} must fall between 111 and 999.") - args = tuple(map(int, str(args[0]))) - - # Subplot spec - if len(args) == 1 and isinstance( - args[0], (maxes.SubplotBase, mgridspec.SubplotSpec) - ): - ss = args[0] - if isinstance(ss, maxes.SubplotBase): - ss = ss.get_subplotspec() - if gs is None: - gs = ss.get_topmost_subplotspec().get_gridspec() - if not isinstance(gs, pgridspec.GridSpec): - raise ValueError( - "Input subplotspec must be derived from a ultraplot.GridSpec." - ) - if ss.get_topmost_subplotspec().get_gridspec() is not gs: - raise ValueError( - "Input subplotspec must be derived from the active figure gridspec." - ) - - # Row and column spec - # TODO: How to pass spacing parameters to gridspec? Consider overriding - # subplots adjust? Or require using gridspec manually? - elif ( - len(args) == 3 - and all(isinstance(arg, Integral) for arg in args[:2]) - and all(isinstance(arg, Integral) for arg in np.atleast_1d(args[2])) - ): - nrows, ncols, num = args - i, j = np.resize(num, 2) - if gs is None: - gs = pgridspec.GridSpec(nrows, ncols) - orows, ocols = gs.get_geometry() - if orows % nrows: - raise ValueError( - f"The input number of rows {nrows} does not divide the " - f"figure gridspec number of rows {orows}." - ) - if ocols % ncols: - raise ValueError( - f"The input number of columns {ncols} does not divide the " - f"figure gridspec number of columns {ocols}." - ) - if any(_ < 1 or _ > nrows * ncols for _ in (i, j)): - raise ValueError( - "The input subplot indices must fall between " - f"1 and {nrows * ncols}. Instead got {i} and {j}." - ) - rowfact, colfact = orows // nrows, ocols // ncols - irow, icol = divmod(i - 1, ncols) # convert to zero-based - jrow, jcol = divmod(j - 1, ncols) - irow, icol = irow * rowfact, icol * colfact - jrow, jcol = (jrow + 1) * rowfact - 1, (jcol + 1) * colfact - 1 - ss = gs[irow : jrow + 1, icol : jcol + 1] - - # Otherwise - else: - raise ValueError(f"Invalid add_subplot positional arguments {args!r}.") - - # Add the subplot - # NOTE: Pass subplotspec as keyword arg for mpl >= 3.4 workaround - # NOTE: Must assign unique label to each subplot or else subsequent calls - # to add_subplot() in mpl < 3.4 may return an already-drawn subplot in the - # wrong location due to gridspec override. Is against OO package design. - self.gridspec = gs # trigger layout adjustment - self._subplot_counter += 1 # unique label for each subplot - kwargs.setdefault("label", f"subplot_{self._subplot_counter}") - kwargs.setdefault("number", 1 + max(self._subplot_dict, default=0)) - kwargs.pop("refwidth", None) # TODO: remove this - - # Use container approach for external projections to make them ultraplot-compatible - projection_name = kwargs.get("projection") - external_axes_class = None - external_axes_kwargs = {} - - if projection_name and isinstance(projection_name, str): - # Check if this is an external (non-ultraplot) projection - # Skip external wrapping for projections that start with "ultraplot_" prefix - # as these are already Ultraplot axes classes - if not projection_name.startswith("ultraplot_"): - try: - # Get the projection class - proj_class = mproj.get_projection_class(projection_name) - - # Check if it's not a built-in ultraplot axes - # Only wrap if it's NOT a subclass of Ultraplot's Axes - if not issubclass(proj_class, paxes.Axes): - # Store the external axes class and original projection name - external_axes_class = proj_class - external_axes_kwargs["projection"] = projection_name - - # Create or get the container class for this external axes type - from .axes.container import create_external_axes_container - - container_name = f"_ultraplot_container_{projection_name}" - - # Check if container is already registered - if container_name not in mproj.get_projection_names(): - container_class = create_external_axes_container( - proj_class, projection_name=container_name - ) - mproj.register_projection(container_class) - - # Use the container projection and pass external axes info - kwargs["projection"] = container_name - kwargs["external_axes_class"] = external_axes_class - kwargs["external_axes_kwargs"] = external_axes_kwargs - except (KeyError, ValueError): - # Projection not found, let matplotlib handle the error - pass - - # Remove _subplot_spec from kwargs if present to prevent it from being passed - # to .set() or other methods that don't accept it. - kwargs.pop("_subplot_spec", None) - - # Pass only the SubplotSpec as a positional argument - # Don't pass _subplot_spec as a keyword argument to avoid it being - # propagated to Axes.set() or other methods that don't accept it - ax = super().add_subplot(ss, **kwargs) - if ax.number: - self._subplot_dict[ax.number] = ax - return ax + """Delegate to SubplotManager.""" + return self._subplots.add_subplot(*args, **kwargs) def _unshare_axes(self): @@ -2287,111 +2057,20 @@ def _add_subplots( basemap=None, **kwargs, ): - """ - The driver function for adding multiple subplots. - """ - - # Clunky helper function - # TODO: Consider deprecating and asking users to use add_subplot() - def _axes_dict(naxs, input, kw=False, default=None): - # First build up dictionary - if not kw: # 'string' or {1: 'string1', (2, 3): 'string2'} - if np.iterable(input) and not isinstance(input, (str, dict)): - input = {num + 1: item for num, item in enumerate(input)} - elif not isinstance(input, dict): - input = {range(1, naxs + 1): input} - else: # {key: value} or {1: {key: value1}, (2, 3): {key: value2}} - nested = [isinstance(_, dict) for _ in input.values()] - if not any(nested): # any([]) == False - input = {range(1, naxs + 1): input.copy()} - elif not all(nested): - raise ValueError(f"Invalid input {input!r}.") - # Unfurl keys that contain multiple axes numbers - output = {} - for nums, item in input.items(): - nums = np.atleast_1d(nums) - for num in nums.flat: - output[num] = item.copy() if kw else item - # Fill with default values - for num in range(1, naxs + 1): - if num not in output: - output[num] = {} if kw else default - if output.keys() != set(range(1, naxs + 1)): - raise ValueError( - f"Have {naxs} axes, but {input!r} includes props for the axes: " - + ", ".join(map(repr, sorted(output))) - + "." - ) - return output - - # Build the subplot array - # NOTE: Currently this may ignore user-input nrows/ncols without warning - if order not in ("C", "F"): # better error message - raise ValueError(f"Invalid order={order!r}. Options are 'C' or 'F'.") - gs = None - if array is None or isinstance(array, mgridspec.GridSpec): - if array is not None: - gs, nrows, ncols = array, array.nrows, array.ncols - array = np.arange(1, nrows * ncols + 1)[..., None] - array = array.reshape((nrows, ncols), order=order) - else: - array = np.atleast_1d(array) - array[array == None] = 0 # None or 0 both valid placeholders # noqa: E711 - array = array.astype(int) - if array.ndim == 1: # interpret as single row or column - array = array[None, :] if order == "C" else array[:, None] - elif array.ndim != 2: - raise ValueError(f"Expected 1D or 2D array of integers. Got {array}.") - - # Parse input format, gridspec, and projection arguments - # NOTE: Permit figure format keywords for e.g. 'collabels' (more intuitive) - nums = np.unique(array[array != 0]) - naxs = len(nums) - if any(num < 0 or not isinstance(num, Integral) for num in nums.flat): - raise ValueError(f"Expected array of positive integers. Got {array}.") - proj = _not_none(projection=projection, proj=proj) - proj = _axes_dict(naxs, proj, kw=False, default="cartesian") - proj_kw = _not_none(projection_kw=projection_kw, proj_kw=proj_kw) or {} - proj_kw = _axes_dict(naxs, proj_kw, kw=True) - backend = self._parse_backend(backend, basemap) - backend = _axes_dict(naxs, backend, kw=False) - axes_kw = { - num: {"proj": proj[num], "proj_kw": proj_kw[num], "backend": backend[num]} - for num in proj - } - for key in ("gridspec_kw", "subplot_kw"): - kw = kwargs.pop(key, None) - if not kw: - continue - warnings._warn_ultraplot( - f"{key!r} is not necessary in ultraplot. Pass the " - "parameters as keyword arguments instead." - ) - kwargs.update(kw or {}) - figure_kw = _pop_params(kwargs, self._format_signature) - gridspec_kw = _pop_params(kwargs, pgridspec.GridSpec._update_params) - - # Create or update the gridspec and add subplots with subplotspecs - # NOTE: The gridspec is added to the figure when we pass the subplotspec - if gs is None: - if "layout_array" not in gridspec_kw: - gridspec_kw = {**gridspec_kw, "layout_array": array} - gs = pgridspec.GridSpec(*array.shape, **gridspec_kw) - else: - gs.update(**gridspec_kw) - axs = naxs * [None] # list of axes - axids = [np.where(array == i) for i in np.sort(np.unique(array)) if i > 0] - axcols = np.array([[x.min(), x.max()] for _, x in axids]) - axrows = np.array([[y.min(), y.max()] for y, _ in axids]) - for idx in range(naxs): - num = idx + 1 - x0, x1 = axcols[idx, 0], axcols[idx, 1] - y0, y1 = axrows[idx, 0], axrows[idx, 1] - ss = gs[y0 : y1 + 1, x0 : x1 + 1] - kw = {**kwargs, **axes_kw[num], "number": num} - axs[idx] = self.add_subplot(ss, **kw) - self.format(skip_axes=True, **figure_kw) - return pgridspec.SubplotGrid(axs) + """Delegate to SubplotManager.""" + return self._subplots.add_subplots( + array=array, + nrows=nrows, + ncols=ncols, + order=order, + proj=proj, + projection=projection, + proj_kw=proj_kw, + projection_kw=projection_kw, + backend=backend, + basemap=basemap, + **kwargs, + ) def _align_axis_label(self, x): """ @@ -2403,7 +2082,7 @@ def _align_axis_label(self, x): seen = set() span = getattr(self, "_span" + x) align = getattr(self, "_align" + x) - for ax in self._subplot_dict.values(): + for ax in self._subplots.subplot_dict.values(): if isinstance(ax, paxes.CartesianAxes): ax._apply_axis_sharing() # always! else: @@ -2652,7 +2331,7 @@ def _align_super_labels(self, side, renderer): Adjust the position of super labels. """ # NOTE: Ensure title is offset only here. - for ax in self._subplot_dict.values(): + for ax in self._subplots.subplot_dict.values(): ax._apply_title_above() if side not in ("left", "right", "bottom", "top"): raise ValueError(f"Invalid side {side!r}.") @@ -3223,7 +2902,7 @@ def format( """ self._layout_dirty = True # Initiate context block - axs = axs or self._subplot_dict.values() + axs = axs or self._subplots.subplot_dict.values() skip_axes = kwargs.pop("skip_axes", False) # internal keyword arg rc_kw, rc_mode = _pop_rc(kwargs) with rc.context(rc_kw, mode=rc_mode): @@ -3912,7 +3591,7 @@ def _iter_axes(self, hidden=False, children=False, panels=True): raise ValueError(f"Invalid sides {panels!r}.") # Iterate axs = ( - *self._subplot_dict.values(), + *self._subplots.subplot_dict.values(), *(ax for side in panels for ax in self._panel_dict[side]), ) for ax in axs: @@ -3932,14 +3611,11 @@ def gridspec(self): ultraplot.gridspec.GridSpec.figure ultraplot.gridspec.SubplotGrid.gridspec """ - return self._gridspec + return self._subplots.gridspec @gridspec.setter def gridspec(self, gs): - if not isinstance(gs, pgridspec.GridSpec): - raise ValueError("Gridspec must be a ultraplot.GridSpec instance.") - self._gridspec = gs - gs.figure = self # trigger copying settings from the figure + self._subplots.gridspec = gs @property def subplotgrid(self): @@ -3952,7 +3628,7 @@ def subplotgrid(self): ultraplot.figure.Figure.gridspec ultraplot.gridspec.SubplotGrid.figure """ - return pgridspec.SubplotGrid([s for _, s in sorted(self._subplot_dict.items())]) + return self._subplots.subplotgrid @property def tight(self): diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 2f4761f23..3ec8943e9 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -1197,7 +1197,7 @@ def _auto_layout_aspect(self): fig = self.figure if not fig: return - ax = fig._subplot_dict.get(fig._refnum, None) + ax = fig._subplots.subplot_dict.get(fig._refnum, None) if ax is None: return @@ -1302,7 +1302,7 @@ def _update_figsize(self): fig = self.figure if fig is None: # drawing before subplots are added? return - ax = fig._subplot_dict.get(fig._refnum, None) + ax = fig._subplots.subplot_dict.get(fig._refnum, None) if ax is None: # drawing before subplots are added? return ss = ax.get_subplotspec().get_topmost_subplotspec() @@ -2086,7 +2086,7 @@ def format(self, **kwargs): ylabel = kwargs.get("ylabel", None) title = kwargs.get("title", None) axes = [ax for ax in self if ax is not None] - all_axes = set(self.figure._subplot_dict.values()) + all_axes = set(self.figure._subplots.subplot_dict.values()) is_subset = bool(axes) and all_axes and set(axes) != all_axes shared_subset_title = len(self) > 1 and is_subset and isinstance(title, str) shared_title_kw = ( From 4d4afb80c19f19084d0b45b68ec92ec69e334a84 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Mon, 20 Apr 2026 13:11:03 +1000 Subject: [PATCH 2/8] forgot to add this --- ultraplot/_subplots.py | 384 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 384 insertions(+) create mode 100644 ultraplot/_subplots.py diff --git a/ultraplot/_subplots.py b/ultraplot/_subplots.py new file mode 100644 index 000000000..4924057f5 --- /dev/null +++ b/ultraplot/_subplots.py @@ -0,0 +1,384 @@ +""" +Subplot creation and management for ultraplot figures. +""" + +import inspect +from numbers import Integral + +try: + from typing import Optional, Tuple, Union +except ImportError: + from typing_extensions import Optional, Tuple, Union + +import matplotlib.axes as maxes +import matplotlib.figure as mfigure +import matplotlib.gridspec as mgridspec +import matplotlib.projections as mproj +import numpy as np + +from . import axes as paxes +from . import constructor +from . import gridspec as pgridspec +from .internals import _not_none, _pop_params, warnings + + +class SubplotManager: + """ + Manages subplot creation, gridspec ownership, and projection parsing + for a Figure instance. + + Parameters + ---------- + figure : `~ultraplot.figure.Figure` + The parent figure. + """ + + def __init__(self, figure: "Figure"): + self.figure = figure + self.subplot_dict: dict = {} + self.counter: int = 0 + self._gridspec = None + + # --------------------------------------------------------------- + # Gridspec property + # --------------------------------------------------------------- + @property + def gridspec(self): + """The single GridSpec used for all subplots in the figure.""" + return self._gridspec + + @gridspec.setter + def gridspec(self, gs): + if not isinstance(gs, pgridspec.GridSpec): + raise ValueError("Gridspec must be a ultraplot.GridSpec instance.") + self._gridspec = gs + gs.figure = self.figure # gridspec.figure should reference the real Figure + + # --------------------------------------------------------------- + # Projection parsing + # --------------------------------------------------------------- + @staticmethod + def parse_backend(backend=None, basemap=None): + """ + Handle deprecation of basemap and cartopy package. + """ + if backend == "basemap": + warnings._warn_ultraplot( + f"{backend=} will be deprecated in next major release (v2.0). " + "See https://github.com/Ultraplot/ultraplot/pull/243" + ) + return backend + + def parse_proj( + self, + proj=None, + projection=None, + proj_kw=None, + projection_kw=None, + backend=None, + basemap=None, + **kwargs, + ): + """ + Translate user-input projection into a registered matplotlib axes class. + """ + proj = _not_none(proj=proj, projection=projection, default="cartesian") + proj_kw = _not_none(proj_kw=proj_kw, projection_kw=projection_kw, default={}) + backend = self.parse_backend(backend, basemap) + if isinstance(proj, str): + proj = proj.lower() + if isinstance(self.figure, paxes.Axes): + proj = self.figure._name + elif isinstance(self.figure, maxes.Axes): + raise ValueError("Matplotlib axes cannot be added to ultraplot figures.") + + name = None + + # Handle cartopy/basemap Projection objects directly + if not isinstance(proj, str): + if constructor.Projection is not object and isinstance( + proj, constructor.Projection + ): + name = "ultraplot_cartopy" + kwargs["map_projection"] = proj + elif constructor.Basemap is not object and isinstance( + proj, constructor.Basemap + ): + name = "ultraplot_basemap" + kwargs["map_projection"] = proj + + if name is None and isinstance(proj, str): + try: + mproj.get_projection_class("ultraplot_" + proj) + except (KeyError, ValueError): + pass + else: + name = "ultraplot_" + proj + if name is None and isinstance(proj, str): + if ( + constructor.Projection is not object + or constructor.Basemap is not object + ): + try: + proj_obj = constructor.Proj( + proj, backend=backend, include_axes=True, **proj_kw + ) + name = "ultraplot_" + proj_obj._proj_backend + kwargs["map_projection"] = proj_obj + except ValueError: + pass + + if name is None and proj in mproj.get_projection_names(): + name = proj + + if name is None and isinstance(proj, str): + raise ValueError( + f"Invalid projection name {proj!r}. If you are trying to generate a " + "GeoAxes with a cartopy.crs.Projection or mpl_toolkits.basemap.Basemap " + "then cartopy or basemap must be installed. Otherwise the known axes " + f"subclasses are:\n{paxes._cls_table}" + ) + + if name is not None: + kwargs["projection"] = name + return kwargs + + # --------------------------------------------------------------- + # Subplot creation + # --------------------------------------------------------------- + def add_subplot(self, *args, **kwargs): + """ + The driver function for adding single subplots. + """ + fig = self.figure + fig._layout_dirty = True + kwargs = self.parse_proj(**kwargs) + + args = args or (1, 1, 1) + gs = self.gridspec + + # Integer arg + if len(args) == 1 and isinstance(args[0], Integral): + if not 111 <= args[0] <= 999: + raise ValueError(f"Input {args[0]} must fall between 111 and 999.") + args = tuple(map(int, str(args[0]))) + + # Subplot spec + if len(args) == 1 and isinstance( + args[0], (maxes.SubplotBase, mgridspec.SubplotSpec) + ): + ss = args[0] + if isinstance(ss, maxes.SubplotBase): + ss = ss.get_subplotspec() + if gs is None: + gs = ss.get_topmost_subplotspec().get_gridspec() + if not isinstance(gs, pgridspec.GridSpec): + raise ValueError( + "Input subplotspec must be derived from a ultraplot.GridSpec." + ) + if ss.get_topmost_subplotspec().get_gridspec() is not gs: + raise ValueError( + "Input subplotspec must be derived from the active figure gridspec." + ) + + # Row and column spec + elif ( + len(args) == 3 + and all(isinstance(arg, Integral) for arg in args[:2]) + and all(isinstance(arg, Integral) for arg in np.atleast_1d(args[2])) + ): + nrows, ncols, num = args + i, j = np.resize(num, 2) + if gs is None: + gs = pgridspec.GridSpec(nrows, ncols) + orows, ocols = gs.get_geometry() + if orows % nrows: + raise ValueError( + f"The input number of rows {nrows} does not divide the " + f"figure gridspec number of rows {orows}." + ) + if ocols % ncols: + raise ValueError( + f"The input number of columns {ncols} does not divide the " + f"figure gridspec number of columns {ocols}." + ) + if any(_ < 1 or _ > nrows * ncols for _ in (i, j)): + raise ValueError( + "The input subplot indices must fall between " + f"1 and {nrows * ncols}. Instead got {i} and {j}." + ) + rowfact, colfact = orows // nrows, ocols // ncols + irow, icol = divmod(i - 1, ncols) + jrow, jcol = divmod(j - 1, ncols) + irow, icol = irow * rowfact, icol * colfact + jrow, jcol = (jrow + 1) * rowfact - 1, (jcol + 1) * colfact - 1 + ss = gs[irow : jrow + 1, icol : jcol + 1] + + else: + raise ValueError(f"Invalid add_subplot positional arguments {args!r}.") + + # Add the subplot + self.gridspec = gs # trigger layout adjustment + self.counter += 1 + kwargs.setdefault("label", f"subplot_{self.counter}") + kwargs.setdefault("number", 1 + max(self.subplot_dict, default=0)) + kwargs.pop("refwidth", None) # TODO: remove this + + # Use container approach for external projections + projection_name = kwargs.get("projection") + external_axes_class = None + external_axes_kwargs = {} + + if projection_name and isinstance(projection_name, str): + if not projection_name.startswith("ultraplot_"): + try: + proj_class = mproj.get_projection_class(projection_name) + if not issubclass(proj_class, paxes.Axes): + external_axes_class = proj_class + external_axes_kwargs["projection"] = projection_name + + from .axes.container import create_external_axes_container + + container_name = f"_ultraplot_container_{projection_name}" + if container_name not in mproj.get_projection_names(): + container_class = create_external_axes_container( + proj_class, projection_name=container_name + ) + mproj.register_projection(container_class) + + kwargs["projection"] = container_name + kwargs["external_axes_class"] = external_axes_class + kwargs["external_axes_kwargs"] = external_axes_kwargs + except (KeyError, ValueError): + pass + + kwargs.pop("_subplot_spec", None) + + # Call matplotlib's Figure.add_subplot directly + ax = mfigure.Figure.add_subplot(fig, ss, **kwargs) + if ax.number: + self.subplot_dict[ax.number] = ax + return ax + + def add_subplots( + self, + array=None, + nrows=1, + ncols=1, + order="C", + proj=None, + projection=None, + proj_kw=None, + projection_kw=None, + backend=None, + basemap=None, + **kwargs, + ): + """ + The driver function for adding multiple subplots. + """ + fig = self.figure + + def _axes_dict(naxs, input, kw=False, default=None): + if not kw: + if np.iterable(input) and not isinstance(input, (str, dict)): + input = {num + 1: item for num, item in enumerate(input)} + elif not isinstance(input, dict): + input = {range(1, naxs + 1): input} + else: + nested = [isinstance(_, dict) for _ in input.values()] + if not any(nested): + input = {range(1, naxs + 1): input.copy()} + elif not all(nested): + raise ValueError(f"Invalid input {input!r}.") + output = {} + for nums, item in input.items(): + nums = np.atleast_1d(nums) + for num in nums.flat: + output[num] = item.copy() if kw else item + for num in range(1, naxs + 1): + if num not in output: + output[num] = {} if kw else default + if output.keys() != set(range(1, naxs + 1)): + raise ValueError( + f"Have {naxs} axes, but {input!r} includes props for the axes: " + + ", ".join(map(repr, sorted(output))) + + "." + ) + return output + + # Build the subplot array + if order not in ("C", "F"): + raise ValueError(f"Invalid order={order!r}. Options are 'C' or 'F'.") + gs = None + if array is None or isinstance(array, mgridspec.GridSpec): + if array is not None: + gs, nrows, ncols = array, array.nrows, array.ncols + array = np.arange(1, nrows * ncols + 1)[..., None] + array = array.reshape((nrows, ncols), order=order) + else: + array = np.atleast_1d(array) + array[array == None] = 0 # noqa: E711 + array = array.astype(int) + if array.ndim == 1: + array = array[None, :] if order == "C" else array[:, None] + elif array.ndim != 2: + raise ValueError(f"Expected 1D or 2D array of integers. Got {array}.") + + # Parse input format, gridspec, and projection arguments + nums = np.unique(array[array != 0]) + naxs = len(nums) + if any(num < 0 or not isinstance(num, Integral) for num in nums.flat): + raise ValueError(f"Expected array of positive integers. Got {array}.") + proj = _not_none(projection=projection, proj=proj) + proj = _axes_dict(naxs, proj, kw=False, default="cartesian") + proj_kw = _not_none(projection_kw=projection_kw, proj_kw=proj_kw) or {} + proj_kw = _axes_dict(naxs, proj_kw, kw=True) + backend = self.parse_backend(backend, basemap) + backend = _axes_dict(naxs, backend, kw=False) + axes_kw = { + num: {"proj": proj[num], "proj_kw": proj_kw[num], "backend": backend[num]} + for num in proj + } + for key in ("gridspec_kw", "subplot_kw"): + kw = kwargs.pop(key, None) + if not kw: + continue + warnings._warn_ultraplot( + f"{key!r} is not necessary in ultraplot. Pass the " + "parameters as keyword arguments instead." + ) + kwargs.update(kw or {}) + figure_kw = _pop_params(kwargs, fig._format_signature) + gridspec_kw = _pop_params(kwargs, pgridspec.GridSpec._update_params) + + # Create or update the gridspec and add subplots with subplotspecs + if gs is None: + if "layout_array" not in gridspec_kw: + gridspec_kw = {**gridspec_kw, "layout_array": array} + gs = pgridspec.GridSpec(*array.shape, **gridspec_kw) + else: + gs.update(**gridspec_kw) + axs = naxs * [None] + axids = [np.where(array == i) for i in np.sort(np.unique(array)) if i > 0] + axcols = np.array([[x.min(), x.max()] for _, x in axids]) + axrows = np.array([[y.min(), y.max()] for y, _ in axids]) + for idx in range(naxs): + num = idx + 1 + x0, x1 = axcols[idx, 0], axcols[idx, 1] + y0, y1 = axrows[idx, 0], axrows[idx, 1] + ss = gs[y0 : y1 + 1, x0 : x1 + 1] + kw = {**kwargs, **axes_kw[num], "number": num} + axs[idx] = fig.add_subplot(ss, **kw) + fig.format(skip_axes=True, **figure_kw) + return pgridspec.SubplotGrid(axs) + + # --------------------------------------------------------------- + # Convenience accessors + # --------------------------------------------------------------- + @property + def subplotgrid(self): + """A SubplotGrid of numbered subplots sorted by number.""" + return pgridspec.SubplotGrid( + [s for _, s in sorted(self.subplot_dict.items())] + ) From 7fd420c67b658871bb621aef800768a50aacb1d2 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 22 Apr 2026 10:18:55 +1000 Subject: [PATCH 3/8] rm comment fancy headers --- ultraplot/_subplots.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/ultraplot/_subplots.py b/ultraplot/_subplots.py index 4924057f5..54f323537 100644 --- a/ultraplot/_subplots.py +++ b/ultraplot/_subplots.py @@ -54,9 +54,6 @@ def gridspec(self, gs): self._gridspec = gs gs.figure = self.figure # gridspec.figure should reference the real Figure - # --------------------------------------------------------------- - # Projection parsing - # --------------------------------------------------------------- @staticmethod def parse_backend(backend=None, basemap=None): """ @@ -143,9 +140,6 @@ def parse_proj( kwargs["projection"] = name return kwargs - # --------------------------------------------------------------- - # Subplot creation - # --------------------------------------------------------------- def add_subplot(self, *args, **kwargs): """ The driver function for adding single subplots. @@ -379,6 +373,4 @@ def _axes_dict(naxs, input, kw=False, default=None): @property def subplotgrid(self): """A SubplotGrid of numbered subplots sorted by number.""" - return pgridspec.SubplotGrid( - [s for _, s in sorted(self.subplot_dict.items())] - ) + return pgridspec.SubplotGrid([s for _, s in sorted(self.subplot_dict.items())]) From 8e05bdecfce610e19f84d827798252cb2fb933ce Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 22 Apr 2026 10:20:40 +1000 Subject: [PATCH 4/8] rm comment fancy headers part 2 --- ultraplot/_subplots.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/ultraplot/_subplots.py b/ultraplot/_subplots.py index 54f323537..fa9d39159 100644 --- a/ultraplot/_subplots.py +++ b/ultraplot/_subplots.py @@ -39,9 +39,6 @@ def __init__(self, figure: "Figure"): self.counter: int = 0 self._gridspec = None - # --------------------------------------------------------------- - # Gridspec property - # --------------------------------------------------------------- @property def gridspec(self): """The single GridSpec used for all subplots in the figure.""" @@ -367,9 +364,6 @@ def _axes_dict(naxs, input, kw=False, default=None): fig.format(skip_axes=True, **figure_kw) return pgridspec.SubplotGrid(axs) - # --------------------------------------------------------------- - # Convenience accessors - # --------------------------------------------------------------- @property def subplotgrid(self): """A SubplotGrid of numbered subplots sorted by number.""" From 47b2a31d13023a203c1d074356a5c2c6b187451e Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 22 Apr 2026 10:30:28 +1000 Subject: [PATCH 5/8] annotate the code --- ultraplot/_subplots.py | 43 +++++++++++++++++++++++++++++++----------- ultraplot/axes/base.py | 2 +- ultraplot/figure.py | 18 +++++++++++++----- ultraplot/gridspec.py | 6 +++--- 4 files changed, 49 insertions(+), 20 deletions(-) diff --git a/ultraplot/_subplots.py b/ultraplot/_subplots.py index fa9d39159..41e2ce4bc 100644 --- a/ultraplot/_subplots.py +++ b/ultraplot/_subplots.py @@ -76,6 +76,7 @@ def parse_proj( """ Translate user-input projection into a registered matplotlib axes class. """ + # Parse arguments proj = _not_none(proj=proj, projection=projection, default="cartesian") proj_kw = _not_none(proj_kw=proj_kw, projection_kw=projection_kw, default={}) backend = self.parse_backend(backend, basemap) @@ -86,9 +87,11 @@ def parse_proj( elif isinstance(self.figure, maxes.Axes): raise ValueError("Matplotlib axes cannot be added to ultraplot figures.") + # Search axes projections name = None # Handle cartopy/basemap Projection objects directly + # These should be converted to Ultraplot GeoAxes if not isinstance(proj, str): if constructor.Projection is not object and isinstance( proj, constructor.Projection @@ -109,6 +112,7 @@ def parse_proj( else: name = "ultraplot_" + proj if name is None and isinstance(proj, str): + # Try geographic projections first if cartopy/basemap available if ( constructor.Projection is not object or constructor.Basemap is not object @@ -120,8 +124,10 @@ def parse_proj( name = "ultraplot_" + proj_obj._proj_backend kwargs["map_projection"] = proj_obj except ValueError: - pass + pass # not a geographic projection, try matplotlib registry below + # If not geographic, check if registered globally in matplotlib + # (e.g., 'ternary', 'polar', '3d') if name is None and proj in mproj.get_projection_names(): name = proj @@ -199,7 +205,7 @@ def add_subplot(self, *args, **kwargs): f"1 and {nrows * ncols}. Instead got {i} and {j}." ) rowfact, colfact = orows // nrows, ocols // ncols - irow, icol = divmod(i - 1, ncols) + irow, icol = divmod(i - 1, ncols) # convert to zero-based jrow, jcol = divmod(j - 1, ncols) irow, icol = irow * rowfact, icol * colfact jrow, jcol = (jrow + 1) * rowfact - 1, (jcol + 1) * colfact - 1 @@ -209,13 +215,18 @@ def add_subplot(self, *args, **kwargs): raise ValueError(f"Invalid add_subplot positional arguments {args!r}.") # Add the subplot + # NOTE: Must assign unique label to each subplot or else subsequent calls + # to add_subplot() in mpl < 3.4 may return an already-drawn subplot in the + # wrong location due to gridspec override. self.gridspec = gs # trigger layout adjustment self.counter += 1 kwargs.setdefault("label", f"subplot_{self.counter}") kwargs.setdefault("number", 1 + max(self.subplot_dict, default=0)) kwargs.pop("refwidth", None) # TODO: remove this - # Use container approach for external projections + # Use container approach for external projections to make them + # ultraplot-compatible. Skip projections that start with "ultraplot_" + # as these are already Ultraplot axes classes. projection_name = kwargs.get("projection") external_axes_class = None external_axes_kwargs = {} @@ -245,7 +256,10 @@ def add_subplot(self, *args, **kwargs): kwargs.pop("_subplot_spec", None) - # Call matplotlib's Figure.add_subplot directly + # NOTE: We call mfigure.Figure.add_subplot directly (unbound) rather + # than fig.add_subplot because SubplotManager is not a Figure subclass + # and cannot use super(). This bypasses any Figure.add_subplot override, + # which is acceptable because Figure._add_subplot is the real entry point. ax = mfigure.Figure.add_subplot(fig, ss, **kwargs) if ax.number: self.subplot_dict[ax.number] = ax @@ -270,23 +284,27 @@ def add_subplots( """ fig = self.figure + # Helper to normalize per-axes arguments into {num: value} dicts. + # Accepts 'string', {1: 'string1', (2, 3): 'string2'}, or lists. def _axes_dict(naxs, input, kw=False, default=None): - if not kw: + if not kw: # 'string' or {1: 'string1', (2, 3): 'string2'} if np.iterable(input) and not isinstance(input, (str, dict)): input = {num + 1: item for num, item in enumerate(input)} elif not isinstance(input, dict): input = {range(1, naxs + 1): input} - else: + else: # {key: value} or {1: {key: value1}, (2, 3): {key: value2}} nested = [isinstance(_, dict) for _ in input.values()] - if not any(nested): + if not any(nested): # any([]) == False input = {range(1, naxs + 1): input.copy()} elif not all(nested): raise ValueError(f"Invalid input {input!r}.") + # Unfurl keys that contain multiple axes numbers output = {} for nums, item in input.items(): nums = np.atleast_1d(nums) for num in nums.flat: output[num] = item.copy() if kw else item + # Fill with default values for num in range(1, naxs + 1): if num not in output: output[num] = {} if kw else default @@ -299,7 +317,8 @@ def _axes_dict(naxs, input, kw=False, default=None): return output # Build the subplot array - if order not in ("C", "F"): + # NOTE: Currently this may ignore user-input nrows/ncols without warning + if order not in ("C", "F"): # better error message raise ValueError(f"Invalid order={order!r}. Options are 'C' or 'F'.") gs = None if array is None or isinstance(array, mgridspec.GridSpec): @@ -309,14 +328,15 @@ def _axes_dict(naxs, input, kw=False, default=None): array = array.reshape((nrows, ncols), order=order) else: array = np.atleast_1d(array) - array[array == None] = 0 # noqa: E711 + array[array == None] = 0 # None or 0 both valid placeholders # noqa: E711 array = array.astype(int) - if array.ndim == 1: + if array.ndim == 1: # interpret as single row or column array = array[None, :] if order == "C" else array[:, None] elif array.ndim != 2: raise ValueError(f"Expected 1D or 2D array of integers. Got {array}.") # Parse input format, gridspec, and projection arguments + # NOTE: Permit figure format keywords for e.g. 'collabels' (more intuitive) nums = np.unique(array[array != 0]) naxs = len(nums) if any(num < 0 or not isinstance(num, Integral) for num in nums.flat): @@ -344,13 +364,14 @@ def _axes_dict(naxs, input, kw=False, default=None): gridspec_kw = _pop_params(kwargs, pgridspec.GridSpec._update_params) # Create or update the gridspec and add subplots with subplotspecs + # NOTE: The gridspec is added to the figure when we pass the subplotspec if gs is None: if "layout_array" not in gridspec_kw: gridspec_kw = {**gridspec_kw, "layout_array": array} gs = pgridspec.GridSpec(*array.shape, **gridspec_kw) else: gs.update(**gridspec_kw) - axs = naxs * [None] + axs = naxs * [None] # list of axes axids = [np.where(array == i) for i in np.sort(np.unique(array)) if i > 0] axcols = np.array([[x.min(), x.max()] for _, x in axids]) axrows = np.array([[y.min(), y.max()] for y, _ in axids]) diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 613a9cf1a..0f41cbd81 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -1365,7 +1365,7 @@ def shared(paxs): child._sharey_setup(parent) # Global sharing, use the reference subplot where compatible - ref = self.figure._subplots.subplot_dict.get(self.figure._refnum, None) + ref = self.figure._get_subplot(self.figure._refnum) if self is not ref and ref is not None: if self.figure._sharex > 3: ok, reason = self.figure._share_axes_compatible(ref, self, "x") diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 8c63b4d9e..1be53df7d 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -1546,7 +1546,7 @@ def _get_align_axes(self, side): For 'left'/'right': select one extreme axis per row (leftmost/rightmost). For 'top'/'bottom': select one extreme axis per column (topmost/bottommost). """ - axs = tuple(self._subplots.subplot_dict.values()) + axs = tuple(self._iter_subplots()) if not axs: return [] if side not in ("left", "right", "top", "bottom"): @@ -2082,7 +2082,7 @@ def _align_axis_label(self, x): seen = set() span = getattr(self, "_span" + x) align = getattr(self, "_align" + x) - for ax in self._subplots.subplot_dict.values(): + for ax in self._iter_subplots(): if isinstance(ax, paxes.CartesianAxes): ax._apply_axis_sharing() # always! else: @@ -2331,7 +2331,7 @@ def _align_super_labels(self, side, renderer): Adjust the position of super labels. """ # NOTE: Ensure title is offset only here. - for ax in self._subplots.subplot_dict.values(): + for ax in self._iter_subplots(): ax._apply_title_above() if side not in ("left", "right", "bottom", "top"): raise ValueError(f"Invalid side {side!r}.") @@ -2902,7 +2902,7 @@ def format( """ self._layout_dirty = True # Initiate context block - axs = axs or self._subplots.subplot_dict.values() + axs = axs or self._iter_subplots() skip_axes = kwargs.pop("skip_axes", False) # internal keyword arg rc_kw, rc_mode = _pop_rc(kwargs) with rc.context(rc_kw, mode=rc_mode): @@ -3591,7 +3591,7 @@ def _iter_axes(self, hidden=False, children=False, panels=True): raise ValueError(f"Invalid sides {panels!r}.") # Iterate axs = ( - *self._subplots.subplot_dict.values(), + *self._iter_subplots(), *(ax for side in panels for ax in self._panel_dict[side]), ) for ax in axs: @@ -3617,6 +3617,14 @@ def gridspec(self): def gridspec(self, gs): self._subplots.gridspec = gs + def _get_subplot(self, number: int): + """Return the subplot with the given *number*, or ``None``.""" + return self._subplots.subplot_dict.get(number, None) + + def _iter_subplots(self): + """Iterate over all numbered subplots.""" + return self._subplots.subplot_dict.values() + @property def subplotgrid(self): """ diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 3ec8943e9..ca989c554 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -1197,7 +1197,7 @@ def _auto_layout_aspect(self): fig = self.figure if not fig: return - ax = fig._subplots.subplot_dict.get(fig._refnum, None) + ax = fig._get_subplot(fig._refnum) if ax is None: return @@ -1302,7 +1302,7 @@ def _update_figsize(self): fig = self.figure if fig is None: # drawing before subplots are added? return - ax = fig._subplots.subplot_dict.get(fig._refnum, None) + ax = fig._get_subplot(fig._refnum) if ax is None: # drawing before subplots are added? return ss = ax.get_subplotspec().get_topmost_subplotspec() @@ -2086,7 +2086,7 @@ def format(self, **kwargs): ylabel = kwargs.get("ylabel", None) title = kwargs.get("title", None) axes = [ax for ax in self if ax is not None] - all_axes = set(self.figure._subplots.subplot_dict.values()) + all_axes = set(self.figure._iter_subplots()) is_subset = bool(axes) and all_axes and set(axes) != all_axes shared_subset_title = len(self) > 1 and is_subset and isinstance(title, str) shared_title_kw = ( From 5b51fb9bbf0e496fa41650322cc05aeb1d7d26b9 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 22 Apr 2026 14:18:19 +1000 Subject: [PATCH 6/8] adding more tests --- ultraplot/tests/test_subplot_manager.py | 394 ++++++++++++++++++++++++ 1 file changed, 394 insertions(+) create mode 100644 ultraplot/tests/test_subplot_manager.py diff --git a/ultraplot/tests/test_subplot_manager.py b/ultraplot/tests/test_subplot_manager.py new file mode 100644 index 000000000..881855dbf --- /dev/null +++ b/ultraplot/tests/test_subplot_manager.py @@ -0,0 +1,394 @@ +""" +Tests for SubplotManager (ultraplot._subplots). +""" + +import numpy as np +import pytest + +import ultraplot as uplt +from ultraplot._subplots import SubplotManager +from ultraplot import gridspec as pgridspec + + +# --------------------------------------------------------------------------- +# gridspec property +# --------------------------------------------------------------------------- + + +def test_gridspec_setter_rejects_non_ultraplot(): + """Setting gridspec to a non-ultraplot GridSpec raises ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="ultraplot.GridSpec"): + fig.gridspec = "not a gridspec" + + +def test_gridspec_setter_accepts_ultraplot(): + """Setting gridspec to an ultraplot GridSpec works.""" + fig = uplt.figure() + gs = pgridspec.GridSpec(2, 2) + fig.gridspec = gs + assert fig.gridspec is gs + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# parse_backend +# --------------------------------------------------------------------------- + + +def test_parse_backend_basemap_warns(): + """parse_backend emits a deprecation warning for basemap.""" + with pytest.warns(match="basemap"): + SubplotManager.parse_backend(backend="basemap") + + +def test_parse_backend_passthrough(): + """parse_backend returns backend unchanged for non-basemap values.""" + assert SubplotManager.parse_backend(backend="cartopy") == "cartopy" + assert SubplotManager.parse_backend(backend=None) is None + + +# --------------------------------------------------------------------------- +# add_subplot — integer arg form +# --------------------------------------------------------------------------- + + +def test_add_subplot_integer_arg(): + """add_subplot(111) creates a single subplot.""" + fig = uplt.figure() + ax = fig.add_subplot(111) + assert ax is not None + assert ax.number == 1 + uplt.close(fig) + + +def test_add_subplot_integer_arg_invalid(): + """add_subplot with out-of-range integer raises ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="must fall between 111 and 999"): + fig.add_subplot(10) + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# add_subplot — row/col validation errors +# --------------------------------------------------------------------------- + + +def test_add_subplot_row_mismatch(): + """Rows that don't divide the gridspec raise ValueError.""" + fig = uplt.figure() + fig.gridspec = pgridspec.GridSpec(3, 3) + with pytest.raises(ValueError, match="does not divide"): + fig.add_subplot(2, 2, 1) + uplt.close(fig) + + +def test_add_subplot_col_mismatch(): + """Columns that don't divide the gridspec raise ValueError.""" + fig = uplt.figure() + fig.gridspec = pgridspec.GridSpec(4, 3) + with pytest.raises(ValueError, match="does not divide"): + fig.add_subplot(2, 2, 1) + uplt.close(fig) + + +def test_add_subplot_index_out_of_range(): + """Subplot index beyond nrows*ncols raises ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="must fall between"): + fig.add_subplot(2, 2, 5) + uplt.close(fig) + + +def test_add_subplot_invalid_args(): + """Unrecognized positional args raise ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="Invalid add_subplot"): + fig.add_subplot("bad", "args") + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# add_subplot — subplotspec validation +# --------------------------------------------------------------------------- + + +def test_add_subplot_non_ultraplot_gridspec_raises(): + """SubplotSpec from a plain matplotlib GridSpec is rejected.""" + import matplotlib.gridspec as mgridspec + + fig = uplt.figure() + mpl_gs = mgridspec.GridSpec(2, 2) + ss = mpl_gs[0, 0] + with pytest.raises(ValueError, match="ultraplot.GridSpec"): + fig.add_subplot(ss) + uplt.close(fig) + + +def test_add_subplot_wrong_gridspec_raises(): + """SubplotSpec from a different GridSpec than the figure's raises.""" + fig = uplt.figure() + gs1 = pgridspec.GridSpec(2, 2) + gs2 = pgridspec.GridSpec(3, 3) + fig.gridspec = gs1 + ss = gs2[0, 0] + with pytest.raises(ValueError, match="active figure gridspec"): + fig.add_subplot(ss) + uplt.close(fig) + + +def test_add_subplot_with_subplotspec(): + """add_subplot accepts an ultraplot SubplotSpec directly.""" + fig = uplt.figure() + gs = pgridspec.GridSpec(2, 2) + ss = gs[0, 0] + ax = fig.add_subplot(ss) + assert ax.number == 1 + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# add_subplot — external projection wrapping (e.g. '3d') +# --------------------------------------------------------------------------- + + +def test_add_subplot_3d_projection(): + """External projection like '3d' gets wrapped in a container.""" + fig = uplt.figure() + ax = fig.add_subplot(111, proj="3d") + assert ax is not None + assert hasattr(ax, "number") + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# add_subplot — projection parsing errors +# --------------------------------------------------------------------------- + + +def test_parse_proj_invalid_name(): + """Invalid projection name raises a helpful ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="Invalid projection"): + fig.add_subplot(111, proj="totally_nonexistent_proj_xyz") + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# add_subplots — order / array validation +# --------------------------------------------------------------------------- + + +def test_add_subplots_invalid_order(): + """Invalid order raises ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="Invalid order"): + fig.add_subplots(nrows=2, ncols=2, order="Z") + uplt.close(fig) + + +def test_add_subplots_fortran_order(): + """Fortran order ('F') arranges subplots column-major.""" + fig = uplt.figure() + axs = fig.add_subplots(nrows=2, ncols=2, order="F") + assert len(axs) == 4 + uplt.close(fig) + + +def test_add_subplots_1d_array_fortran(): + """1D array with order='F' creates a column layout.""" + fig = uplt.figure() + axs = fig.add_subplots(array=[1, 2, 3], order="F") + assert len(axs) == 3 + uplt.close(fig) + + +def test_add_subplots_3d_array_raises(): + """3D+ array raises ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="1D or 2D"): + fig.add_subplots(array=np.ones((2, 2, 2), dtype=int)) + uplt.close(fig) + + +def test_add_subplots_negative_indices_raises(): + """Negative indices in the array raise ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="positive integers"): + fig.add_subplots(array=[[-1, 1], [2, 3]]) + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# add_subplots — deprecated keyword warnings +# --------------------------------------------------------------------------- + + +def test_add_subplots_gridspec_kw_warns(): + """Passing gridspec_kw emits a deprecation warning.""" + fig = uplt.figure() + with pytest.warns(match="gridspec_kw"): + axs = fig.add_subplots(nrows=1, ncols=2, gridspec_kw={"wspace": 1}) + uplt.close(fig) + + +def test_add_subplots_subplot_kw_warns(): + """Passing subplot_kw emits a deprecation warning.""" + fig = uplt.figure() + with pytest.warns(match="subplot_kw"): + axs = fig.add_subplots(nrows=1, ncols=2, subplot_kw={"facecolor": "red"}) + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# add_subplots — per-axes projection dicts +# --------------------------------------------------------------------------- + + +def test_add_subplots_per_axes_proj(): + """Per-axes projection as a dict selects different projections.""" + fig = uplt.figure() + axs = fig.add_subplots( + nrows=1, + ncols=2, + proj={1: "cartesian", 2: "polar"}, + ) + assert len(axs) == 2 + uplt.close(fig) + + +def test_add_subplots_proj_as_list(): + """Per-axes projection as a list works.""" + fig = uplt.figure() + axs = fig.add_subplots( + nrows=1, + ncols=2, + proj=["cartesian", "polar"], + ) + assert len(axs) == 2 + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# add_subplots — array with None placeholders +# --------------------------------------------------------------------------- + + +def test_add_subplots_none_placeholder(): + """None in the array is treated as an empty slot.""" + fig = uplt.figure() + axs = fig.add_subplots(array=[[1, None], [2, 3]]) + assert len(axs) == 3 + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# subplotgrid property +# --------------------------------------------------------------------------- + + +def test_subplotgrid_sorted_by_number(): + """subplotgrid returns subplots sorted by number.""" + fig, axs = uplt.subplots(nrows=2, ncols=2) + grid = fig.subplotgrid + numbers = [ax.number for ax in grid] + assert numbers == sorted(numbers) + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# Figure accessors (_get_subplot, _iter_subplots) +# --------------------------------------------------------------------------- + + +def test_get_subplot_returns_correct_axes(): + """_get_subplot returns the axes with the given number.""" + fig, axs = uplt.subplots(nrows=1, ncols=3) + for i in range(1, 4): + ax = fig._get_subplot(i) + assert ax is not None + assert ax.number == i + uplt.close(fig) + + +def test_get_subplot_missing_returns_none(): + """_get_subplot returns None for a nonexistent number.""" + fig, axs = uplt.subplots(nrows=1, ncols=2) + assert fig._get_subplot(999) is None + uplt.close(fig) + + +def test_iter_subplots_yields_all(): + """_iter_subplots yields all numbered subplots.""" + fig, axs = uplt.subplots(nrows=2, ncols=2) + count = sum(1 for _ in fig._iter_subplots()) + assert count == 4 + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# parse_proj — matplotlib-registered projection (e.g. 'polar') +# --------------------------------------------------------------------------- + + +def test_parse_proj_polar(): + """'polar' is found via matplotlib's projection registry.""" + fig = uplt.figure() + ax = fig.add_subplot(111, proj="polar") + assert ax is not None + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# _axes_dict edge cases +# --------------------------------------------------------------------------- + + +def test_add_subplots_proj_kw_mixed_nested_raises(): + """Mixed nested/flat proj_kw dicts raise ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError): + fig.add_subplots( + nrows=1, + ncols=2, + proj_kw={1: {"key": "val"}, 2: "not_a_dict"}, + ) + uplt.close(fig) + + +def test_add_subplots_proj_dict_wrong_keys_raises(): + """proj dict with wrong axes numbers raises ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="Have 2 axes"): + fig.add_subplots( + nrows=1, + ncols=2, + proj={1: "cartesian", 2: "cartesian", 3: "polar"}, + ) + uplt.close(fig) + + +# --------------------------------------------------------------------------- +# External projection container wrapping +# --------------------------------------------------------------------------- + + +def test_add_subplot_external_projection_wrapped(): + """Non-ultraplot projections (e.g. 'hammer') get wrapped in a container.""" + fig = uplt.figure() + ax = fig.add_subplot(111, proj="hammer") + assert ax is not None + assert hasattr(ax, "number") + uplt.close(fig) + + +def test_add_subplot_external_projection_reuses_container(): + """Calling add_subplot twice with same external projection reuses container.""" + fig = uplt.figure() + gs = pgridspec.GridSpec(1, 2) + ax1 = fig.add_subplot(gs[0, 0], proj="mollweide") + ax2 = fig.add_subplot(gs[0, 1], proj="mollweide") + assert ax1 is not None + assert ax2 is not None + uplt.close(fig) From 92c5275d5a4ba7ecf04cc8432d4d5acda4887f9a Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 22 Apr 2026 14:19:31 +1000 Subject: [PATCH 7/8] formatting --- ultraplot/tests/test_subplot_manager.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ultraplot/tests/test_subplot_manager.py b/ultraplot/tests/test_subplot_manager.py index 881855dbf..eddfcd1e5 100644 --- a/ultraplot/tests/test_subplot_manager.py +++ b/ultraplot/tests/test_subplot_manager.py @@ -9,7 +9,6 @@ from ultraplot._subplots import SubplotManager from ultraplot import gridspec as pgridspec - # --------------------------------------------------------------------------- # gridspec property # --------------------------------------------------------------------------- From e81a7c8cb5fcae37b5e72cc5d7059d2345361897 Mon Sep 17 00:00:00 2001 From: cvanelteren Date: Wed, 22 Apr 2026 14:38:46 +1000 Subject: [PATCH 8/8] remove fancy headers --- ultraplot/tests/test_subplot_manager.py | 81 +------------------------ 1 file changed, 1 insertion(+), 80 deletions(-) diff --git a/ultraplot/tests/test_subplot_manager.py b/ultraplot/tests/test_subplot_manager.py index eddfcd1e5..0405c1244 100644 --- a/ultraplot/tests/test_subplot_manager.py +++ b/ultraplot/tests/test_subplot_manager.py @@ -6,12 +6,8 @@ import pytest import ultraplot as uplt -from ultraplot._subplots import SubplotManager from ultraplot import gridspec as pgridspec - -# --------------------------------------------------------------------------- -# gridspec property -# --------------------------------------------------------------------------- +from ultraplot._subplots import SubplotManager def test_gridspec_setter_rejects_non_ultraplot(): @@ -30,11 +26,6 @@ def test_gridspec_setter_accepts_ultraplot(): uplt.close(fig) -# --------------------------------------------------------------------------- -# parse_backend -# --------------------------------------------------------------------------- - - def test_parse_backend_basemap_warns(): """parse_backend emits a deprecation warning for basemap.""" with pytest.warns(match="basemap"): @@ -47,11 +38,6 @@ def test_parse_backend_passthrough(): assert SubplotManager.parse_backend(backend=None) is None -# --------------------------------------------------------------------------- -# add_subplot — integer arg form -# --------------------------------------------------------------------------- - - def test_add_subplot_integer_arg(): """add_subplot(111) creates a single subplot.""" fig = uplt.figure() @@ -69,11 +55,6 @@ def test_add_subplot_integer_arg_invalid(): uplt.close(fig) -# --------------------------------------------------------------------------- -# add_subplot — row/col validation errors -# --------------------------------------------------------------------------- - - def test_add_subplot_row_mismatch(): """Rows that don't divide the gridspec raise ValueError.""" fig = uplt.figure() @@ -108,11 +89,6 @@ def test_add_subplot_invalid_args(): uplt.close(fig) -# --------------------------------------------------------------------------- -# add_subplot — subplotspec validation -# --------------------------------------------------------------------------- - - def test_add_subplot_non_ultraplot_gridspec_raises(): """SubplotSpec from a plain matplotlib GridSpec is rejected.""" import matplotlib.gridspec as mgridspec @@ -147,11 +123,6 @@ def test_add_subplot_with_subplotspec(): uplt.close(fig) -# --------------------------------------------------------------------------- -# add_subplot — external projection wrapping (e.g. '3d') -# --------------------------------------------------------------------------- - - def test_add_subplot_3d_projection(): """External projection like '3d' gets wrapped in a container.""" fig = uplt.figure() @@ -161,11 +132,6 @@ def test_add_subplot_3d_projection(): uplt.close(fig) -# --------------------------------------------------------------------------- -# add_subplot — projection parsing errors -# --------------------------------------------------------------------------- - - def test_parse_proj_invalid_name(): """Invalid projection name raises a helpful ValueError.""" fig = uplt.figure() @@ -174,11 +140,6 @@ def test_parse_proj_invalid_name(): uplt.close(fig) -# --------------------------------------------------------------------------- -# add_subplots — order / array validation -# --------------------------------------------------------------------------- - - def test_add_subplots_invalid_order(): """Invalid order raises ValueError.""" fig = uplt.figure() @@ -219,11 +180,6 @@ def test_add_subplots_negative_indices_raises(): uplt.close(fig) -# --------------------------------------------------------------------------- -# add_subplots — deprecated keyword warnings -# --------------------------------------------------------------------------- - - def test_add_subplots_gridspec_kw_warns(): """Passing gridspec_kw emits a deprecation warning.""" fig = uplt.figure() @@ -240,11 +196,6 @@ def test_add_subplots_subplot_kw_warns(): uplt.close(fig) -# --------------------------------------------------------------------------- -# add_subplots — per-axes projection dicts -# --------------------------------------------------------------------------- - - def test_add_subplots_per_axes_proj(): """Per-axes projection as a dict selects different projections.""" fig = uplt.figure() @@ -269,11 +220,6 @@ def test_add_subplots_proj_as_list(): uplt.close(fig) -# --------------------------------------------------------------------------- -# add_subplots — array with None placeholders -# --------------------------------------------------------------------------- - - def test_add_subplots_none_placeholder(): """None in the array is treated as an empty slot.""" fig = uplt.figure() @@ -282,11 +228,6 @@ def test_add_subplots_none_placeholder(): uplt.close(fig) -# --------------------------------------------------------------------------- -# subplotgrid property -# --------------------------------------------------------------------------- - - def test_subplotgrid_sorted_by_number(): """subplotgrid returns subplots sorted by number.""" fig, axs = uplt.subplots(nrows=2, ncols=2) @@ -296,11 +237,6 @@ def test_subplotgrid_sorted_by_number(): uplt.close(fig) -# --------------------------------------------------------------------------- -# Figure accessors (_get_subplot, _iter_subplots) -# --------------------------------------------------------------------------- - - def test_get_subplot_returns_correct_axes(): """_get_subplot returns the axes with the given number.""" fig, axs = uplt.subplots(nrows=1, ncols=3) @@ -326,11 +262,6 @@ def test_iter_subplots_yields_all(): uplt.close(fig) -# --------------------------------------------------------------------------- -# parse_proj — matplotlib-registered projection (e.g. 'polar') -# --------------------------------------------------------------------------- - - def test_parse_proj_polar(): """'polar' is found via matplotlib's projection registry.""" fig = uplt.figure() @@ -339,11 +270,6 @@ def test_parse_proj_polar(): uplt.close(fig) -# --------------------------------------------------------------------------- -# _axes_dict edge cases -# --------------------------------------------------------------------------- - - def test_add_subplots_proj_kw_mixed_nested_raises(): """Mixed nested/flat proj_kw dicts raise ValueError.""" fig = uplt.figure() @@ -368,11 +294,6 @@ def test_add_subplots_proj_dict_wrong_keys_raises(): uplt.close(fig) -# --------------------------------------------------------------------------- -# External projection container wrapping -# --------------------------------------------------------------------------- - - def test_add_subplot_external_projection_wrapped(): """Non-ultraplot projections (e.g. 'hammer') get wrapped in a container.""" fig = uplt.figure()