[Frontend] API support for beam search (#9087)

Co-authored-by: youkaichao <youkaichao@126.com>
Brendan Wong 2024-10-05 23:39:03 -07:00 committed by GitHub
parent 23fea8714a
commit 168cab6bbf
12 changed files with 275 additions and 68 deletions

View File

@@ -15,6 +15,7 @@ from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
 from vllm.entrypoints.openai.api_server import (
     build_async_engine_client_from_engine_args)
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.sampling_params import BeamSearchParams
 from vllm.utils import FlexibleArgumentParser, merge_async_iterators
@@ -145,10 +146,13 @@ def run_vllm(
         for prompt, input_len, _output_len in requests:
             assert _output_len == output_len
         start = time.perf_counter()
-        llm.beam_search(prompts,
-                        beam_width=n,
-                        max_tokens=output_len,
-                        ignore_eos=True)
+        llm.beam_search(
+            prompts,
+            BeamSearchParams(
+                beam_width=n,
+                max_tokens=output_len,
+                ignore_eos=True,
+            ))
         end = time.perf_counter()
     return end - start

View File

@@ -35,6 +35,7 @@ from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
                          to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
+from vllm.sampling_params import BeamSearchParams
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
                         identity, is_cpu)
@@ -812,7 +813,9 @@ class VllmRunner:
         beam_width: int,
         max_tokens: int,
     ) -> List[Tuple[List[List[int]], List[str]]]:
-        outputs = self.model.beam_search(prompts, beam_width, max_tokens)
+        outputs = self.model.beam_search(
+            prompts,
+            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
         returned_outputs = []
         for output in outputs:
             token_ids = [x.tokens for x in output.sequences]

View File

@@ -495,25 +495,30 @@ async def test_batch_completions(client: openai.AsyncOpenAI, model_name: str):
     assert len(batch.choices) == 2
     assert batch.choices[0].text == batch.choices[1].text

-    # test n = 2
-    batch = await client.completions.create(
-        model=model_name,
-        prompt=prompts,
-        n=2,
-        max_tokens=5,
-        temperature=0.0,
-        extra_body=dict(
-            # NOTE: this has to be true for n > 1 in vLLM, but not necessary
-            # for official client.
-            use_beam_search=True),
-    )
-    assert len(batch.choices) == 4
-    assert batch.choices[0].text != batch.choices[
-        1].text, "beam search should be different"
-    assert batch.choices[0].text == batch.choices[
-        2].text, "two copies of the same prompt should be the same"
-    assert batch.choices[1].text == batch.choices[
-        3].text, "two copies of the same prompt should be the same"
+    try:
+        # test n = 2
+        batch = await client.completions.create(
+            model=model_name,
+            prompt=prompts,
+            n=2,
+            max_tokens=5,
+            temperature=0.0,
+            extra_body=dict(
+                # NOTE: this has to be true for n > 1 in vLLM, but
+                # not necessary for official client.
+                use_beam_search=True),
+        )
+        assert len(batch.choices) == 4
+        assert batch.choices[0].text != batch.choices[
+            1].text, "beam search should be different"
+        assert batch.choices[0].text == batch.choices[
+            2].text, "two copies of the same prompt should be the same"
+        assert batch.choices[1].text == batch.choices[
+            3].text, "two copies of the same prompt should be the same"
+    except BadRequestError as e:
+        # the only allowed exception is when beam search is not supported
+        # in the default mqllmengine
+        assert "--disable-frontend-multiprocessing" in str(e)

     # test streaming
     batch = await client.completions.create(

View File

@@ -14,23 +14,26 @@ from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_timeout import asyncio_timeout
 from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
 from vllm.engine.metrics_types import StatLoggerBase
+from vllm.entrypoints.llm import BeamSearchSequence
 from vllm.executor.executor_base import ExecutorAsyncBase
 from vllm.executor.gpu_executor import GPUExecutorAsync
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType
+from vllm.inputs import PromptType, TokensPrompt
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
 from vllm.model_executor.layers.sampler import SamplerOutput
-from vllm.outputs import EmbeddingRequestOutput, RequestOutput
+from vllm.outputs import (CompletionOutput, EmbeddingRequestOutput,
+                          RequestOutput)
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import ExecuteModelRequest
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.usage.usage_lib import UsageContext
-from vllm.utils import deprecate_kwargs, weak_bind
+from vllm.utils import (collect_from_async_generator, deprecate_kwargs,
+                        random_uuid, weak_bind)

 logger = init_logger(__name__)
 ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@@ -1036,6 +1039,102 @@ class AsyncLLMEngine:
     ):
         yield LLMEngine.validate_output(output, RequestOutput)

+    async def beam_search(
+        self,
+        prompt: Union[PromptType, List[int]],
+        request_id: str,
+        params: BeamSearchParams,
+    ) -> AsyncGenerator[RequestOutput, None]:
+
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        ignore_eos = params.ignore_eos
+        temperature = params.temperature
+
+        tokenizer = await self.get_tokenizer()
+        tokenizedPrompt = prompt if isinstance(
+            prompt, list) else tokenizer.encode(prompt)
+        tokenizedLength = len(tokenizedPrompt)
+
+        beam_search_params = SamplingParams(logprobs=2 * beam_width,
+                                            max_tokens=1,
+                                            temperature=temperature)
+        all_beams = [BeamSearchSequence(tokens=tokenizedPrompt, cum_logprob=0)]
+        completed = []
+
+        for _ in range(max_tokens):
+            prompts_batch = [
+                TokensPrompt(prompt_token_ids=beam.tokens)
+                for beam in all_beams
+            ]
+
+            tasks = []
+            request_id = f"beam_search-{random_uuid()}"
+            for i, individual_prompt in enumerate(prompts_batch):
+                request_id_item = f"{request_id}-{i}"
+                task = asyncio.create_task(
+                    collect_from_async_generator(
+                        self.generate(individual_prompt, beam_search_params,
+                                      request_id_item)))
+                tasks.append(task)
+
+            output = await asyncio.gather(*tasks)
+            output = [x[0] for x in output]
+
+            logger.info(output)
+
+            new_beams = []
+            for i, current_beam in enumerate(all_beams):
+                result = output[i]
+
+                if result.outputs[0].logprobs is not None:
+                    logprobs = result.outputs[0].logprobs[0]
+                    for token_id, logprob_obj in logprobs.items():
+                        new_beam = BeamSearchSequence(
+                            tokens=current_beam.tokens + [token_id],
+                            cum_logprob=current_beam.cum_logprob +
+                            logprob_obj.logprob)
+
+                        if token_id == tokenizer.eos_token_id and \
+                            not ignore_eos:
+                            completed.append(new_beam)
+                        else:
+                            new_beams.append(new_beam)
+
+            sorted_beams = sorted(new_beams,
+                                  key=lambda x: x.cum_logprob,
+                                  reverse=True)
+            all_beams = sorted_beams[:beam_width]
+
+        completed.extend(all_beams)
+        sorted_completed = sorted(completed,
+                                  key=lambda x: x.cum_logprob,
+                                  reverse=True)
+        best_beams = sorted_completed[:beam_width]
+
+        for beam in best_beams:
+            beam.text = tokenizer.decode(beam.tokens[tokenizedLength:])
+
+        beam_search_output = RequestOutput(
+            request_id=request_id,
+            prompt=prompt,
+            outputs=[
+                CompletionOutput(
+                    text=beam.text,
+                    cumulative_logprob=beam.cum_logprob,
+                    token_ids=beam.tokens,
+                    index=i,
+                    logprobs=beam.cum_logprob,
+                ) for (i, beam) in enumerate(best_beams)
+            ],
+            finished=True,
+            prompt_token_ids=tokenizedPrompt,
+            prompt_logprobs=None)
+
+        yield LLMEngine.validate_output(beam_search_output, RequestOutput)
+
     async def encode(
         self,
         prompt: PromptType,
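
For context, a minimal sketch of how the new async generator might be driven end to end. The engine arguments and the model name are placeholders (not part of this diff); only the call signature (prompt, request_id, params) and the yielded RequestOutput shape follow the method added above.

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import BeamSearchParams


async def demo() -> None:
    # Placeholder model; any locally available model works the same way.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    params = BeamSearchParams(beam_width=4, max_tokens=32)
    # beam_search yields a single finished RequestOutput whose outputs are
    # the top `beam_width` beams, each carrying text and cumulative_logprob.
    async for output in engine.beam_search("The capital of France is",
                                           "beam-demo-0", params):
        for beam in output.outputs:
            print(beam.cumulative_logprob, beam.text)


if __name__ == "__main__":
    asyncio.run(demo())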

View File

@@ -22,8 +22,8 @@ from vllm.model_executor.guided_decoding.guided_fields import (
 from vllm.outputs import EmbeddingRequestOutput, RequestOutput
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
-                                  SamplingParams)
+from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
+                                  RequestOutputKind, SamplingParams)
 from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
                                                get_cached_tokenizer)
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -394,10 +394,7 @@ class LLM:
     def beam_search(
         self,
         prompts: List[Union[str, List[int]]],
-        beam_width: int,
-        max_tokens: int,
-        ignore_eos: bool = False,
-        temperature: float = 0.0,
+        params: BeamSearchParams,
     ) -> List[BeamSearchOutput]:
         """
         Generate sequences using beam search.
@@ -405,14 +402,17 @@ class LLM:
         Args:
             prompts: A list of prompts. Each prompt can be a string or a list
                 of token IDs.
-            beam_width: The number of beams to keep at each step.
-            max_tokens: The max number of tokens to generate for each prompt.
-            temperature: The temperature to use for generation.
+            params: The beam search parameters.

         TODO: how does beam search work together with length penalty, frequency
         penalty, and stopping criteria, etc.?
         """
+        beam_width = params.beam_width
+        max_tokens = params.max_tokens
+        temperature = params.temperature
+        ignore_eos = params.ignore_eos
         tokenizer = self.get_tokenizer()
         # generate 2 * beam_width candidates at each step
         # following the huggingface transformers implementation
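
A short sketch of the updated offline entrypoint under the new signature. The model name is a placeholder, and the `.tokens` attribute is the one exercised by the tests/conftest.py change above; everything else is an assumption for illustration.

from vllm import LLM
from vllm.sampling_params import BeamSearchParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
outputs = llm.beam_search(
    ["The quick brown fox"],
    BeamSearchParams(beam_width=4, max_tokens=16),
)
for output in outputs:
    # Each output carries `beam_width` candidate sequences; their token ids
    # are exposed on the `tokens` attribute, as used in the conftest change.
    for seq in output.sequences:
        print(seq.tokens)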

View File

@@ -4,7 +4,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams

 logger = init_logger(__name__)
@@ -21,7 +21,8 @@ class RequestLogger:
         request_id: str,
         prompt: Optional[str],
         prompt_token_ids: Optional[List[int]],
-        params: Optional[Union[SamplingParams, PoolingParams]],
+        params: Optional[Union[SamplingParams, PoolingParams,
+                               BeamSearchParams]],
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> None:

View File

@@ -11,8 +11,8 @@ from typing_extensions import Annotated, Required, TypedDict
 from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
 from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import (GuidedDecodingParams, RequestOutputKind,
-                                  SamplingParams)
+from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
+                                  RequestOutputKind, SamplingParams)
 from vllm.sequence import Logprob
 from vllm.utils import random_uuid
@@ -288,6 +288,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
     # doc: end-chat-completion-extra-params

+    def to_beam_search_params(self,
+                              default_max_tokens: int) -> BeamSearchParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
+        n = self.n if self.n is not None else 1
+        temperature = self.temperature if self.temperature is not None else 0.0
+
+        return BeamSearchParams(
+            beam_width=n,
+            max_tokens=max_tokens,
+            ignore_eos=self.ignore_eos,
+            temperature=temperature,
+        )
+
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
@@ -567,6 +583,22 @@ class CompletionRequest(OpenAIBaseModel):
     # doc: end-completion-extra-params

+    def to_beam_search_params(self,
+                              default_max_tokens: int) -> BeamSearchParams:
+        max_tokens = self.max_tokens
+        if max_tokens is None:
+            max_tokens = default_max_tokens
+
+        n = self.n if self.n is not None else 1
+        temperature = self.temperature if self.temperature is not None else 0.0
+
+        return BeamSearchParams(
+            beam_width=n,
+            max_tokens=max_tokens,
+            ignore_eos=self.ignore_eos,
+            temperature=temperature,
+        )
+
     def to_sampling_params(self, default_max_tokens: int) -> SamplingParams:
         max_tokens = self.max_tokens
         if max_tokens is None:
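
A small illustration (not from the diff) of how the new converter maps OpenAI-style request fields onto BeamSearchParams: `n` becomes the beam width and a missing `max_tokens` falls back to the server-computed default. The request values here are arbitrary.

from vllm.entrypoints.openai.protocol import CompletionRequest

# Illustrative request; `use_beam_search` is the vLLM-specific extra parameter.
request = CompletionRequest(model="my-model",
                            prompt="Hello",
                            n=4,
                            max_tokens=None,
                            use_beam_search=True)
params = request.to_beam_search_params(default_max_tokens=128)
print(params.beam_width, params.max_tokens)  # 4 128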

View File

@@ -9,6 +9,7 @@ from typing import Union
 from fastapi import Request
 from vllm.config import ModelConfig
+from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          apply_hf_chat_template,
@@ -33,6 +34,7 @@ from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
 from vllm.inputs import TokensPrompt
 from vllm.logger import init_logger
 from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
@@ -203,9 +205,15 @@ class OpenAIServingChat(OpenAIServing):
             assert prompt_inputs is not None

-            sampling_params = request.to_sampling_params(
-                default_max_tokens=self.max_model_len -
-                len(prompt_inputs["prompt_token_ids"]))
+            sampling_params: Union[SamplingParams, BeamSearchParams]
+            default_max_tokens = self.max_model_len - len(
+                prompt_inputs["prompt_token_ids"])
+            if request.use_beam_search:
+                sampling_params = request.to_beam_search_params(
+                    default_max_tokens)
+            else:
+                sampling_params = request.to_sampling_params(
+                    default_max_tokens)

             self._log_inputs(request_id,
                              prompt_inputs,
@@ -227,15 +235,26 @@ class OpenAIServingChat(OpenAIServing):
                     and contains_trace_headers(raw_request.headers)):
                 log_tracing_disabled_warning()

-            result_generator = self.engine_client.generate(
-                engine_inputs,
-                sampling_params,
-                request_id,
-                lora_request=lora_request,
-                trace_headers=trace_headers,
-                prompt_adapter_request=prompt_adapter_request,
-                priority=request.priority,
-            )
+            if isinstance(sampling_params, BeamSearchParams):
+                if not isinstance(self.engine_client, AsyncLLMEngine):
+                    raise ValueError(
+                        "Beam search in the API server is only supported with"
+                        " AsyncLLMEngine. please add "
+                        "`--disable-frontend-multiprocessing` to "
+                        "use beam search.")
+                result_generator = self.engine_client.beam_search(
+                    engine_inputs['prompt_token_ids'], request_id,
+                    sampling_params)
+            else:
+                result_generator = self.engine_client.generate(
+                    engine_inputs,
+                    sampling_params,
+                    request_id,
+                    lora_request=lora_request,
+                    trace_headers=trace_headers,
+                    prompt_adapter_request=prompt_adapter_request,
+                    priority=request.priority,
+                )
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
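
For reference, a hedged client-side sketch of exercising the new chat-completions beam-search path (the completions path is already covered by the test above). The served model name is a placeholder, and it assumes the server was started with `--disable-frontend-multiprocessing` so the AsyncLLMEngine check above passes.

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
chat = client.chat.completions.create(
    model="my-model",  # placeholder served model name
    messages=[{"role": "user", "content": "Name three prime numbers."}],
    n=2,
    temperature=0.0,
    # vLLM-specific flag routed through extra_body, as in the completions test.
    extra_body={"use_beam_search": True},
)
for choice in chat.choices:
    print(choice.message.content)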

View File

@@ -8,6 +8,7 @@ from typing import Tuple, Union, cast
 from fastapi import Request
 from vllm.config import ModelConfig
+from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
@@ -28,6 +29,7 @@ from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
                                                     PromptAdapterPath)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.tracing import (contains_trace_headers, extract_trace_headers,
                           log_tracing_disabled_warning)
@@ -120,9 +122,15 @@ class OpenAIServingCompletion(OpenAIServing):
                 ))

             for i, prompt_inputs in enumerate(prompts):
-                sampling_params = request.to_sampling_params(
-                    default_max_tokens=self.max_model_len -
-                    len(prompt_inputs["prompt_token_ids"]))
+                sampling_params: Union[SamplingParams, BeamSearchParams]
+                default_max_tokens = self.max_model_len - len(
+                    prompt_inputs["prompt_token_ids"])
+                if request.use_beam_search:
+                    sampling_params = request.to_beam_search_params(
+                        default_max_tokens)
+                else:
+                    sampling_params = request.to_sampling_params(
+                        default_max_tokens)

                 request_id_item = f"{request_id}-{i}"
@@ -141,15 +149,29 @@ class OpenAIServingCompletion(OpenAIServing):
                         raw_request.headers):
                     log_tracing_disabled_warning()

-                generator = self.engine_client.generate(
-                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
-                    sampling_params,
-                    request_id_item,
-                    lora_request=lora_request,
-                    prompt_adapter_request=prompt_adapter_request,
-                    trace_headers=trace_headers,
-                    priority=request.priority,
-                )
+                if isinstance(sampling_params, BeamSearchParams):
+                    if not isinstance(self.engine_client, AsyncLLMEngine):
+                        raise ValueError(
+                            "Beam search in the API server is only supported"
+                            " with AsyncLLMEngine. please add "
+                            "`--disable-frontend-multiprocessing` to "
+                            "use beam search.")
+                    generator = self.engine_client.beam_search(
+                        prompt_inputs["prompt_token_ids"], request_id_item,
+                        sampling_params)
+                else:
+                    generator = self.engine_client.generate(
+                        {
+                            "prompt_token_ids":
+                            prompt_inputs["prompt_token_ids"]
+                        },
+                        sampling_params,
+                        request_id_item,
+                        lora_request=lora_request,
+                        prompt_adapter_request=prompt_adapter_request,
+                        trace_headers=trace_headers,
+                        priority=request.priority,
+                    )

                 generators.append(generator)
         except ValueError as e:

View File

@@ -29,7 +29,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import BeamSearchParams, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer
 from vllm.utils import AtomicCounter
@@ -371,7 +371,8 @@ class OpenAIServing:
         self,
         request_id: str,
         inputs: Union[str, List[int], TextTokensPrompt],
-        params: Optional[Union[SamplingParams, PoolingParams]],
+        params: Optional[Union[SamplingParams, PoolingParams,
+                               BeamSearchParams]],
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> None:

View File

@@ -530,3 +530,15 @@ class SamplingParams(
             f"{self.spaces_between_special_tokens}, "
             f"truncate_prompt_tokens={self.truncate_prompt_tokens}), "
             f"guided_decoding={self.guided_decoding}")
+
+
+class BeamSearchParams(
+        msgspec.Struct,
+        omit_defaults=True,  # type: ignore[call-arg]
+        # required for @cached_property.
+        dict=True):  # type: ignore[call-arg]
+    """Beam search parameters for text generation."""
+    beam_width: int
+    max_tokens: int
+    ignore_eos: bool = False
+    temperature: float = 0.0
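
A quick sanity-check sketch of the new container's defaults (the values are arbitrary):

from vllm.sampling_params import BeamSearchParams

params = BeamSearchParams(beam_width=4, max_tokens=64)
# ignore_eos and temperature keep their defaults (False and 0.0).
print(params.beam_width, params.ignore_eos, params.temperature)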

View File

@@ -504,6 +504,15 @@ async def merge_async_iterators(
         await it.aclose()


+async def collect_from_async_generator(
+        iterator: AsyncGenerator[T, None]) -> List[T]:
+    """Collect all items from an async generator into a list."""
+    items = []
+    async for item in iterator:
+        items.append(item)
+    return items
+
+
 def get_ip() -> str:
     host_ip = envs.VLLM_HOST_IP
     if host_ip:
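
To round out the helper above, a minimal illustration with a throwaway generator (the generator itself is hypothetical; only the helper comes from this diff):

import asyncio

from vllm.utils import collect_from_async_generator


async def numbers():
    for i in range(3):
        yield i


# Collects every yielded item into a list: [0, 1, 2].
print(asyncio.run(collect_from_async_generator(numbers())))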