Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion docs/docs/tutorials/components.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,30 @@
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "be79f3fc",
"metadata": {},
"outputs": [],
"source": [
"expr = ExpressionComponent(\n",
" 'A*erf(B*x)',\n",
")\n",
"\n",
"expr.A = 1.0\n",
"expr.B = 0.5\n",
"\n",
"\n",
"x = np.linspace(-5, 5, 100)\n",
"y = expr.evaluate(x)\n",
"\n",
"plt.figure()\n",
"plt.plot(x, y, label='erf')\n",
"plt.legend()\n",
"plt.show()"
]
}
],
"metadata": {
Expand All @@ -168,7 +192,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.13"
"version": "3.14.4"
}
},
"nbformat": 4,
Expand Down
64 changes: 47 additions & 17 deletions src/easydynamics/sample_model/components/expression_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import TYPE_CHECKING
from typing import ClassVar

import scipy
import sympy as sp
from easyscience.variable import Parameter

Expand All @@ -20,11 +21,6 @@
class ExpressionComponent(ModelComponent):
"""
Model component defined by a symbolic expression.

Example: expr = ExpressionComponent( "A * exp(-(x - x0)**2 / (2*sigma**2))", parameters={"A":
10, "x0": 0, "sigma": 1}, )

expr.A = 5 y = expr.evaluate(x)
"""

# -------------------------
Expand Down Expand Up @@ -87,9 +83,9 @@ def __init__(
The symbolic expression as a string. Must contain 'x' as the independent variable.
parameters : dict[str, Numeric] | None, default=None
Dictionary of parameter names and their initial values.
unit : str | sc.Unit, default='meV'
unit : str | sc.Unit, default="meV"
Unit of the output.
display_name : str | None, default='Expression'
display_name : str | None, default="Expression"
Display name for the component.
unique_name : str | None, default=None
Unique name for the component.
Expand All @@ -100,6 +96,18 @@ def __init__(
If the expression is invalid or does not contain 'x'.
TypeError
If any parameter value is not numeric.

Examples
--------
>>> expr = ExpressionComponent(
... 'A * exp(-(x - x0)**2 / (2*sigma**2))',
... parameters={'A': 10, 'x0': 0, 'sigma': 1},
... unit='meV',
... display_name='Gaussian Peak',
... )

>>> expr.A = 5
>>> y = expr.evaluate(x)
"""
super().__init__(unit=unit, display_name=display_name, unique_name=unique_name)

Expand Down Expand Up @@ -157,8 +165,11 @@ def __init__(

if parameters is not None:
for name, value in parameters.items():
if not isinstance(value, Numeric):
raise TypeError(f"Parameter '{name}' must be numeric")
if not isinstance(value, (Numeric, Parameter, dict)):
raise TypeError(
f"Parameter '{name}' must be numeric, "
f'a Parameter instance, or a dictionary, got {type(value).__name__}'
)
parameters = parameters or {}
self._parameters: dict[str, Parameter] = {}

Expand All @@ -168,20 +179,25 @@ def __init__(
continue

value = parameters.get(name, 1.0)
if isinstance(value, Parameter):
self._parameters[name] = value

self._parameters[name] = Parameter(
name=name,
value=value,
unit=self._unit,
)
elif isinstance(value, dict) and value.get('@class') == 'Parameter':
self._parameters[name] = Parameter.from_dict(value)
else:
self._parameters[name] = Parameter(
name=name,
value=value,
unit=self._unit,
)

# Create numerical function
ordered_symbols = [sp.Symbol(name) for name in self._symbol_names]

self._func = sp.lambdify(
ordered_symbols,
self._expr,
modules=['numpy'],
modules=[{'erf': scipy.special.erf}, 'numpy'],
)

# -------------------------
Expand All @@ -190,7 +206,14 @@ def __init__(

@property
def expression(self) -> str:
"""Return the original expression string."""
"""
Return the original expression string.

Returns
-------
str
The original expression string provided at initialization.
"""
return self._expression_str

@expression.setter
Expand Down Expand Up @@ -334,7 +357,14 @@ def __dir__(self) -> list[str]:
return super().__dir__() + list(self._parameters.keys())

def __repr__(self) -> str:
"""Repr function."""
"""
Return a string representation of the ExpressionComponent.

Returns
-------
str
String representation of the ExpressionComponent.
"""
param_str = ', '.join(f'{k}={v.value}' for k, v in self._parameters.items())
return (
f'{self.__class__.__name__}(\n'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: 2026 EasyScience contributors
# SPDX-License-Identifier: BSD-3-Clause

from copy import copy

import numpy as np
import pytest
from easyscience.variable import Parameter
Expand Down Expand Up @@ -34,6 +36,14 @@ def test_init_without_parameters(self):
# EXPECT
assert expr.A.value == pytest.approx(1.0) # default

def test_init_with_parameter(self):
# WHEN THEN
A = Parameter('A', 3.0)
expr = ExpressionComponent('A * x', parameters={'A': A})

# EXPECT
assert expr.A.value == pytest.approx(3.0)

def test_invalid_expression_raises(self):
# WHEN THEN EXPECT
with pytest.raises(ValueError, match='Invalid expression'):
Expand Down Expand Up @@ -172,3 +182,30 @@ def test_reserved_name_not_parameter(self):

assert 'A' in names
assert 'x' not in names # x is reserved

def test_copy(self, expr: ExpressionComponent):
# WHEN THEN
expr_copy = copy(expr)

# EXPECT the copy is a new instance with the same properties
assert expr_copy is not expr
assert isinstance(expr_copy, ExpressionComponent)
assert expr_copy.expression == expr.expression
assert expr_copy.unit == expr.unit
assert expr_copy.display_name == expr.display_name

assert expr_copy.A.value == pytest.approx(expr.A.value)
assert expr_copy.x0.value == pytest.approx(expr.x0.value)
assert expr_copy.sigma.value == pytest.approx(expr.sigma.value)

def test_erf(self):
# WHEN
expr = ExpressionComponent('erf(x)')
x = np.array([-1.0, 0.0, 1.0])

# THEN
result = expr.evaluate(x)

# EXPECT
expected = np.array([-0.84270079, 0.0, 0.84270079]) # erf(-1), erf(0), erf(1)
np.testing.assert_allclose(result, expected, rtol=1e-5)
Loading