diff --git a/docs/docs/tutorials/components.ipynb b/docs/docs/tutorials/components.ipynb index 55570c35..4934347b 100644 --- a/docs/docs/tutorials/components.ipynb +++ b/docs/docs/tutorials/components.ipynb @@ -150,6 +150,30 @@ "plt.legend()\n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "be79f3fc", + "metadata": {}, + "outputs": [], + "source": [ + "expr = ExpressionComponent(\n", + " 'A*erf(B*x)',\n", + ")\n", + "\n", + "expr.A = 1.0\n", + "expr.B = 0.5\n", + "\n", + "\n", + "x = np.linspace(-5, 5, 100)\n", + "y = expr.evaluate(x)\n", + "\n", + "plt.figure()\n", + "plt.plot(x, y, label='erf')\n", + "plt.legend()\n", + "plt.show()" + ] } ], "metadata": { @@ -168,7 +192,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.13" + "version": "3.14.4" } }, "nbformat": 4, diff --git a/src/easydynamics/sample_model/components/expression_component.py b/src/easydynamics/sample_model/components/expression_component.py index 79ce73ef..14f75a63 100644 --- a/src/easydynamics/sample_model/components/expression_component.py +++ b/src/easydynamics/sample_model/components/expression_component.py @@ -6,6 +6,7 @@ from typing import TYPE_CHECKING from typing import ClassVar +import scipy import sympy as sp from easyscience.variable import Parameter @@ -20,11 +21,6 @@ class ExpressionComponent(ModelComponent): """ Model component defined by a symbolic expression. - - Example: expr = ExpressionComponent( "A * exp(-(x - x0)**2 / (2*sigma**2))", parameters={"A": - 10, "x0": 0, "sigma": 1}, ) - - expr.A = 5 y = expr.evaluate(x) """ # ------------------------- @@ -87,9 +83,9 @@ def __init__( The symbolic expression as a string. Must contain 'x' as the independent variable. parameters : dict[str, Numeric] | None, default=None Dictionary of parameter names and their initial values. - unit : str | sc.Unit, default='meV' + unit : str | sc.Unit, default="meV" Unit of the output. - display_name : str | None, default='Expression' + display_name : str | None, default="Expression" Display name for the component. unique_name : str | None, default=None Unique name for the component. @@ -100,6 +96,18 @@ def __init__( If the expression is invalid or does not contain 'x'. TypeError If any parameter value is not numeric. + + Examples + -------- + >>> expr = ExpressionComponent( + ... 'A * exp(-(x - x0)**2 / (2*sigma**2))', + ... parameters={'A': 10, 'x0': 0, 'sigma': 1}, + ... unit='meV', + ... display_name='Gaussian Peak', + ... ) + + >>> expr.A = 5 + >>> y = expr.evaluate(x) """ super().__init__(unit=unit, display_name=display_name, unique_name=unique_name) @@ -157,8 +165,11 @@ def __init__( if parameters is not None: for name, value in parameters.items(): - if not isinstance(value, Numeric): - raise TypeError(f"Parameter '{name}' must be numeric") + if not isinstance(value, (Numeric, Parameter, dict)): + raise TypeError( + f"Parameter '{name}' must be numeric, " + f'a Parameter instance, or a dictionary, got {type(value).__name__}' + ) parameters = parameters or {} self._parameters: dict[str, Parameter] = {} @@ -168,12 +179,17 @@ def __init__( continue value = parameters.get(name, 1.0) + if isinstance(value, Parameter): + self._parameters[name] = value - self._parameters[name] = Parameter( - name=name, - value=value, - unit=self._unit, - ) + elif isinstance(value, dict) and value.get('@class') == 'Parameter': + self._parameters[name] = Parameter.from_dict(value) + else: + self._parameters[name] = Parameter( + name=name, + value=value, + unit=self._unit, + ) # Create numerical function ordered_symbols = [sp.Symbol(name) for name in self._symbol_names] @@ -181,7 +197,7 @@ def __init__( self._func = sp.lambdify( ordered_symbols, self._expr, - modules=['numpy'], + modules=[{'erf': scipy.special.erf}, 'numpy'], ) # ------------------------- @@ -190,7 +206,14 @@ def __init__( @property def expression(self) -> str: - """Return the original expression string.""" + """ + Return the original expression string. + + Returns + ------- + str + The original expression string provided at initialization. + """ return self._expression_str @expression.setter @@ -334,7 +357,14 @@ def __dir__(self) -> list[str]: return super().__dir__() + list(self._parameters.keys()) def __repr__(self) -> str: - """Repr function.""" + """ + Return a string representation of the ExpressionComponent. + + Returns + ------- + str + String representation of the ExpressionComponent. + """ param_str = ', '.join(f'{k}={v.value}' for k, v in self._parameters.items()) return ( f'{self.__class__.__name__}(\n' diff --git a/tests/unit/easydynamics/sample_model/components/test_expression_component.py b/tests/unit/easydynamics/sample_model/components/test_expression_component.py index 27399844..48fbafe2 100644 --- a/tests/unit/easydynamics/sample_model/components/test_expression_component.py +++ b/tests/unit/easydynamics/sample_model/components/test_expression_component.py @@ -1,6 +1,8 @@ # SPDX-FileCopyrightText: 2026 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause +from copy import copy + import numpy as np import pytest from easyscience.variable import Parameter @@ -34,6 +36,14 @@ def test_init_without_parameters(self): # EXPECT assert expr.A.value == pytest.approx(1.0) # default + def test_init_with_parameter(self): + # WHEN THEN + A = Parameter('A', 3.0) + expr = ExpressionComponent('A * x', parameters={'A': A}) + + # EXPECT + assert expr.A.value == pytest.approx(3.0) + def test_invalid_expression_raises(self): # WHEN THEN EXPECT with pytest.raises(ValueError, match='Invalid expression'): @@ -172,3 +182,30 @@ def test_reserved_name_not_parameter(self): assert 'A' in names assert 'x' not in names # x is reserved + + def test_copy(self, expr: ExpressionComponent): + # WHEN THEN + expr_copy = copy(expr) + + # EXPECT the copy is a new instance with the same properties + assert expr_copy is not expr + assert isinstance(expr_copy, ExpressionComponent) + assert expr_copy.expression == expr.expression + assert expr_copy.unit == expr.unit + assert expr_copy.display_name == expr.display_name + + assert expr_copy.A.value == pytest.approx(expr.A.value) + assert expr_copy.x0.value == pytest.approx(expr.x0.value) + assert expr_copy.sigma.value == pytest.approx(expr.sigma.value) + + def test_erf(self): + # WHEN + expr = ExpressionComponent('erf(x)') + x = np.array([-1.0, 0.0, 1.0]) + + # THEN + result = expr.evaluate(x) + + # EXPECT + expected = np.array([-0.84270079, 0.0, 0.84270079]) # erf(-1), erf(0), erf(1) + np.testing.assert_allclose(result, expected, rtol=1e-5)