4949
5050public class AINodeTestUtils {
5151
52- public static final Map <String , FakeModelInfo > BUILTIN_LTSM_MAP =
53- Stream .of (
54- new AbstractMap .SimpleEntry <>(
55- "sundial" , new FakeModelInfo ("sundial" , "sundial" , "builtin" , "active" )),
56- new AbstractMap .SimpleEntry <>(
57- "timer_xl" , new FakeModelInfo ("timer_xl" , "timer" , "builtin" , "active" )),
58- new AbstractMap .SimpleEntry <>(
59- "chronos2" , new FakeModelInfo ("chronos2" , "t5" , "builtin" , "active" )),
60- new AbstractMap .SimpleEntry <>(
61- "moirai2" , new FakeModelInfo ("moirai2" , "moirai" , "builtin" , "active" )))
62- .collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue ));
52+ public static final Map <String , FakeModelInfo > BUILTIN_LTSM_MAP = Stream .of (
53+ new AbstractMap .SimpleEntry <>(
54+ "sundial" , new FakeModelInfo ("sundial" , "sundial" , "builtin" , "active" )),
55+ new AbstractMap .SimpleEntry <>(
56+ "timer_xl" , new FakeModelInfo ("timer_xl" , "timer" , "builtin" , "active" )),
57+ new AbstractMap .SimpleEntry <>(
58+ "chronos2" , new FakeModelInfo ("chronos2" , "t5" , "builtin" , "active" )),
59+ new AbstractMap .SimpleEntry <>(
60+ "moirai2" , new FakeModelInfo ("moirai2" , "moirai" , "builtin" , "active" )),
61+ new AbstractMap .SimpleEntry <>(
62+ "patchtst_fm" , new FakeModelInfo ("patchtst_fm" , "patchtst_fm" , "builtin" , "active" )))
63+ .collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue ));
6364
6465 public static final Map <String , FakeModelInfo > BUILTIN_MODEL_MAP ;
6566
6667 static {
67- Map <String , FakeModelInfo > tmp =
68- Stream .of (
69- new AbstractMap .SimpleEntry <>(
70- "arima" , new FakeModelInfo ("arima" , "sktime" , "builtin" , "active" )),
71- new AbstractMap .SimpleEntry <>(
72- "holtwinters" , new FakeModelInfo ("holtwinters" , "sktime" , "builtin" , "active" )),
73- new AbstractMap .SimpleEntry <>(
74- "exponential_smoothing" ,
75- new FakeModelInfo ("exponential_smoothing" , "sktime" , "builtin" , "active" )),
76- new AbstractMap .SimpleEntry <>(
77- "naive_forecaster" ,
78- new FakeModelInfo ("naive_forecaster" , "sktime" , "builtin" , "active" )),
79- new AbstractMap .SimpleEntry <>(
80- "stl_forecaster" ,
81- new FakeModelInfo ("stl_forecaster" , "sktime" , "builtin" , "active" )),
82- new AbstractMap .SimpleEntry <>(
83- "gaussian_hmm" ,
84- new FakeModelInfo ("gaussian_hmm" , "sktime" , "builtin" , "active" )),
85- new AbstractMap .SimpleEntry <>(
86- "gmm_hmm" , new FakeModelInfo ("gmm_hmm" , "sktime" , "builtin" , "active" )),
87- new AbstractMap .SimpleEntry <>(
88- "stray" , new FakeModelInfo ("stray" , "sktime" , "builtin" , "active" )))
89- .collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue ));
68+ Map <String , FakeModelInfo > tmp = Stream .of (
69+ new AbstractMap .SimpleEntry <>(
70+ "arima" , new FakeModelInfo ("arima" , "sktime" , "builtin" , "active" )),
71+ new AbstractMap .SimpleEntry <>(
72+ "holtwinters" , new FakeModelInfo ("holtwinters" , "sktime" , "builtin" , "active" )),
73+ new AbstractMap .SimpleEntry <>(
74+ "exponential_smoothing" ,
75+ new FakeModelInfo ("exponential_smoothing" , "sktime" , "builtin" , "active" )),
76+ new AbstractMap .SimpleEntry <>(
77+ "naive_forecaster" ,
78+ new FakeModelInfo ("naive_forecaster" , "sktime" , "builtin" , "active" )),
79+ new AbstractMap .SimpleEntry <>(
80+ "stl_forecaster" ,
81+ new FakeModelInfo ("stl_forecaster" , "sktime" , "builtin" , "active" )),
82+ new AbstractMap .SimpleEntry <>(
83+ "gaussian_hmm" ,
84+ new FakeModelInfo ("gaussian_hmm" , "sktime" , "builtin" , "active" )),
85+ new AbstractMap .SimpleEntry <>(
86+ "gmm_hmm" , new FakeModelInfo ("gmm_hmm" , "sktime" , "builtin" , "active" )),
87+ new AbstractMap .SimpleEntry <>(
88+ "stray" , new FakeModelInfo ("stray" , "sktime" , "builtin" , "active" )))
89+ .collect (Collectors .toMap (Map .Entry ::getKey , Map .Entry ::getValue ));
9090 tmp .putAll (BUILTIN_LTSM_MAP );
9191 BUILTIN_MODEL_MAP = Collections .unmodifiableMap (tmp );
9292 }
@@ -115,36 +115,35 @@ public static void concurrentInference(
115115 AtomicBoolean allPass = new AtomicBoolean (true );
116116 Thread [] threads = new Thread [threadCnt ];
117117 for (int i = 0 ; i < threadCnt ; i ++) {
118- threads [i ] =
119- new Thread (
120- () -> {
121- try {
122- for (int j = 0 ; j < loop ; j ++) {
123- try (ResultSet resultSet = statement .executeQuery (sql )) {
124- int outputCnt = 0 ;
125- while (resultSet .next ()) {
126- outputCnt ++;
127- }
128- if (expectedOutputLength != outputCnt ) {
129- allPass .set (false );
130- fail (
131- "Output count mismatch for SQL: "
132- + sql
133- + ". Expected: "
134- + expectedOutputLength
135- + ", but got: "
136- + outputCnt );
137- }
138- } catch (SQLException e ) {
139- allPass .set (false );
140- fail (e .getMessage ());
141- }
118+ threads [i ] = new Thread (
119+ () -> {
120+ try {
121+ for (int j = 0 ; j < loop ; j ++) {
122+ try (ResultSet resultSet = statement .executeQuery (sql )) {
123+ int outputCnt = 0 ;
124+ while (resultSet .next ()) {
125+ outputCnt ++;
142126 }
143- } catch (Exception e ) {
127+ if (expectedOutputLength != outputCnt ) {
128+ allPass .set (false );
129+ fail (
130+ "Output count mismatch for SQL: "
131+ + sql
132+ + ". Expected: "
133+ + expectedOutputLength
134+ + ", but got: "
135+ + outputCnt );
136+ }
137+ } catch (SQLException e ) {
144138 allPass .set (false );
145139 fail (e .getMessage ());
146140 }
147- });
141+ }
142+ } catch (Exception e ) {
143+ allPass .set (false );
144+ fail (e .getMessage ());
145+ }
146+ });
148147 threads [i ].start ();
149148 }
150149 for (Thread thread : threads ) {
@@ -162,8 +161,7 @@ public static void checkModelOnSpecifiedDevice(Statement statement, String model
162161 LOGGER .info ("Checking model: {} on target devices: {}" , modelId , targetDevices );
163162 for (int retry = 0 ; retry < 200 ; retry ++) {
164163 Set <String > foundDevices = new HashSet <>();
165- try (final ResultSet resultSet =
166- statement .executeQuery (String .format ("SHOW LOADED MODELS '%s'" , device ))) {
164+ try (final ResultSet resultSet = statement .executeQuery (String .format ("SHOW LOADED MODELS '%s'" , device ))) {
167165 while (resultSet .next ()) {
168166 String deviceId = resultSet .getString ("DeviceId" );
169167 String loadedModelId = resultSet .getString ("ModelId" );
@@ -191,8 +189,7 @@ public static void checkModelNotOnSpecifiedDevice(
191189 LOGGER .info ("Checking model: {} not on target devices: {}" , modelId , targetDevices );
192190 for (int retry = 0 ; retry < 50 ; retry ++) {
193191 Set <String > foundDevices = new HashSet <>();
194- try (final ResultSet resultSet =
195- statement .executeQuery (String .format ("SHOW LOADED MODELS '%s'" , device ))) {
192+ try (final ResultSet resultSet = statement .executeQuery (String .format ("SHOW LOADED MODELS '%s'" , device ))) {
196193 while (resultSet .next ()) {
197194 String deviceId = resultSet .getString ("DeviceId" );
198195 String loadedModelId = resultSet .getString ("ModelId" );
@@ -213,16 +210,18 @@ public static void checkModelNotOnSpecifiedDevice(
213210 fail ("Model " + modelId + " is still loaded on device " + device );
214211 }
215212
216- private static final String [] WRITE_SQL_IN_TREE =
217- new String [] {
218- "CREATE DATABASE root.AI" ,
219- "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE" ,
220- "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE" ,
221- "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE" ,
222- "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE" ,
223- };
213+ private static final String [] WRITE_SQL_IN_TREE = new String [] {
214+ "CREATE DATABASE root.AI" ,
215+ "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE" ,
216+ "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE" ,
217+ "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE" ,
218+ "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE" ,
219+ };
224220
225- /** Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in tree. */
221+ /**
222+ * Prepare root.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of
223+ * data in tree.
224+ */
226225 public static void prepareDataInTree () throws SQLException {
227226 prepareData (WRITE_SQL_IN_TREE );
228227 try (Connection connection = EnvFactory .getEnv ().getConnection (BaseEnv .TREE_SQL_DIALECT );
@@ -236,7 +235,10 @@ public static void prepareDataInTree() throws SQLException {
236235 }
237236 }
238237
239- /** Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data in table. */
238+ /**
239+ * Prepare db.AI(s0 FLOAT, s1 DOUBLE, s2 INT32, s3 INT64) with 5760 rows of data
240+ * in table.
241+ */
240242 public static void prepareDataInTable () throws SQLException {
241243 try (Connection connection = EnvFactory .getEnv ().getConnection (BaseEnv .TABLE_SQL_DIALECT );
242244 Statement statement = connection .createStatement ()) {
0 commit comments