Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions docs/source/_rst/_code.rst
Original file line number Diff line number Diff line change
Expand Up @@ -195,13 +195,31 @@ Equations and Differential Operators
.. toctree::
:titlesonly:

EquationInterface <equation/equation_interface.rst>
Equation Interface <equation/equation_interface.rst>
Base Equation <equation/base_equation.rst>
Equation <equation/equation.rst>
SystemEquation <equation/system_equation.rst>
Equation Factory <equation/equation_factory.rst>
System Equation <equation/system_equation.rst>
Differential Operators <operator.rst>


Equations Zoo
---------------------------------------

.. toctree::
:titlesonly:

Acoustic Wave Equation <equation/zoo/acoustic_wave_equation.rst>
Advection Equation <equation/zoo/advection_equation.rst>
Allen-Cahn Equation <equation/zoo/allen_cahn_equation.rst>
Diffusion-Reaction Equation <equation/zoo/diffusion_reaction_equation.rst>
Fixed Flux <equation/zoo/fixed_flux.rst>
Fixed Gradient <equation/zoo/fixed_gradient.rst>
Fixed Laplacian <equation/zoo/fixed_laplacian.rst>
Fixed Value <equation/zoo/fixed_value.rst>
Helmholtz Equation <equation/zoo/helmholtz_equation.rst>
Poisson Equation <equation/zoo/poisson_equation.rst>


Problems
--------------

Expand Down
7 changes: 7 additions & 0 deletions docs/source/_rst/equation/base_equation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Base Equation
====================

.. currentmodule:: pina.equation.base_equation
.. autoclass:: pina._src.equation.base_equation.BaseEquation
:members:
:show-inheritance:
43 changes: 0 additions & 43 deletions docs/source/_rst/equation/equation_factory.rst

This file was deleted.

7 changes: 7 additions & 0 deletions docs/source/_rst/equation/zoo/acoustic_wave_equation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
AcousticWaveEquation
=====================
.. currentmodule:: pina.equation.zoo.acoustic_wave_equation

.. automodule:: pina._src.equation.zoo.acoustic_wave_equation
:members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/equation/zoo/advection_equation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Advection Equation
=====================
.. currentmodule:: pina.equation.zoo.advection_equation

.. automodule:: pina._src.equation.zoo.advection_equation
:members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/equation/zoo/allen_cahn_equation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Allen Cahn Equation
=====================
.. currentmodule:: pina.equation.zoo.allen_cahn_equation

.. automodule:: pina._src.equation.zoo.allen_cahn_equation
:members:
:show-inheritance:
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Diffusion Reaction Equation
==============================
.. currentmodule:: pina.equation.zoo.diffusion_reaction_equation

.. automodule:: pina._src.equation.zoo.diffusion_reaction_equation
:members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/equation/zoo/fixed_flux.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Fixed Flux
=====================
.. currentmodule:: pina.equation.zoo.fixed_flux

.. automodule:: pina._src.equation.zoo.fixed_flux
:members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/equation/zoo/fixed_gradient.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Fixed Gradient
=====================
.. currentmodule:: pina.equation.zoo.fixed_gradient

.. automodule:: pina._src.equation.zoo.fixed_gradient
:members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/equation/zoo/fixed_laplacian.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Fixed Laplacian
=====================
.. currentmodule:: pina.equation.zoo.fixed_laplacian

.. automodule:: pina._src.equation.zoo.fixed_laplacian
:members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/_rst/equation/zoo/fixed_value.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Fixed Value
=====================
.. currentmodule:: pina.equation.zoo.fixed_value

.. automodule:: pina._src.equation.zoo.fixed_value
:members:
:show-inheritance:
9 changes: 9 additions & 0 deletions docs/source/_rst/equation/zoo/helmholtz_equation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Helmholtz Equation
=====================
.. currentmodule:: pina.equation.zoo.helmholtz_equation

.. automodule:: pina._src.equation.zoo.helmholtz_equation

.. autoclass:: pina._src.equation.zoo.helmholtz_equation.HelmholtzEquation
:members:
:show-inheritance:
9 changes: 9 additions & 0 deletions docs/source/_rst/equation/zoo/poisson_equation.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
Poisson Equation
=====================
.. currentmodule:: pina.equation.zoo.poisson_equation

.. automodule:: pina._src.equation.zoo.poisson_equation

.. autoclass:: pina._src.equation.zoo.poisson_equation.PoissonEquation
:members:
:show-inheritance:
4 changes: 2 additions & 2 deletions pina/_src/condition/data_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch_geometric.data.batch import Batch
from pina import LabelTensor
from pina._src.core.graph import Graph, LabelBatch
from ..equation.equation_interface import EquationInterface
from pina._src.equation.base_equation import BaseEquation
from .batch_manager import _BatchManager


Expand Down Expand Up @@ -39,7 +39,7 @@ def __new__(cls, **kwargs):

# Does the data contain only tensors/LabelTensors/Equations?
is_tensor_only = all(
isinstance(v, (torch.Tensor, LabelTensor, EquationInterface))
isinstance(v, (torch.Tensor, LabelTensor, BaseEquation))
for v in kwargs.values()
)
# Choose the appropriate subclass, GraphDataManager or TensorDataManager
Expand Down
6 changes: 3 additions & 3 deletions pina/_src/condition/domain_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pina._src.condition.condition_base import ConditionBase
from pina._src.domain.domain_interface import DomainInterface
from pina._src.equation.equation_interface import EquationInterface
from pina._src.equation.base_equation import BaseEquation


class DomainEquationCondition(ConditionBase):
Expand Down Expand Up @@ -32,7 +32,7 @@ class DomainEquationCondition(ConditionBase):
__fields__ = ["domain", "equation"]

_avail_domain_cls = (DomainInterface, str)
_avail_equation_cls = EquationInterface
_avail_equation_cls = BaseEquation

def __new__(cls, domain, equation):
"""
Expand All @@ -52,7 +52,7 @@ def __new__(cls, domain, equation):

if not isinstance(equation, cls._avail_equation_cls):
raise ValueError(
"The equation must be an instance of EquationInterface."
"The equation must be an instance of BaseEquation."
)

return super().__new__(cls)
Expand Down
18 changes: 8 additions & 10 deletions pina/_src/condition/input_equation_condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pina._src.condition.condition_base import ConditionBase
from pina._src.core.label_tensor import LabelTensor
from pina._src.core.graph import Graph
from pina._src.equation.equation_interface import EquationInterface
from pina._src.equation.base_equation import BaseEquation
from pina._src.condition.data_manager import _DataManager


Expand Down Expand Up @@ -32,7 +32,7 @@ class InputEquationCondition(ConditionBase):
# Available input data types
__fields__ = ["input", "equation"]
_avail_input_cls = (LabelTensor, Graph)
_avail_equation_cls = EquationInterface
_avail_equation_cls = BaseEquation

def __new__(cls, input, equation):
"""
Expand All @@ -41,7 +41,7 @@ def __new__(cls, input, equation):

:param input: The input data for the condition.
:type input: LabelTensor | Graph | list[Graph] | tuple[Graph]
:param EquationInterface equation: The equation to be satisfied over the
:param BaseEquation equation: The equation to be satisfied over the
specified ``input`` data.
:return: The subclass of InputEquationCondition.
:rtype: pina.condition.input_equation_condition.
Expand All @@ -61,7 +61,7 @@ def __new__(cls, input, equation):
# Check equation type
if not isinstance(equation, cls._avail_equation_cls):
raise ValueError(
"The equation must be an instance of EquationInterface."
"The equation must be an instance of BaseEquation."
)

return super().__new__(cls)
Expand Down Expand Up @@ -90,7 +90,7 @@ def equation(self):
Return the equation associated with this condition.

:return: Equation associated with this condition.
:rtype: EquationInterface
:rtype: BaseEquation
"""
return self._equation

Expand All @@ -99,11 +99,9 @@ def equation(self, value):
"""
Set the equation associated with this condition.

:param EquationInterface value: The equation to associate with this
:param BaseEquation value: The equation to associate with this
condition
"""
if not isinstance(value, EquationInterface):
raise TypeError(
"The equation must be an instance of EquationInterface."
)
if not isinstance(value, BaseEquation):
raise TypeError("The equation must be an instance of BaseEquation.")
self._equation = value
67 changes: 67 additions & 0 deletions pina/_src/equation/base_equation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Module for the Base Equation."""

from abc import ABCMeta, abstractmethod
import torch


class BaseEquation(metaclass=ABCMeta):
"""
Base class for all equations, implementing common functionality.

Equations are fundamental components in PINA, representing mathematical
constraints that must be satisfied by the model outputs. They can be passed
to :class:`~pina.condition.condition.Condition` objects to define the
conditions under which the model is trained.

All specific equation types should inherit from this class and implement its
abstract methods.

This class is not meant to be instantiated directly.
"""

@abstractmethod
def residual(self, input_, output_, params_):
"""
Evaluate the equation residual at the given inputs.

:param LabelTensor input_: The input points where the residual is
computed.
:param LabelTensor output_: The output tensor, potentially produced by a
:class:`torch.nn.Module` instance.
:param dict params_: An optional dictionary of unknown parameters, used
in :class:`~pina.problem.inverse_problem.InverseProblem` settings.
If the equation is not related to an inverse problem, this should be
set to ``None``. Default is ``None``.
:return: The residual values of the equation.
:rtype: LabelTensor
"""

def to(self, device):
"""
Move all tensor attributes to the specified device.

:param torch.device device: The target device to move the tensors to.
:return: The instance moved to the specified device.
:rtype: BaseEquation
"""
# Iterate over all attributes of the Equation
for key, val in self.__dict__.items():

# Move tensors in dictionaries to the specified device
if isinstance(val, dict):
self.__dict__[key] = {
k: v.to(device) if torch.is_tensor(v) else v
for k, v in val.items()
}

# Move tensors in lists to the specified device
elif isinstance(val, list):
self.__dict__[key] = [
v.to(device) if torch.is_tensor(v) else v for v in val
]

# Move tensor attributes to the specified device
elif torch.is_tensor(val):
self.__dict__[key] = val.to(device)

return self
Loading
Loading