diff --git a/ultraplot/_subplots.py b/ultraplot/_subplots.py new file mode 100644 index 000000000..41e2ce4bc --- /dev/null +++ b/ultraplot/_subplots.py @@ -0,0 +1,391 @@ +""" +Subplot creation and management for ultraplot figures. +""" + +import inspect +from numbers import Integral + +try: + from typing import Optional, Tuple, Union +except ImportError: + from typing_extensions import Optional, Tuple, Union + +import matplotlib.axes as maxes +import matplotlib.figure as mfigure +import matplotlib.gridspec as mgridspec +import matplotlib.projections as mproj +import numpy as np + +from . import axes as paxes +from . import constructor +from . import gridspec as pgridspec +from .internals import _not_none, _pop_params, warnings + + +class SubplotManager: + """ + Manages subplot creation, gridspec ownership, and projection parsing + for a Figure instance. + + Parameters + ---------- + figure : `~ultraplot.figure.Figure` + The parent figure. + """ + + def __init__(self, figure: "Figure"): + self.figure = figure + self.subplot_dict: dict = {} + self.counter: int = 0 + self._gridspec = None + + @property + def gridspec(self): + """The single GridSpec used for all subplots in the figure.""" + return self._gridspec + + @gridspec.setter + def gridspec(self, gs): + if not isinstance(gs, pgridspec.GridSpec): + raise ValueError("Gridspec must be a ultraplot.GridSpec instance.") + self._gridspec = gs + gs.figure = self.figure # gridspec.figure should reference the real Figure + + @staticmethod + def parse_backend(backend=None, basemap=None): + """ + Handle deprecation of basemap and cartopy package. + """ + if backend == "basemap": + warnings._warn_ultraplot( + f"{backend=} will be deprecated in next major release (v2.0). " + "See https://github.com/Ultraplot/ultraplot/pull/243" + ) + return backend + + def parse_proj( + self, + proj=None, + projection=None, + proj_kw=None, + projection_kw=None, + backend=None, + basemap=None, + **kwargs, + ): + """ + Translate user-input projection into a registered matplotlib axes class. + """ + # Parse arguments + proj = _not_none(proj=proj, projection=projection, default="cartesian") + proj_kw = _not_none(proj_kw=proj_kw, projection_kw=projection_kw, default={}) + backend = self.parse_backend(backend, basemap) + if isinstance(proj, str): + proj = proj.lower() + if isinstance(self.figure, paxes.Axes): + proj = self.figure._name + elif isinstance(self.figure, maxes.Axes): + raise ValueError("Matplotlib axes cannot be added to ultraplot figures.") + + # Search axes projections + name = None + + # Handle cartopy/basemap Projection objects directly + # These should be converted to Ultraplot GeoAxes + if not isinstance(proj, str): + if constructor.Projection is not object and isinstance( + proj, constructor.Projection + ): + name = "ultraplot_cartopy" + kwargs["map_projection"] = proj + elif constructor.Basemap is not object and isinstance( + proj, constructor.Basemap + ): + name = "ultraplot_basemap" + kwargs["map_projection"] = proj + + if name is None and isinstance(proj, str): + try: + mproj.get_projection_class("ultraplot_" + proj) + except (KeyError, ValueError): + pass + else: + name = "ultraplot_" + proj + if name is None and isinstance(proj, str): + # Try geographic projections first if cartopy/basemap available + if ( + constructor.Projection is not object + or constructor.Basemap is not object + ): + try: + proj_obj = constructor.Proj( + proj, backend=backend, include_axes=True, **proj_kw + ) + name = "ultraplot_" + proj_obj._proj_backend + kwargs["map_projection"] = proj_obj + except ValueError: + pass # not a geographic projection, try matplotlib registry below + + # If not geographic, check if registered globally in matplotlib + # (e.g., 'ternary', 'polar', '3d') + if name is None and proj in mproj.get_projection_names(): + name = proj + + if name is None and isinstance(proj, str): + raise ValueError( + f"Invalid projection name {proj!r}. If you are trying to generate a " + "GeoAxes with a cartopy.crs.Projection or mpl_toolkits.basemap.Basemap " + "then cartopy or basemap must be installed. Otherwise the known axes " + f"subclasses are:\n{paxes._cls_table}" + ) + + if name is not None: + kwargs["projection"] = name + return kwargs + + def add_subplot(self, *args, **kwargs): + """ + The driver function for adding single subplots. + """ + fig = self.figure + fig._layout_dirty = True + kwargs = self.parse_proj(**kwargs) + + args = args or (1, 1, 1) + gs = self.gridspec + + # Integer arg + if len(args) == 1 and isinstance(args[0], Integral): + if not 111 <= args[0] <= 999: + raise ValueError(f"Input {args[0]} must fall between 111 and 999.") + args = tuple(map(int, str(args[0]))) + + # Subplot spec + if len(args) == 1 and isinstance( + args[0], (maxes.SubplotBase, mgridspec.SubplotSpec) + ): + ss = args[0] + if isinstance(ss, maxes.SubplotBase): + ss = ss.get_subplotspec() + if gs is None: + gs = ss.get_topmost_subplotspec().get_gridspec() + if not isinstance(gs, pgridspec.GridSpec): + raise ValueError( + "Input subplotspec must be derived from a ultraplot.GridSpec." + ) + if ss.get_topmost_subplotspec().get_gridspec() is not gs: + raise ValueError( + "Input subplotspec must be derived from the active figure gridspec." + ) + + # Row and column spec + elif ( + len(args) == 3 + and all(isinstance(arg, Integral) for arg in args[:2]) + and all(isinstance(arg, Integral) for arg in np.atleast_1d(args[2])) + ): + nrows, ncols, num = args + i, j = np.resize(num, 2) + if gs is None: + gs = pgridspec.GridSpec(nrows, ncols) + orows, ocols = gs.get_geometry() + if orows % nrows: + raise ValueError( + f"The input number of rows {nrows} does not divide the " + f"figure gridspec number of rows {orows}." + ) + if ocols % ncols: + raise ValueError( + f"The input number of columns {ncols} does not divide the " + f"figure gridspec number of columns {ocols}." + ) + if any(_ < 1 or _ > nrows * ncols for _ in (i, j)): + raise ValueError( + "The input subplot indices must fall between " + f"1 and {nrows * ncols}. Instead got {i} and {j}." + ) + rowfact, colfact = orows // nrows, ocols // ncols + irow, icol = divmod(i - 1, ncols) # convert to zero-based + jrow, jcol = divmod(j - 1, ncols) + irow, icol = irow * rowfact, icol * colfact + jrow, jcol = (jrow + 1) * rowfact - 1, (jcol + 1) * colfact - 1 + ss = gs[irow : jrow + 1, icol : jcol + 1] + + else: + raise ValueError(f"Invalid add_subplot positional arguments {args!r}.") + + # Add the subplot + # NOTE: Must assign unique label to each subplot or else subsequent calls + # to add_subplot() in mpl < 3.4 may return an already-drawn subplot in the + # wrong location due to gridspec override. + self.gridspec = gs # trigger layout adjustment + self.counter += 1 + kwargs.setdefault("label", f"subplot_{self.counter}") + kwargs.setdefault("number", 1 + max(self.subplot_dict, default=0)) + kwargs.pop("refwidth", None) # TODO: remove this + + # Use container approach for external projections to make them + # ultraplot-compatible. Skip projections that start with "ultraplot_" + # as these are already Ultraplot axes classes. + projection_name = kwargs.get("projection") + external_axes_class = None + external_axes_kwargs = {} + + if projection_name and isinstance(projection_name, str): + if not projection_name.startswith("ultraplot_"): + try: + proj_class = mproj.get_projection_class(projection_name) + if not issubclass(proj_class, paxes.Axes): + external_axes_class = proj_class + external_axes_kwargs["projection"] = projection_name + + from .axes.container import create_external_axes_container + + container_name = f"_ultraplot_container_{projection_name}" + if container_name not in mproj.get_projection_names(): + container_class = create_external_axes_container( + proj_class, projection_name=container_name + ) + mproj.register_projection(container_class) + + kwargs["projection"] = container_name + kwargs["external_axes_class"] = external_axes_class + kwargs["external_axes_kwargs"] = external_axes_kwargs + except (KeyError, ValueError): + pass + + kwargs.pop("_subplot_spec", None) + + # NOTE: We call mfigure.Figure.add_subplot directly (unbound) rather + # than fig.add_subplot because SubplotManager is not a Figure subclass + # and cannot use super(). This bypasses any Figure.add_subplot override, + # which is acceptable because Figure._add_subplot is the real entry point. + ax = mfigure.Figure.add_subplot(fig, ss, **kwargs) + if ax.number: + self.subplot_dict[ax.number] = ax + return ax + + def add_subplots( + self, + array=None, + nrows=1, + ncols=1, + order="C", + proj=None, + projection=None, + proj_kw=None, + projection_kw=None, + backend=None, + basemap=None, + **kwargs, + ): + """ + The driver function for adding multiple subplots. + """ + fig = self.figure + + # Helper to normalize per-axes arguments into {num: value} dicts. + # Accepts 'string', {1: 'string1', (2, 3): 'string2'}, or lists. + def _axes_dict(naxs, input, kw=False, default=None): + if not kw: # 'string' or {1: 'string1', (2, 3): 'string2'} + if np.iterable(input) and not isinstance(input, (str, dict)): + input = {num + 1: item for num, item in enumerate(input)} + elif not isinstance(input, dict): + input = {range(1, naxs + 1): input} + else: # {key: value} or {1: {key: value1}, (2, 3): {key: value2}} + nested = [isinstance(_, dict) for _ in input.values()] + if not any(nested): # any([]) == False + input = {range(1, naxs + 1): input.copy()} + elif not all(nested): + raise ValueError(f"Invalid input {input!r}.") + # Unfurl keys that contain multiple axes numbers + output = {} + for nums, item in input.items(): + nums = np.atleast_1d(nums) + for num in nums.flat: + output[num] = item.copy() if kw else item + # Fill with default values + for num in range(1, naxs + 1): + if num not in output: + output[num] = {} if kw else default + if output.keys() != set(range(1, naxs + 1)): + raise ValueError( + f"Have {naxs} axes, but {input!r} includes props for the axes: " + + ", ".join(map(repr, sorted(output))) + + "." + ) + return output + + # Build the subplot array + # NOTE: Currently this may ignore user-input nrows/ncols without warning + if order not in ("C", "F"): # better error message + raise ValueError(f"Invalid order={order!r}. Options are 'C' or 'F'.") + gs = None + if array is None or isinstance(array, mgridspec.GridSpec): + if array is not None: + gs, nrows, ncols = array, array.nrows, array.ncols + array = np.arange(1, nrows * ncols + 1)[..., None] + array = array.reshape((nrows, ncols), order=order) + else: + array = np.atleast_1d(array) + array[array == None] = 0 # None or 0 both valid placeholders # noqa: E711 + array = array.astype(int) + if array.ndim == 1: # interpret as single row or column + array = array[None, :] if order == "C" else array[:, None] + elif array.ndim != 2: + raise ValueError(f"Expected 1D or 2D array of integers. Got {array}.") + + # Parse input format, gridspec, and projection arguments + # NOTE: Permit figure format keywords for e.g. 'collabels' (more intuitive) + nums = np.unique(array[array != 0]) + naxs = len(nums) + if any(num < 0 or not isinstance(num, Integral) for num in nums.flat): + raise ValueError(f"Expected array of positive integers. Got {array}.") + proj = _not_none(projection=projection, proj=proj) + proj = _axes_dict(naxs, proj, kw=False, default="cartesian") + proj_kw = _not_none(projection_kw=projection_kw, proj_kw=proj_kw) or {} + proj_kw = _axes_dict(naxs, proj_kw, kw=True) + backend = self.parse_backend(backend, basemap) + backend = _axes_dict(naxs, backend, kw=False) + axes_kw = { + num: {"proj": proj[num], "proj_kw": proj_kw[num], "backend": backend[num]} + for num in proj + } + for key in ("gridspec_kw", "subplot_kw"): + kw = kwargs.pop(key, None) + if not kw: + continue + warnings._warn_ultraplot( + f"{key!r} is not necessary in ultraplot. Pass the " + "parameters as keyword arguments instead." + ) + kwargs.update(kw or {}) + figure_kw = _pop_params(kwargs, fig._format_signature) + gridspec_kw = _pop_params(kwargs, pgridspec.GridSpec._update_params) + + # Create or update the gridspec and add subplots with subplotspecs + # NOTE: The gridspec is added to the figure when we pass the subplotspec + if gs is None: + if "layout_array" not in gridspec_kw: + gridspec_kw = {**gridspec_kw, "layout_array": array} + gs = pgridspec.GridSpec(*array.shape, **gridspec_kw) + else: + gs.update(**gridspec_kw) + axs = naxs * [None] # list of axes + axids = [np.where(array == i) for i in np.sort(np.unique(array)) if i > 0] + axcols = np.array([[x.min(), x.max()] for _, x in axids]) + axrows = np.array([[y.min(), y.max()] for y, _ in axids]) + for idx in range(naxs): + num = idx + 1 + x0, x1 = axcols[idx, 0], axcols[idx, 1] + y0, y1 = axrows[idx, 0], axrows[idx, 1] + ss = gs[y0 : y1 + 1, x0 : x1 + 1] + kw = {**kwargs, **axes_kw[num], "number": num} + axs[idx] = fig.add_subplot(ss, **kw) + fig.format(skip_axes=True, **figure_kw) + return pgridspec.SubplotGrid(axs) + + @property + def subplotgrid(self): + """A SubplotGrid of numbered subplots sorted by number.""" + return pgridspec.SubplotGrid([s for _, s in sorted(self.subplot_dict.items())]) diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index 29a962e14..0f41cbd81 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -1365,7 +1365,7 @@ def shared(paxs): child._sharey_setup(parent) # Global sharing, use the reference subplot where compatible - ref = self.figure._subplot_dict.get(self.figure._refnum, None) + ref = self.figure._get_subplot(self.figure._refnum) if self is not ref and ref is not None: if self.figure._sharex > 3: ok, reason = self.figure._share_axes_compatible(ref, self, "x") diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 334c61a44..1be53df7d 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -43,6 +43,7 @@ labels, warnings, ) +from ._subplots import SubplotManager from .utils import _Crawler, units __all__ = [ @@ -881,10 +882,8 @@ def _init_figure_state(self, figwidth, figheight, kwargs): Initialize internal state, call matplotlib's Figure.__init__, set up super labels, and apply initial formatting. """ - self._gridspec = None + self._subplots = SubplotManager(self) self._panel_dict = {"left": [], "right": [], "bottom": [], "top": []} - self._subplot_dict = {} - self._subplot_counter = 0 self._is_adjusting = False self._is_authorized = False self._layout_initialized = False @@ -1533,107 +1532,12 @@ def _context_authorized(self): @staticmethod def _parse_backend(backend=None, basemap=None): - """ - Handle deprecation of basemap and cartopy package. - """ - # Basemap is currently being developed again so are removing the deprecation warning - if backend == "basemap": - warnings._warn_ultraplot( - f"{backend=} will be deprecated in next major release (v2.0). See https://github.com/Ultraplot/ultraplot/pull/243" - ) - return backend - - def _parse_proj( - self, - proj=None, - projection=None, - proj_kw=None, - projection_kw=None, - backend=None, - basemap=None, - **kwargs, - ): - """ - Translate the user-input projection into a registered matplotlib - axes class. Input projection can be a string, `matplotlib.axes.Axes`, - `cartopy.crs.Projection`, or `mpl_toolkits.basemap.Basemap`. - """ - # Parse arguments - proj = _not_none(proj=proj, projection=projection, default="cartesian") - proj_kw = _not_none(proj_kw=proj_kw, projection_kw=projection_kw, default={}) - backend = self._parse_backend(backend, basemap) - if isinstance(proj, str): - proj = proj.lower() - if isinstance(self, paxes.Axes): - proj = self._name - elif isinstance(self, maxes.Axes): - raise ValueError("Matplotlib axes cannot be added to ultraplot figures.") - - # Search axes projections - name = None - - # Handle cartopy/basemap Projection objects directly - # These should be converted to Ultraplot GeoAxes - if not isinstance(proj, str): - # Check if it's a cartopy or basemap projection object - if constructor.Projection is not object and isinstance( - proj, constructor.Projection - ): - # It's a cartopy projection - use cartopy backend - name = "ultraplot_cartopy" - kwargs["map_projection"] = proj - elif constructor.Basemap is not object and isinstance( - proj, constructor.Basemap - ): - # It's a basemap projection - name = "ultraplot_basemap" - kwargs["map_projection"] = proj - # If not recognized, leave name as None and it will pass through - - if name is None and isinstance(proj, str): - try: - mproj.get_projection_class("ultraplot_" + proj) - except (KeyError, ValueError): - pass - else: - name = "ultraplot_" + proj - if name is None and isinstance(proj, str): - # Try geographic projections first if cartopy/basemap available - if ( - constructor.Projection is not object - or constructor.Basemap is not object - ): - try: - proj_obj = constructor.Proj( - proj, backend=backend, include_axes=True, **proj_kw - ) - name = "ultraplot_" + proj_obj._proj_backend - kwargs["map_projection"] = proj_obj - except ValueError: - # Not a geographic projection, will try matplotlib registry below - pass - - # If not geographic, check if registered globally in Matplotlib (e.g., 'ternary', 'polar', '3d') - if name is None and proj in mproj.get_projection_names(): - name = proj - - # Helpful error message if still not found - if name is None and isinstance(proj, str): - raise ValueError( - f"Invalid projection name {proj!r}. If you are trying to generate a " - "GeoAxes with a cartopy.crs.Projection or mpl_toolkits.basemap.Basemap " - "then cartopy or basemap must be installed. Otherwise the known axes " - f"subclasses are:\n{paxes._cls_table}" - ) + """Delegate to SubplotManager.""" + return SubplotManager.parse_backend(backend, basemap) - # Only set projection if we found a named projection - # Otherwise preserve the original projection (e.g., cartopy Projection objects) - if name is not None: - kwargs["projection"] = name - # If name is None and proj is not a string, it means we have a non-string - # projection (e.g., cartopy.crs.Projection object) that should be passed through - # The original projection kwarg is already in kwargs, so no action needed - return kwargs + def _parse_proj(self, *args, **kwargs): + """Delegate to SubplotManager.""" + return self._subplots.parse_proj(*args, **kwargs) def _get_align_axes(self, side): """ @@ -1642,7 +1546,7 @@ def _get_align_axes(self, side): For 'left'/'right': select one extreme axis per row (leftmost/rightmost). For 'top'/'bottom': select one extreme axis per column (topmost/bottommost). """ - axs = tuple(self._subplot_dict.values()) + axs = tuple(self._iter_subplots()) if not axs: return [] if side not in ("left", "right", "top", "bottom"): @@ -1650,7 +1554,7 @@ def _get_align_axes(self, side): from .utils import _get_subplot_layout grid = _get_subplot_layout( - self._gridspec, list(self._iter_axes(panels=False, hidden=False)) + self.gridspec, list(self._iter_axes(panels=False, hidden=False)) )[0] # From the @side we find the first non-zero # entry in each row or column and collect the axes @@ -2057,142 +1961,8 @@ def _add_figure_panel( @_clear_border_cache def _add_subplot(self, *args, **kwargs): - """ - The driver function for adding single subplots. - """ - self._layout_dirty = True - # Parse arguments - kwargs = self._parse_proj(**kwargs) - - args = args or (1, 1, 1) - gs = self.gridspec - - # Integer arg - if len(args) == 1 and isinstance(args[0], Integral): - if not 111 <= args[0] <= 999: - raise ValueError(f"Input {args[0]} must fall between 111 and 999.") - args = tuple(map(int, str(args[0]))) - - # Subplot spec - if len(args) == 1 and isinstance( - args[0], (maxes.SubplotBase, mgridspec.SubplotSpec) - ): - ss = args[0] - if isinstance(ss, maxes.SubplotBase): - ss = ss.get_subplotspec() - if gs is None: - gs = ss.get_topmost_subplotspec().get_gridspec() - if not isinstance(gs, pgridspec.GridSpec): - raise ValueError( - "Input subplotspec must be derived from a ultraplot.GridSpec." - ) - if ss.get_topmost_subplotspec().get_gridspec() is not gs: - raise ValueError( - "Input subplotspec must be derived from the active figure gridspec." - ) - - # Row and column spec - # TODO: How to pass spacing parameters to gridspec? Consider overriding - # subplots adjust? Or require using gridspec manually? - elif ( - len(args) == 3 - and all(isinstance(arg, Integral) for arg in args[:2]) - and all(isinstance(arg, Integral) for arg in np.atleast_1d(args[2])) - ): - nrows, ncols, num = args - i, j = np.resize(num, 2) - if gs is None: - gs = pgridspec.GridSpec(nrows, ncols) - orows, ocols = gs.get_geometry() - if orows % nrows: - raise ValueError( - f"The input number of rows {nrows} does not divide the " - f"figure gridspec number of rows {orows}." - ) - if ocols % ncols: - raise ValueError( - f"The input number of columns {ncols} does not divide the " - f"figure gridspec number of columns {ocols}." - ) - if any(_ < 1 or _ > nrows * ncols for _ in (i, j)): - raise ValueError( - "The input subplot indices must fall between " - f"1 and {nrows * ncols}. Instead got {i} and {j}." - ) - rowfact, colfact = orows // nrows, ocols // ncols - irow, icol = divmod(i - 1, ncols) # convert to zero-based - jrow, jcol = divmod(j - 1, ncols) - irow, icol = irow * rowfact, icol * colfact - jrow, jcol = (jrow + 1) * rowfact - 1, (jcol + 1) * colfact - 1 - ss = gs[irow : jrow + 1, icol : jcol + 1] - - # Otherwise - else: - raise ValueError(f"Invalid add_subplot positional arguments {args!r}.") - - # Add the subplot - # NOTE: Pass subplotspec as keyword arg for mpl >= 3.4 workaround - # NOTE: Must assign unique label to each subplot or else subsequent calls - # to add_subplot() in mpl < 3.4 may return an already-drawn subplot in the - # wrong location due to gridspec override. Is against OO package design. - self.gridspec = gs # trigger layout adjustment - self._subplot_counter += 1 # unique label for each subplot - kwargs.setdefault("label", f"subplot_{self._subplot_counter}") - kwargs.setdefault("number", 1 + max(self._subplot_dict, default=0)) - kwargs.pop("refwidth", None) # TODO: remove this - - # Use container approach for external projections to make them ultraplot-compatible - projection_name = kwargs.get("projection") - external_axes_class = None - external_axes_kwargs = {} - - if projection_name and isinstance(projection_name, str): - # Check if this is an external (non-ultraplot) projection - # Skip external wrapping for projections that start with "ultraplot_" prefix - # as these are already Ultraplot axes classes - if not projection_name.startswith("ultraplot_"): - try: - # Get the projection class - proj_class = mproj.get_projection_class(projection_name) - - # Check if it's not a built-in ultraplot axes - # Only wrap if it's NOT a subclass of Ultraplot's Axes - if not issubclass(proj_class, paxes.Axes): - # Store the external axes class and original projection name - external_axes_class = proj_class - external_axes_kwargs["projection"] = projection_name - - # Create or get the container class for this external axes type - from .axes.container import create_external_axes_container - - container_name = f"_ultraplot_container_{projection_name}" - - # Check if container is already registered - if container_name not in mproj.get_projection_names(): - container_class = create_external_axes_container( - proj_class, projection_name=container_name - ) - mproj.register_projection(container_class) - - # Use the container projection and pass external axes info - kwargs["projection"] = container_name - kwargs["external_axes_class"] = external_axes_class - kwargs["external_axes_kwargs"] = external_axes_kwargs - except (KeyError, ValueError): - # Projection not found, let matplotlib handle the error - pass - - # Remove _subplot_spec from kwargs if present to prevent it from being passed - # to .set() or other methods that don't accept it. - kwargs.pop("_subplot_spec", None) - - # Pass only the SubplotSpec as a positional argument - # Don't pass _subplot_spec as a keyword argument to avoid it being - # propagated to Axes.set() or other methods that don't accept it - ax = super().add_subplot(ss, **kwargs) - if ax.number: - self._subplot_dict[ax.number] = ax - return ax + """Delegate to SubplotManager.""" + return self._subplots.add_subplot(*args, **kwargs) def _unshare_axes(self): @@ -2287,111 +2057,20 @@ def _add_subplots( basemap=None, **kwargs, ): - """ - The driver function for adding multiple subplots. - """ - - # Clunky helper function - # TODO: Consider deprecating and asking users to use add_subplot() - def _axes_dict(naxs, input, kw=False, default=None): - # First build up dictionary - if not kw: # 'string' or {1: 'string1', (2, 3): 'string2'} - if np.iterable(input) and not isinstance(input, (str, dict)): - input = {num + 1: item for num, item in enumerate(input)} - elif not isinstance(input, dict): - input = {range(1, naxs + 1): input} - else: # {key: value} or {1: {key: value1}, (2, 3): {key: value2}} - nested = [isinstance(_, dict) for _ in input.values()] - if not any(nested): # any([]) == False - input = {range(1, naxs + 1): input.copy()} - elif not all(nested): - raise ValueError(f"Invalid input {input!r}.") - # Unfurl keys that contain multiple axes numbers - output = {} - for nums, item in input.items(): - nums = np.atleast_1d(nums) - for num in nums.flat: - output[num] = item.copy() if kw else item - # Fill with default values - for num in range(1, naxs + 1): - if num not in output: - output[num] = {} if kw else default - if output.keys() != set(range(1, naxs + 1)): - raise ValueError( - f"Have {naxs} axes, but {input!r} includes props for the axes: " - + ", ".join(map(repr, sorted(output))) - + "." - ) - return output - - # Build the subplot array - # NOTE: Currently this may ignore user-input nrows/ncols without warning - if order not in ("C", "F"): # better error message - raise ValueError(f"Invalid order={order!r}. Options are 'C' or 'F'.") - gs = None - if array is None or isinstance(array, mgridspec.GridSpec): - if array is not None: - gs, nrows, ncols = array, array.nrows, array.ncols - array = np.arange(1, nrows * ncols + 1)[..., None] - array = array.reshape((nrows, ncols), order=order) - else: - array = np.atleast_1d(array) - array[array == None] = 0 # None or 0 both valid placeholders # noqa: E711 - array = array.astype(int) - if array.ndim == 1: # interpret as single row or column - array = array[None, :] if order == "C" else array[:, None] - elif array.ndim != 2: - raise ValueError(f"Expected 1D or 2D array of integers. Got {array}.") - - # Parse input format, gridspec, and projection arguments - # NOTE: Permit figure format keywords for e.g. 'collabels' (more intuitive) - nums = np.unique(array[array != 0]) - naxs = len(nums) - if any(num < 0 or not isinstance(num, Integral) for num in nums.flat): - raise ValueError(f"Expected array of positive integers. Got {array}.") - proj = _not_none(projection=projection, proj=proj) - proj = _axes_dict(naxs, proj, kw=False, default="cartesian") - proj_kw = _not_none(projection_kw=projection_kw, proj_kw=proj_kw) or {} - proj_kw = _axes_dict(naxs, proj_kw, kw=True) - backend = self._parse_backend(backend, basemap) - backend = _axes_dict(naxs, backend, kw=False) - axes_kw = { - num: {"proj": proj[num], "proj_kw": proj_kw[num], "backend": backend[num]} - for num in proj - } - for key in ("gridspec_kw", "subplot_kw"): - kw = kwargs.pop(key, None) - if not kw: - continue - warnings._warn_ultraplot( - f"{key!r} is not necessary in ultraplot. Pass the " - "parameters as keyword arguments instead." - ) - kwargs.update(kw or {}) - figure_kw = _pop_params(kwargs, self._format_signature) - gridspec_kw = _pop_params(kwargs, pgridspec.GridSpec._update_params) - - # Create or update the gridspec and add subplots with subplotspecs - # NOTE: The gridspec is added to the figure when we pass the subplotspec - if gs is None: - if "layout_array" not in gridspec_kw: - gridspec_kw = {**gridspec_kw, "layout_array": array} - gs = pgridspec.GridSpec(*array.shape, **gridspec_kw) - else: - gs.update(**gridspec_kw) - axs = naxs * [None] # list of axes - axids = [np.where(array == i) for i in np.sort(np.unique(array)) if i > 0] - axcols = np.array([[x.min(), x.max()] for _, x in axids]) - axrows = np.array([[y.min(), y.max()] for y, _ in axids]) - for idx in range(naxs): - num = idx + 1 - x0, x1 = axcols[idx, 0], axcols[idx, 1] - y0, y1 = axrows[idx, 0], axrows[idx, 1] - ss = gs[y0 : y1 + 1, x0 : x1 + 1] - kw = {**kwargs, **axes_kw[num], "number": num} - axs[idx] = self.add_subplot(ss, **kw) - self.format(skip_axes=True, **figure_kw) - return pgridspec.SubplotGrid(axs) + """Delegate to SubplotManager.""" + return self._subplots.add_subplots( + array=array, + nrows=nrows, + ncols=ncols, + order=order, + proj=proj, + projection=projection, + proj_kw=proj_kw, + projection_kw=projection_kw, + backend=backend, + basemap=basemap, + **kwargs, + ) def _align_axis_label(self, x): """ @@ -2403,7 +2082,7 @@ def _align_axis_label(self, x): seen = set() span = getattr(self, "_span" + x) align = getattr(self, "_align" + x) - for ax in self._subplot_dict.values(): + for ax in self._iter_subplots(): if isinstance(ax, paxes.CartesianAxes): ax._apply_axis_sharing() # always! else: @@ -2652,7 +2331,7 @@ def _align_super_labels(self, side, renderer): Adjust the position of super labels. """ # NOTE: Ensure title is offset only here. - for ax in self._subplot_dict.values(): + for ax in self._iter_subplots(): ax._apply_title_above() if side not in ("left", "right", "bottom", "top"): raise ValueError(f"Invalid side {side!r}.") @@ -3223,7 +2902,7 @@ def format( """ self._layout_dirty = True # Initiate context block - axs = axs or self._subplot_dict.values() + axs = axs or self._iter_subplots() skip_axes = kwargs.pop("skip_axes", False) # internal keyword arg rc_kw, rc_mode = _pop_rc(kwargs) with rc.context(rc_kw, mode=rc_mode): @@ -3912,7 +3591,7 @@ def _iter_axes(self, hidden=False, children=False, panels=True): raise ValueError(f"Invalid sides {panels!r}.") # Iterate axs = ( - *self._subplot_dict.values(), + *self._iter_subplots(), *(ax for side in panels for ax in self._panel_dict[side]), ) for ax in axs: @@ -3932,14 +3611,19 @@ def gridspec(self): ultraplot.gridspec.GridSpec.figure ultraplot.gridspec.SubplotGrid.gridspec """ - return self._gridspec + return self._subplots.gridspec @gridspec.setter def gridspec(self, gs): - if not isinstance(gs, pgridspec.GridSpec): - raise ValueError("Gridspec must be a ultraplot.GridSpec instance.") - self._gridspec = gs - gs.figure = self # trigger copying settings from the figure + self._subplots.gridspec = gs + + def _get_subplot(self, number: int): + """Return the subplot with the given *number*, or ``None``.""" + return self._subplots.subplot_dict.get(number, None) + + def _iter_subplots(self): + """Iterate over all numbered subplots.""" + return self._subplots.subplot_dict.values() @property def subplotgrid(self): @@ -3952,7 +3636,7 @@ def subplotgrid(self): ultraplot.figure.Figure.gridspec ultraplot.gridspec.SubplotGrid.figure """ - return pgridspec.SubplotGrid([s for _, s in sorted(self._subplot_dict.items())]) + return self._subplots.subplotgrid @property def tight(self): diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 2f4761f23..ca989c554 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -1197,7 +1197,7 @@ def _auto_layout_aspect(self): fig = self.figure if not fig: return - ax = fig._subplot_dict.get(fig._refnum, None) + ax = fig._get_subplot(fig._refnum) if ax is None: return @@ -1302,7 +1302,7 @@ def _update_figsize(self): fig = self.figure if fig is None: # drawing before subplots are added? return - ax = fig._subplot_dict.get(fig._refnum, None) + ax = fig._get_subplot(fig._refnum) if ax is None: # drawing before subplots are added? return ss = ax.get_subplotspec().get_topmost_subplotspec() @@ -2086,7 +2086,7 @@ def format(self, **kwargs): ylabel = kwargs.get("ylabel", None) title = kwargs.get("title", None) axes = [ax for ax in self if ax is not None] - all_axes = set(self.figure._subplot_dict.values()) + all_axes = set(self.figure._iter_subplots()) is_subset = bool(axes) and all_axes and set(axes) != all_axes shared_subset_title = len(self) > 1 and is_subset and isinstance(title, str) shared_title_kw = ( diff --git a/ultraplot/tests/test_subplot_manager.py b/ultraplot/tests/test_subplot_manager.py new file mode 100644 index 000000000..0405c1244 --- /dev/null +++ b/ultraplot/tests/test_subplot_manager.py @@ -0,0 +1,314 @@ +""" +Tests for SubplotManager (ultraplot._subplots). +""" + +import numpy as np +import pytest + +import ultraplot as uplt +from ultraplot import gridspec as pgridspec +from ultraplot._subplots import SubplotManager + + +def test_gridspec_setter_rejects_non_ultraplot(): + """Setting gridspec to a non-ultraplot GridSpec raises ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="ultraplot.GridSpec"): + fig.gridspec = "not a gridspec" + + +def test_gridspec_setter_accepts_ultraplot(): + """Setting gridspec to an ultraplot GridSpec works.""" + fig = uplt.figure() + gs = pgridspec.GridSpec(2, 2) + fig.gridspec = gs + assert fig.gridspec is gs + uplt.close(fig) + + +def test_parse_backend_basemap_warns(): + """parse_backend emits a deprecation warning for basemap.""" + with pytest.warns(match="basemap"): + SubplotManager.parse_backend(backend="basemap") + + +def test_parse_backend_passthrough(): + """parse_backend returns backend unchanged for non-basemap values.""" + assert SubplotManager.parse_backend(backend="cartopy") == "cartopy" + assert SubplotManager.parse_backend(backend=None) is None + + +def test_add_subplot_integer_arg(): + """add_subplot(111) creates a single subplot.""" + fig = uplt.figure() + ax = fig.add_subplot(111) + assert ax is not None + assert ax.number == 1 + uplt.close(fig) + + +def test_add_subplot_integer_arg_invalid(): + """add_subplot with out-of-range integer raises ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="must fall between 111 and 999"): + fig.add_subplot(10) + uplt.close(fig) + + +def test_add_subplot_row_mismatch(): + """Rows that don't divide the gridspec raise ValueError.""" + fig = uplt.figure() + fig.gridspec = pgridspec.GridSpec(3, 3) + with pytest.raises(ValueError, match="does not divide"): + fig.add_subplot(2, 2, 1) + uplt.close(fig) + + +def test_add_subplot_col_mismatch(): + """Columns that don't divide the gridspec raise ValueError.""" + fig = uplt.figure() + fig.gridspec = pgridspec.GridSpec(4, 3) + with pytest.raises(ValueError, match="does not divide"): + fig.add_subplot(2, 2, 1) + uplt.close(fig) + + +def test_add_subplot_index_out_of_range(): + """Subplot index beyond nrows*ncols raises ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="must fall between"): + fig.add_subplot(2, 2, 5) + uplt.close(fig) + + +def test_add_subplot_invalid_args(): + """Unrecognized positional args raise ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="Invalid add_subplot"): + fig.add_subplot("bad", "args") + uplt.close(fig) + + +def test_add_subplot_non_ultraplot_gridspec_raises(): + """SubplotSpec from a plain matplotlib GridSpec is rejected.""" + import matplotlib.gridspec as mgridspec + + fig = uplt.figure() + mpl_gs = mgridspec.GridSpec(2, 2) + ss = mpl_gs[0, 0] + with pytest.raises(ValueError, match="ultraplot.GridSpec"): + fig.add_subplot(ss) + uplt.close(fig) + + +def test_add_subplot_wrong_gridspec_raises(): + """SubplotSpec from a different GridSpec than the figure's raises.""" + fig = uplt.figure() + gs1 = pgridspec.GridSpec(2, 2) + gs2 = pgridspec.GridSpec(3, 3) + fig.gridspec = gs1 + ss = gs2[0, 0] + with pytest.raises(ValueError, match="active figure gridspec"): + fig.add_subplot(ss) + uplt.close(fig) + + +def test_add_subplot_with_subplotspec(): + """add_subplot accepts an ultraplot SubplotSpec directly.""" + fig = uplt.figure() + gs = pgridspec.GridSpec(2, 2) + ss = gs[0, 0] + ax = fig.add_subplot(ss) + assert ax.number == 1 + uplt.close(fig) + + +def test_add_subplot_3d_projection(): + """External projection like '3d' gets wrapped in a container.""" + fig = uplt.figure() + ax = fig.add_subplot(111, proj="3d") + assert ax is not None + assert hasattr(ax, "number") + uplt.close(fig) + + +def test_parse_proj_invalid_name(): + """Invalid projection name raises a helpful ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="Invalid projection"): + fig.add_subplot(111, proj="totally_nonexistent_proj_xyz") + uplt.close(fig) + + +def test_add_subplots_invalid_order(): + """Invalid order raises ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="Invalid order"): + fig.add_subplots(nrows=2, ncols=2, order="Z") + uplt.close(fig) + + +def test_add_subplots_fortran_order(): + """Fortran order ('F') arranges subplots column-major.""" + fig = uplt.figure() + axs = fig.add_subplots(nrows=2, ncols=2, order="F") + assert len(axs) == 4 + uplt.close(fig) + + +def test_add_subplots_1d_array_fortran(): + """1D array with order='F' creates a column layout.""" + fig = uplt.figure() + axs = fig.add_subplots(array=[1, 2, 3], order="F") + assert len(axs) == 3 + uplt.close(fig) + + +def test_add_subplots_3d_array_raises(): + """3D+ array raises ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="1D or 2D"): + fig.add_subplots(array=np.ones((2, 2, 2), dtype=int)) + uplt.close(fig) + + +def test_add_subplots_negative_indices_raises(): + """Negative indices in the array raise ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="positive integers"): + fig.add_subplots(array=[[-1, 1], [2, 3]]) + uplt.close(fig) + + +def test_add_subplots_gridspec_kw_warns(): + """Passing gridspec_kw emits a deprecation warning.""" + fig = uplt.figure() + with pytest.warns(match="gridspec_kw"): + axs = fig.add_subplots(nrows=1, ncols=2, gridspec_kw={"wspace": 1}) + uplt.close(fig) + + +def test_add_subplots_subplot_kw_warns(): + """Passing subplot_kw emits a deprecation warning.""" + fig = uplt.figure() + with pytest.warns(match="subplot_kw"): + axs = fig.add_subplots(nrows=1, ncols=2, subplot_kw={"facecolor": "red"}) + uplt.close(fig) + + +def test_add_subplots_per_axes_proj(): + """Per-axes projection as a dict selects different projections.""" + fig = uplt.figure() + axs = fig.add_subplots( + nrows=1, + ncols=2, + proj={1: "cartesian", 2: "polar"}, + ) + assert len(axs) == 2 + uplt.close(fig) + + +def test_add_subplots_proj_as_list(): + """Per-axes projection as a list works.""" + fig = uplt.figure() + axs = fig.add_subplots( + nrows=1, + ncols=2, + proj=["cartesian", "polar"], + ) + assert len(axs) == 2 + uplt.close(fig) + + +def test_add_subplots_none_placeholder(): + """None in the array is treated as an empty slot.""" + fig = uplt.figure() + axs = fig.add_subplots(array=[[1, None], [2, 3]]) + assert len(axs) == 3 + uplt.close(fig) + + +def test_subplotgrid_sorted_by_number(): + """subplotgrid returns subplots sorted by number.""" + fig, axs = uplt.subplots(nrows=2, ncols=2) + grid = fig.subplotgrid + numbers = [ax.number for ax in grid] + assert numbers == sorted(numbers) + uplt.close(fig) + + +def test_get_subplot_returns_correct_axes(): + """_get_subplot returns the axes with the given number.""" + fig, axs = uplt.subplots(nrows=1, ncols=3) + for i in range(1, 4): + ax = fig._get_subplot(i) + assert ax is not None + assert ax.number == i + uplt.close(fig) + + +def test_get_subplot_missing_returns_none(): + """_get_subplot returns None for a nonexistent number.""" + fig, axs = uplt.subplots(nrows=1, ncols=2) + assert fig._get_subplot(999) is None + uplt.close(fig) + + +def test_iter_subplots_yields_all(): + """_iter_subplots yields all numbered subplots.""" + fig, axs = uplt.subplots(nrows=2, ncols=2) + count = sum(1 for _ in fig._iter_subplots()) + assert count == 4 + uplt.close(fig) + + +def test_parse_proj_polar(): + """'polar' is found via matplotlib's projection registry.""" + fig = uplt.figure() + ax = fig.add_subplot(111, proj="polar") + assert ax is not None + uplt.close(fig) + + +def test_add_subplots_proj_kw_mixed_nested_raises(): + """Mixed nested/flat proj_kw dicts raise ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError): + fig.add_subplots( + nrows=1, + ncols=2, + proj_kw={1: {"key": "val"}, 2: "not_a_dict"}, + ) + uplt.close(fig) + + +def test_add_subplots_proj_dict_wrong_keys_raises(): + """proj dict with wrong axes numbers raises ValueError.""" + fig = uplt.figure() + with pytest.raises(ValueError, match="Have 2 axes"): + fig.add_subplots( + nrows=1, + ncols=2, + proj={1: "cartesian", 2: "cartesian", 3: "polar"}, + ) + uplt.close(fig) + + +def test_add_subplot_external_projection_wrapped(): + """Non-ultraplot projections (e.g. 'hammer') get wrapped in a container.""" + fig = uplt.figure() + ax = fig.add_subplot(111, proj="hammer") + assert ax is not None + assert hasattr(ax, "number") + uplt.close(fig) + + +def test_add_subplot_external_projection_reuses_container(): + """Calling add_subplot twice with same external projection reuses container.""" + fig = uplt.figure() + gs = pgridspec.GridSpec(1, 2) + ax1 = fig.add_subplot(gs[0, 0], proj="mollweide") + ax2 = fig.add_subplot(gs[0, 1], proj="mollweide") + assert ax1 is not None + assert ax2 is not None + uplt.close(fig)