@@ -36,10 +36,10 @@ def __init__(self, params: dict[str, Any]):
3636 self .is_multinode = int (self .params ["num_nodes" ]) > 1
3737 self .use_container = self .params ["venv" ] == CONTAINER_MODULE_NAME
3838 self .additional_binds = (
39- { self .params [' bind' ]} if self .params .get ("bind" ) else ""
39+ self .params [" bind" ] if self .params .get ("bind" ) else ""
4040 )
41- self .model_weights_path = Path (
42- self .params ["model_weights_parent_dir" ], self .params ["model_name" ]
41+ self .model_weights_path = str (
42+ Path ( self .params ["model_weights_parent_dir" ], self .params ["model_name" ])
4343 )
4444 self .env_str = self ._generate_env_str ()
4545
@@ -187,7 +187,7 @@ def _generate_launch_cmd(self) -> str:
187187
188188 launch_cmd .append (
189189 "\n " .join (SLURM_SCRIPT_TEMPLATE ["launch_cmd" ][self .engine ]).format ( # type: ignore[literal-required]
190- model_weights_path = self .model_weights_path if not self . params .get ("hf_model" ) else self .params [ "hf_model" ] ,
190+ model_weights_path = self .params .get ("hf_model" ) or self .model_weights_path ,
191191 model_name = self .params ["model_name" ],
192192 )
193193 )
@@ -217,7 +217,7 @@ def _generate_multinode_sglang_launch_cmd(self) -> str:
217217 SLURM_SCRIPT_TEMPLATE ["launch_cmd" ]["sglang_multinode" ]
218218 ).format (
219219 num_nodes = self .params ["num_nodes" ],
220- model_weights_path = self .model_weights_path if not self . params .get ("hf_model" ) else self .params [ "hf_model" ] ,
220+ model_weights_path = self .params .get ("hf_model" ) or self .model_weights_path ,
221221 model_name = self .params ["model_name" ],
222222 )
223223
@@ -277,7 +277,7 @@ def __init__(self, params: dict[str, Any]):
277277 self .use_container = self .params ["venv" ] == CONTAINER_MODULE_NAME
278278 for model_name in self .params ["models" ]:
279279 self .params ["models" ][model_name ]["additional_binds" ] = (
280- { self .params [' models' ][model_name ][' bind' ]}
280+ self .params [" models" ][model_name ][" bind" ]
281281 if self .params ["models" ][model_name ].get ("bind" )
282282 else ""
283283 )
@@ -352,7 +352,7 @@ def _generate_model_launch_script(self, model_name: str) -> Path:
352352 "\n " .join (
353353 BATCH_MODEL_LAUNCH_SCRIPT_TEMPLATE ["launch_cmd" ][model_params ["engine" ]]
354354 ).format (
355- model_weights_path = model_params [ "model_weights_path" ] if not model_params .get ("hf_model" ) else model_params ["hf_model " ],
355+ model_weights_path = model_params .get ("hf_model" ) or model_params ["model_weights_path " ],
356356 model_name = model_name ,
357357 )
358358 )
0 commit comments