Skip to content

Commit 8fdfcb7

Browse files
committed
evaluate for physics condition
1 parent fce15bc commit 8fdfcb7

6 files changed

Lines changed: 128 additions & 5 deletions

File tree

pina/_src/condition/domain_equation_condition.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Module for the DomainEquationCondition class."""
22

3-
from pina._src.condition.condition_base import ConditionBase
3+
from pina._src.condition.equation_condition_base import (
4+
EquationConditionBase,
5+
)
46
from pina._src.domain.domain_interface import DomainInterface
57
from pina._src.equation.equation_interface import EquationInterface
68

79

8-
class DomainEquationCondition(ConditionBase):
10+
class DomainEquationCondition(EquationConditionBase):
911
"""
1012
The class :class:`DomainEquationCondition` defines a condition based on a
1113
``domain`` and an ``equation``. This condition is typically used in
@@ -92,4 +94,29 @@ def store_data(self, **kwargs):
9294
:rtype: dict
9395
"""
9496
setattr(self, "domain", kwargs.get("domain"))
95-
setattr(self, "equation", kwargs.get("equation"))
97+
setattr(self, "_equation", kwargs.get("equation"))
98+
99+
@property
100+
def equation(self):
101+
"""
102+
Return the equation associated with this condition.
103+
104+
:return: Equation associated with this condition.
105+
:rtype: EquationInterface
106+
"""
107+
return self._equation
108+
109+
@equation.setter
110+
def equation(self, value):
111+
"""
112+
Set the equation associated with this condition.
113+
114+
:param EquationInterface value: The equation to associate
115+
with this condition.
116+
"""
117+
if not isinstance(value, EquationInterface):
118+
raise TypeError(
119+
"The equation must be an instance of "
120+
"EquationInterface."
121+
)
122+
self._equation = value
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Module for the EquationConditionBase class."""
2+
3+
from pina._src.condition.condition_base import ConditionBase
4+
5+
6+
class EquationConditionBase(ConditionBase):
7+
"""
8+
Base class for conditions that involve an equation.
9+
10+
This class provides the :meth:`evaluate` method, which computes the
11+
non-aggregated residual of the equation given the input samples and a
12+
solver. It is intended to be subclassed by conditions that define an
13+
``equation`` attribute, such as
14+
:class:`~pina.condition.DomainEquationCondition` and
15+
:class:`~pina.condition.InputEquationCondition`.
16+
"""
17+
18+
def evaluate(self, samples, solver):
19+
"""
20+
Evaluate the equation residual on the given samples using the solver.
21+
22+
This method computes the non-aggregated, element-wise residual of the
23+
equation. It performs a forward pass of the solver's model on the
24+
input samples and then evaluates the equation residual. The returned
25+
tensor is **not** reduced (i.e., no mean, sum, etc.), preserving the
26+
per-sample residual values.
27+
28+
:param samples: The input samples on which to evaluate the residual.
29+
:type samples: ~pina.label_tensor.LabelTensor
30+
:param solver: The solver containing the model and any additional
31+
parameters (e.g., unknown parameters for inverse problems).
32+
:type solver: ~pina.solver.solver.SolverInterface
33+
:return: The non-aggregated residual tensor.
34+
:rtype: ~pina.label_tensor.LabelTensor
35+
36+
:Example:
37+
38+
>>> residuals = condition.evaluate(input_samples, solver)
39+
>>> # residuals is a non-reduced tensor of shape (n_samples, ...)
40+
"""
41+
return self.equation.residual(
42+
samples, solver.forward(samples), solver._params
43+
)

pina/_src/condition/input_equation_condition.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
"""Module for the InputEquationCondition class and its subclasses."""
22

3-
from pina._src.condition.condition_base import ConditionBase
3+
from pina._src.condition.equation_condition_base import (
4+
EquationConditionBase,
5+
)
46
from pina._src.core.label_tensor import LabelTensor
57
from pina._src.core.graph import Graph
68
from pina._src.equation.equation_interface import EquationInterface
79
from pina._src.condition.data_manager import _DataManager
810

911

10-
class InputEquationCondition(ConditionBase):
12+
class InputEquationCondition(EquationConditionBase):
1113
"""
1214
The class :class:`InputEquationCondition` defines a condition based on
1315
``input`` data and an ``equation``. This condition is typically used in

pina/condition/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
"Condition",
1111
"ConditionInterface",
1212
"ConditionBase",
13+
"EquationConditionBase",
1314
"DomainEquationCondition",
1415
"InputTargetCondition",
1516
"InputEquationCondition",
@@ -18,6 +19,9 @@
1819

1920
from pina._src.condition.condition_interface import ConditionInterface
2021
from pina._src.condition.condition_base import ConditionBase
22+
from pina._src.condition.equation_condition_base import (
23+
EquationConditionBase,
24+
)
2125
from pina._src.condition.condition import Condition
2226
from pina._src.condition.domain_equation_condition import (
2327
DomainEquationCondition,

tests/test_condition/test_domain_equation_condition.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
11
import pytest
2+
import torch
23
from pina import Condition
4+
from pina import LabelTensor
35
from pina.domain import CartesianDomain
46
from pina._src.equation.equation_factory import FixedValue
7+
from pina.equation import Equation
58
from pina.condition import DomainEquationCondition
69

10+
11+
class DummySolver:
12+
def __init__(self):
13+
self._params = {"shift": torch.tensor(0.25)}
14+
15+
def forward(self, samples):
16+
return samples.extract(["x"]) - samples.extract(["y"])
17+
718
example_domain = CartesianDomain({"x": [0, 1], "y": [0, 1]})
819
example_equation = FixedValue(0.0)
920

@@ -27,3 +38,17 @@ def test_getitem_not_implemented():
2738
cond = Condition(domain=example_domain, equation=FixedValue(0.0))
2839
with pytest.raises(NotImplementedError):
2940
cond[0]
41+
42+
43+
def test_evaluate_domain_equation_condition():
44+
def equation_func(input_, output_, params_):
45+
return output_ + input_.extract(["y"]) - params_["shift"]
46+
47+
samples = LabelTensor(torch.randn(12, 2), labels=["x", "y"])
48+
cond = Condition(domain=example_domain, equation=Equation(equation_func))
49+
solver = DummySolver()
50+
51+
residual = cond.evaluate(samples, solver)
52+
expected = samples.extract(["x"]) - solver._params["shift"]
53+
54+
torch.testing.assert_close(residual, expected)

tests/test_condition/test_input_equation_condition.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@
88
from pina._src.condition.data_manager import _DataManager
99

1010

11+
class DummySolver:
12+
def __init__(self):
13+
self._params = {"shift": torch.tensor(1.5)}
14+
15+
def forward(self, samples):
16+
return samples.extract(["x"]) + samples.extract(["y"])
17+
18+
1119
def _create_pts_and_equation():
1220
def dummy_equation(pts):
1321
return pts["x"] ** 2 + pts["y"] ** 2 - 1
@@ -77,3 +85,17 @@ def test_getitems_tensor_equation_condition():
7785
assert isinstance(item, _DataManager)
7886
assert hasattr(item, "input")
7987
assert item.input.shape == (3, 2)
88+
89+
90+
def test_evaluate_tensor_equation_condition():
91+
def equation_func(input_, output_, params_):
92+
return output_ - input_.extract(["x"]) - params_["shift"]
93+
94+
pts = LabelTensor(torch.randn(10, 2), labels=["x", "y"])
95+
condition = Condition(input=pts, equation=Equation(equation_func))
96+
solver = DummySolver()
97+
98+
residual = condition.evaluate(pts, solver)
99+
expected = pts.extract(["y"]) - solver._params["shift"]
100+
101+
torch.testing.assert_close(residual, expected)

0 commit comments

Comments
 (0)