[Core] Refactor executor classes to make it easier to inherit GPUExecutor (#7673)
commit b6f99a6ffe
parent ad28a74beb
@@ -62,6 +62,18 @@ class GPUExecutor(ExecutorBase):
             observability_config=self.observability_config,
         )
 
+    def _get_worker_module_and_class(self) -> Tuple[str, str]:
+        if self.scheduler_config.is_multi_step:
+            worker_module_name = "vllm.worker.multi_step_worker"
+            worker_class_name = "MultiStepWorker"
+        elif self.speculative_config:
+            worker_module_name = "vllm.spec_decode.spec_decode_worker"
+            worker_class_name = "create_spec_worker"
+        else:
+            worker_module_name = "vllm.worker.worker"
+            worker_class_name = "Worker"
+        return (worker_module_name, worker_class_name)
+
     def _get_create_worker_kwargs(
             self,
             local_rank: int = 0,
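The new _get_worker_module_and_class helper centralizes the worker selection that used to be duplicated per executor, so a subclass that ships its own worker only has to override this one method. Below is a minimal sketch of such an override, assuming GPUExecutor is importable from vllm.executor.gpu_executor; the subclass name and the "my_pkg.my_worker" / "MyWorker" target are placeholders, not part of this commit.

from typing import Tuple

from vllm.executor.gpu_executor import GPUExecutor


class MyGPUExecutor(GPUExecutor):
    """Hypothetical executor that plugs in a custom default worker."""

    def _get_worker_module_and_class(self) -> Tuple[str, str]:
        # Keep the stock selection for multi-step and speculative decoding;
        # only the plain-worker case is redirected to our own class.
        if self.scheduler_config.is_multi_step or self.speculative_config:
            return super()._get_worker_module_and_class()
        return ("my_pkg.my_worker", "MyWorker")  # placeholder module/class

Because GPUExecutor._get_create_worker_kwargs and RayGPUExecutor._get_worker_wrapper_args (see the hunks below) both read the module and class name from this method, a subclass needs only this single override regardless of which of the two executors it derives from.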
@@ -70,17 +82,10 @@ class GPUExecutor(ExecutorBase):
         worker_kwargs = self._get_worker_kwargs(local_rank, rank,
                                                 distributed_init_method)
 
-        if self.scheduler_config.is_multi_step:
-            worker_kwargs.update(
-                worker_module_name="vllm.worker.multi_step_worker",
-                worker_class_name="MultiStepWorker")
-        elif self.speculative_config:
-            worker_kwargs.update(
-                worker_module_name="vllm.spec_decode.spec_decode_worker",
-                worker_class_name="create_spec_worker")
-        else:
-            worker_kwargs.update(worker_module_name="vllm.worker.worker",
-                                 worker_class_name="Worker")
+        (worker_module_name,
+         worker_class_name) = self._get_worker_module_and_class()
+        worker_kwargs.update(worker_module_name=worker_module_name,
+                             worker_class_name=worker_class_name)
 
         return worker_kwargs
 
@@ -91,15 +91,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
         return ray_remote_kwargs
 
     def _get_worker_wrapper_args(self) -> Dict[str, Any]:
-        if self.speculative_config is not None:
-            worker_module_name = "vllm.spec_decode.spec_decode_worker"
-            worker_class_name = "create_spec_worker"
-        elif self.scheduler_config.is_multi_step:
-            worker_module_name = "vllm.worker.multi_step_worker"
-            worker_class_name = "MultiStepWorker"
-        else:
-            worker_module_name = "vllm.worker.worker"
-            worker_class_name = "Worker"
+        (worker_module_name,
+         worker_class_name) = self._get_worker_module_and_class()
 
         return dict(
             worker_module_name=worker_module_name,
@@ -107,6 +100,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
             trust_remote_code=self.model_config.trust_remote_code,
         )
 
+    # child class could overwrite this to return actual env vars.
+    def _get_env_vars_to_be_updated(self):
+        return self._env_vars_for_all_workers
+
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
         if (self.parallel_config.tensor_parallel_size == 1
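As the comment in the hunk above notes, _get_env_vars_to_be_updated is an extension point for child classes. Below is a minimal sketch of such an override, assuming RayGPUExecutor is importable from vllm.executor.ray_gpu_executor; the subclass name and the injected variable are placeholders, and the shape of the stored arguments is inferred from the _init_workers_ray hunk further below.

import copy

from vllm.executor.ray_gpu_executor import RayGPUExecutor


class MyRayGPUExecutor(RayGPUExecutor):
    """Hypothetical executor that adds an env var for every Ray worker."""

    def _get_env_vars_to_be_updated(self):
        # self._env_vars_for_all_workers holds one ({name: value}, ) tuple per
        # worker; work on a copy so the stored defaults stay untouched.
        all_args = copy.deepcopy(self._env_vars_for_all_workers)
        for (env_vars, ) in all_args:
            env_vars["MY_CUSTOM_FLAG"] = "1"  # placeholder variable
        return all_args

The next hunk is what makes this hook usable: the stock per-worker env-var arguments are stashed on self._env_vars_for_all_workers before _run_workers("update_environment_variables", ...) is invoked, so the override sees the defaults and its return value is what actually reaches the workers.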
@@ -231,8 +228,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
             "VLLM_TRACE_FUNCTION":
             str(envs.VLLM_TRACE_FUNCTION),
         }, ) for (node_id, _) in worker_node_and_gpu_ids]
+
+        self._env_vars_for_all_workers = (
+            all_args_to_update_environment_variables)
+
         self._run_workers("update_environment_variables",
-                          all_args=all_args_to_update_environment_variables)
+                          all_args=self._get_env_vars_to_be_updated())
 
         if len(node_gpus) == 1:
             # in single node case, we don't need to get the IP address.