Skip to content

Commit 992a688

Browse files
committed
first new generation solver
1 parent c34ca05 commit 992a688

8 files changed

Lines changed: 269 additions & 15 deletions

pina/_src/condition/equation_condition_base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ class EquationConditionBase(ConditionBase):
1515
:class:`~pina.condition.InputEquationCondition`.
1616
"""
1717

18-
def evaluate(self, batch, solver, loss):
    """
    Evaluate the equation residual on the given batch using the solver.

    The residual is computed from the model forward pass and then compared
    against a zero target through the provided non-aggregating loss.

    :param batch: The batch containing the ``input`` entry with the sample
        points on which the residual is evaluated.
    :type batch: dict | _DataManager
    :param solver: The solver containing the model and any additional
        parameters (e.g., unknown parameters for inverse problems).
    :type solver: ~pina.solver.solver.SolverInterface
    :param loss: The non-aggregating loss function to apply to the
        computed residual against zero.
    :type loss: torch.nn.Module
    :return: The non-aggregated loss tensor.
    :rtype: ~pina.label_tensor.LabelTensor

    :Example:

    >>> residuals = condition.evaluate(
    ...     {"input": input_samples}, solver, loss
    ... )
    >>> # residuals is a non-reduced tensor of shape (n_samples, ...)
    """
    samples = batch["input"]
    residual = self.equation.residual(
        samples, solver.forward(samples), solver._params
    )
    # Apply the caller-supplied loss against a zero target instead of
    # hard-coding ``residual**2``: the two only coincide for
    # MSELoss(reduction="none"); otherwise the ``loss`` argument was
    # silently ignored, contradicting both this docstring and the
    # behavior of InputTargetCondition.evaluate.
    return loss(residual, residual.new_zeros(residual.shape))

pina/_src/condition/input_target_condition.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,18 +113,20 @@ def target(self):
113113
"""
114114
return self.data.target
115115

116-
def evaluate(self, batch, solver, loss):
    """
    Evaluate the supervised condition on a batch with the given solver.

    The model prediction for the batch input is compared against the
    stored target through the provided element-wise loss, and the
    resulting non-reduced tensor is returned.

    :param batch: The batch containing ``input`` and ``target`` entries.
    :type batch: dict | _DataManager
    :param solver: The solver containing the model.
    :type solver: ~pina.solver.solver.SolverInterface
    :param loss: The non-aggregating loss function to apply.
    :type loss: torch.nn.Module
    :return: The non-aggregated loss tensor.
    :rtype: LabelTensor | torch.Tensor | Graph | Data
    """
    prediction = solver.forward(batch["input"])
    target = batch["target"]
    return loss(prediction, target)
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""Module for the SingleModelSimpleSolver."""
2+
3+
import torch
4+
from torch.nn.modules.loss import _Loss
5+
6+
from pina._src.condition.domain_equation_condition import (
7+
DomainEquationCondition,
8+
)
9+
from pina._src.condition.input_equation_condition import (
10+
InputEquationCondition,
11+
)
12+
from pina._src.condition.input_target_condition import InputTargetCondition
13+
from pina._src.core.utils import check_consistency
14+
from pina._src.loss.loss_interface import LossInterface
15+
from pina._src.solver.solver import SingleSolverInterface
16+
17+
18+
class SingleModelSimpleSolver(SingleSolverInterface):
19+
"""
20+
Minimal single-model solver with explicit residual evaluation, reduction,
21+
and loss aggregation across conditions.
22+
23+
The solver orchestrates a uniform workflow for all conditions in the batch:
24+
25+
1. evaluate the condition and obtain a non-aggregated loss tensor;
26+
2. apply a reduction to obtain a scalar loss for that condition;
27+
4. return the per-condition losses, which are aggregated by the inherited
28+
solver machinery through the configured weighting.
29+
"""
30+
31+
accepted_conditions_types = (
32+
InputTargetCondition,
33+
InputEquationCondition,
34+
DomainEquationCondition,
35+
)
36+
37+
def __init__(
38+
self,
39+
problem,
40+
model,
41+
optimizer=None,
42+
scheduler=None,
43+
weighting=None,
44+
loss=None,
45+
use_lt=True,
46+
):
47+
"""
48+
Initialize the single-model simple solver.
49+
50+
:param AbstractProblem problem: The problem to be solved.
51+
:param torch.nn.Module model: The neural network model to be used.
52+
:param Optimizer optimizer: The optimizer to be used.
53+
:param Scheduler scheduler: Learning rate scheduler.
54+
:param WeightingInterface weighting: The weighting schema to be used.
55+
:param torch.nn.Module loss: The element-wise loss module whose
56+
reduction strategy is reused by the solver. If ``None``,
57+
:class:`torch.nn.MSELoss` is used.
58+
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
59+
"""
60+
if loss is None:
61+
loss = torch.nn.MSELoss()
62+
63+
check_consistency(loss, (LossInterface, _Loss), subclass=False)
64+
65+
super().__init__(
66+
model=model,
67+
problem=problem,
68+
optimizer=optimizer,
69+
scheduler=scheduler,
70+
weighting=weighting,
71+
use_lt=use_lt,
72+
)
73+
74+
self._loss_fn = loss
75+
self._reduction = getattr(loss, "reduction", "mean")
76+
77+
if hasattr(self._loss_fn, "reduction"):
78+
self._loss_fn.reduction = "none"
79+
80+
def optimization_cycle(self, batch):
81+
"""
82+
Compute one reduced loss per condition in the batch.
83+
84+
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
85+
tuple containing a condition name and a dictionary of points.
86+
:return: The reduced losses for all conditions.
87+
:rtype: dict[str, torch.Tensor]
88+
"""
89+
condition_losses = {}
90+
91+
for condition_name, data in batch:
92+
condition = self.problem.conditions[condition_name]
93+
condition_data = dict(data)
94+
95+
if hasattr(condition_data.get("input"), "requires_grad_"):
96+
condition_data["input"] = condition_data[
97+
"input"
98+
].requires_grad_()
99+
100+
condition_loss_tensor = condition.evaluate(
101+
condition_data, self, self._loss_fn
102+
)
103+
condition_losses[condition_name] = self._apply_reduction(
104+
condition_loss_tensor
105+
)
106+
107+
return condition_losses
108+
109+
def _apply_reduction(self, value):
110+
"""
111+
Apply the configured reduction to a non-aggregated condition tensor.
112+
113+
:param value: The non-aggregated tensor returned by a condition.
114+
:type value: torch.Tensor
115+
:return: The reduced scalar tensor.
116+
:rtype: torch.Tensor
117+
:raises ValueError: If the reduction is not supported.
118+
"""
119+
if self._reduction == "none":
120+
return value
121+
if self._reduction == "mean":
122+
return value.mean()
123+
if self._reduction == "sum":
124+
return value.sum()
125+
raise ValueError(f"Unsupported reduction '{self._reduction}'.")
126+
127+
@property
128+
def loss(self):
129+
"""
130+
The underlying element-wise loss module.
131+
132+
:return: The stored loss module.
133+
:rtype: torch.nn.Module
134+
"""
135+
return self._loss_fn

pina/solver/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"SolverInterface",
1414
"SingleSolverInterface",
1515
"MultiSolverInterface",
16+
"SingleModelSimpleSolver",
1617
"PINNInterface",
1718
"PINN",
1819
"GradientPINN",
@@ -36,6 +37,9 @@
3637
SingleSolverInterface,
3738
MultiSolverInterface,
3839
)
40+
from pina._src.solver.single_model_simple_solver import (
41+
SingleModelSimpleSolver,
42+
)
3943
from pina._src.solver.physics_informed_solver.pinn import PINNInterface, PINN
4044
from pina._src.solver.physics_informed_solver.gradient_pinn import GradientPINN
4145
from pina._src.solver.physics_informed_solver.causal_pinn import CausalPINN

tests/test_condition/test_domain_equation_condition.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,12 @@ def equation_func(input_, output_, params_):
4848
cond = Condition(domain=example_domain, equation=Equation(equation_func))
4949
solver = DummySolver()
5050
batch = {"input": samples}
51+
loss = torch.nn.MSELoss(reduction="none")
5152

52-
residual = cond.evaluate(batch, solver)
53-
expected = samples.extract(["x"]) - solver._params["shift"]
53+
residual = cond.evaluate(batch, solver, loss)
54+
expected = loss(
55+
samples.extract(["x"]) - solver._params["shift"],
56+
torch.zeros_like(samples.extract(["x"]) - solver._params["shift"]),
57+
)
5458

5559
torch.testing.assert_close(residual, expected)

tests/test_condition/test_input_equation_condition.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,12 @@ def equation_func(input_, output_, params_):
9595
condition = Condition(input=pts, equation=Equation(equation_func))
9696
solver = DummySolver()
9797
batch = {"input": pts}
98+
loss = torch.nn.MSELoss(reduction="none")
9899

99-
residual = condition.evaluate(batch, solver)
100-
expected = pts.extract(["y"]) - solver._params["shift"]
100+
residual = condition.evaluate(batch, solver, loss)
101+
expected = loss(
102+
pts.extract(["y"]) - solver._params["shift"],
103+
torch.zeros_like(pts.extract(["y"]) - solver._params["shift"]),
104+
)
101105

102106
torch.testing.assert_close(residual, expected)

tests/test_condition/test_input_target_condition.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,10 +220,11 @@ def test_evaluate_tensor_input_target_condition():
220220
target_tensor = torch.tensor([[1.5, 3.5], [5.5, 7.5]])
221221
condition = Condition(input=input_tensor, target=target_tensor)
222222
solver = DummySolver()
223+
loss_fn = torch.nn.MSELoss(reduction="none")
223224

224225
batch = {"input": condition.input, "target": condition.target}
225-
loss = condition.evaluate(batch, solver)
226-
expected = solver.forward(input_tensor) - target_tensor
226+
loss = condition.evaluate(batch, solver, loss_fn)
227+
expected = loss_fn(solver.forward(input_tensor), target_tensor)
227228

228229
torch.testing.assert_close(loss, expected)
229230

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import pytest
2+
import torch
3+
4+
from pina import LabelTensor, Condition
5+
from pina.model import FeedForward
6+
from pina.trainer import Trainer
7+
from pina.solver import SingleModelSimpleSolver
8+
from pina.condition import (
9+
InputTargetCondition,
10+
InputEquationCondition,
11+
DomainEquationCondition,
12+
)
13+
from pina.problem.zoo import (
14+
Poisson2DSquareProblem as Poisson,
15+
InversePoisson2DSquareProblem as InversePoisson,
16+
)
17+
from torch._dynamo.eval_frame import OptimizedModule
18+
19+
20+
# Module-level fixtures shared by every test below.
# NOTE(review): these statements run at import/collection time; a failure
# here aborts the whole test module.

# Forward Poisson problem discretised with 10 points per domain.
problem = Poisson()
problem.discretise_domain(10)
# Inverse Poisson variant; load=True presumably reads bundled reference
# data — TODO confirm against the problem zoo implementation.
inverse_problem = InversePoisson(load=True, data_size=0.01)
inverse_problem.discretise_domain(10)

# Random supervised data attached to the forward problem as an extra
# "data" condition, so the solver sees an InputTargetCondition too.
input_pts = torch.rand(10, len(problem.input_variables))
input_pts = LabelTensor(input_pts, problem.input_variables)
output_pts = torch.rand(10, len(problem.output_variables))
output_pts = LabelTensor(output_pts, problem.output_variables)
problem.conditions["data"] = Condition(input=input_pts, target=output_pts)

# Single feed-forward model shared (and mutated by training) across tests.
model = FeedForward(len(problem.input_variables), len(problem.output_variables))
32+
33+
34+
@pytest.mark.parametrize("problem", [problem, inverse_problem])
35+
def test_constructor(problem):
36+
solver = SingleModelSimpleSolver(problem=problem, model=model)
37+
38+
assert solver.accepted_conditions_types == (
39+
InputTargetCondition,
40+
InputEquationCondition,
41+
DomainEquationCondition,
42+
)
43+
44+
45+
@pytest.mark.parametrize("problem", [problem])
46+
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
47+
@pytest.mark.parametrize("compile", [True, False])
48+
def test_solver_train(problem, batch_size, compile):
49+
solver = SingleModelSimpleSolver(model=model, problem=problem)
50+
trainer = Trainer(
51+
solver=solver,
52+
max_epochs=2,
53+
accelerator="cpu",
54+
batch_size=batch_size,
55+
train_size=1.0,
56+
val_size=0.0,
57+
test_size=0.0,
58+
#compile=compile,
59+
)
60+
trainer.train()
61+
if trainer.compile:
62+
assert isinstance(solver.model, OptimizedModule)
63+
64+
65+
@pytest.mark.parametrize("problem", [problem, inverse_problem])
66+
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
67+
@pytest.mark.parametrize("compile", [True, False])
68+
def test_solver_validation(problem, batch_size, compile):
69+
solver = SingleModelSimpleSolver(model=model, problem=problem)
70+
trainer = Trainer(
71+
solver=solver,
72+
max_epochs=2,
73+
accelerator="cpu",
74+
batch_size=batch_size,
75+
train_size=0.9,
76+
val_size=0.1,
77+
test_size=0.0,
78+
compile=compile,
79+
)
80+
trainer.train()
81+
if trainer.compile:
82+
assert isinstance(solver.model, OptimizedModule)
83+
84+
85+
@pytest.mark.parametrize("problem", [problem, inverse_problem])
86+
@pytest.mark.parametrize("batch_size", [None, 1, 5, 20])
87+
@pytest.mark.parametrize("compile", [True, False])
88+
def test_solver_test(problem, batch_size, compile):
89+
solver = SingleModelSimpleSolver(model=model, problem=problem)
90+
trainer = Trainer(
91+
solver=solver,
92+
max_epochs=2,
93+
accelerator="cpu",
94+
batch_size=batch_size,
95+
train_size=0.7,
96+
val_size=0.2,
97+
test_size=0.1,
98+
#compile=compile,
99+
)
100+
trainer.test()

0 commit comments

Comments
 (0)