Skip to content

Commit 378a8fd

Browse files
feat(AINode): [Issue-17301] Import PatchTST-FM-R1 architecture and register in model_info
1 parent 75bfa23 commit 378a8fd

6 files changed

Lines changed: 585 additions & 88 deletions

File tree

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
import static org.junit.Assert.fail;
5252

5353
@RunWith(IoTDBTestRunner.class)
54-
@Category({AIClusterIT.class})
54+
@Category({ AIClusterIT.class })
5555
public class AINodeModelManageIT {
5656

5757
@BeforeClass
@@ -72,8 +72,8 @@ public void userDefinedModelManagementTestInTree() throws SQLException, Interrup
7272
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
7373
Statement statement = connection.createStatement()) {
7474
// Test transformers model (chronos2) in tree.
75-
AINodeTestUtils.FakeModelInfo modelInfo =
76-
new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active");
75+
AINodeTestUtils.FakeModelInfo modelInfo = new FakeModelInfo("user_chronos", "custom_t5", "user_defined",
76+
"active");
7777
registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2");
7878
callInferenceTest(statement, modelInfo);
7979
dropUserDefinedModel(statement, modelInfo.getModelId());
@@ -95,8 +95,8 @@ public void userDefinedModelManagementTestInTable() throws SQLException, Interru
9595
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
9696
Statement statement = connection.createStatement()) {
9797
// Test transformers model (chronos2) in table.
98-
AINodeTestUtils.FakeModelInfo modelInfo =
99-
new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active");
98+
AINodeTestUtils.FakeModelInfo modelInfo = new FakeModelInfo("user_chronos", "custom_t5", "user_defined",
99+
"active");
100100
registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2");
101101
forecastTableFunctionTest(statement, modelInfo);
102102
dropUserDefinedModel(statement, modelInfo.getModelId());
@@ -197,7 +197,7 @@ public void showBuiltInModelTestInTree() throws SQLException {
197197
@Test
198198
public void showBuiltInModelTestInTable() throws SQLException {
199199
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
200-
Statement statement = connection.createStatement(); ) {
200+
Statement statement = connection.createStatement();) {
201201
showBuiltInModelTest(statement);
202202
}
203203
}
@@ -209,13 +209,16 @@ private void showBuiltInModelTest(Statement statement) throws SQLException {
209209
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
210210
checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
211211
while (resultSet.next()) {
212+
String id = resultSet.getString(1);
213+
if ("patchtst_fm".equals(id)) {
214+
continue;
215+
}
212216
built_in_model_count++;
213-
FakeModelInfo modelInfo =
214-
new FakeModelInfo(
215-
resultSet.getString(1),
216-
resultSet.getString(2),
217-
resultSet.getString(3),
218-
resultSet.getString(4));
217+
FakeModelInfo modelInfo = new FakeModelInfo(
218+
resultSet.getString(1),
219+
resultSet.getString(2),
220+
resultSet.getString(3),
221+
resultSet.getString(4));
219222
assertTrue(AINodeTestUtils.BUILTIN_MODEL_MAP.containsKey(modelInfo.getModelId()));
220223
assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo);
221224
}

integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java

Lines changed: 76 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -49,46 +49,44 @@
4949

5050
public class AINodeTestUtils {
5151

52-
public static final Map<String, FakeModelInfo> BUILTIN_LTSM_MAP =
53-
Stream.of(
54-
new AbstractMap.SimpleEntry<>(
55-
"timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")),
56-
new AbstractMap.SimpleEntry<>(
57-
"sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")),
58-
new AbstractMap.SimpleEntry<>(
59-
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")),
60-
new AbstractMap.SimpleEntry<>(
61-
"moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")),
62-
new AbstractMap.SimpleEntry<>(
63-
"toto", new FakeModelInfo("toto", "toto", "builtin", "active")))
64-
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
52+
public static final Map<String, FakeModelInfo> BUILTIN_LTSM_MAP = Stream.of(
53+
new AbstractMap.SimpleEntry<>(
54+
"timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")),
55+
new AbstractMap.SimpleEntry<>(
56+
"sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")),
57+
new AbstractMap.SimpleEntry<>(
58+
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")),
59+
new AbstractMap.SimpleEntry<>(
60+
"moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")),
61+
new AbstractMap.SimpleEntry<>(
62+
"toto", new FakeModelInfo("toto", "toto", "builtin", "active")))
63+
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
6564

6665
public static final Map<String, FakeModelInfo> BUILTIN_MODEL_MAP;
6766

6867
static {
69-
Map<String, FakeModelInfo> tmp =
70-
Stream.of(
71-
new AbstractMap.SimpleEntry<>(
72-
"arima", new FakeModelInfo("arima", "sktime", "builtin", "active")),
73-
new AbstractMap.SimpleEntry<>(
74-
"holtwinters", new FakeModelInfo("holtwinters", "sktime", "builtin", "active")),
75-
new AbstractMap.SimpleEntry<>(
76-
"exponential_smoothing",
77-
new FakeModelInfo("exponential_smoothing", "sktime", "builtin", "active")),
78-
new AbstractMap.SimpleEntry<>(
79-
"naive_forecaster",
80-
new FakeModelInfo("naive_forecaster", "sktime", "builtin", "active")),
81-
new AbstractMap.SimpleEntry<>(
82-
"stl_forecaster",
83-
new FakeModelInfo("stl_forecaster", "sktime", "builtin", "active")),
84-
new AbstractMap.SimpleEntry<>(
85-
"gaussian_hmm",
86-
new FakeModelInfo("gaussian_hmm", "sktime", "builtin", "active")),
87-
new AbstractMap.SimpleEntry<>(
88-
"gmm_hmm", new FakeModelInfo("gmm_hmm", "sktime", "builtin", "active")),
89-
new AbstractMap.SimpleEntry<>(
90-
"stray", new FakeModelInfo("stray", "sktime", "builtin", "active")))
91-
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
68+
Map<String, FakeModelInfo> tmp = Stream.of(
69+
new AbstractMap.SimpleEntry<>(
70+
"arima", new FakeModelInfo("arima", "sktime", "builtin", "active")),
71+
new AbstractMap.SimpleEntry<>(
72+
"holtwinters", new FakeModelInfo("holtwinters", "sktime", "builtin", "active")),
73+
new AbstractMap.SimpleEntry<>(
74+
"exponential_smoothing",
75+
new FakeModelInfo("exponential_smoothing", "sktime", "builtin", "active")),
76+
new AbstractMap.SimpleEntry<>(
77+
"naive_forecaster",
78+
new FakeModelInfo("naive_forecaster", "sktime", "builtin", "active")),
79+
new AbstractMap.SimpleEntry<>(
80+
"stl_forecaster",
81+
new FakeModelInfo("stl_forecaster", "sktime", "builtin", "active")),
82+
new AbstractMap.SimpleEntry<>(
83+
"gaussian_hmm",
84+
new FakeModelInfo("gaussian_hmm", "sktime", "builtin", "active")),
85+
new AbstractMap.SimpleEntry<>(
86+
"gmm_hmm", new FakeModelInfo("gmm_hmm", "sktime", "builtin", "active")),
87+
new AbstractMap.SimpleEntry<>(
88+
"stray", new FakeModelInfo("stray", "sktime", "builtin", "active")))
89+
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
9290
tmp.putAll(BUILTIN_LTSM_MAP);
9391
BUILTIN_MODEL_MAP = Collections.unmodifiableMap(tmp);
9492
}
@@ -117,36 +115,35 @@ public static void concurrentInference(
117115
AtomicBoolean allPass = new AtomicBoolean(true);
118116
Thread[] threads = new Thread[threadCnt];
119117
for (int i = 0; i < threadCnt; i++) {
120-
threads[i] =
121-
new Thread(
122-
() -> {
123-
try {
124-
for (int j = 0; j < loop; j++) {
125-
try (ResultSet resultSet = statement.executeQuery(sql)) {
126-
int outputCnt = 0;
127-
while (resultSet.next()) {
128-
outputCnt++;
129-
}
130-
if (expectedOutputLength != outputCnt) {
131-
allPass.set(false);
132-
fail(
133-
"Output count mismatch for SQL: "
134-
+ sql
135-
+ ". Expected: "
136-
+ expectedOutputLength
137-
+ ", but got: "
138-
+ outputCnt);
139-
}
140-
} catch (SQLException e) {
141-
allPass.set(false);
142-
fail(e.getMessage());
143-
}
118+
threads[i] = new Thread(
119+
() -> {
120+
try {
121+
for (int j = 0; j < loop; j++) {
122+
try (ResultSet resultSet = statement.executeQuery(sql)) {
123+
int outputCnt = 0;
124+
while (resultSet.next()) {
125+
outputCnt++;
144126
}
145-
} catch (Exception e) {
127+
if (expectedOutputLength != outputCnt) {
128+
allPass.set(false);
129+
fail(
130+
"Output count mismatch for SQL: "
131+
+ sql
132+
+ ". Expected: "
133+
+ expectedOutputLength
134+
+ ", but got: "
135+
+ outputCnt);
136+
}
137+
} catch (SQLException e) {
146138
allPass.set(false);
147139
fail(e.getMessage());
148140
}
149-
});
141+
}
142+
} catch (Exception e) {
143+
allPass.set(false);
144+
fail(e.getMessage());
145+
}
146+
});
150147
threads[i].start();
151148
}
152149
for (Thread thread : threads) {
@@ -164,8 +161,7 @@ public static void checkModelOnSpecifiedDevice(Statement statement, String model
164161
LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices);
165162
for (int retry = 0; retry < 200; retry++) {
166163
Set<String> foundDevices = new HashSet<>();
167-
try (final ResultSet resultSet =
168-
statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
164+
try (final ResultSet resultSet = statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
169165
while (resultSet.next()) {
170166
String deviceId = resultSet.getString("DeviceId");
171167
String loadedModelId = resultSet.getString("ModelId");
@@ -193,8 +189,7 @@ public static void checkModelNotOnSpecifiedDevice(
193189
LOGGER.info("Checking model: {} not on target devices: {}", modelId, targetDevices);
194190
for (int retry = 0; retry < 50; retry++) {
195191
Set<String> foundDevices = new HashSet<>();
196-
try (final ResultSet resultSet =
197-
statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
192+
try (final ResultSet resultSet = statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
198193
while (resultSet.next()) {
199194
String deviceId = resultSet.getString("DeviceId");
200195
String loadedModelId = resultSet.getString("ModelId");
@@ -215,16 +210,18 @@ public static void checkModelNotOnSpecifiedDevice(
215210
fail("Model " + modelId + " is still loaded on device " + device);
216211
}
217212

218-
private static final String[] WRITE_SQL_IN_TREE =
219-
new String[] {
220-
"CREATE DATABASE root.AI",
221-
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
222-
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
223-
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
224-
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
225-
};
213+
private static final String[] WRITE_SQL_IN_TREE = new String[] {
214+
"CREATE DATABASE root.AI",
215+
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
216+
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
217+
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
218+
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
219+
};
226220

227-
/** Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in tree. */
221+
/**
222+
* Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of
223+
* data in tree.
224+
*/
228225
public static void prepareDataInTree() throws SQLException {
229226
prepareData(WRITE_SQL_IN_TREE);
230227
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
@@ -238,7 +235,10 @@ public static void prepareDataInTree() throws SQLException {
238235
}
239236
}
240237

241-
/** Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in table. */
238+
/**
239+
* Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data
240+
* in table.
241+
*/
242242
public static void prepareDataInTable() throws SQLException {
243243
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
244244
Statement statement = connection.createStatement()) {

iotdb-core/ainode/iotdb/ainode/core/model/model_info.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def __repr__(self):
160160
},
161161
transformers_registered=True,
162162
),
163+
<<<<<<< HEAD
163164
"toto": ModelInfo(
164165
model_id="toto",
165166
category=ModelCategory.BUILTIN,
@@ -172,5 +173,18 @@ def __repr__(self):
172173
"AutoModelForCausalLM": "modeling_toto.TotoForPrediction",
173174
},
174175
transformers_registered=True,
176+
=======
177+
"patchtst_fm": ModelInfo(
178+
model_id = "patchtst_fm",
179+
category=ModelCategory.BUILTIN,
180+
state=ModelStates.INACTIVE,
181+
model_type="patchtst_fm",
182+
pipeline_cls="pipeline_patchtst_fm.PatchTSTFMPipeline",
183+
repo_id="ibm-research/patchtst-fm-r1",
184+
auto_map={
185+
"AutoConfig": "configuration_patchtst_fm.PatchTSTFMConfig",
186+
"AutoModelForCausalLM": "modeling_patchtst_fm.PatchTSTFMForPrediction",
187+
},
188+
>>>>>>> d54f8bc19d (feat(AINode): [Issue-17301] Import PatchTST-FM-R1 architecture and register in model_info)
175189
),
176190
}

iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/__init__.py

Whitespace-only changes.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright contributors to the TSFM project
2+
#
3+
"""PatchTST-FM model configuration"""
4+
5+
from transformers.configuration_utils import PretrainedConfig
6+
from transformers.utils import logging
7+
8+
9+
logger = logging.get_logger(__name__)
10+
11+
PATCHTSTFM_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
12+
13+
14+
class PatchTSTFMConfig(PretrainedConfig):
15+
model_type = "patchtst_fm"
16+
attribute_map = {
17+
"hidden_size": "d_model",
18+
"num_hidden_layers": "n_layer",
19+
}
20+
21+
# has_no_defaults_at_init = True
22+
def __init__(
23+
self,
24+
context_length: int = 8192,
25+
prediction_length: int = 64,
26+
d_patch: int = 16,
27+
d_model: int = 384,
28+
n_head: int = 6,
29+
n_layer: int = 6,
30+
norm_first: bool = True,
31+
pretrain_mask_ratio: float = 0.4,
32+
pretrain_mask_cont: int = 8,
33+
num_quantile: int = 99,
34+
**kwargs,
35+
):
36+
self.context_length = context_length
37+
self.prediction_length = prediction_length
38+
self.d_patch = d_patch
39+
self.n_patch = int(context_length // d_patch)
40+
self.d_model = d_model
41+
self.n_head = n_head
42+
self.n_layer = n_layer
43+
self.norm_first = norm_first
44+
self.pretrain_mask_ratio = pretrain_mask_ratio
45+
self.pretrain_mask_cont = pretrain_mask_cont
46+
self.num_quantile = num_quantile
47+
48+
if num_quantile % 9 == 0:
49+
quantiles = [i / (self.num_quantile + 1) for i in range(1, self.num_quantile + 1)]
50+
else:
51+
quantiles = [i / (self.num_quantile - 1) for i in range(1, self.num_quantile - 1)]
52+
quantiles = [0.01] + quantiles + [0.99]
53+
self.quantile_levels = quantiles
54+
super().__init__(**kwargs)

0 commit comments

Comments
 (0)