Skip to content

Commit ba3303c

Browse files
authored
Merge pull request #160 from VectorInstitute/feature/misc-fixes
* Move all bindings into BINDPATH env var * Fix required field checking for batch launch
2 parents 6d21da4 + aab1eef commit ba3303c

File tree

4 files changed

+43
-34
lines changed

4 files changed

+43
-34
lines changed

tests/vec_inf/client/test_slurm_script_generator.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_init_singularity(self, singularity_params):
9999
assert generator.params == singularity_params
100100
assert generator.use_container
101101
assert not generator.is_multinode
102-
assert generator.additional_binds == " --bind /scratch:/scratch,/data:/data"
102+
assert generator.additional_binds == ",/scratch:/scratch,/data:/data"
103103
assert generator.model_weights_path == "/path/to/model_weights/test-model"
104104
assert (
105105
generator.env_str
@@ -185,8 +185,6 @@ def test_generate_launch_cmd_singularity(self, singularity_params):
185185
launch_cmd = generator._generate_launch_cmd()
186186

187187
assert "apptainer exec --nv" in launch_cmd
188-
assert "--bind /path/to/model_weights/test-model" in launch_cmd
189-
assert "--bind /scratch:/scratch,/data:/data" in launch_cmd
190188
assert "source" not in launch_cmd
191189

192190
def test_generate_launch_cmd_boolean_args(self, basic_params):
@@ -327,14 +325,17 @@ def test_init_singularity(self, batch_singularity_params):
327325
"""Test initialization with Singularity configuration."""
328326
generator = BatchSlurmScriptGenerator(batch_singularity_params)
329327

328+
print(generator.params["models"]["model1"]["additional_binds"])
329+
print(generator.params["models"]["model2"]["additional_binds"])
330+
330331
assert generator.use_container
331332
assert (
332333
generator.params["models"]["model1"]["additional_binds"]
333-
== " --bind /scratch:/scratch,/data:/data"
334+
== ",/scratch:/scratch,/data:/data"
334335
)
335336
assert (
336337
generator.params["models"]["model2"]["additional_binds"]
337-
== " --bind /scratch:/scratch,/data:/data"
338+
== ",/scratch:/scratch,/data:/data"
338339
)
339340

340341
def test_init_singularity_no_bind(self, batch_params):

vec_inf/client/_helper.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -469,16 +469,15 @@ def _get_launch_params(
469469
If required fields are missing or tensor parallel size is not specified
470470
when using multiple GPUs
471471
"""
472-
params: dict[str, Any] = {
473-
"models": {},
472+
common_params: dict[str, Any] = {
474473
"slurm_job_name": self.slurm_job_name,
475474
"src_dir": str(SRC_DIR),
476475
"account": account,
477476
"work_dir": work_dir,
478477
}
479478

480-
# Check for required fields without default vals, will raise an error if missing
481-
utils.check_required_fields(params)
479+
params: dict[str, Any] = common_params.copy()
480+
params["models"] = {}
482481

483482
for i, (model_name, config) in enumerate(self.model_configs.items()):
484483
params["models"][model_name] = config.model_dump(exclude_none=True)
@@ -555,6 +554,10 @@ def _get_launch_params(
555554
raise ValueError(
556555
f"Mismatch found for {arg}: {params[arg]} != {params['models'][model_name][arg]}, check your configuration"
557556
)
557+
# Check for required fields, will raise an error if missing any
558+
utils.check_required_fields(
559+
{**params["models"][model_name], **common_params}
560+
)
558561

559562
return params
560563

vec_inf/client/_slurm_script_generator.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def __init__(self, params: dict[str, Any]):
3434
self.params = params
3535
self.is_multinode = int(self.params["num_nodes"]) > 1
3636
self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
37-
self.additional_binds = self.params.get("bind", "")
38-
if self.additional_binds:
39-
self.additional_binds = f" --bind {self.additional_binds}"
37+
self.additional_binds = (
38+
f",{self.params['bind']}" if self.params.get("bind") else ""
39+
)
4040
self.model_weights_path = str(
4141
Path(self.params["model_weights_parent_dir"], self.params["model_name"])
4242
)
@@ -107,7 +107,12 @@ def _generate_server_setup(self) -> str:
107107
server_script = ["\n"]
108108
if self.use_container:
109109
server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_setup"]))
110-
server_script.append("\n".join(SLURM_SCRIPT_TEMPLATE["container_env_vars"]))
110+
server_script.append(
111+
SLURM_SCRIPT_TEMPLATE["bind_path"].format(
112+
model_weights_path=self.model_weights_path,
113+
additional_binds=self.additional_binds,
114+
)
115+
)
111116
else:
112117
server_script.append(
113118
SLURM_SCRIPT_TEMPLATE["activate_venv"].format(venv=self.params["venv"])
@@ -125,7 +130,6 @@ def _generate_server_setup(self) -> str:
125130
"CONTAINER_PLACEHOLDER",
126131
SLURM_SCRIPT_TEMPLATE["container_command"].format(
127132
model_weights_path=self.model_weights_path,
128-
additional_binds=self.additional_binds,
129133
env_str=self.env_str,
130134
),
131135
)
@@ -163,7 +167,6 @@ def _generate_launch_cmd(self) -> str:
163167
launcher_script.append(
164168
SLURM_SCRIPT_TEMPLATE["container_command"].format(
165169
model_weights_path=self.model_weights_path,
166-
additional_binds=self.additional_binds,
167170
env_str=self.env_str,
168171
)
169172
)
@@ -215,11 +218,11 @@ def __init__(self, params: dict[str, Any]):
215218
self.script_paths: list[Path] = []
216219
self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
217220
for model_name in self.params["models"]:
218-
self.params["models"][model_name]["additional_binds"] = ""
219-
if self.params["models"][model_name].get("bind"):
220-
self.params["models"][model_name]["additional_binds"] = (
221-
f" --bind {self.params['models'][model_name]['bind']}"
222-
)
221+
self.params["models"][model_name]["additional_binds"] = (
222+
f",{self.params['models'][model_name]['bind']}"
223+
if self.params["models"][model_name].get("bind")
224+
else ""
225+
)
223226
self.params["models"][model_name]["model_weights_path"] = str(
224227
Path(
225228
self.params["models"][model_name]["model_weights_parent_dir"],
@@ -259,7 +262,12 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
259262
script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["shebang"])
260263
if self.use_container:
261264
script_content.append(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_setup"])
262-
script_content.append("\n".join(BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["env_vars"]))
265+
script_content.append(
266+
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["bind_path"].format(
267+
model_weights_path=model_params["model_weights_path"],
268+
additional_binds=model_params["additional_binds"],
269+
)
270+
)
263271
script_content.append(
264272
"\n".join(
265273
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["server_address_setup"]
@@ -277,7 +285,6 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
277285
script_content.append(
278286
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["container_command"].format(
279287
model_weights_path=model_params["model_weights_path"],
280-
additional_binds=model_params["additional_binds"],
281288
)
282289
)
283290
script_content.append(

vec_inf/client/_slurm_templates.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ class SlurmScriptTemplate(TypedDict):
5757
Commands for container setup
5858
imports : str
5959
Import statements and source commands
60+
bind_path : str
61+
Bind path environment variable for the container
6062
container_command : str
6163
Template for container execution command
6264
activate_venv : str
@@ -74,7 +76,7 @@ class SlurmScriptTemplate(TypedDict):
7476
shebang: ShebangConfig
7577
container_setup: list[str]
7678
imports: str
77-
container_env_vars: list[str]
79+
bind_path: str
7880
container_command: str
7981
activate_venv: str
8082
server_setup: ServerSetupConfig
@@ -96,10 +98,8 @@ class SlurmScriptTemplate(TypedDict):
9698
f"{CONTAINER_MODULE_NAME} exec {IMAGE_PATH} ray stop",
9799
],
98100
"imports": "source {src_dir}/find_port.sh",
99-
"container_env_vars": [
100-
f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp"
101-
],
102-
"container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} --bind {{model_weights_path}}{{additional_binds}} --containall {IMAGE_PATH} \\",
101+
"bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp,{{model_weights_path}}{{additional_binds}}",
102+
"container_command": f"{CONTAINER_MODULE_NAME} exec --nv {{env_str}} --containall {IMAGE_PATH} \\",
103103
"activate_venv": "source {venv}/bin/activate",
104104
"server_setup": {
105105
"single_node": [
@@ -215,8 +215,8 @@ class BatchModelLaunchScriptTemplate(TypedDict):
215215
Shebang line for the script
216216
container_setup : list[str]
217217
Commands for container setup
218-
env_vars : list[str]
219-
Environment variables to set
218+
bind_path : str
219+
Bind path environment variable for the container
220220
server_address_setup : list[str]
221221
Commands to setup the server address
222222
launch_cmd : list[str]
@@ -227,7 +227,7 @@ class BatchModelLaunchScriptTemplate(TypedDict):
227227

228228
shebang: str
229229
container_setup: str
230-
env_vars: list[str]
230+
bind_path: str
231231
server_address_setup: list[str]
232232
write_to_json: list[str]
233233
launch_cmd: list[str]
@@ -237,9 +237,7 @@ class BatchModelLaunchScriptTemplate(TypedDict):
237237
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE: BatchModelLaunchScriptTemplate = {
238238
"shebang": "#!/bin/bash\n",
239239
"container_setup": f"{CONTAINER_LOAD_CMD}\n",
240-
"env_vars": [
241-
f"export {CONTAINER_MODULE_NAME}_BINDPATH=${CONTAINER_MODULE_NAME}_BINDPATH,$(echo /dev/infiniband* | sed -e 's/ /,/g')"
242-
],
240+
"bind_path": f"export {CONTAINER_MODULE_NAME.upper()}_BINDPATH=${CONTAINER_MODULE_NAME.upper()}_BINDPATH,/dev,/tmp,{{model_weights_path}}{{additional_binds}}",
243241
"server_address_setup": [
244242
"source {src_dir}/find_port.sh",
245243
"head_node_ip=${{SLURMD_NODENAME}}",
@@ -255,7 +253,7 @@ class BatchModelLaunchScriptTemplate(TypedDict):
255253
' "$json_path" > temp_{model_name}.json \\',
256254
' && mv temp_{model_name}.json "$json_path"\n',
257255
],
258-
"container_command": f"{CONTAINER_MODULE_NAME} exec --nv --bind {{model_weights_path}}{{additional_binds}} --containall {IMAGE_PATH} \\",
256+
"container_command": f"{CONTAINER_MODULE_NAME} exec --nv --containall {IMAGE_PATH} \\",
259257
"launch_cmd": [
260258
"vllm serve {model_weights_path} \\",
261259
" --served-model-name {model_name} \\",

0 commit comments

Comments
 (0)