Skip to content

Commit 31e2c03

Browse files
committed
autoregressive
1 parent 563d709 commit 31e2c03

10 files changed

Lines changed: 310 additions & 266 deletions

File tree

pina/_src/condition/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from pina._src.condition.condition_base import ConditionBase
2+
from pina._src.condition.data_condition import DataCondition
3+
from pina._src.condition.domain_equation_condition import (
4+
DomainEquationCondition,
5+
)
6+
from pina._src.condition.equation_condition_base import (
7+
EquationConditionBase,
8+
)
9+
from pina._src.condition.input_equation_condition import InputEquationCondition
10+
from pina._src.condition.input_target_condition import InputTargetCondition
11+
from pina._src.condition.time_series_condition import TimeSeriesCondition
12+
13+
__all__ = [
14+
"ConditionBase",
15+
"DataCondition",
16+
"DomainEquationCondition",
17+
"EquationConditionBase",
18+
"InputEquationCondition",
19+
"InputTargetCondition",
20+
"TimeSeriesCondition",
21+
]

pina/_src/condition/equation_condition_base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,8 @@ def evaluate(self, batch, solver, loss):
4444
>>> # residuals is a non-reduced tensor of shape (n_samples, ...)
4545
"""
4646
samples = batch["input"].requires_grad_(True)
47-
print("samples", samples)
4847
residual = self.equation.residual(
4948
samples, solver.forward(samples), solver._params
5049
)
5150
# assert False
52-
return residual**2
51+
return residual
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""Module for the TimeSeriesCondition class."""
2+
3+
import torch
4+
5+
from pina._src.condition.condition_base import ConditionBase
6+
from pina._src.condition.data_manager import _DataManager
7+
from pina._src.core.label_tensor import LabelTensor
8+
9+
10+
class TimeSeriesCondition(ConditionBase):
11+
"""
12+
Condition for autoregressive time-series training.
13+
14+
The condition stores an input tensor containing unroll windows with shape
15+
``[trajectories, windows, time_steps, *features]`` and computes the
16+
autoregressive non-aggregated/aggregated temporal loss inside
17+
:meth:`evaluate` by recursively applying the solver model over time.
18+
"""
19+
20+
__fields__ = ["input", "eps", "aggregation_strategy", "kwargs"]
21+
_avail_input_cls = (torch.Tensor, LabelTensor)
22+
23+
def __new__(cls, input, eps=None, aggregation_strategy=None, kwargs=None):
24+
if cls != TimeSeriesCondition:
25+
return super().__new__(cls)
26+
27+
if not isinstance(input, cls._avail_input_cls):
28+
raise ValueError(
29+
"Invalid input type. Expected one of the following: "
30+
"torch.Tensor, LabelTensor."
31+
)
32+
33+
return super().__new__(cls)
34+
35+
def store_data(self, **kwargs):
36+
return _DataManager(input=kwargs.get("input"))
37+
38+
@property
39+
def input(self):
40+
return self.data.input
41+
42+
@property
43+
def settings(self):
44+
return {
45+
"eps": getattr(self, "_eps", None),
46+
"aggregation_strategy": getattr(
47+
self, "_aggregation_strategy", None
48+
),
49+
"kwargs": getattr(self, "_kwargs", {}),
50+
}
51+
52+
def __init__(
53+
self, input, eps=None, aggregation_strategy=None, kwargs=None
54+
):
55+
super().__init__(input=input)
56+
self._eps = eps
57+
self._aggregation_strategy = aggregation_strategy
58+
self._kwargs = kwargs or {}
59+
60+
def evaluate(self, batch, solver, loss, condition_name=None):
61+
input_tensor = batch["input"]
62+
63+
if input_tensor.dim() < 4:
64+
raise ValueError(
65+
"The provided input tensor must have at least 4 dimensions:"
66+
" [trajectories, windows, time_steps, *features]."
67+
f" Got shape {input_tensor.shape}."
68+
)
69+
70+
current_state = input_tensor[:, :, 0]
71+
losses = []
72+
step_kwargs = self._kwargs.copy()
73+
74+
for step in range(1, input_tensor.shape[2]):
75+
processed_input = solver.preprocess_step(current_state, **step_kwargs)
76+
output = solver.forward(processed_input)
77+
predicted_state = solver.postprocess_step(output, **step_kwargs)
78+
79+
target_state = input_tensor[:, :, step]
80+
step_loss = loss(predicted_state, target_state, **step_kwargs)
81+
losses.append(step_loss)
82+
current_state = predicted_state
83+
84+
step_losses = torch.stack(losses).as_subclass(torch.Tensor)
85+
86+
with torch.no_grad():
87+
name = condition_name or getattr(self, "name", None) or "default"
88+
#weights = solver._get_weights(name, step_losses, self._eps)
89+
90+
aggregation_strategy = self._aggregation_strategy or torch.mean
91+
return aggregation_strategy(step_losses)# * weights)
92+
93+
@staticmethod
94+
def unroll(data, unroll_length, n_unrolls=None, randomize=True):
95+
"""
96+
Create unrolling time windows from temporal data.
97+
98+
This function takes as input a tensor of shape
99+
``[trajectories, time_steps, *features]`` and produces a tensor of
100+
shape ``[trajectories, windows, unroll_length, *features]``.
101+
Each window contains a sequence of subsequent states used for
102+
computing the multi-step loss during training.
103+
104+
:param data: The temporal data tensor to be unrolled.
105+
:type data: torch.Tensor | LabelTensor
106+
:param int unroll_length: The number of time steps in each window.
107+
:param int n_unrolls: The maximum number of windows to return.
108+
If ``None``, all valid windows are returned. Default is ``None``.
109+
:param bool randomize: If ``True``, starting indices are randomly
110+
permuted before applying ``n_unrolls``. Default is ``True``.
111+
:raise ValueError: If the input ``data`` has less than 3 dimensions.
112+
:raise ValueError: If ``unroll_length`` is greater or equal to the
113+
number of time steps in ``data``.
114+
:return: A tensor of unrolled windows.
115+
:rtype: torch.Tensor | LabelTensor
116+
"""
117+
if data.dim() < 3:
118+
raise ValueError(
119+
"The provided data tensor must have at least 3 dimensions:"
120+
" [trajectories, time_steps, *features]."
121+
f" Got shape {data.shape}."
122+
)
123+
124+
start_idx = TimeSeriesCondition._get_start_idx(
125+
n_steps=data.shape[1],
126+
unroll_length=unroll_length,
127+
n_unrolls=n_unrolls,
128+
randomize=randomize,
129+
)
130+
131+
windows = [data[:, s : s + unroll_length] for s in start_idx]
132+
return torch.stack(windows, dim=1)
133+
134+
@staticmethod
135+
def _get_start_idx(n_steps, unroll_length, n_unrolls=None, randomize=True):
136+
"""
137+
Determine starting indices for unroll windows.
138+
139+
:param int n_steps: The total number of time steps in the data.
140+
:param int unroll_length: The number of time steps in each window.
141+
:param int n_unrolls: The maximum number of windows to return.
142+
If ``None``, all valid windows are returned. Default is ``None``.
143+
:param bool randomize: If ``True``, starting indices are randomly
144+
permuted before applying ``n_unrolls``. Default is ``True``.
145+
:raise ValueError: If ``unroll_length`` is greater or equal to the
146+
number of time steps in ``data``.
147+
:return: A tensor of starting indices for unroll windows.
148+
:rtype: torch.Tensor
149+
"""
150+
last_idx = n_steps - unroll_length
151+
152+
if last_idx < 0:
153+
raise ValueError(
154+
"Cannot create unroll windows: "
155+
f"unroll_length ({unroll_length})"
156+
" cannot be greater or equal to the number of time_steps"
157+
f" ({n_steps})."
158+
)
159+
160+
indices = torch.arange(last_idx + 1)
161+
162+
if randomize:
163+
indices = indices[torch.randperm(len(indices))]
164+
165+
if n_unrolls is not None and n_unrolls < len(indices):
166+
indices = indices[:n_unrolls]
167+
168+
return indices

0 commit comments

Comments
 (0)