-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrunpod_scaling_test.py
More file actions
150 lines (128 loc) · 5.85 KB
/
runpod_scaling_test.py
File metadata and controls
150 lines (128 loc) · 5.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import argparse
import logging
import os
import subprocess
import requests
import time
import threading
# Global state for SSH tunnel management.
# Lock held by _setup_shared_ssh_tunnel while it inspects or replaces the tunnel.
_ssh_tunnel_lock = threading.Lock()
# PID recorded for the current tunnel process; None while no tunnel is tracked.
_ssh_tunnel_pid = None
# Local port of the current tunnel; also the first port tried for a new one.
_ssh_tunnel_port = 7337
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def _get_available_port(base_port=7337, max_attempts=50):
    """Return the first local port at or above ``base_port`` that can be bound.

    Each candidate is probed by binding a throwaway TCP socket on the
    loopback interface. This needs no external tools, so a missing or
    failing probe command can never make an occupied port look free.

    Args:
        base_port: First port number to try.
        max_attempts: Number of consecutive ports to try, starting at
            ``base_port``.

    Returns:
        The first port in ``base_port .. base_port + max_attempts - 1`` on
        which the loopback bind succeeded.

    Raises:
        RuntimeError: If none of the candidate ports could be bound.
    """
    import socket  # local import: only this helper needs it

    for port in range(base_port, base_port + max_attempts):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
            try:
                probe.bind(("127.0.0.1", port))
            except (OSError, OverflowError):
                # OSError: port in use or not permitted.
                # OverflowError: candidate is outside the 0-65535 range.
                continue
        return port
    raise RuntimeError(f"Could not find available port after {max_attempts} attempts")
def _test_tunnel_connection(port, timeout=10):
    """Report whether an HTTP GET to ``http://localhost:<port>/`` gets a reply.

    Args:
        port: Local port to probe.
        timeout: Timeout, in seconds, passed to ``requests.get``.

    Returns:
        True if the request completed with any response (the status code is
        not inspected); False if it raised. Connection errors are logged as
        warnings, every other exception as an error.
    """
    reachable = False
    try:
        # Only the ability to complete a request matters here, so the
        # response object itself is discarded.
        requests.get(f"http://localhost:{port}/", timeout=timeout)
    except requests.ConnectionError:
        logger.warning(f"Connection error on port {port}. Tunnel may not be ready.")
    except Exception as exc:
        logger.error(f"Error testing tunnel connection on port {port}: {exc}")
    else:
        reachable = True
    return reachable
def _setup_shared_ssh_tunnel(remote_port, host="runpod_a100_box", max_retries=3):
    """Return the local port of the shared SSH tunnel, creating it if needed.

    The module-level lock is held for the whole call, so concurrent callers
    are serialized. A previously recorded tunnel is reused when its process
    still exists and ``_test_tunnel_connection`` succeeds on its port.
    Otherwise the recorded process is stopped and a new ``ssh -N -L`` process
    is started on a port chosen by ``_get_available_port``.

    Args:
        remote_port: Port on the remote side that the tunnel forwards to.
        host: SSH destination, passed to ``ssh`` as one argument (no shell
            is involved, so it must not contain extra options).
        max_retries: Maximum number of attempts to start the tunnel.

    Returns:
        The local port the tunnel listens on.

    Raises:
        RuntimeError: If no tunnel process stayed alive after
            ``max_retries`` attempts.
    """
    global _ssh_tunnel_pid, _ssh_tunnel_port

    with _ssh_tunnel_lock:
        if _ssh_tunnel_pid is not None:
            try:
                os.kill(_ssh_tunnel_pid, 0)  # signal 0: existence check only
            except OSError:
                logger.info(f"SSH tunnel process {_ssh_tunnel_pid} is dead, creating new tunnel")
            else:
                if _test_tunnel_connection(_ssh_tunnel_port):
                    logger.info(f"Reusing existing SSH tunnel on port {_ssh_tunnel_port} (PID: {_ssh_tunnel_pid})")
                    return _ssh_tunnel_port
                logger.warning("SSH tunnel process exists but connection test failed")
                # Stop and reap the unresponsive tunnel so it is not leaked
                # once its PID is replaced below.
                try:
                    os.kill(_ssh_tunnel_pid, 9)
                    os.waitpid(_ssh_tunnel_pid, 0)
                except OSError:
                    pass  # already gone, or already reaped elsewhere
            _ssh_tunnel_pid = None

        for attempt in range(max_retries):
            try:
                # The port has just been probed as free, so nothing needs to
                # be (and nothing is) killed to make room for the tunnel.
                local_port = _get_available_port(_ssh_tunnel_port)
                # Argument list instead of a shell string: no quoting or
                # injection issues, and process.pid is ssh itself.
                cmd = [
                    "ssh",
                    "-N",
                    # Make ssh exit if it cannot set up the forward, so the
                    # poll() check below actually detects that failure.
                    "-o", "ExitOnForwardFailure=yes",
                    "-L", f"{local_port}:localhost:{remote_port}",
                    host,
                ]
                logger.info(f"Starting SSH tunnel: {' '.join(cmd)}")
                process = subprocess.Popen(cmd)
                time.sleep(3)  # Give tunnel time to establish
                if process.poll() is None:  # still running
                    _ssh_tunnel_pid = process.pid
                    _ssh_tunnel_port = local_port
                    logger.info(f"SSH tunnel established on port {local_port} (PID: {process.pid})")
                    return local_port
                logger.warning(f"Tunnel process exited unexpectedly, attempt {attempt + 1}")
            except Exception as e:
                logger.warning(f"SSH tunnel setup failed, attempt {attempt + 1}: {e}")
            if attempt + 1 < max_retries:
                time.sleep(2)  # back off only when another attempt follows

        raise RuntimeError(f"Failed to establish SSH tunnel after {max_retries} attempts")
def get_vllm_api_base(
    use_runpod: bool,
    runpod_endpoint_id: str = "pmave9bk168p0q",
    remote_port: int = 8000,
) -> str:
    """Determine the vLLM API base URL.

    Args:
        use_runpod: If true, return the RunPod API endpoint URL without
            touching SSH. If false, set up (or reuse) the shared SSH tunnel
            and return the local URL that goes through it.
        runpod_endpoint_id: Endpoint identifier interpolated into the RunPod
            URL. Used only when ``use_runpod`` is true.
        remote_port: Remote port the SSH tunnel forwards to. Used only when
            ``use_runpod`` is false. Defaults to 8000.

    Returns:
        ``https://api.runpod.ai/v2/<runpod_endpoint_id>/openai`` or
        ``http://localhost:<local_port>/v1``.

    Raises:
        Exception: Whatever ``_setup_shared_ssh_tunnel`` raised, re-raised
            unchanged after being logged.
    """
    if use_runpod:
        vllm_api_base = f"https://api.runpod.ai/v2/{runpod_endpoint_id}/openai"
        logger.info(f"Using RunPod endpoint: {vllm_api_base}")
        return vllm_api_base

    logger.info("Setting up local SSH tunnel for vLLM access.")
    try:
        local_port = _setup_shared_ssh_tunnel(remote_port)
    except Exception as e:
        logger.error(f"Failed to setup SSH tunnel: {e}")
        raise
    vllm_api_base = f"http://localhost:{local_port}/v1"
    logger.info(f"Using local vLLM endpoint via SSH tunnel: {vllm_api_base}")
    return vllm_api_base
def main():
    """Command-line entry point: resolve and print the vLLM API base URL.

    Parses ``--use-runpod``, mirrors the choice into the
    ``VLLM_BACKEND_USE_RUNPOD`` environment variable, and calls
    ``get_vllm_api_base``. Any SSH tunnel recorded in ``_ssh_tunnel_pid`` is
    killed before returning, whether or not the lookup succeeded.

    Returns:
        0 if the URL was determined, 1 if an error was reported.
    """
    global _ssh_tunnel_pid

    parser = argparse.ArgumentParser(description="Test RunPod scaling toggle.")
    parser.add_argument("--use-runpod", action="store_true", help="Use RunPod for vLLM backend.")
    args = parser.parse_args()

    os.environ["VLLM_BACKEND_USE_RUNPOD"] = "True" if args.use_runpod else "False"
    logger.info(f"VLLM_BACKEND_USE_RUNPOD set to: {os.environ['VLLM_BACKEND_USE_RUNPOD']}")

    exit_code = 0
    try:
        api_base_url = get_vllm_api_base(use_runpod=args.use_runpod)
        print(f"Successfully determined vLLM API base URL: {api_base_url}")
    except Exception as e:
        print(f"Error: {e}")
        exit_code = 1  # surface the failure through the process exit status
    finally:
        # Cleanup the SSH tunnel if it was created
        if _ssh_tunnel_pid is not None:
            logger.info(f"Cleaning up SSH tunnel (PID: {_ssh_tunnel_pid})...")
            try:
                os.kill(_ssh_tunnel_pid, 9)
            except OSError:
                pass  # Process might already be gone
            # Forget the PID so later code never signals a stale process id.
            _ssh_tunnel_pid = None
    return exit_code


if __name__ == "__main__":
    # Propagate main()'s result as the exit status of the script.
    raise SystemExit(main())