Skip to content

Commit 03864a5

Browse files
committed
Fix typos
1 parent cd7a9b1 commit 03864a5

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

vec_inf/client/_slurm_script_generator.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,10 @@ def __init__(self, params: dict[str, Any]):
3636
self.is_multinode = int(self.params["num_nodes"]) > 1
3737
self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
3838
self.additional_binds = (
39-
{self.params['bind']} if self.params.get("bind") else ""
39+
self.params["bind"] if self.params.get("bind") else ""
4040
)
41-
self.model_weights_path = Path(
42-
self.params["model_weights_parent_dir"], self.params["model_name"]
41+
self.model_weights_path = str(
42+
Path(self.params["model_weights_parent_dir"], self.params["model_name"])
4343
)
4444
self.env_str = self._generate_env_str()
4545

@@ -187,7 +187,7 @@ def _generate_launch_cmd(self) -> str:
187187

188188
launch_cmd.append(
189189
"\n".join(SLURM_SCRIPT_TEMPLATE["launch_cmd"][self.engine]).format( # type: ignore[literal-required]
190-
model_weights_path=self.model_weights_path if not self.params.get("hf_model") else self.params["hf_model"],
190+
model_weights_path=self.params.get("hf_model") or self.model_weights_path,
191191
model_name=self.params["model_name"],
192192
)
193193
)
@@ -217,7 +217,7 @@ def _generate_multinode_sglang_launch_cmd(self) -> str:
217217
SLURM_SCRIPT_TEMPLATE["launch_cmd"]["sglang_multinode"]
218218
).format(
219219
num_nodes=self.params["num_nodes"],
220-
model_weights_path=self.model_weights_path if not self.params.get("hf_model") else self.params["hf_model"],
220+
model_weights_path=self.params.get("hf_model") or self.model_weights_path,
221221
model_name=self.params["model_name"],
222222
)
223223

@@ -277,7 +277,7 @@ def __init__(self, params: dict[str, Any]):
277277
self.use_container = self.params["venv"] == CONTAINER_MODULE_NAME
278278
for model_name in self.params["models"]:
279279
self.params["models"][model_name]["additional_binds"] = (
280-
{self.params['models'][model_name]['bind']}
280+
self.params["models"][model_name]["bind"]
281281
if self.params["models"][model_name].get("bind")
282282
else ""
283283
)
@@ -352,7 +352,7 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
352352
"\n".join(
353353
BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE["launch_cmd"][model_params["engine"]]
354354
).format(
355-
model_weights_path=model_params["model_weights_path"] if not model_params.get("hf_model") else model_params["hf_model"],
355+
model_weights_path=model_params.get("hf_model") or model_params["model_weights_path"],
356356
model_name=model_name,
357357
)
358358
)

0 commit comments

Comments
 (0)