Skip to content

Commit 2e4a5e3

Browse files
AINode: [Issue-17301] Implement PatchTSTFMPipeline forecast workflow
1 parent 378a8fd commit 2e4a5e3

1 file changed

Lines changed: 70 additions & 0 deletions

File tree

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import torch
2+
3+
from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
4+
5+
class PatchTSTFMPipeline(ForecastPipeline):
6+
def __init__(self, model_info, **model_kwargs):
7+
super().__init__(model_info, **model_kwargs)
8+
9+
def preprocess(self, inputs, **infer_kwargs):
10+
inputs = super().preprocess(inputs, **infer_kwargs)
11+
for idx, item in enumerate(inputs):
12+
# Model expects float32
13+
target_tensor = item["targets"].to(torch.float32)
14+
15+
# Expand 1D tensor [length] to [batch=1, length]
16+
if target_tensor.ndim == 1:
17+
target_tensor = target_tensor.unsqueeze(0)
18+
19+
item["targets"] = target_tensor
20+
return inputs
21+
22+
def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]:
23+
"""
24+
TODO: YOU WRITE THIS.
25+
1. Create an empty list called `forecasts`.
26+
2. Iterate through the [inputs](cci:1://file:///Users/karthik/Documents/projects/iotdb/iotdb-core/ainode/iotdb/ainode/core/model/chronos2/pipeline_chronos2.py:143:4-178:77) list.
27+
3. For each input dictionary, extract the "targets" tensor.
28+
4. Extract the prediction length using `infer_kwargs.get("output_length", 96)`.
29+
5. Move the tensor to the model's device: `tensor = tensor.to(self.device)`
30+
6. IBM's PatchTST expects "past_values" and "prediction_length" as arguments.
31+
Run the forward pass inside a `with torch.no_grad():` block natively using:
32+
output = self.model(past_values=tensor, prediction_length=pred_len)
33+
7. Extract the forecast using `output.prediction_outputs` and append it to your list.
34+
8. Return the list.
35+
"""
36+
forecasts = []
37+
for input in inputs:
38+
targets = input['targets']
39+
pred_length = infer_kwargs.get("output_length", 96)
40+
tensor = targets.to(self.device)
41+
with torch.no_grad():
42+
output = self.model(past_values = tensor, prediction_length = pred_length)
43+
44+
forecasts.append(output.prediction_outputs)
45+
return forecasts
46+
47+
48+
49+
50+
51+
52+
53+
54+
55+
def postprocess(self, outputs: list[torch.Tensor], **infer_kwargs) -> list[torch.Tensor]:
56+
"""
57+
The IBM Model returns quantiles [batch, variates, prediction_length, quantiles].
58+
We reduce this to [variates, prediction_length] by taking the median or mean.
59+
"""
60+
final_outputs = []
61+
for output in outputs:
62+
# Remove batch dimension if it is just a single batch
63+
if output.ndim == 4:
64+
output = output.squeeze(0)
65+
66+
# Average out the quantiles to get a point forecast
67+
point_forecast = output.mean(dim=-1)
68+
final_outputs.append(point_forecast)
69+
70+
return super().postprocess(final_outputs, **infer_kwargs)

0 commit comments

Comments
 (0)