Skip to content

Commit c34ca05

Browse files
committed
input target evaluate
1 parent 8fdfcb7 commit c34ca05

5 files changed

Lines changed: 46 additions & 7 deletions

File tree

pina/_src/condition/equation_condition_base.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,18 @@ class EquationConditionBase(ConditionBase):
1515
:class:`~pina.condition.InputEquationCondition`.
1616
"""
1717

18-
def evaluate(self, samples, solver):
18+
def evaluate(self, batch, solver):
1919
"""
20-
Evaluate the equation residual on the given samples using the solver.
20+
Evaluate the equation residual on the given batch using the solver.
2121
2222
This method computes the non-aggregated, element-wise residual of the
2323
equation. It performs a forward pass of the solver's model on the
2424
input samples and then evaluates the equation residual. The returned
2525
tensor is **not** reduced (i.e., no mean, sum, etc.), preserving the
2626
per-sample residual values.
2727
28-
:param samples: The input samples on which to evaluate the residual.
29-
:type samples: ~pina.label_tensor.LabelTensor
28+
:param batch: The batch containing the ``input`` entry.
29+
:type batch: dict | _DataManager
3030
:param solver: The solver containing the model and any additional
3131
parameters (e.g., unknown parameters for inverse problems).
3232
:type solver: ~pina.solver.solver.SolverInterface
@@ -35,9 +35,12 @@ def evaluate(self, samples, solver):
3535
3636
:Example:
3737
38-
>>> residuals = condition.evaluate(input_samples, solver)
38+
>>> residuals = condition.evaluate(
39+
... {"input": input_samples}, solver
40+
... )
3941
>>> # residuals is a non-reduced tensor of shape (n_samples, ...)
4042
"""
43+
samples = batch["input"]
4144
return self.equation.residual(
4245
samples, solver.forward(samples), solver._params
4346
)

pina/_src/condition/input_target_condition.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,19 @@ def target(self):
112112
list[Data] | tuple[Graph] | tuple[Data]
113113
"""
114114
return self.data.target
115+
116+
def evaluate(self, batch, solver):
117+
"""
118+
Evaluate the supervised condition on the given batch using the solver.
119+
120+
This method computes the element-wise prediction error associated with
121+
the condition using the input and target stored in the provided batch.
122+
123+
:param batch: The batch containing ``input`` and ``target`` entries.
124+
:type batch: dict | _DataManager
125+
:param solver: The solver containing the model.
126+
:type solver: ~pina.solver.solver.SolverInterface
127+
:return: The non-aggregated prediction error.
128+
:rtype: LabelTensor | torch.Tensor | Graph | Data
129+
"""
130+
return solver.forward(batch["input"]) - batch["target"]

tests/test_condition/test_domain_equation_condition.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ def equation_func(input_, output_, params_):
4747
samples = LabelTensor(torch.randn(12, 2), labels=["x", "y"])
4848
cond = Condition(domain=example_domain, equation=Equation(equation_func))
4949
solver = DummySolver()
50+
batch = {"input": samples}
5051

51-
residual = cond.evaluate(samples, solver)
52+
residual = cond.evaluate(batch, solver)
5253
expected = samples.extract(["x"]) - solver._params["shift"]
5354

5455
torch.testing.assert_close(residual, expected)

tests/test_condition/test_input_equation_condition.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,9 @@ def equation_func(input_, output_, params_):
9494
pts = LabelTensor(torch.randn(10, 2), labels=["x", "y"])
9595
condition = Condition(input=pts, equation=Equation(equation_func))
9696
solver = DummySolver()
97+
batch = {"input": pts}
9798

98-
residual = condition.evaluate(pts, solver)
99+
residual = condition.evaluate(batch, solver)
99100
expected = pts.extract(["y"]) - solver._params["shift"]
100101

101102
torch.testing.assert_close(residual, expected)

tests/test_condition/test_input_target_condition.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@
55
from pina._src.condition.batch_manager import _BatchManager
66

77

8+
class DummySolver:
9+
def forward(self, samples):
10+
return 2 * samples
11+
12+
813
def _create_tensor_data(use_lt=False):
914
if use_lt:
1015
input_tensor = LabelTensor(torch.rand((10, 3)), ["x", "y", "z"])
@@ -210,6 +215,19 @@ def test_getitem_tensor_input_tensor_target_condition_label_tensor():
210215
assert torch.allclose(item.target, target_tensor[index])
211216

212217

218+
def test_evaluate_tensor_input_target_condition():
219+
input_tensor = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
220+
target_tensor = torch.tensor([[1.5, 3.5], [5.5, 7.5]])
221+
condition = Condition(input=input_tensor, target=target_tensor)
222+
solver = DummySolver()
223+
224+
batch = {"input": condition.input, "target": condition.target}
225+
loss = condition.evaluate(batch, solver)
226+
expected = solver.forward(input_tensor) - target_tensor
227+
228+
torch.testing.assert_close(loss, expected)
229+
230+
213231
@pytest.mark.parametrize("use_lt", [True, False])
214232
def test_getitem_graph_input_tensor_target_condition(use_lt):
215233
input_graph, target_tensor = _create_graph_data(False, use_lt=use_lt)

0 commit comments

Comments
 (0)