diff --git a/deeptrack/aberrations.py b/deeptrack/aberrations.py index 22905c99..89de890d 100644 --- a/deeptrack/aberrations.py +++ b/deeptrack/aberrations.py @@ -63,7 +63,7 @@ >>> wavelength=530e-9, >>> output_region=(0, 0, 64, 48), >>> padding=(64, 64, 64, 64), ->>> aberration=aberrations.GaussianApodization(sigma=0.9), +>>> pupil=dt.GaussianApodization(sigma=0.9), >>> z = -1.0 * dt.units.micrometer, >>> ) >>> aberrated_particle = aberrated_optics(particle) @@ -71,22 +71,26 @@ """ -#TODO ***??*** revise class docstring -#TODO ***??*** revise DTAT325 from __future__ import annotations import math -from typing import Any +from typing import Any, TYPE_CHECKING import numpy as np +from deeptrack.backend import TORCH_AVAILABLE, xp from deeptrack.features import Feature from deeptrack.types import PropertyLike from deeptrack.utils import as_list +if TORCH_AVAILABLE: + import torch + +if TYPE_CHECKING: + import torch + -#TODO ***??*** revise Aberration - torch, docstring, unit test class Aberration(Feature): """Base class for optical aberrations. @@ -108,19 +112,20 @@ class Aberration(Feature): Methods ------- - `_process_and_get(image_list: list[np.ndarray], **kwargs: dict) -> list[np.ndarray]` + `_process_and_get(image_list, **kwargs) -> list[np.ndarray | torch.Tensor]` Processes a list of input images to compute pupil coordinates (rho and theta) and passes them, along with the original images, to the superclass method for further processing. """ + __distributed__: bool = True def _process_and_get( - self: Feature, - image_list: list[np.ndarray], - **kwargs: dict[str, np.ndarray] - ) -> list[np.ndarray]: + self: Aberration, + image_list: list[np.ndarray | torch.Tensor], + **kwargs: Any, + ) -> list[np.ndarray | torch.Tensor]: """Computes pupil coordinates. Computes pupil coordinates (rho and theta) for each input image and @@ -128,27 +133,29 @@ def _process_and_get( Parameters ---------- - image_list: list[np.ndarray] + image_list: list[np.ndarray | torch.Tensor] A list of 2D input images to be processed. - **kwargs: dict[str, np.ndarray] + **kwargs: Any Additional parameters to be passed to the superclass's `_process_and_get` method. Returns ------- - list: list[np.ndarray] + list[np.ndarray | torch.Tensor] A list of processed images with added pupil coordinates. """ new_list = [] for image in image_list: - x = np.arange(image.shape[0]) - image.shape[0] / 2 - y = np.arange(image.shape[1]) - image.shape[1] / 2 - X, Y = np.meshgrid(y, x) - rho = np.sqrt(X ** 2 + Y ** 2) - rho /= np.max(rho[image != 0]) - theta = np.arctan2(Y, X) + x = xp.arange(image.shape[0]) - image.shape[0] / 2 + y = xp.arange(image.shape[1]) - image.shape[1] / 2 + X, Y = xp.meshgrid(y, x) + rho = xp.sqrt(X ** 2 + Y ** 2) + mask = image != 0 + if bool(xp.any(mask)): + rho /= xp.max(rho[mask]) + theta = xp.arctan2(Y, X) new_list += super()._process_and_get( [image], rho=rho, theta=theta, **kwargs @@ -156,7 +163,6 @@ def _process_and_get( return new_list -#TODO ***??*** revise GaussianApodization - torch, docstring, unit test class GaussianApodization(Aberration): """Introduces pupil apodization. @@ -177,7 +183,7 @@ class GaussianApodization(Aberration): Methods ------- - `get(pupil: np.ndarray, offset: tuple[float, float], sigma: float, rho: np.ndarray, **kwargs: dict[str, Any]) -> np.ndarray` + `get(pupil, offset, sigma, rho, **kwargs) -> np.ndarray | torch.Tensor` Applies Gaussian apodization to the input pupil function. Examples @@ -198,8 +204,8 @@ class GaussianApodization(Aberration): def __init__( self: GaussianApodization, sigma: PropertyLike[float] = 1, - offset: PropertyLike[tuple[int, int]] = (0, 0), - **kwargs: dict[str, Any] + offset: PropertyLike[tuple[float, float]] = (0, 0), + **kwargs: Any, ) -> None: """Initializes the GaussianApodization class. @@ -212,9 +218,8 @@ def __init__( The standard deviation of the Gaussian apodization. A smaller value results in more rapid attenuation at the edges. Default is 1. offset: tuple of float, optional - The (x, y) coordinates of the Gaussian center's offset relative - to the geometric center of the pupil. Default is (0, 0). - **kwargs: dict, optional + Offset of the Gaussian center relative to the pupil center. + **kwargs: Any, optional Additional parameters passed to the parent class `Aberration`. """ @@ -223,12 +228,12 @@ def __init__( def get( self: GaussianApodization, - pupil: np.ndarray, + pupil: np.ndarray | torch.Tensor, offset: tuple[float, float], sigma: float, - rho: np.ndarray, - **kwargs: dict[str, Any] - ) -> np.ndarray: + rho: np.ndarray | torch.Tensor, + **kwargs: Any, + ) -> np.ndarray | torch.Tensor: """Applies Gaussian apodization to the input pupil function. This method attenuates the amplitude of the pupil function based @@ -237,17 +242,17 @@ def get( Parameters ---------- - pupil: np.ndarray + pupil: np.ndarray or torch.Tensor A 2D array representing the input pupil function. offset: tuple of float Specifies the (x, y) offset of the Gaussian center relative to the pupil's center. sigma: float The standard deviation of the Gaussian apodization. - rho: np.ndarray + rho: np.ndarray or torch.Tensor A 2D array of radial coordinates normalized to the pupil aperture. - **kwargs: dict, optional + **kwargs: Any, optional Additional parameters for compatibility with other features or inherited methods. These are typically passed by the parent class and may include: @@ -256,7 +261,7 @@ def get( Returns ------- - np.ndarray + np.ndarray or torch.Tensor The modified pupil function after applying Gaussian apodization. Examples @@ -291,18 +296,19 @@ def get( """ if offset != (0, 0): - x = np.arange(pupil.shape[0]) - pupil.shape[0] / 2 - offset[0] - y = np.arange(pupil.shape[1]) - pupil.shape[1] / 2 - offset[1] - X, Y = np.meshgrid(x, y) - rho = np.sqrt(X ** 2 + Y ** 2) - rho /= np.max(rho[pupil != 0]) - rho[rho > 1] = np.inf - - pupil = pupil * np.exp(-((rho / sigma) ** 2)) + x = xp.arange(pupil.shape[0]) - pupil.shape[0] / 2 - offset[0] + y = xp.arange(pupil.shape[1]) - pupil.shape[1] / 2 - offset[1] + X, Y = xp.meshgrid(y, x) + rho = xp.sqrt(X ** 2 + Y ** 2) + mask = pupil != 0 + if bool(xp.any(mask)): + rho /= xp.max(rho[mask]) + rho[rho > 1] = xp.inf + + pupil = pupil * xp.exp(-((rho / sigma) ** 2)) return pupil -#TODO ***??*** revise Zernike - torch, docstring, unit test class Zernike(Aberration): """Introduces a Zernike phase aberration. @@ -336,7 +342,7 @@ class Zernike(Aberration): Methods ------- - `get(pupil: np.ndarray, rho: np.ndarray, theta: np.ndarray, n: int | list[int], m: int | list[int], coefficient: float | list[float], **kwargs: dict[str, Any]) -> np.ndarray` + `get(pupil, rho, theta, n, m, coefficient, **kwargs) -> np.ndarray | torch.Tensor` Applies the Zernike phase aberration to the input pupil function. Notes @@ -354,8 +360,8 @@ class Zernike(Aberration): >>> particle = dt.PointParticle(z = 1 * dt.units.micrometer) >>> aberrated_optics = dt.Fluorescence( >>> pupil=dt.Zernike( - >>> n=[0, 1], - >>> m = [1, 2], + >>> n = [2, 3], + >>> m = [0, 1], >>> coefficient=[1, 1] >>> ) >>> ) @@ -369,7 +375,7 @@ def __init__( n: PropertyLike[int | list[int]], m: PropertyLike[int | list[int]], coefficient: PropertyLike[float | list[float]] = 1, - **kwargs: dict[str, Any] + **kwargs: Any, ) -> None: """ Initializes the Zernike class. @@ -385,7 +391,7 @@ def __init__( coefficient: float or list of floats, optional The coefficients for the Zernike polynomials. These determine the relative contribution of each polynomial. Default is 1. - **kwargs: dict, optional + **kwargs: Any, optional Additional parameters passed to the parent class `Aberration`. Notes @@ -399,14 +405,14 @@ def __init__( def get( self: Zernike, - pupil: np.ndarray, - rho: np.ndarray, - theta: np.ndarray, + pupil: np.ndarray | torch.Tensor, + rho: np.ndarray | torch.Tensor, + theta: np.ndarray | torch.Tensor, n: int | list[int], m: int | list[int], coefficient: float | list[float], **kwargs: Any, - ) -> np.ndarray: + ) -> np.ndarray | torch.Tensor: """Applies the Zernike phase aberration to the input pupil function. The method calculates Zernike polynomials for the specified indices `n` @@ -416,13 +422,13 @@ def get( Parameters ---------- - pupil: np.ndarray + pupil: np.ndarray or torch.Tensor A 2D array representing the input pupil function. The values should represent the amplitude and phase across the aperture. - rho: np.ndarray + rho: np.ndarray or torch.Tensor A 2D array of radial coordinates normalized to the pupil aperture. The values should range from 0 to 1 within the aperture. - theta: np.ndarray + theta: np.ndarray or torch.Tensor A 2D array of angular coordinates in radians. These define the azimuthal positions for the pupil. n: int or list of ints @@ -432,20 +438,21 @@ def get( coefficient: float or list of floats The coefficients for the Zernike polynomials, controlling their relative contributions to the phase. - **kwargs: dict, optional + **kwargs: Any, optional Additional parameters for compatibility with other features or inherited methods. Returns ------- - np.ndarray + np.ndarray or torch.Tensor The modified pupil function with the applied Zernike phase aberration. Raises ------ - AssertionError - If the lengths of `n`, `m`, and `coefficient` lists do not match. + ValueError + If `n`, `m`, and `coefficient` do not have matching lengths when + provided as lists. Notes ----- @@ -464,7 +471,7 @@ def get( >>> pupil = np.ones((128, 128), dtype=complex) >>> x = np.linspace(-1, 1, 128) >>> y = np.linspace(-1, 1, 128) - >>> X, Y = np.meshgrid(x, y) + >>> X, Y = np.meshgrid(y, x) >>> rho = np.sqrt(X**2 + Y**2) >>> theta = np.arctan2(Y, X) >>> pupil[rho > 1] = 0 @@ -486,63 +493,64 @@ def get( n_list = as_list(n) coefficients = as_list(coefficient) - assert len(m_list) == len(n_list), "The number of indices need to match" - assert len(m_list) == len( - coefficients - ), "The number of indices need to match the number of coefficients" + if len(m_list) != len(n_list): + raise ValueError("`n` and `m` must have the same length.") + if len(m_list) != len(coefficients): + raise ValueError("`n`, `m`, and `coefficient` must have the same length.") pupil_bool = pupil != 0 rho = rho[pupil_bool] theta = theta[pupil_bool] - Z = 0 + Z = 0 * rho for n, m, coefficient in zip(n_list, m_list, coefficients): - if (n - m) % 2 or coefficient == 0: + if (n - abs(m)) % 2 or coefficient == 0: continue - R = 0 - for k in range((n - np.abs(m)) // 2 + 1): + R = 0 * rho + for k in range((n - abs(m)) // 2 + 1): R += ( (-1) ** k * math.factorial(n - k) / ( math.factorial(k) - * math.factorial((n - m) // 2 - k) - * math.factorial((n + m) // 2 - k) + * math.factorial((n - abs(m)) // 2 - k) + * math.factorial((n + abs(m)) // 2 - k) ) * rho ** (n - 2 * k) ) if m > 0: - R = R * np.cos(m * theta) * (np.sqrt(2 * n + 2) * coefficient) + R = R * xp.cos(m * theta) * (math.sqrt(2 * n + 2) * coefficient) elif m < 0: - R = R * np.sin(-m * theta) * (np.sqrt(2 * n + 2) * coefficient) + R = R * xp.sin(-m * theta) * (math.sqrt(2 * n + 2) * coefficient) else: - R = R * (np.sqrt(n + 1) * coefficient) + R = R * (math.sqrt(n + 1) * coefficient) Z += R - phase = np.exp(1j * Z) + phase = xp.exp(1j * Z) - pupil[pupil_bool] *= phase + pupil[pupil_bool] = pupil[pupil_bool] * phase return pupil -#TODO ***??*** revise Piston - torch, docstring, unit test class Piston(Zernike): """Zernike polynomial with n=0, m=0. - This class represents the simplest Zernike polynomial, often referred to as the piston term, - which has no radial or azimuthal variations (n=0, m=0). It adds a uniform phase contribution - to the pupil function. + This class represents the simplest Zernike polynomial, often referred to as + the piston term, which has no radial or azimuthal variations (n=0, m=0). It + adds a uniform phase contribution to the pupil function. Parameters ---------- coefficient: PropertyLike[float or list of floats], optional The coefficient of the polynomial. Default is 1. + kwargs: Any, optional + Additional parameters passed to the parent Zernike class. Attributes ---------- @@ -571,8 +579,7 @@ class Piston(Zernike): """ def __init__( - self: "Piston", - *args: tuple[Any, ...], + self: Piston, coefficient: PropertyLike[float | list[float]] = 1, **kwargs: Any, ) -> None: @@ -582,17 +589,14 @@ def __init__( ---------- coefficient: float or list of floats, optional The coefficient for the piston term. Default is 1. - *args: tuple, optional - Additional arguments passed to the parent Zernike class. - **kwargs: dict, optional + **kwargs: Any, optional Additional parameters passed to the parent Zernike class. """ - super().__init__(*args, n=0, m=0, coefficient=coefficient, **kwargs) + super().__init__(n=0, m=0, coefficient=coefficient, **kwargs) -#TODO ***??*** revise VerticalTilt - torch, docstring, unit test class VerticalTilt(Zernike): """Zernike polynomial with n=1, m=-1. @@ -604,6 +608,8 @@ class VerticalTilt(Zernike): ---------- coefficient: PropertyLike[float or list of floats], optional The coefficient of the polynomial. Default is 1. + kwargs: Any, optional + Additional parameters passed to the parent Zernike class. Attributes ---------- @@ -628,11 +634,11 @@ class VerticalTilt(Zernike): >>> ) >>> aberrated_particle = aberrated_optics(particle) >>> aberrated_particle.plot(cmap="gray") + """ def __init__( self: VerticalTilt, - *args: tuple[Any, ...], coefficient: PropertyLike[float | list[float]] = 1, **kwargs: Any, ) -> None: @@ -642,15 +648,14 @@ def __init__( ---------- coefficient: float or list of floats, optional The coefficient for the vertical tilt term. Default is 1. - *args: tuple, optional - Additional arguments passed to the parent Zernike class. - **kwargs: dict, optional + **kwargs: Any, optional Additional parameters passed to the parent Zernike class. + """ - super().__init__(*args, n=1, m=-1, coefficient=coefficient, **kwargs) + + super().__init__(n=1, m=-1, coefficient=coefficient, **kwargs) -#TODO ***??*** revise HorizontalTilt - torch, docstring, unit test class HorizontalTilt(Zernike): """Zernike polynomial with n=1, m=1. @@ -662,6 +667,8 @@ class HorizontalTilt(Zernike): ---------- coefficient: PropertyLike[float or list of floats], optional The coefficient of the polynomial. Default is 1. + kwargs: Any, optional + Additional parameters passed to the parent Zernike class. Attributes ---------- @@ -688,11 +695,11 @@ class HorizontalTilt(Zernike): >>> ) >>> aberrated_particle = aberrated_optics(particle) >>> aberrated_particle.plot(cmap="gray") + """ def __init__( self: HorizontalTilt, - *args: tuple[Any, ...], coefficient: PropertyLike[float | list[float]] = 1, **kwargs: Any, ) -> None: @@ -702,15 +709,14 @@ def __init__( ---------- coefficient: float or list of floats, optional The coefficient for the horizontal tilt term. Default is 1. - *args: tuple, optional - Additional arguments passed to the parent Zernike class. - **kwargs: dict, optional + **kwargs: Any, optional Additional parameters passed to the parent Zernike class. + """ - super().__init__(*args, n=1, m=1, coefficient=coefficient, **kwargs) + + super().__init__(n=1, m=1, coefficient=coefficient, **kwargs) -#TODO ***??*** revise ObliqueAstigmatism - torch, docstring, unit test class ObliqueAstigmatism(Zernike): """Zernike polynomial with n=2, m=-2. @@ -723,6 +729,8 @@ class ObliqueAstigmatism(Zernike): ---------- coefficient: PropertyLike[float or list of floats], optional The coefficient of the polynomial. Default is 1. + kwargs: Any, optional + Additional parameters passed to the parent Zernike class. Attributes ---------- @@ -749,11 +757,11 @@ class ObliqueAstigmatism(Zernike): >>> ) >>> aberrated_particle = aberrated_optics(particle) >>> aberrated_particle.plot(cmap="gray") + """ def __init__( self: ObliqueAstigmatism, - *args: tuple[Any, ...], coefficient: PropertyLike[float | list[float]] = 1, **kwargs: Any, ) -> None: @@ -763,15 +771,14 @@ def __init__( ---------- coefficient: float or list of floats, optional The coefficient for the oblique astigmatism term. Default is 1. - *args: tuple, optional - Additional arguments passed to the parent Zernike class. - **kwargs: dict, optional + **kwargs: Any, optional Additional parameters passed to the parent Zernike class. + """ - super().__init__(*args, n=2, m=-2, coefficient=coefficient, **kwargs) + + super().__init__(n=2, m=-2, coefficient=coefficient, **kwargs) -#TODO ***??*** revise Defocus - torch, docstring, unit test class Defocus(Zernike): """Zernike polynomial with n=2, m=0. @@ -784,6 +791,8 @@ class Defocus(Zernike): ---------- coefficient: PropertyLike[float or list of floats], optional The coefficient of the polynomial. Default is 1. + kwargs: Any, optional + Additional parameters passed to the parent Zernike class. Attributes ---------- @@ -812,7 +821,6 @@ class Defocus(Zernike): def __init__( self: Defocus, - *args: tuple[Any, ...], coefficient: PropertyLike[float | list[float]] = 1, **kwargs: Any, ) -> None: @@ -822,15 +830,14 @@ def __init__( ---------- coefficient: float or list of floats, optional The coefficient for the defocus term. Default is 1. - *args: tuple, optional - Additional arguments passed to the parent Zernike class. - **kwargs: dict, optional + **kwargs: Any, optional Additional parameters passed to the parent Zernike class. + """ - super().__init__(*args, n=2, m=0, coefficient=coefficient, **kwargs) + + super().__init__(n=2, m=0, coefficient=coefficient, **kwargs) -#TODO ***??*** revise Astigmatism - torch, docstring, unit test class Astigmatism(Zernike): """Zernike polynomial with n=2, m=2. @@ -843,6 +850,8 @@ class Astigmatism(Zernike): ---------- coefficient: PropertyLike[float or list of floats], optional The coefficient of the polynomial. Default is 1. + kwargs: Any, optional + Additional parameters passed to the parent Zernike class. Attributes ---------- @@ -867,11 +876,11 @@ class Astigmatism(Zernike): >>> ) >>> aberrated_particle = aberrated_optics(particle) >>> aberrated_particle.plot(cmap="gray") + """ def __init__( self: Astigmatism, - *args: tuple[Any, ...], coefficient: PropertyLike[float | list[float]] = 1, **kwargs: Any, ) -> None: @@ -881,15 +890,14 @@ def __init__( ---------- coefficient: float or list of floats, optional The coefficient for the astigmatism term. Default is 1. - *args: tuple, optional - Additional arguments passed to the parent Zernike class. - **kwargs: dict, optional + **kwargs: Any, optional Additional parameters passed to the parent Zernike class. + """ - super().__init__(*args, n=2, m=2, coefficient=coefficient, **kwargs) + + super().__init__(n=2, m=2, coefficient=coefficient, **kwargs) -#TODO ***??*** revise ObliqueTrefoil - torch, docstring, unit test class ObliqueTrefoil(Zernike): """Zernike polynomial with n=3, m=-3. @@ -901,6 +909,8 @@ class ObliqueTrefoil(Zernike): ---------- coefficient: PropertyLike[float or list of floats], optional The coefficient of the polynomial. Default is 1. + kwargs: Any, optional + Additional parameters passed to the parent Zernike class. Examples -------- @@ -916,18 +926,29 @@ class ObliqueTrefoil(Zernike): >>> ) >>> aberrated_particle = aberrated_optics(particle) >>> aberrated_particle.plot(cmap="gray") + """ def __init__( self: ObliqueTrefoil, - *args: tuple[Any, ...], coefficient: PropertyLike[float | list[float]] = 1, **kwargs: Any, ) -> None: - super().__init__(*args, n=3, m=-3, coefficient=coefficient, **kwargs) + """Initializes the ObliqueTrefoil class. + + Parameters + ---------- + coefficient: float or list of floats, optional + The coefficient for the oblique trefoil term. Default is 1. + **kwargs: Any, optional + Additional parameters passed to the parent Zernike class. + + + """ + + super().__init__(n=3, m=-3, coefficient=coefficient, **kwargs) -#TODO ***??*** revise VerticalComa - torch, docstring, unit test class VerticalComa(Zernike): """Zernike polynomial with n=3, m=-1. @@ -938,18 +959,30 @@ class VerticalComa(Zernike): ---------- coefficient: PropertyLike[float or list of floats], optional The coefficient of the polynomial. Default is 1. + kwargs: Any, optional + Additional parameters passed to the parent Zernike class. + """ def __init__( self: VerticalComa, - *args: tuple[Any, ...], coefficient: PropertyLike[float | list[float]] = 1, **kwargs: Any, ) -> None: - super().__init__(*args, n=3, m=-1, coefficient=coefficient, **kwargs) + """Initializes the VerticalComa class. + + Parameters + ---------- + coefficient: float or list of floats, optional + The coefficient for the vertical coma term. Default is 1. + **kwargs: Any, optional + Additional parameters passed to the parent Zernike class. + + """ + + super().__init__(n=3, m=-1, coefficient=coefficient, **kwargs) -#TODO ***??*** revise HorizontalComa - torch, docstring, unit test class HorizontalComa(Zernike): """Zernike polynomial with n=3, m=1. @@ -960,18 +993,30 @@ class HorizontalComa(Zernike): ---------- coefficient: PropertyLike[float or list of floats], optional The coefficient of the polynomial. Default is 1. + kwargs: Any, optional + Additional parameters passed to the parent Zernike class. + """ def __init__( self: HorizontalComa, - *args: tuple[Any, ...], coefficient: PropertyLike[float | list[float]] = 1, **kwargs: Any, ) -> None: - super().__init__(*args, n=3, m=1, coefficient=coefficient, **kwargs) + """Initializes the HorizontalComa class. + + Parameters + ---------- + coefficient: float or list of floats, optional + The coefficient for the horizontal coma term. Default is 1. + **kwargs: Any, optional + Additional parameters passed to the parent Zernike class. + + """ + + super().__init__(n=3, m=1, coefficient=coefficient, **kwargs) -#TODO ***??*** revise Trefoil - torch, docstring, unit test class Trefoil(Zernike): """Zernike polynomial with n=3, m=3. @@ -982,18 +1027,29 @@ class Trefoil(Zernike): ---------- coefficient: PropertyLike[float or list of floats], optional The coefficient of the polynomial. Default is 1. + kwargs: Any, optional + Additional parameters passed to the parent Zernike class. + """ def __init__( self: Trefoil, - *args: tuple[Any, ...], coefficient: PropertyLike[float | list[float]] = 1, **kwargs: Any, ) -> None: - super().__init__(*args, n=3, m=3, coefficient=coefficient, **kwargs) + """Initializes the Trefoil class. + Parameters + ---------- + coefficient: float or list of floats, optional + The coefficient for the trefoil term. Default is 1. + **kwargs: Any, optional + Additional parameters passed to the parent Zernike class. + + """ + + super().__init__(n=3, m=3, coefficient=coefficient, **kwargs) -#TODO ***??*** revise SphericalAberration - torch, docstring, unit test class SphericalAberration(Zernike): """Zernike polynomial with n=4, m=0. @@ -1004,12 +1060,25 @@ class SphericalAberration(Zernike): ---------- coefficient: PropertyLike[float or list of floats], optional The coefficient of the polynomial. Default is 1. + kwargs: Any, optional + Additional parameters passed to the parent Zernike class. + """ def __init__( self: SphericalAberration, - *args: tuple[Any, ...], coefficient: PropertyLike[float | list[float]] = 1, **kwargs: Any, ) -> None: - super().__init__(*args, n=4, m=0, coefficient=coefficient, **kwargs) + """Initializes the SphericalAberration class. + + Parameters + ---------- + coefficient: float or list of floats, optional + The coefficient for the spherical aberration term. Default is 1. + **kwargs: Any, optional + Additional parameters passed to the parent Zernike class. + + """ + + super().__init__(n=4, m=0, coefficient=coefficient, **kwargs) diff --git a/deeptrack/tests/test_aberrations.py b/deeptrack/tests/test_aberrations.py index 14a771da..8460b38c 100644 --- a/deeptrack/tests/test_aberrations.py +++ b/deeptrack/tests/test_aberrations.py @@ -1,312 +1,220 @@ -import sys - -# sys.path.append(".") # Adds the module to path - import unittest -raise unittest.SkipTest("Temporarily skipped") - import numpy as np from deeptrack import aberrations - -from deeptrack.scatterers import PointParticle from deeptrack.optics import Fluorescence -from deeptrack.image import Image - +from deeptrack.scatterers import PointParticle +from deeptrack.backend import TORCH_AVAILABLE +from deeptrack.tests import BackendTestBase +if TORCH_AVAILABLE: + import torch -class TestAberrations(unittest.TestCase): - particle = PointParticle(position=(32, 32), position_unit="pixel", intensity=1) +class TestAberrations_NumPy(BackendTestBase): + BACKEND = "numpy" - def testGaussianApodization(self): - aberrated_optics = Fluorescence( - NA=0.3, - resolution=1e-6, - magnification=10, - wavelength=530e-9, - output_region=(0, 0, 64, 48), - padding=(64, 64, 64, 64), - pupil=aberrations.GaussianApodization(sigma=0.5), + def setUp(self): + super().setUp() + self.particle = PointParticle( + position=(32, 32), + position_unit="pixel", + intensity=1, ) - aberrated_particle = aberrated_optics(self.particle) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, np.ndarray) - self.assertEqual(im.shape, (64, 48, 1)) - - aberrated_particle.store_properties(True) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, Image) - self.assertEqual(im.shape, (64, 48, 1)) + @property + def array_type(self): + if self.BACKEND == "numpy": + return np.ndarray + if self.BACKEND == "torch": + return torch.Tensor + raise ValueError(f"Unsupported backend: {self.BACKEND}") - def testZernike(self): - aberrated_optics = Fluorescence( + def _make_optics(self, pupil): + return Fluorescence( NA=0.3, resolution=1e-6, magnification=10, wavelength=530e-9, output_region=(0, 0, 64, 48), padding=(64, 64, 64, 64), - pupil=aberrations.Zernike( - n=[2, 3], m=[0, 1], coefficient=[0.5, 0.3], - ), + pupil=pupil, ) + + def _to_numpy(self, x): + if TORCH_AVAILABLE and isinstance(x, torch.Tensor): + return x.detach().cpu().numpy() + return np.asarray(x) + + def _render(self, pupil=None, z=0): + optics = self._make_optics(pupil) + image = optics(self.particle).resolve(z=z) + + self.assertIsInstance(image, self.array_type) + self.assertEqual(image.shape, (64, 48, 1)) + + return self._to_numpy(image[..., 0]) + + def _com(self, img): + y, x = np.indices(img.shape) + s = img.sum() + return (y * img).sum() / s, (x * img).sum() / s + + def _second_moments(self, img): + cy, cx = self._com(img) + y, x = np.indices(img.shape) + s = img.sum() + vy = (img * (y - cy) ** 2).sum() / s + vx = (img * (x - cx) ** 2).sum() / s + return vy, vx + + def _radial_second_moment(self, img): + cy, cx = self._com(img) + y, x = np.indices(img.shape) + r2 = (y - cy) ** 2 + (x - cx) ** 2 + return (img * r2).sum() / img.sum() + + def _assert_resolves(self, pupil): + aberrated_optics = self._make_optics(pupil) aberrated_particle = aberrated_optics(self.particle) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, np.ndarray) - self.assertEqual(im.shape, (64, 48, 1)) - aberrated_particle.store_properties(True) for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, Image) - self.assertEqual(im.shape, (64, 48, 1)) + with self.subTest(z=z, pupil=type(pupil).__name__): + im = aberrated_particle.resolve(z=z) + self.assertIsInstance(im, self.array_type) + self.assertEqual(im.shape, (64, 48, 1)) + + def test___all__(self): + from deeptrack import ( + GaussianApodization, + Zernike, + Piston, + VerticalTilt, + HorizontalTilt, + ObliqueAstigmatism, + Defocus, + Astigmatism, + ObliqueTrefoil, + VerticalComa, + HorizontalComa, + Trefoil, + SphericalAberration, + ) - def testPiston(self): - aberrated_optics = Fluorescence( - NA=0.3, - resolution=1e-6, - magnification=10, - wavelength=530e-9, - output_region=(0, 0, 64, 48), - padding=(64, 64, 64, 64), - pupil=aberrations.Piston(coefficient=1), + def testGaussianApodization_resolves(self): + self._assert_resolves( + aberrations.GaussianApodization(sigma=0.5) ) - aberrated_particle = aberrated_optics(self.particle) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, np.ndarray) - self.assertEqual(im.shape, (64, 48, 1)) - aberrated_particle.store_properties(True) + def testGaussianApodization_reduces_peak(self): for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, Image) - self.assertEqual(im.shape, (64, 48, 1)) + with self.subTest(z=z): + base = self._render(pupil=None, z=z) + out = self._render( + pupil=aberrations.GaussianApodization(sigma=0.5), + z=z, + ) + self.assertLess(out.max(), base.max()) - def testVerticalTilt(self): - aberrated_optics = Fluorescence( - NA=0.3, - resolution=1e-6, - magnification=10, - wavelength=530e-9, - output_region=(0, 0, 64, 48), - padding=(64, 64, 64, 64), - pupil=aberrations.VerticalTilt(coefficient=1), + def testZernike_resolves(self): + self._assert_resolves( + aberrations.Zernike(n=[2, 3], m=[0, 1], coefficient=[0.5, 0.3]) ) - aberrated_particle = aberrated_optics(self.particle) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, np.ndarray) - self.assertEqual(im.shape, (64, 48, 1)) - aberrated_particle.store_properties(True) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, Image) - self.assertEqual(im.shape, (64, 48, 1)) + def testPiston_resolves(self): + self._assert_resolves(aberrations.Piston(coefficient=1)) - def testHorizontalTilt(self): - aberrated_optics = Fluorescence( - NA=0.3, - resolution=1e-6, - magnification=10, - wavelength=530e-9, - output_region=(0, 0, 64, 48), - padding=(64, 64, 64, 64), - pupil=aberrations.HorizontalTilt(coefficient=1), - ) - aberrated_particle = aberrated_optics(self.particle) + def testPiston_image_invariant(self): for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, np.ndarray) - self.assertEqual(im.shape, (64, 48, 1)) + with self.subTest(z=z): + base = self._render(pupil=None, z=z) + out = self._render( + pupil=aberrations.Piston(coefficient=1), + z=z, + ) + np.testing.assert_allclose(out, base, atol=1e-6, rtol=1e-6) - aberrated_particle.store_properties(True) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, Image) - self.assertEqual(im.shape, (64, 48, 1)) + def testVerticalTilt_resolves(self): + self._assert_resolves(aberrations.VerticalTilt(coefficient=1)) - def testObliqueAstigmatism(self): - aberrated_optics = Fluorescence( - NA=0.3, - resolution=1e-6, - magnification=10, - wavelength=530e-9, - output_region=(0, 0, 64, 48), - padding=(64, 64, 64, 64), - pupil=aberrations.ObliqueAstigmatism(coefficient=1), - ) - aberrated_particle = aberrated_optics(self.particle) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, np.ndarray) - self.assertEqual(im.shape, (64, 48, 1)) + def testVerticalTilt_shifts_y(self): + base = self._render(pupil=None, z=0) + out = self._render(pupil=aberrations.VerticalTilt(coefficient=5), z=0) - aberrated_particle.store_properties(True) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, Image) - self.assertEqual(im.shape, (64, 48, 1)) + cy0, cx0 = self._com(base) + cy1, cx1 = self._com(out) - def testDefocus(self): - aberrated_optics = Fluorescence( - NA=0.3, - resolution=1e-6, - magnification=10, - wavelength=530e-9, - output_region=(0, 0, 64, 48), - padding=(64, 64, 64, 64), - pupil=aberrations.Defocus(coefficient=1), - ) - aberrated_particle = aberrated_optics(self.particle) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, np.ndarray) - self.assertEqual(im.shape, (64, 48, 1)) + self.assertGreater(abs(cy1 - cy0), 0.05) + self.assertLess(abs(cx1 - cx0), abs(cy1 - cy0)) - aberrated_particle.store_properties(True) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, Image) - self.assertEqual(im.shape, (64, 48, 1)) + def testHorizontalTilt_resolves(self): + self._assert_resolves(aberrations.HorizontalTilt(coefficient=1)) - def testAstigmatism(self): - aberrated_optics = Fluorescence( - NA=0.3, - resolution=1e-6, - magnification=10, - wavelength=530e-9, - output_region=(0, 0, 64, 48), - padding=(64, 64, 64, 64), - pupil=aberrations.Astigmatism(coefficient=1), - ) - aberrated_particle = aberrated_optics(self.particle) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, np.ndarray) - self.assertEqual(im.shape, (64, 48, 1)) + def testHorizontalTilt_shifts_x(self): + base = self._render(pupil=None, z=0) + out = self._render(pupil=aberrations.HorizontalTilt(coefficient=5), z=0) - aberrated_particle.store_properties(True) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, Image) - self.assertEqual(im.shape, (64, 48, 1)) + cy0, cx0 = self._com(base) + cy1, cx1 = self._com(out) - def testObliqueTrefoil(self): - aberrated_optics = Fluorescence( - NA=0.3, - resolution=1e-6, - magnification=10, - wavelength=530e-9, - output_region=(0, 0, 64, 48), - padding=(64, 64, 64, 64), - pupil=aberrations.ObliqueTrefoil(coefficient=1), - ) - aberrated_particle = aberrated_optics(self.particle) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, np.ndarray) - self.assertEqual(im.shape, (64, 48, 1)) + self.assertGreater(abs(cx1 - cx0), 0.05) + self.assertLess(abs(cy1 - cy0), abs(cx1 - cx0)) - aberrated_particle.store_properties(True) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, Image) - self.assertEqual(im.shape, (64, 48, 1)) + def testObliqueAstigmatism_resolves(self): + self._assert_resolves(aberrations.ObliqueAstigmatism(coefficient=1)) - def testVerticalComa(self): - aberrated_optics = Fluorescence( - NA=0.3, - resolution=1e-6, - magnification=10, - wavelength=530e-9, - output_region=(0, 0, 64, 48), - padding=(64, 64, 64, 64), - pupil=aberrations.VerticalComa(coefficient=1), - ) - aberrated_particle = aberrated_optics(self.particle) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, np.ndarray) - self.assertEqual(im.shape, (64, 48, 1)) + def testDefocus_resolves(self): + self._assert_resolves(aberrations.Defocus(coefficient=1)) - aberrated_particle.store_properties(True) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, Image) - self.assertEqual(im.shape, (64, 48, 1)) + def testDefocus_matches_Zernike(self): + img1 = self._render(pupil=aberrations.Defocus(coefficient=1), z=0) + img2 = self._render(pupil=aberrations.Zernike(n=2, m=0, coefficient=1), z=0) + np.testing.assert_allclose(img1, img2, atol=1e-6, rtol=1e-6) - def testHorizontalComa(self): - aberrated_optics = Fluorescence( - NA=0.3, - resolution=1e-6, - magnification=10, - wavelength=530e-9, - output_region=(0, 0, 64, 48), - padding=(64, 64, 64, 64), - pupil=aberrations.HorizontalComa(coefficient=1), + def testDefocus_broadens_psf(self): + base = self._render(pupil=None, z=0) + out = self._render(pupil=aberrations.Defocus(coefficient=1), z=0) + + self.assertLess(out.max(), base.max()) + self.assertGreater( + self._radial_second_moment(out), + self._radial_second_moment(base), ) - aberrated_particle = aberrated_optics(self.particle) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, np.ndarray) - self.assertEqual(im.shape, (64, 48, 1)) - aberrated_particle.store_properties(True) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, Image) - self.assertEqual(im.shape, (64, 48, 1)) + def testAstigmatism_resolves(self): + self._assert_resolves(aberrations.Astigmatism(coefficient=1)) - def testTrefoil(self): - aberrated_optics = Fluorescence( - NA=0.3, - resolution=1e-6, - magnification=10, - wavelength=530e-9, - output_region=(0, 0, 64, 48), - padding=(64, 64, 64, 64), - pupil=aberrations.Trefoil(coefficient=1), - ) - aberrated_particle = aberrated_optics(self.particle) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, np.ndarray) - self.assertEqual(im.shape, (64, 48, 1)) + def testAstigmatism_breaks_xy_symmetry(self): + base = self._render(pupil=None, z=0) + out = self._render(pupil=aberrations.Astigmatism(coefficient=1), z=0) - aberrated_particle.store_properties(True) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, Image) - self.assertEqual(im.shape, (64, 48, 1)) + vy0, vx0 = self._second_moments(base) + vy1, vx1 = self._second_moments(out) - def testSphericalAberration(self): - aberrated_optics = Fluorescence( - NA=0.3, - resolution=1e-6, - magnification=10, - wavelength=530e-9, - output_region=(0, 0, 64, 48), - padding=(64, 64, 64, 64), - pupil=aberrations.SphericalAberration(coefficient=1), - ) - aberrated_particle = aberrated_optics(self.particle) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, np.ndarray) - self.assertEqual(im.shape, (64, 48, 1)) + base_anisotropy = abs(vy0 - vx0) + out_anisotropy = abs(vy1 - vx1) - aberrated_particle.store_properties(True) - for z in (-100, 0, 100): - im = aberrated_particle.resolve(z=z) - self.assertIsInstance(im, Image) - self.assertEqual(im.shape, (64, 48, 1)) + self.assertGreater(out_anisotropy, base_anisotropy) + + def testObliqueTrefoil_resolves(self): + self._assert_resolves(aberrations.ObliqueTrefoil(coefficient=1)) + + def testVerticalComa_resolves(self): + self._assert_resolves(aberrations.VerticalComa(coefficient=1)) + + def testHorizontalComa_resolves(self): + self._assert_resolves(aberrations.HorizontalComa(coefficient=1)) + + def testTrefoil_resolves(self): + self._assert_resolves(aberrations.Trefoil(coefficient=1)) + + def testSphericalAberration_resolves(self): + self._assert_resolves(aberrations.SphericalAberration(coefficient=1)) + + +@unittest.skipUnless(TORCH_AVAILABLE, "PyTorch is not installed.") +class TestAberrations_PyTorch(TestAberrations_NumPy): + BACKEND = "torch" if __name__ == "__main__":