[Core] Refactor executor classes for easier inheritance (#7673)

[Core] Refactor executor classes to make it easier to inherit GPUExecutor (#7673)
This commit is contained in:
Kunshang Ji 2024-08-20 15:56:50 +08:00 committed by GitHub
parent ad28a74beb
commit b6f99a6ffe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 27 additions and 21 deletions

View File

@ -62,6 +62,18 @@ class GPUExecutor(ExecutorBase):
observability_config=self.observability_config,
)
def _get_worker_module_and_class(self) -> Tuple[str, str]:
    """Select the worker module name and worker class name to use.

    The choice is driven by this executor's configuration, checked in
    order: multi-step scheduling first, then speculative decoding, and
    finally the plain worker as the fallback. Kept as a separate method
    so that the selection can be overridden independently of the rest
    of the worker kwargs.

    Returns:
        A ``(worker_module_name, worker_class_name)`` tuple of strings.
    """
    # Multi-step scheduling takes precedence over speculative decoding.
    if self.scheduler_config.is_multi_step:
        return ("vllm.worker.multi_step_worker", "MultiStepWorker")
    # Any truthy speculative config selects the spec-decode entry point.
    if self.speculative_config:
        return ("vllm.spec_decode.spec_decode_worker", "create_spec_worker")
    return ("vllm.worker.worker", "Worker")
def _get_create_worker_kwargs(
self,
local_rank: int = 0,
@ -70,17 +82,10 @@ class GPUExecutor(ExecutorBase):
worker_kwargs = self._get_worker_kwargs(local_rank, rank,
distributed_init_method)
if self.scheduler_config.is_multi_step:
worker_kwargs.update(
worker_module_name="vllm.worker.multi_step_worker",
worker_class_name="MultiStepWorker")
elif self.speculative_config:
worker_kwargs.update(
worker_module_name="vllm.spec_decode.spec_decode_worker",
worker_class_name="create_spec_worker")
else:
worker_kwargs.update(worker_module_name="vllm.worker.worker",
worker_class_name="Worker")
(worker_module_name,
worker_class_name) = self._get_worker_module_and_class()
worker_kwargs.update(worker_module_name=worker_module_name,
worker_class_name=worker_class_name)
return worker_kwargs

View File

@ -91,15 +91,8 @@ class RayGPUExecutor(DistributedGPUExecutor):
return ray_remote_kwargs
def _get_worker_wrapper_args(self) -> Dict[str, Any]:
if self.speculative_config is not None:
worker_module_name = "vllm.spec_decode.spec_decode_worker"
worker_class_name = "create_spec_worker"
elif self.scheduler_config.is_multi_step:
worker_module_name = "vllm.worker.multi_step_worker"
worker_class_name = "MultiStepWorker"
else:
worker_module_name = "vllm.worker.worker"
worker_class_name = "Worker"
(worker_module_name,
worker_class_name) = self._get_worker_module_and_class()
return dict(
worker_module_name=worker_module_name,
@ -107,6 +100,10 @@ class RayGPUExecutor(DistributedGPUExecutor):
trust_remote_code=self.model_config.trust_remote_code,
)
def _get_env_vars_to_be_updated(self):
    """Return the environment-variable arguments to pass to the workers.

    This default implementation hands back
    ``self._env_vars_for_all_workers`` unchanged. A child class could
    override this to return the actual env vars.
    """
    env_vars = self._env_vars_for_all_workers
    return env_vars
def _init_workers_ray(self, placement_group: "PlacementGroup",
**ray_remote_kwargs):
if (self.parallel_config.tensor_parallel_size == 1
@ -231,8 +228,12 @@ class RayGPUExecutor(DistributedGPUExecutor):
"VLLM_TRACE_FUNCTION":
str(envs.VLLM_TRACE_FUNCTION),
}, ) for (node_id, _) in worker_node_and_gpu_ids]
self._env_vars_for_all_workers = (
all_args_to_update_environment_variables)
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
all_args=self._get_env_vars_to_be_updated())
if len(node_gpus) == 1:
# in single node case, we don't need to get the IP address.