@@ -15,18 +15,18 @@ class EquationConditionBase(ConditionBase):
1515 :class:`~pina.condition.InputEquationCondition`.
1616 """
1717
18- def evaluate (self , samples , solver ):
18+ def evaluate (self , batch , solver ):
1919 """
20- Evaluate the equation residual on the given samples using the solver.
20+ Evaluate the equation residual on the given batch using the solver.
2121
2222 This method computes the non-aggregated, element-wise residual of the
2323 equation. It performs a forward pass of the solver's model on the
2424 input samples and then evaluates the equation residual. The returned
2525 tensor is **not** reduced (i.e., no mean, sum, etc.), preserving the
2626 per-sample residual values.
2727
28- :param samples : The input samples on which to evaluate the residual .
29- :type samples: ~pina.label_tensor.LabelTensor
28+ :param batch : The batch containing the ``input`` entry .
29+ :type batch: dict | _DataManager
3030 :param solver: The solver containing the model and any additional
3131 parameters (e.g., unknown parameters for inverse problems).
3232 :type solver: ~pina.solver.solver.SolverInterface
@@ -35,9 +35,12 @@ def evaluate(self, samples, solver):
3535
3636 :Example:
3737
38- >>> residuals = condition.evaluate(input_samples, solver)
38+ >>> residuals = condition.evaluate(
39+ ... {"input": input_samples}, solver
40+ ... )
3941 >>> # residuals is a non-reduced tensor of shape (n_samples, ...)
4042 """
43+ samples = batch ["input" ]
4144 return self .equation .residual (
4245 samples , solver .forward (samples ), solver ._params
4346 )
0 commit comments