Skip to content

Commit 5a83067

Browse files
authored
[AINode] Support hubmixin models and modify pipeline (#17334)
1 parent 6f54691 commit 5a83067

File tree

21 files changed

+709
-298
lines changed

21 files changed

+709
-298
lines changed

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,16 @@
2020
package org.apache.iotdb.ainode.it;
2121

2222
import org.apache.iotdb.it.env.EnvFactory;
23+
import org.apache.iotdb.it.framework.IoTDBTestRunner;
24+
import org.apache.iotdb.itbase.category.AIClusterIT;
2325
import org.apache.iotdb.itbase.env.BaseEnv;
2426

2527
import org.junit.AfterClass;
2628
import org.junit.Assert;
2729
import org.junit.BeforeClass;
2830
import org.junit.Test;
31+
import org.junit.experimental.categories.Category;
32+
import org.junit.runner.RunWith;
2933

3034
import java.sql.Connection;
3135
import java.sql.ResultSet;
@@ -41,9 +45,13 @@
4145
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice;
4246
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
4347

48+
@RunWith(IoTDBTestRunner.class)
49+
@Category({AIClusterIT.class})
4450
public class AINodeInstanceManagementIT {
4551

46-
private static final Set<String> TARGET_DEVICES = new HashSet<>(Arrays.asList("cpu", "0", "1"));
52+
private static final String TARGET_DEVICES_STR = "0,1";
53+
private static final Set<String> TARGET_DEVICES =
54+
new HashSet<>(Arrays.asList(TARGET_DEVICES_STR.split(",")));
4755

4856
@BeforeClass
4957
public static void setUp() throws Exception {
@@ -76,53 +84,57 @@ private void basicManagementTest(Statement statement) throws SQLException, Inter
7684
// Ensure resources
7785
try (ResultSet resultSet = statement.executeQuery("SHOW AI_DEVICES")) {
7886
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
79-
checkHeader(resultSetMetaData, "DeviceID");
87+
checkHeader(resultSetMetaData, "DeviceId,DeviceType");
8088
final Set<String> resultDevices = new HashSet<>();
8189
while (resultSet.next()) {
82-
resultDevices.add(resultSet.getString("DeviceID"));
90+
resultDevices.add(resultSet.getString("DeviceId"));
8391
}
84-
Assert.assertEquals(TARGET_DEVICES, resultDevices);
92+
Set<String> expected = new HashSet<>(TARGET_DEVICES);
93+
expected.add("cpu");
94+
Assert.assertEquals(expected, resultDevices);
8595
}
8696

8797
// Load sundial to each device
88-
statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES));
89-
checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
98+
statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES_STR));
99+
checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);
90100
// Unload sundial from each device
91-
statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES));
92-
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
101+
statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES_STR));
102+
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);
93103

94104
// Load timer_xl to each device
95-
statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES));
96-
checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString());
105+
statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES_STR));
106+
checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES_STR);
97107
// Unload timer_xl from each device
98-
statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES));
99-
checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString());
108+
statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES_STR));
109+
checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES_STR);
100110
}
101111

102112
private static final int LOOP_CNT = 10;
103113

104-
@Test
114+
// @Test
105115
public void repeatLoadAndUnloadTest() throws SQLException, InterruptedException {
106116
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
107117
Statement statement = connection.createStatement()) {
108118
for (int i = 0; i < LOOP_CNT; i++) {
109-
statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\"");
110-
checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
111-
statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
112-
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
119+
statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES_STR));
120+
checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);
121+
statement.execute(
122+
String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES_STR));
123+
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);
113124
}
114125
}
115126
}
116127

117-
@Test
128+
// @Test
118129
public void concurrentLoadAndUnloadTest() throws SQLException, InterruptedException {
119130
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
120131
Statement statement = connection.createStatement()) {
121132
for (int i = 0; i < LOOP_CNT; i++) {
122-
statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\"");
123-
statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
133+
statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES_STR));
134+
statement.execute(
135+
String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES_STR));
124136
}
125-
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
137+
checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR);
126138
}
127139
}
128140

@@ -145,23 +157,23 @@ public void failTestInTableModel() throws SQLException {
145157
private void failTest(Statement statement) {
146158
errorTest(
147159
statement,
148-
"LOAD MODEL unknown TO DEVICES \"cpu,0,1\"",
149-
"1505: Cannot load model [unknown], because it is neither a built-in nor a fine-tuned model. You can use 'SHOW MODELS' to retrieve the available models.");
160+
"LOAD MODEL unknown TO DEVICES 'cpu,0,1'",
161+
"1504: Model [unknown] is not registered yet. You can use 'SHOW MODELS' to retrieve the available models.");
150162
errorTest(
151163
statement,
152-
"LOAD MODEL sundial TO DEVICES \"unknown\"",
153-
"1507: Device ID [unknown] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.");
164+
"LOAD MODEL sundial TO DEVICES '999'",
165+
"1508: AIDevice ID [999] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.");
154166
errorTest(
155167
statement,
156-
"UNLOAD MODEL sundial FROM DEVICES \"unknown\"",
157-
"1507: Device ID [unknown] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.");
168+
"UNLOAD MODEL sundial FROM DEVICES '999'",
169+
"1508: AIDevice ID [999] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.");
158170
errorTest(
159171
statement,
160-
"LOAD MODEL sundial TO DEVICES \"0,0\"",
172+
"LOAD MODEL sundial TO DEVICES '0,0'",
161173
"1509: Device ID list contains duplicate entries.");
162174
errorTest(
163175
statement,
164-
"UNLOAD MODEL sundial FROM DEVICES \"0,0\"",
176+
"UNLOAD MODEL sundial FROM DEVICES '0,0'",
165177
"1510: Device ID list contains duplicate entries.");
166178
}
167179
}

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -71,39 +71,58 @@ public static void tearDown() throws Exception {
7171
public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException {
7272
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
7373
Statement statement = connection.createStatement()) {
74-
registerUserDefinedModel(statement);
75-
callInferenceTest(
76-
statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"));
77-
dropUserDefinedModel(statement);
74+
// Test transformers model (chronos2) in tree.
75+
AINodeTestUtils.FakeModelInfo modelInfo =
76+
new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active");
77+
registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2");
78+
callInferenceTest(statement, modelInfo);
79+
dropUserDefinedModel(statement, modelInfo.getModelId());
7880
errorTest(
7981
statement,
8082
"create model origin_chronos using uri \"file:///data/chronos2_origin\"",
8183
"1505: 't5' is already used by a Transformers config, pick another name.");
8284
statement.execute("drop model origin_chronos");
85+
86+
// Test PytorchModelHubMixin model (mantis) in tree.
87+
modelInfo = new FakeModelInfo("user_mantis", "custom_mantis", "user_defined", "active");
88+
registerUserDefinedModel(statement, modelInfo, "file:///data/mantis");
89+
dropUserDefinedModel(statement, modelInfo.getModelId());
8390
}
8491
}
8592

8693
@Test
8794
public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException {
8895
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
8996
Statement statement = connection.createStatement()) {
90-
registerUserDefinedModel(statement);
91-
forecastTableFunctionTest(
92-
statement, new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"));
93-
dropUserDefinedModel(statement);
97+
// Test transformers model (chronos2) in table.
98+
AINodeTestUtils.FakeModelInfo modelInfo =
99+
new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active");
100+
registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2");
101+
forecastTableFunctionTest(statement, modelInfo);
102+
dropUserDefinedModel(statement, modelInfo.getModelId());
94103
errorTest(
95104
statement,
96105
"create model origin_chronos using uri \"file:///data/chronos2_origin\"",
97106
"1505: 't5' is already used by a Transformers config, pick another name.");
98107
statement.execute("drop model origin_chronos");
108+
109+
// Test PytorchModelHubMixin model (mantis) in table.
110+
modelInfo = new FakeModelInfo("user_mantis", "custom_mantis", "user_defined", "active");
111+
registerUserDefinedModel(statement, modelInfo, "file:///data/mantis");
112+
dropUserDefinedModel(statement, modelInfo.getModelId());
99113
}
100114
}
101115

102-
private void registerUserDefinedModel(Statement statement)
116+
public static void registerUserDefinedModel(
117+
Statement statement, AINodeTestUtils.FakeModelInfo modelInfo, String uri)
103118
throws SQLException, InterruptedException {
119+
String modelId = modelInfo.getModelId();
120+
String modelType = modelInfo.getModelType();
121+
String category = modelInfo.getCategory();
122+
final String CREATE_MODEL_TEMPLATE = "create model %s using uri \"%s\"";
104123
final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'";
105-
final String registerSql = "create model user_chronos using uri \"file:///data/chronos2\"";
106-
final String showSql = "SHOW MODELS user_chronos";
124+
final String registerSql = String.format(CREATE_MODEL_TEMPLATE, modelId, uri);
125+
final String showSql = String.format("SHOW MODELS %s", modelId);
107126
statement.execute(alterConfigSQL);
108127
statement.execute(registerSql);
109128
boolean loading = true;
@@ -112,13 +131,13 @@ private void registerUserDefinedModel(Statement statement)
112131
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
113132
checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State");
114133
while (resultSet.next()) {
115-
String modelId = resultSet.getString(1);
116-
String modelType = resultSet.getString(2);
117-
String category = resultSet.getString(3);
134+
String resultModelId = resultSet.getString(1);
135+
String resultModelType = resultSet.getString(2);
136+
String resultCategory = resultSet.getString(3);
118137
String state = resultSet.getString(4);
119-
assertEquals("user_chronos", modelId);
120-
assertEquals("custom_t5", modelType);
121-
assertEquals("user_defined", category);
138+
assertEquals(modelId, resultModelId);
139+
assertEquals(modelType, resultModelType);
140+
assertEquals(category, resultCategory);
122141
if (state.equals("active")) {
123142
loading = false;
124143
} else if (state.equals("loading")) {
@@ -136,9 +155,9 @@ private void registerUserDefinedModel(Statement statement)
136155
assertFalse(loading);
137156
}
138157

139-
private void dropUserDefinedModel(Statement statement) throws SQLException {
140-
final String showSql = "SHOW MODELS user_chronos";
141-
final String dropSql = "DROP MODEL user_chronos";
158+
public static void dropUserDefinedModel(Statement statement, String modelId) throws SQLException {
159+
final String showSql = String.format("SHOW MODELS %s", modelId);
160+
final String dropSql = String.format("DROP MODEL %s", modelId);
142161
statement.execute(dropSql);
143162
try (ResultSet resultSet = statement.executeQuery(showSql)) {
144163
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();

integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ public class AINodeTestUtils {
5151

5252
public static final Map<String, FakeModelInfo> BUILTIN_LTSM_MAP =
5353
Stream.of(
54-
new AbstractMap.SimpleEntry<>(
55-
"sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")),
5654
new AbstractMap.SimpleEntry<>(
5755
"timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")),
56+
new AbstractMap.SimpleEntry<>(
57+
"sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")),
5858
new AbstractMap.SimpleEntry<>(
5959
"chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")),
6060
new AbstractMap.SimpleEntry<>(
@@ -171,7 +171,7 @@ public static void checkModelOnSpecifiedDevice(Statement statement, String model
171171
LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count);
172172
if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) {
173173
foundDevices.add(deviceId);
174-
LOGGER.info("Model {} is loaded to device {}", modelId, device);
174+
LOGGER.info("Model {} is loaded to device {}", modelId, deviceId);
175175
}
176176
}
177177
if (foundDevices.containsAll(targetDevices)) {
@@ -252,6 +252,32 @@ public static void prepareDataInTable() throws SQLException {
252252
}
253253
}
254254

255+
/** Prepare db.AI2(s0 FLOAT,...) with 2880 rows of data in table. */
256+
public static void prepareDataInTable2() throws SQLException {
257+
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
258+
Statement statement = connection.createStatement()) {
259+
statement.execute("CREATE DATABASE db");
260+
statement.execute(
261+
"CREATE TABLE db.AI2 (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD, s4 FLOAT FIELD, s5 DOUBLE FIELD, s6 INT32 FIELD, s7 INT64 FIELD, s8 FLOAT FIELD, s9 DOUBLE FIELD)");
262+
for (int i = 0; i < 2880; i++) {
263+
statement.execute(
264+
String.format(
265+
"INSERT INTO db.AI2(time,s0,s1,s2,s3,s4,s5,s6,s7,s8,s9) VALUES(%d,%f,%f,%d,%d,%f,%f,%d,%d,%f,%f)",
266+
i,
267+
(float) i,
268+
(double) i,
269+
i,
270+
i,
271+
(float) (i * 2),
272+
(double) (i * 2),
273+
i * 2,
274+
i * 2,
275+
(float) (i * 3),
276+
(double) (i * 3)));
277+
}
278+
}
279+
}
280+
255281
public static class FakeModelInfo {
256282

257283
private final String modelId;

iotdb-core/ainode/iotdb/ainode/core/exception.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,6 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
#
18-
import re
19-
20-
from iotdb.ainode.core.model.model_constants import (
21-
MODEL_CONFIG_FILE_IN_YAML,
22-
MODEL_WEIGHTS_FILE_IN_PT,
23-
)
2418

2519

2620
class _BaseException(Exception):

0 commit comments

Comments
 (0)