|
| 1 | +import torch |
| 2 | + |
| 3 | +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline |
| 4 | + |
class PatchTSTFMPipeline(ForecastPipeline):
    """Forecast pipeline wrapping IBM's PatchTST time-series model.

    Responsibilities: cast/shape inputs for the model (preprocess), run the
    forward pass per input item (forecast), and reduce the model's quantile
    output to a point forecast (postprocess).
    """

    def __init__(self, model_info, **model_kwargs):
        super().__init__(model_info, **model_kwargs)

    def preprocess(self, inputs, **infer_kwargs):
        """Cast each "targets" tensor to float32 and ensure a batch dim.

        Args:
            inputs: list of dicts, each carrying a "targets" tensor
                (schema produced by ForecastPipeline.preprocess).
            infer_kwargs: forwarded to the base implementation.

        Returns:
            The same list, with every "targets" tensor as float32 and
            shaped [batch, length].
        """
        inputs = super().preprocess(inputs, **infer_kwargs)
        for item in inputs:
            # Model expects float32.
            target_tensor = item["targets"].to(torch.float32)

            # Expand 1D tensor [length] to [batch=1, length].
            if target_tensor.ndim == 1:
                target_tensor = target_tensor.unsqueeze(0)

            item["targets"] = target_tensor
        return inputs

    def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]:
        """Run the PatchTST forward pass for every input item.

        Args:
            inputs: list of dicts, each holding a "targets" tensor shaped
                [batch, length] (guaranteed by preprocess).
            infer_kwargs: may contain "output_length", the number of
                future steps to predict (defaults to 96).

        Returns:
            One raw `prediction_outputs` tensor per input item.
        """
        # The prediction length is loop-invariant; read it once.
        prediction_length = infer_kwargs.get("output_length", 96)
        forecasts = []
        for item in inputs:
            tensor = item["targets"].to(self.device)
            # Inference only: disable autograd to save memory and compute.
            with torch.no_grad():
                output = self.model(
                    past_values=tensor, prediction_length=prediction_length
                )
            forecasts.append(output.prediction_outputs)
        return forecasts

    def postprocess(self, outputs: list[torch.Tensor], **infer_kwargs) -> list[torch.Tensor]:
        """Reduce quantile outputs to point forecasts.

        The IBM model returns [batch, variates, prediction_length, quantiles].
        We reduce this to [variates, prediction_length] by dropping a
        singleton batch dimension and averaging over the quantile axis.
        """
        final_outputs = []
        for output in outputs:
            # Remove batch dimension if it is just a single batch
            # (squeeze(0) is a no-op when batch > 1).
            if output.ndim == 4:
                output = output.squeeze(0)

            # Average out the quantiles to get a point forecast.
            point_forecast = output.mean(dim=-1)
            final_outputs.append(point_forecast)

        return super().postprocess(final_outputs, **infer_kwargs)
0 commit comments