Skip to content

Commit e8af596

Browse files
authored
[AINode] Upgrade torch version (#17323)
1 parent c9066c7 commit e8af596

File tree

4 files changed

+83
-16
lines changed

4 files changed

+83
-16
lines changed

integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import java.util.ArrayList;
3939
import java.util.List;
4040
import java.util.Properties;
41+
import java.util.concurrent.TimeUnit;
4142
import java.util.stream.Stream;
4243

4344
import static org.apache.iotdb.it.env.cluster.ClusterConstant.AI_NODE_NAME;
@@ -52,7 +53,8 @@ public class AINodeWrapper extends AbstractNodeWrapper {
5253
private final String seedConfigNode;
5354
private final int clusterIngressPort;
5455

55-
private static final String SCRIPT_FILE = "start-ainode.sh";
56+
private static final String START_SCRIPT_FILE = "start-ainode.sh";
57+
private static final String STOP_SCRIPT_FILE = "stop-ainode.sh";
5658

5759
private static final String SHELL_COMMAND = "bash";
5860

@@ -165,8 +167,8 @@ public void start() {
165167
// start AINode
166168
List<String> startCommand = new ArrayList<>();
167169
startCommand.add(SHELL_COMMAND);
168-
startCommand.add(filePrefix + File.separator + SCRIPT_PATH + File.separator + SCRIPT_FILE);
169-
startCommand.add("-r");
170+
startCommand.add(
171+
filePrefix + File.separator + SCRIPT_PATH + File.separator + START_SCRIPT_FILE);
170172

171173
ProcessBuilder processBuilder =
172174
new ProcessBuilder(startCommand)
@@ -179,6 +181,48 @@ public void start() {
179181
}
180182
}
181183

184+
@Override
185+
public void stop() {
186+
if (this.instance == null) {
187+
return;
188+
}
189+
try {
190+
// stop AINode
191+
File stdoutFile = new File(getLogPath());
192+
String filePrefix = getNodePath();
193+
List<String> stopCommand = new ArrayList<>();
194+
stopCommand.add(SHELL_COMMAND);
195+
stopCommand.add(
196+
filePrefix + File.separator + SCRIPT_PATH + File.separator + STOP_SCRIPT_FILE);
197+
ProcessBuilder processBuilder =
198+
new ProcessBuilder(stopCommand)
199+
.redirectOutput(ProcessBuilder.Redirect.appendTo(stdoutFile))
200+
.redirectError(ProcessBuilder.Redirect.appendTo(stdoutFile));
201+
Process stopProcess = processBuilder.inheritIO().start();
202+
if (!stopProcess.waitFor(20, TimeUnit.SECONDS)) {
203+
logger.warn("Node {} does not exit within 20s, killing it", getId());
204+
if (!this.instance.destroyForcibly().waitFor(10, TimeUnit.SECONDS)) {
205+
logger.error("Cannot forcibly stop node {}", getId());
206+
}
207+
}
208+
int exitCode = stopProcess.exitValue();
209+
if (exitCode != 0) {
210+
logger.warn("Node {}'s stop script exited with code {}", getId(), exitCode);
211+
}
212+
} catch (InterruptedException e) {
213+
Thread.currentThread().interrupt();
214+
logger.error("Waiting node to shutdown error.", e);
215+
} catch (IOException e) {
216+
logger.error("Waiting node to shutdown error.", e);
217+
}
218+
logger.info("In test {} {} stopped.", getTestLogDirName(), getId());
219+
}
220+
221+
@Override
222+
public void stopForcibly() {
223+
this.stop();
224+
}
225+
182226
@Override
183227
public int getMetricPort() {
184228
// no metric currently

iotdb-core/ainode/iotdb/ainode/core/device/backend/cuda_backend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
#
18+
import time
1819

1920
import torch
2021

@@ -24,6 +25,9 @@
2425
class CUDABackend(BackendAdapter):
2526
type = BackendType.CUDA
2627

28+
def __init__(self) -> None:
29+
self._safe_cuda_init()
30+
2731
def is_available(self) -> bool:
2832
return torch.cuda.is_available()
2933

@@ -37,3 +41,19 @@ def make_device(self, index: int | None) -> torch.device:
3741

3842
def set_device(self, index: int) -> None:
3943
torch.cuda.set_device(index)
44+
45+
def _safe_cuda_init(self) -> None:
46+
# Safe CUDA initialization to avoid potential deadlocks
47+
# This is a workaround for certain PyTorch versions where the first CUDA call can cause a long delay
48+
# By calling a simple CUDA operation at startup, we can ensure that the CUDA context is initialized early
49+
# and avoid unexpected delays during actual model loading or inference.
50+
attempt_cnt = 3
51+
for attempt in range(attempt_cnt):
52+
try:
53+
if self.is_available():
54+
return
55+
raise RuntimeError("CUDA not available")
56+
except Exception as e:
57+
print(f"CUDA init attempt {attempt + 1} failed: {e}")
58+
if attempt < attempt_cnt:
59+
time.sleep(1.5)

iotdb-core/ainode/iotdb/ainode/core/script.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,26 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
#
18+
1819
import multiprocessing
1920
import sys
2021

22+
# PyInstaller multiprocessing support
23+
# freeze_support() is essential for PyInstaller frozen executables on all platforms
24+
# It detects if the current process is a multiprocessing child process
25+
# If it is, it executes the child process target function and exits
26+
# If it's not, it returns immediately and continues with main() execution
27+
# This prevents child processes from executing the main application logic
28+
if getattr(sys, "frozen", False):
29+
# Call freeze_support() for both standard multiprocessing and torch.multiprocessing
30+
multiprocessing.freeze_support()
31+
multiprocessing.set_start_method("spawn", force=True)
32+
2133
import torch.multiprocessing as mp
2234

35+
mp.freeze_support()
36+
mp.set_start_method("spawn", force=True)
37+
2338
from iotdb.ainode.core.ai_node import AINode
2439
from iotdb.ainode.core.log import Logger
2540

@@ -42,7 +57,6 @@ def main():
4257
command = arguments[1]
4358
if command == "start":
4459
try:
45-
mp.set_start_method("spawn", force=True)
4660
logger.info(f"Current multiprocess start method: {mp.get_start_method()}")
4761
logger.info("IoTDB-AINode is starting...")
4862
ai_node = AINode()
@@ -55,15 +69,4 @@ def main():
5569

5670

5771
if __name__ == "__main__":
58-
# PyInstaller multiprocessing support
59-
# freeze_support() is essential for PyInstaller frozen executables on all platforms
60-
# It detects if the current process is a multiprocessing child process
61-
# If it is, it executes the child process target function and exits
62-
# If it's not, it returns immediately and continues with main() execution
63-
# This prevents child processes from executing the main application logic
64-
if getattr(sys, "frozen", False):
65-
# Call freeze_support() for both standard multiprocessing and torch.multiprocessing
66-
multiprocessing.freeze_support()
67-
mp.freeze_support()
68-
6972
main()

iotdb-core/ainode/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ exclude = [
7979
python = ">=3.11.0,<3.12.0"
8080

8181
# ---- DL / HF stack ----
82-
torch = "^2.8.0,<2.9.0"
82+
torch = "^2.9.0,<2.10.0"
8383
torchmetrics = "^1.8.0"
8484
transformers = "==4.56.2"
8585
tokenizers = ">=0.22.0,<=0.23.0"

0 commit comments

Comments
 (0)