Skip to content

Commit cec7ace

Browse files
[AINode]: PatchTST-FM integration with license compliance and formatting
1 parent d1a348a commit cec7ace

11 files changed

Lines changed: 393 additions & 148 deletions

File tree

LICENSE

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,3 +360,13 @@ Project page: https://github.com/SalesforceAIResearch/uni2ts
360360
License: https://github.com/SalesforceAIResearch/uni2ts/blob/main/LICENSE.txt
361361

362362
--------------------------------------------------------------------------------
363+
364+
The following files include code modified from the PatchTST project.
365+
366+
./iotdb-core/ainode/iotdb/ainode/core/model/patchtst_fm/*
367+
368+
PatchTST is open source software licensed under the Apache License 2.0
369+
Project page: https://github.com/ibm-research/patchtst
370+
License: https://github.com/ibm-research/patchtst/blob/main/LICENSE
371+
372+
--------------------------------------------------------------------------------

integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java

Lines changed: 76 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -49,44 +49,44 @@
4949

5050
public class AINodeTestUtils {
5151

52-
public static final Map<String, FakeModelInfo> BUILTIN_LTSM_MAP =
53-
Stream.of(
54-
new AbstractMap.SimpleEntry<>(
55-
"sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")),
56-
new AbstractMap.SimpleEntry<>(
57-
"timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")),
58-
new AbstractMap.SimpleEntry<>(
59-
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")),
60-
new AbstractMap.SimpleEntry<>(
61-
"moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")))
62-
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
52+
public static final Map<String, FakeModelInfo> BUILTIN_LTSM_MAP = Stream.of(
53+
new AbstractMap.SimpleEntry<>(
54+
"sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")),
55+
new AbstractMap.SimpleEntry<>(
56+
"timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")),
57+
new AbstractMap.SimpleEntry<>(
58+
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")),
59+
new AbstractMap.SimpleEntry<>(
60+
"moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")),
61+
new AbstractMap.SimpleEntry<>(
62+
"patchtst_fm", new FakeModelInfo("patchtst_fm", "patchtst_fm", "builtin", "active")))
63+
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
6364

6465
public static final Map<String, FakeModelInfo> BUILTIN_MODEL_MAP;
6566

6667
static {
67-
Map<String, FakeModelInfo> tmp =
68-
Stream.of(
69-
new AbstractMap.SimpleEntry<>(
70-
"arima", new FakeModelInfo("arima", "sktime", "builtin", "active")),
71-
new AbstractMap.SimpleEntry<>(
72-
"holtwinters", new FakeModelInfo("holtwinters", "sktime", "builtin", "active")),
73-
new AbstractMap.SimpleEntry<>(
74-
"exponential_smoothing",
75-
new FakeModelInfo("exponential_smoothing", "sktime", "builtin", "active")),
76-
new AbstractMap.SimpleEntry<>(
77-
"naive_forecaster",
78-
new FakeModelInfo("naive_forecaster", "sktime", "builtin", "active")),
79-
new AbstractMap.SimpleEntry<>(
80-
"stl_forecaster",
81-
new FakeModelInfo("stl_forecaster", "sktime", "builtin", "active")),
82-
new AbstractMap.SimpleEntry<>(
83-
"gaussian_hmm",
84-
new FakeModelInfo("gaussian_hmm", "sktime", "builtin", "active")),
85-
new AbstractMap.SimpleEntry<>(
86-
"gmm_hmm", new FakeModelInfo("gmm_hmm", "sktime", "builtin", "active")),
87-
new AbstractMap.SimpleEntry<>(
88-
"stray", new FakeModelInfo("stray", "sktime", "builtin", "active")))
89-
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
68+
Map<String, FakeModelInfo> tmp = Stream.of(
69+
new AbstractMap.SimpleEntry<>(
70+
"arima", new FakeModelInfo("arima", "sktime", "builtin", "active")),
71+
new AbstractMap.SimpleEntry<>(
72+
"holtwinters", new FakeModelInfo("holtwinters", "sktime", "builtin", "active")),
73+
new AbstractMap.SimpleEntry<>(
74+
"exponential_smoothing",
75+
new FakeModelInfo("exponential_smoothing", "sktime", "builtin", "active")),
76+
new AbstractMap.SimpleEntry<>(
77+
"naive_forecaster",
78+
new FakeModelInfo("naive_forecaster", "sktime", "builtin", "active")),
79+
new AbstractMap.SimpleEntry<>(
80+
"stl_forecaster",
81+
new FakeModelInfo("stl_forecaster", "sktime", "builtin", "active")),
82+
new AbstractMap.SimpleEntry<>(
83+
"gaussian_hmm",
84+
new FakeModelInfo("gaussian_hmm", "sktime", "builtin", "active")),
85+
new AbstractMap.SimpleEntry<>(
86+
"gmm_hmm", new FakeModelInfo("gmm_hmm", "sktime", "builtin", "active")),
87+
new AbstractMap.SimpleEntry<>(
88+
"stray", new FakeModelInfo("stray", "sktime", "builtin", "active")))
89+
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
9090
tmp.putAll(BUILTIN_LTSM_MAP);
9191
BUILTIN_MODEL_MAP = Collections.unmodifiableMap(tmp);
9292
}
@@ -115,36 +115,35 @@ public static void concurrentInference(
115115
AtomicBoolean allPass = new AtomicBoolean(true);
116116
Thread[] threads = new Thread[threadCnt];
117117
for (int i = 0; i < threadCnt; i++) {
118-
threads[i] =
119-
new Thread(
120-
() -> {
121-
try {
122-
for (int j = 0; j < loop; j++) {
123-
try (ResultSet resultSet = statement.executeQuery(sql)) {
124-
int outputCnt = 0;
125-
while (resultSet.next()) {
126-
outputCnt++;
127-
}
128-
if (expectedOutputLength != outputCnt) {
129-
allPass.set(false);
130-
fail(
131-
"Output count mismatch for SQL: "
132-
+ sql
133-
+ ". Expected: "
134-
+ expectedOutputLength
135-
+ ", but got: "
136-
+ outputCnt);
137-
}
138-
} catch (SQLException e) {
139-
allPass.set(false);
140-
fail(e.getMessage());
141-
}
118+
threads[i] = new Thread(
119+
() -> {
120+
try {
121+
for (int j = 0; j < loop; j++) {
122+
try (ResultSet resultSet = statement.executeQuery(sql)) {
123+
int outputCnt = 0;
124+
while (resultSet.next()) {
125+
outputCnt++;
142126
}
143-
} catch (Exception e) {
127+
if (expectedOutputLength != outputCnt) {
128+
allPass.set(false);
129+
fail(
130+
"Output count mismatch for SQL: "
131+
+ sql
132+
+ ". Expected: "
133+
+ expectedOutputLength
134+
+ ", but got: "
135+
+ outputCnt);
136+
}
137+
} catch (SQLException e) {
144138
allPass.set(false);
145139
fail(e.getMessage());
146140
}
147-
});
141+
}
142+
} catch (Exception e) {
143+
allPass.set(false);
144+
fail(e.getMessage());
145+
}
146+
});
148147
threads[i].start();
149148
}
150149
for (Thread thread : threads) {
@@ -162,8 +161,7 @@ public static void checkModelOnSpecifiedDevice(Statement statement, String model
162161
LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices);
163162
for (int retry = 0; retry < 200; retry++) {
164163
Set<String> foundDevices = new HashSet<>();
165-
try (final ResultSet resultSet =
166-
statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
164+
try (final ResultSet resultSet = statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
167165
while (resultSet.next()) {
168166
String deviceId = resultSet.getString("DeviceId");
169167
String loadedModelId = resultSet.getString("ModelId");
@@ -191,8 +189,7 @@ public static void checkModelNotOnSpecifiedDevice(
191189
LOGGER.info("Checking model: {} not on target devices: {}", modelId, targetDevices);
192190
for (int retry = 0; retry < 50; retry++) {
193191
Set<String> foundDevices = new HashSet<>();
194-
try (final ResultSet resultSet =
195-
statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
192+
try (final ResultSet resultSet = statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
196193
while (resultSet.next()) {
197194
String deviceId = resultSet.getString("DeviceId");
198195
String loadedModelId = resultSet.getString("ModelId");
@@ -213,16 +210,18 @@ public static void checkModelNotOnSpecifiedDevice(
213210
fail("Model " + modelId + " is still loaded on device " + device);
214211
}
215212

216-
private static final String[] WRITE_SQL_IN_TREE =
217-
new String[] {
218-
"CREATE DATABASE root.AI",
219-
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
220-
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
221-
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
222-
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
223-
};
213+
private static final String[] WRITE_SQL_IN_TREE = new String[] {
214+
"CREATE DATABASE root.AI",
215+
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
216+
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
217+
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
218+
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
219+
};
224220

225-
/** Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in tree. */
221+
/**
222+
* Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of
223+
* data in tree.
224+
*/
226225
public static void prepareDataInTree() throws SQLException {
227226
prepareData(WRITE_SQL_IN_TREE);
228227
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
@@ -236,7 +235,10 @@ public static void prepareDataInTree() throws SQLException {
236235
}
237236
}
238237

239-
/** Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in table. */
238+
/**
239+
* Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data
240+
* in table.
241+
*/
240242
public static void prepareDataInTable() throws SQLException {
241243
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
242244
Statement statement = connection.createStatement()) {

iotdb-core/ainode/iotdb/ainode/core/model/model_info.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def __init__(
3030
model_type: str = "",
3131
pipeline_cls: str = "",
3232
repo_id: str = "",
33+
download_weights: bool = True,
3334
auto_map: Optional[Dict] = None,
3435
transformers_registered: bool = False,
3536
):
@@ -39,6 +40,7 @@ def __init__(
3940
self.state = state
4041
self.pipeline_cls = pipeline_cls
4142
self.repo_id = repo_id
43+
self.download_weights = download_weights
4244
self.auto_map = auto_map # If exists, indicates it's a Transformers model
4345
self.transformers_registered = (
4446
transformers_registered # Internal flag: whether registered to Transformers
@@ -159,15 +161,17 @@ def __repr__(self):
159161
transformers_registered=True,
160162
),
161163
"patchtst_fm": ModelInfo(
162-
model_id = "patchtst_fm",
164+
model_id="patchtst_fm",
163165
category=ModelCategory.BUILTIN,
164166
state=ModelStates.INACTIVE,
165167
model_type="patchtst_fm",
166168
pipeline_cls="pipeline_patchtst_fm.PatchTSTFMPipeline",
167169
repo_id="ibm-research/patchtst-fm-r1",
170+
download_weights=False,
168171
auto_map={
169-
"AutoConfig": "configuration_patchtst_fm.PatchTSTFMConfig",
170-
"AutoModelForCausalLM": "modeling_patchtst_fm.PatchTSTFMForPrediction",
171-
},
172+
"AutoConfig": "configuration_patchtst_fm.PatchTSTFMConfig",
173+
"AutoModelForCausalLM": "modeling_patchtst_fm.PatchTSTFMForPrediction",
174+
},
175+
transformers_registered=True,
172176
),
173177
}

iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -146,21 +146,28 @@ def _process_builtin_model_directory(self, model_dir: str, model_id: str):
146146

147147
def _download_model_if_necessary() -> bool:
148148
"""Returns: True if the model already exists or is downloaded successfully, False otherwise."""
149-
repo_id = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id].repo_id
149+
model_info = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id]
150+
repo_id = model_info.repo_id
150151
weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)
151152
config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON)
152-
if not os.path.exists(weights_path):
153-
try:
154-
hf_hub_download(
155-
repo_id=repo_id,
156-
filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
157-
local_dir=model_dir,
158-
)
159-
except Exception as e:
160-
logger.error(
161-
f"Failed to download model weights from HuggingFace: {e}"
162-
)
163-
return False
153+
if model_info.download_weights:
154+
if not os.path.exists(weights_path):
155+
try:
156+
hf_hub_download(
157+
repo_id=repo_id,
158+
filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
159+
local_dir=model_dir,
160+
)
161+
except Exception as e:
162+
logger.error(
163+
f"Failed to download model weights from HuggingFace: {e}"
164+
)
165+
return False
166+
167+
else:
168+
logger.info(
169+
f"Skipping weight download for {model_id} due to configuration."
170+
)
164171
if not os.path.exists(config_path):
165172
try:
166173
hf_hub_download(
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.

0 commit comments

Comments
 (0)