Skip to content

Commit d54f8bc

Browse files
feat(AINode): [Issue-17301] Import PatchTST-FM-R1 architecture and register in model_info
1 parent 9acf1e9 commit d54f8bc

4 files changed

Lines changed: 492 additions & 0 deletions

File tree

iotdb-core/ainode/iotdb/ainode/core/model/model_info.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,4 +158,16 @@ def __repr__(self):
158158
},
159159
transformers_registered=True,
160160
),
161+
"patchtst_fm": ModelInfo(
162+
model_id = "patchtst_fm",
163+
category=ModelCategory.BUILTIN,
164+
state=ModelStates.INACTIVE,
165+
model_type="patchtst_fm",
166+
pipeline_cls="pipeline_patchtst_fm.PatchTSTFMPipeline",
167+
repo_id="ibm-research/patchtst-fm-r1",
168+
auto_map={
169+
"AutoConfig": "configuration_patchtst_fm.PatchTSTFMConfig",
170+
"AutoModelForCausalLM": "modeling_patchtst_fm.PatchTSTFMForPrediction",
171+
},
172+
),
161173
}

iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/__init__.py

Whitespace-only changes.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright contributors to the TSFM project
2+
#
3+
"""PatchTST-FM model configuration"""
4+
5+
from transformers.configuration_utils import PretrainedConfig
6+
from transformers.utils import logging
7+
8+
9+
logger = logging.get_logger(__name__)
10+
11+
PATCHTSTFM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
12+
13+
14+
class PatchTSTFMConfig(PretrainedConfig):
15+
model_type = "patchtst_fm"
16+
attribute_map = {
17+
"hidden_size": "d_model",
18+
"num_hidden_layers": "n_layer",
19+
}
20+
21+
# has_no_defaults_at_init = True
22+
def __init__(
23+
self,
24+
context_length: int = 8192,
25+
prediction_length: int = 64,
26+
d_patch: int = 16,
27+
d_model: int = 384,
28+
n_head: int = 6,
29+
n_layer: int = 6,
30+
norm_first: bool = True,
31+
pretrain_mask_ratio: float = 0.4,
32+
pretrain_mask_cont: int = 8,
33+
num_quantile: int = 99,
34+
**kwargs,
35+
):
36+
self.context_length = context_length
37+
self.prediction_length = prediction_length
38+
self.d_patch = d_patch
39+
self.n_patch = int(context_length // d_patch)
40+
self.d_model = d_model
41+
self.n_head = n_head
42+
self.n_layer = n_layer
43+
self.norm_first = norm_first
44+
self.pretrain_mask_ratio = pretrain_mask_ratio
45+
self.pretrain_mask_cont = pretrain_mask_cont
46+
self.num_quantile = num_quantile
47+
48+
if num_quantile % 9 == 0:
49+
quantiles = [i / (self.num_quantile + 1) for i in range(1, self.num_quantile + 1)]
50+
else:
51+
quantiles = [i / (self.num_quantile - 1) for i in range(1, self.num_quantile - 1)]
52+
quantiles = [0.01] + quantiles + [0.99]
53+
self.quantile_levels = quantiles
54+
super().__init__(**kwargs)

0 commit comments

Comments
 (0)