|
| 1 | +"""Module for the EquationConditionBase class.""" |
| 2 | + |
| 3 | +from pina._src.condition.condition_base import ConditionBase |
| 4 | + |
| 5 | + |
| 6 | +class EquationConditionBase(ConditionBase): |
| 7 | + """ |
| 8 | + Base class for conditions that involve an equation. |
| 9 | +
|
| 10 | + This class provides the :meth:`evaluate` method, which computes the |
| 11 | + non-aggregated residual of the equation given the input samples and a |
| 12 | + solver. It is intended to be subclassed by conditions that define an |
| 13 | + ``equation`` attribute, such as |
| 14 | + :class:`~pina.condition.DomainEquationCondition` and |
| 15 | + :class:`~pina.condition.InputEquationCondition`. |
| 16 | + """ |
| 17 | + |
| 18 | + def evaluate(self, samples, solver): |
| 19 | + """ |
| 20 | + Evaluate the equation residual on the given samples using the solver. |
| 21 | +
|
| 22 | + This method computes the non-aggregated, element-wise residual of the |
| 23 | + equation. It performs a forward pass of the solver's model on the |
| 24 | + input samples and then evaluates the equation residual. The returned |
| 25 | + tensor is **not** reduced (i.e., no mean, sum, etc.), preserving the |
| 26 | + per-sample residual values. |
| 27 | +
|
| 28 | + :param samples: The input samples on which to evaluate the residual. |
| 29 | + :type samples: ~pina.label_tensor.LabelTensor |
| 30 | + :param solver: The solver containing the model and any additional |
| 31 | + parameters (e.g., unknown parameters for inverse problems). |
| 32 | + :type solver: ~pina.solver.solver.SolverInterface |
| 33 | + :return: The non-aggregated residual tensor. |
| 34 | + :rtype: ~pina.label_tensor.LabelTensor |
| 35 | +
|
| 36 | + :Example: |
| 37 | +
|
| 38 | + >>> residuals = condition.evaluate(input_samples, solver) |
| 39 | + >>> # residuals is a non-reduced tensor of shape (n_samples, ...) |
| 40 | + """ |
| 41 | + return self.equation.residual( |
| 42 | + samples, solver.forward(samples), solver._params |
| 43 | + ) |
0 commit comments