[Core] Fix engine-use-ray broken (#4105)
This commit is contained in:
parent
37e84a403d
commit
4e7ee664e2
@ -25,21 +25,30 @@ def _query_server_long(prompt: str) -> dict:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def api_server(tokenizer_pool_size: int):
|
def api_server(tokenizer_pool_size: int, engine_use_ray: bool,
|
||||||
|
worker_use_ray: bool):
|
||||||
script_path = Path(__file__).parent.joinpath(
|
script_path = Path(__file__).parent.joinpath(
|
||||||
"api_server_async_engine.py").absolute()
|
"api_server_async_engine.py").absolute()
|
||||||
uvicorn_process = subprocess.Popen([
|
commands = [
|
||||||
sys.executable, "-u",
|
sys.executable, "-u",
|
||||||
str(script_path), "--model", "facebook/opt-125m", "--host",
|
str(script_path), "--model", "facebook/opt-125m", "--host",
|
||||||
"127.0.0.1", "--tokenizer-pool-size",
|
"127.0.0.1", "--tokenizer-pool-size",
|
||||||
str(tokenizer_pool_size)
|
str(tokenizer_pool_size)
|
||||||
])
|
]
|
||||||
|
if engine_use_ray:
|
||||||
|
commands.append("--engine-use-ray")
|
||||||
|
if worker_use_ray:
|
||||||
|
commands.append("--worker-use-ray")
|
||||||
|
uvicorn_process = subprocess.Popen(commands)
|
||||||
yield
|
yield
|
||||||
uvicorn_process.terminate()
|
uvicorn_process.terminate()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
|
@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
|
||||||
def test_api_server(api_server, tokenizer_pool_size: int):
|
@pytest.mark.parametrize("worker_use_ray", [False, True])
|
||||||
|
@pytest.mark.parametrize("engine_use_ray", [False, True])
|
||||||
|
def test_api_server(api_server, tokenizer_pool_size: int, worker_use_ray: bool,
|
||||||
|
engine_use_ray: bool):
|
||||||
"""
|
"""
|
||||||
Run the API server and test it.
|
Run the API server and test it.
|
||||||
|
|
||||||
|
@ -333,8 +333,7 @@ class AsyncLLMEngine:
|
|||||||
if engine_config.device_config.device_type == "neuron":
|
if engine_config.device_config.device_type == "neuron":
|
||||||
raise NotImplementedError("Neuron is not supported for "
|
raise NotImplementedError("Neuron is not supported for "
|
||||||
"async engine yet.")
|
"async engine yet.")
|
||||||
elif (engine_config.parallel_config.worker_use_ray
|
elif engine_config.parallel_config.worker_use_ray:
|
||||||
or engine_args.engine_use_ray):
|
|
||||||
initialize_ray_cluster(engine_config.parallel_config)
|
initialize_ray_cluster(engine_config.parallel_config)
|
||||||
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
|
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
|
||||||
executor_class = RayGPUExecutorAsync
|
executor_class = RayGPUExecutorAsync
|
||||||
@ -410,8 +409,8 @@ class AsyncLLMEngine:
|
|||||||
else:
|
else:
|
||||||
# FIXME(woosuk): This is a bit hacky. Be careful when changing the
|
# FIXME(woosuk): This is a bit hacky. Be careful when changing the
|
||||||
# order of the arguments.
|
# order of the arguments.
|
||||||
cache_config = args[1]
|
cache_config = kwargs["cache_config"]
|
||||||
parallel_config = args[2]
|
parallel_config = kwargs["parallel_config"]
|
||||||
if parallel_config.tensor_parallel_size == 1:
|
if parallel_config.tensor_parallel_size == 1:
|
||||||
num_gpus = cache_config.gpu_memory_utilization
|
num_gpus = cache_config.gpu_memory_utilization
|
||||||
else:
|
else:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user