Compare commits: c16fb5dae8...26507f8973

10 commits:
- 26507f8973
- 9c1d5b456d
- e31045f95c
- aaec845f8e
- 7bdfd29a35
- e78587a64c
- 7eb4255628
- 6a0f547561
- 30ed81b7ca
- 7a4a5de729
@@ -17,10 +17,12 @@ source /etc/environment
 docker run --privileged --net host --shm-size=16G -it \
     -e "HF_TOKEN=$HF_TOKEN" --name tpu-test \
     vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git \
-    && python3 -m pip install pytest \
+    && python3 -m pip install pytest tpu-info \
     && python3 -m pip install lm_eval[api]==0.4.4 \
     && export VLLM_USE_V1=1 \
     && export VLLM_XLA_CHECK_RECOMPILATION=1 \
+    && echo HARDWARE \
+    && tpu-info \
     && echo TEST_0 \
     && pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \
     && echo TEST_1 \
.github/ISSUE_TEMPLATE/200-installation.yml (vendored, 2 changes)

@@ -14,7 +14,7 @@ body:
    description: |
      Please run the following and paste the output below.
      ```sh
-     wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
+     wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
      # For security purposes, please feel free to check the contents of collect_env.py before running it.
      python collect_env.py
      ```

.github/ISSUE_TEMPLATE/300-usage.yml (vendored, 2 changes)

@@ -14,7 +14,7 @@ body:
    description: |
      Please run the following and paste the output below.
      ```sh
-     wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
+     wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
      # For security purposes, please feel free to check the contents of collect_env.py before running it.
      python collect_env.py
      ```
.github/ISSUE_TEMPLATE/400-bug-report.yml (vendored, 2 changes)

@@ -14,7 +14,7 @@ body:
    description: |
      Please run the following and paste the output below.
      ```sh
-     wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
+     wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
      # For security purposes, please feel free to check the contents of collect_env.py before running it.
      python collect_env.py
      ```

@@ -35,7 +35,7 @@ body:
    description: |
      Please run the following and paste the output below.
      ```sh
-     wget https://raw.githubusercontent.com/vllm-project/vllm/main/collect_env.py
+     wget https://raw.githubusercontent.com/vllm-project/vllm/main/vllm/collect_env.py
      # For security purposes, please feel free to check the contents of collect_env.py before running it.
      python collect_env.py
      ```
@@ -13,7 +13,6 @@
 template <typename scalar_t, int bit, int GROUPS>
 __global__ void moe_wna16_gemm_kernel(
     const scalar_t* __restrict__ input, scalar_t* __restrict__ output,
-
     const uint32_t* __restrict__ qweight, const scalar_t* __restrict__ scales,
     const uint32_t* __restrict__ qzeros,
 
@@ -54,8 +53,6 @@ __global__ void moe_wna16_gemm_kernel(
       if (token_index / top_k >= size_m) break;
 
       num_valid_tokens = m + 1;
-      if (blockIdx.z == 0 && offset_n < size_n)
-        output[token_index * size_n + offset_n] = Dtype::int2num(0);
 
       if (expert_id != -1) {
         int k_per_thread = DIVIDE(BLOCK_SIZE_K, BLOCK_SIZE_N);
@@ -284,8 +281,7 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
                              int64_t BLOCK_SIZE_K, int64_t bit) {
   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
-  auto options =
-      torch::TensorOptions().dtype(input.dtype()).device(input.device());
+  output.zero_();
 
   const int num_experts = b_qweight.size(0);
   const int size_m = input.size(0);
@@ -302,9 +298,9 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
   const uint32_t* b_qzeros_ptr;
   if (b_qzeros.has_value())
     b_qzeros_ptr = (const uint32_t*)b_qzeros.value().data_ptr<uint8_t>();
-  const float* topk_weights_ptr;
+  const float* topk_weights_ptr = nullptr;
   if (topk_weights.has_value())
-    topk_weights_ptr = (const float*)topk_weights.value().data_ptr();
+    topk_weights_ptr = (const float*)topk_weights.value().data_ptr<float>();
 
   int groups_per_block_row = BLOCK_SIZE_K / group_size;
   TORCH_CHECK(bit == 4 || bit == 8, "bit must be 4 or 8");
@@ -241,6 +241,7 @@ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
 fi
 COPY examples examples
 COPY benchmarks benchmarks
+COPY ./vllm/collect_env.py .
 
 # Although we build Flashinfer with AOT mode, there's still
 # some issues w.r.t. JIT compilation. Therefore we need to
@@ -121,6 +121,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
 ADD ./tests/ ./tests/
 ADD ./examples/ ./examples/
 ADD ./benchmarks/ ./benchmarks/
+ADD ./vllm/collect_env.py .
 
 # install development dependencies (for testing)
 RUN --mount=type=cache,target=/root/.cache/uv \
@@ -19,6 +19,18 @@ $ docker run --runtime nvidia --gpus all \
     --model mistralai/Mistral-7B-v0.1
 ```
 
+This image can also be used with other container engines such as [Podman](https://podman.io/).
+
+```console
+$ podman run --gpus all \
+    -v ~/.cache/huggingface:/root/.cache/huggingface \
+    --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
+    -p 8000:8000 \
+    --ipc=host \
+    vllm/vllm-openai:latest \
+    --model mistralai/Mistral-7B-v0.1
+```
+
 You can add any other <project:#engine-args> you need after the image tag (`vllm/vllm-openai:latest`).
 
 :::{note}

@@ -16,7 +16,7 @@ Ensure that you have a running Kubernetes environment with GPU (you can follow [
 
 ## Deployment using vLLM production stack
 
-The standard vLLM production stack install uses a Helm chart. You can run this [bash script](https://github.com/vllm-project/production-stack/blob/main/tutorials/install-helm.sh) to install Helm on your GPU server.
+The standard vLLM production stack is installed using a Helm chart. You can run this [bash script](https://github.com/vllm-project/production-stack/blob/main/utils/install-helm.sh) to install Helm on your GPU server.
 
 To install the vLLM production stack, run the following commands on your desktop:
 
@@ -788,7 +788,7 @@ llm = LLM(
 Online serving:
 
 ```bash
-vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4
+vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt '{"image":4}'
 ```
 
 **This is no longer required if you are using vLLM V1.**

@@ -228,7 +228,7 @@ First, launch the OpenAI-compatible server:
 
 ```bash
 vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
-  --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
+  --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}'
 ```
 
 Then, you can use the OpenAI client as follows:
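The documentation hunks above change `--limit-mm-per-prompt` from the old `key=value` syntax to a JSON string. A small sketch of how a launcher script could build the new flag value; the model name and limit come from the docs above, the script itself is only illustrative:

```python
import json
import shlex

# New flag format introduced by these commits: a JSON object string,
# e.g. '{"image": 4}', instead of the old comma-separated "image=4" form.
limits = {"image": 4}
flag_value = json.dumps(limits)

cmd = [
    "vllm", "serve", "Qwen/Qwen2-VL-7B-Instruct",
    "--limit-mm-per-prompt", flag_value,
]
print(shlex.join(cmd))
# vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt '{"image": 4}'
```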
@@ -37,11 +37,11 @@ def build_llm_with_lmcache():
         '{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}')
     # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB
     # memory. Reduce the value if your GPU has less memory.
-    # Note that LMCache is not compatible with chunked prefill for now.
+    # Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392).
     llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2",
               kv_transfer_config=ktc,
               max_model_len=8000,
-              enable_chunked_prefill=False,
+              enable_chunked_prefill=True,
               gpu_memory_utilization=0.8)
 
     try:
@@ -16,11 +16,11 @@ from vllm.sampling_params import SamplingParams
 # # Mistral format
 # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
 #   --tokenizer-mode mistral --config-format mistral --load-format mistral \
-#   --limit-mm-per-prompt 'image=4' --max-model-len 16384
+#   --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
 #
 # # HF format
 # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
-#   --limit-mm-per-prompt 'image=4' --max-model-len 16384
+#   --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
 # ```
 #
 # - Client:
@@ -9,7 +9,7 @@ vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja
 
 (multi-image inference with Phi-3.5-vision-instruct)
 vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
-    --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
+    --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}'
 
 (audio inference with Ultravox)
 vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096
@@ -24,6 +24,10 @@ from vllm.utils import FlexibleArgumentParser
     }),
 ])
 def test_limit_mm_per_prompt_parser(arg, expected):
+    """This functionality is deprecated and will be removed in the future.
+    This argument should be passed as JSON string instead.
+
+    TODO: Remove with nullable_kvs."""
     parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
     if arg is None:
         args = parser.parse_args([])
@@ -27,7 +27,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"audio={MAXIMUM_AUDIOS}",
+        str({"audio": MAXIMUM_AUDIOS}),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

@@ -31,7 +31,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"video={MAXIMUM_VIDEOS}",
+        str({"video": MAXIMUM_VIDEOS}),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

@@ -35,7 +35,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"image={MAXIMUM_IMAGES}",
+        str({"image": MAXIMUM_IMAGES}),
     ]
 
     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

@@ -37,7 +37,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"image={MAXIMUM_IMAGES}",
+        str({"image": MAXIMUM_IMAGES}),
         "--chat-template",
         str(vlm2vec_jinja_path),
     ]
@@ -48,9 +48,9 @@ def audio(request):
 ])
 def server(request, audio_assets):
     args = [
-        "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
-        f"--limit-mm-per-prompt=audio={len(audio_assets)}",
-        "--trust-remote-code"
+        "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
+        "--limit-mm-per-prompt",
+        str({"audio": len(audio_assets)}), "--trust-remote-code"
     ] + [
         f"--{key.replace('_','-')}={value}"
         for key, value in request.param.items()
@@ -27,6 +27,7 @@ _PARTITION_SIZE_ROCM = 256
 
 
 class ROCmFlashAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True
 
     @staticmethod
     def get_name() -> str:

@@ -515,7 +516,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
 
                 from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                     triton_attention)
-                self.attn_func = triton_attention
+                self.triton_attn_func = triton_attention
                 logger.debug("Using Triton FA in ROCmBackend")
                 if self.sliding_window != (-1, -1):
                     logger.warning("ROCm Triton FA does not currently support "

@@ -531,7 +532,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             else:
                 try:
                     from flash_attn import flash_attn_varlen_func  # noqa: F401
-                    self.attn_func = flash_attn_varlen_func
+                    self.fa_attn_func = flash_attn_varlen_func
                     logger.debug("Using CK FA in ROCmBackend")
                 except ModuleNotFoundError:
                     self.use_naive_attn = True

@@ -542,7 +543,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     "ROCm Naive FlashAttention does not support "
                     "attention logits soft capping.")
 
-            self.attn_func = _sdpa_attention
+            self.sdpa_attn_func = _sdpa_attention
             logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -613,6 +614,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is not None, "Output tensor must be provided."
+
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None

@@ -656,7 +659,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
 
-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.

@@ -704,11 +706,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     query.dtype,
                     seq_lens,
                     make_attn_mask=causal_mask)  # type: ignore
-                out, _ = self.attn_func(
+                self.triton_attn_func(
                     query,
                     key,
                     value,
-                    None,
+                    output[:num_prefill_tokens],
                     query_seq_start_loc,
                     key_seq_start_loc,
                     query_max_seq_len,
@@ -733,10 +735,11 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                 key = key.movedim(0, key.dim() - 2)
                 value = value.movedim(0, value.dim() - 2)
                 # sdpa math backend attention
-                out = self.attn_func(
+                self.sdpa_attn_func(
                     query,
                     key,
                     value,
+                    output[:num_prefill_tokens],
                     query_seq_start_loc,
                     num_prefill_tokens,
                     self.num_heads,

@@ -745,7 +748,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     attn_masks,
                 )
             else:
-                out = self.attn_func(
+                # upstream FA does not support an output arg, copy
+                output[:num_prefill_tokens] = self.fa_attn_func(
                     q=query,
                     k=key,
                     v=value,

@@ -760,12 +764,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                     softcap=self.logits_soft_cap,
                 )
 
-            # common code for prefill
-            assert output[:num_prefill_tokens].shape == out.shape
-            if output.shape[0] > num_prefill_tokens:
-                output[:num_prefill_tokens] = out
-            else:
-                output = out
         else:
             # prefix-enabled attention -
             # not applicable for encoder-only models
|
|||||||
device=output.device,
|
device=output.device,
|
||||||
)
|
)
|
||||||
max_logits = torch.empty_like(exp_sums)
|
max_logits = torch.empty_like(exp_sums)
|
||||||
if num_prefill_tokens > 0:
|
|
||||||
out = output[num_prefill_tokens:]
|
|
||||||
else:
|
|
||||||
out = output
|
|
||||||
|
|
||||||
query_start_loc = None
|
query_start_loc = None
|
||||||
ops.paged_attention_rocm(
|
ops.paged_attention_rocm(
|
||||||
out,
|
output[num_prefill_tokens:],
|
||||||
exp_sums,
|
exp_sums,
|
||||||
max_logits,
|
max_logits,
|
||||||
tmp_output,
|
tmp_output,
|
||||||
@ -878,7 +872,8 @@ def _sdpa_attention(
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
value: torch.Tensor,
|
value: torch.Tensor,
|
||||||
seq_lens: List[int],
|
output: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
num_heads: int,
|
num_heads: int,
|
||||||
head_size: int,
|
head_size: int,
|
||||||
@ -886,9 +881,9 @@ def _sdpa_attention(
|
|||||||
attn_masks: Optional[List[torch.Tensor]] = None,
|
attn_masks: Optional[List[torch.Tensor]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
start = 0
|
start = 0
|
||||||
output = torch.empty((num_tokens, num_heads, head_size),
|
assert output.shape == (num_tokens, num_heads, head_size)
|
||||||
dtype=query.dtype,
|
assert output.dtype == query.dtype
|
||||||
device=query.device)
|
assert output.device == query.device
|
||||||
|
|
||||||
for i, seq_len in enumerate(seq_lens):
|
for i, seq_len in enumerate(seq_lens):
|
||||||
end = start + seq_len
|
end = start + seq_len
|
||||||
|
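The ROCm attention hunks above thread a caller-provided `output` tensor through every path (`accept_output_buffer = True`, writes into `output[:num_prefill_tokens]` and `output[num_prefill_tokens:]`) instead of allocating an intermediate `out` and copying it back. A minimal, self-contained sketch of that output-buffer pattern; the kernel bodies below are placeholders, not the vLLM kernels:

```python
import torch

def prefill_kernel(q: torch.Tensor, out: torch.Tensor) -> None:
    # Placeholder for e.g. the Triton/CK prefill kernel: writes in place.
    out.copy_(q * 2.0)

def decode_kernel(q: torch.Tensor, out: torch.Tensor) -> None:
    # Placeholder for the paged-attention decode kernel: writes in place.
    out.copy_(q + 1.0)

def attention_forward(query: torch.Tensor, num_prefill_tokens: int,
                      output: torch.Tensor) -> torch.Tensor:
    # The caller pre-allocates `output`; each path fills its own slice,
    # so no temporary tensor is allocated and no final copy is needed.
    prefill_kernel(query[:num_prefill_tokens], output[:num_prefill_tokens])
    decode_kernel(query[num_prefill_tokens:], output[num_prefill_tokens:])
    return output

query = torch.randn(8, 4)
output = torch.empty_like(query)
attention_forward(query, num_prefill_tokens=5, output=output)
```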
@@ -283,12 +283,13 @@ def get_vllm_version():
     if __version__ == "dev":
         return "N/A (dev)"
 
     if len(__version_tuple__) == 4:  # dev build
         git_sha = __version_tuple__[-1][1:]  # type: ignore
         return f"{__version__} (git sha: {git_sha}"
 
     return __version__
 
+
 def summarize_vllm_build_flags():
     # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
     return 'CUDA Archs: {}; ROCm: {}; Neuron: {}'.format(

@@ -502,7 +503,9 @@ def get_pip_packages(run_lambda, patterns=None):
         print("uv is set")
         cmd = ["uv", "pip", "list", "--format=freeze"]
     else:
-        raise RuntimeError("Could not collect pip list output (pip or uv module not available)")
+        raise RuntimeError(
+            "Could not collect pip list output (pip or uv module not available)"
+        )
 
     out = run_and_read_all(run_lambda, cmd)
     return "\n".join(line for line in out.splitlines()

@@ -535,13 +538,12 @@ def is_xnnpack_available():
     else:
         return "N/A"
 
+
 def get_env_vars():
     env_vars = ''
-    secret_terms=('secret', 'token', 'api', 'access', 'password')
-    report_prefix = ("TORCH", "NCCL", "PYTORCH",
-                     "CUDA", "CUBLAS", "CUDNN",
-                     "OMP_", "MKL_",
-                     "NVIDIA")
+    secret_terms = ('secret', 'token', 'api', 'access', 'password')
+    report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN",
+                     "OMP_", "MKL_", "NVIDIA")
     for k, v in os.environ.items():
         if any(term in k.lower() for term in secret_terms):
             continue

@@ -552,6 +554,7 @@ def get_env_vars():
 
     return env_vars
 
+
 def get_env_info():
     run_lambda = run
     pip_version, pip_list_output = get_pip_packages(run_lambda)
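The `collect_env.py` hunks above mostly reflow `report_prefix` and the `RuntimeError` message; the surrounding logic (skip anything that looks like a secret, report only variables with known prefixes) is unchanged. A standalone sketch of that filtering, using the term and prefix tuples shown in the diff; how the prefixes are applied is an assumption for illustration (`str.startswith` accepts a tuple), not necessarily the exact logic in `collect_env.py`:

```python
import os

secret_terms = ('secret', 'token', 'api', 'access', 'password')
report_prefix = ("TORCH", "NCCL", "PYTORCH", "CUDA", "CUBLAS", "CUDNN",
                 "OMP_", "MKL_", "NVIDIA")

env_vars = ''
for k, v in os.environ.items():
    # Never echo values of variables whose names suggest they hold credentials.
    if any(term in k.lower() for term in secret_terms):
        continue
    # Only report variables relevant to the ML/CUDA stack.
    if k.startswith(report_prefix):
        env_vars += f"{k}={v}\n"

print(env_vars)
```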
@@ -17,7 +17,7 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
 from importlib.util import find_spec
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
-                    Optional, Protocol, TypeVar, Union)
+                    Optional, Protocol, TypeVar, Union, get_args)
 
 import torch
 from pydantic import BaseModel, Field, PrivateAttr

@@ -2725,6 +2725,7 @@ class PromptAdapterConfig:
             self.prompt_adapter_dtype)
 
 
+@config
 @dataclass
 class MultiModalConfig:
     """Controls the behavior of multimodal models."""

@@ -2732,6 +2733,8 @@ class MultiModalConfig:
     limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
     """
     The maximum number of input items allowed per prompt for each modality.
+    This should be a JSON string that will be parsed into a dictionary.
+    Defaults to 1 (V0) or 999 (V1) for each modality.
     """
 
     def compute_hash(self) -> str:

@@ -2753,24 +2756,20 @@ class MultiModalConfig:
                             usedforsecurity=False).hexdigest()
         return hash_str
 
-    def get_default_limit_per_prompt(self) -> int:
-        """
-        Return the default number of input items allowed per prompt
-        for any modality if not specified by the user.
-        """
-        return 999 if envs.VLLM_USE_V1 else 1
-
     def get_limit_per_prompt(self, modality: str) -> int:
         """
         Get the maximum number of input items allowed per prompt
         for the given modality.
         """
-        default = self.get_default_limit_per_prompt()
-        return self.limit_per_prompt.get(modality, default)
+        return self.limit_per_prompt.get(
+            modality,
+            999 if envs.VLLM_USE_V1 else 1,
+        )
 
     # TODO: Add configs to init vision tower or not.
 
 
+@config
 @dataclass
 class PoolerConfig:
     """Controls the behavior of output pooling in pooling models."""
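The `MultiModalConfig` hunk above folds the removed `get_default_limit_per_prompt()` helper into a single `dict.get` call with the V0/V1 default inline. A reduced sketch of the resulting behavior, with `envs.VLLM_USE_V1` replaced by a plain boolean and a sketch class name:

```python
from dataclasses import dataclass, field
from typing import Mapping

USE_V1 = True  # stand-in for envs.VLLM_USE_V1

@dataclass
class MultiModalConfigSketch:
    # Parsed from the --limit-mm-per-prompt JSON string, e.g. '{"image": 4}'.
    limit_per_prompt: Mapping[str, int] = field(default_factory=dict)

    def get_limit_per_prompt(self, modality: str) -> int:
        # The removed helper is now inlined: 999 items per prompt on V1,
        # 1 on V0, whenever the user did not set a limit for this modality.
        return self.limit_per_prompt.get(
            modality,
            999 if USE_V1 else 1,
        )

cfg = MultiModalConfigSketch(limit_per_prompt={"image": 4})
assert cfg.get_limit_per_prompt("image") == 4
assert cfg.get_limit_per_prompt("video") == 999  # falls back to the V1 default
```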
@@ -3095,15 +3094,28 @@ def get_served_model_name(model: str,
     return served_model_name
 
 
+GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
+                                  "xgrammar"]
+GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
+
+
+@config
 @dataclass
 class DecodingConfig:
-    """Dataclass which contains the decoding strategy of the engine"""
+    """Dataclass which contains the decoding strategy of the engine."""
 
-    # Which guided decoding algo to use.
-    # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
-    guided_decoding_backend: str = "auto" if envs.VLLM_USE_V1 else "xgrammar"
+    guided_decoding_backend: Union[
+        GuidedDecodingBackendV0,
+        GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar"
+    """Which engine will be used for guided decoding (JSON schema / regex etc)
+    by default. With "auto", we will make opinionated choices based on request
+    contents and what the backend libraries currently support, so the behavior
+    is subject to change in each release."""
 
     reasoning_backend: Optional[str] = None
+    """Select the reasoning parser depending on the model that you're using.
+    This is used to parse the reasoning content into OpenAI API format.
+    Required for `--enable-reasoning`."""
 
     def compute_hash(self) -> str:
         """
@@ -3125,17 +3137,12 @@ class DecodingConfig:
         return hash_str
 
     def __post_init__(self):
-        v0_valid_guided_backends = [
-            'outlines', 'lm-format-enforcer', 'xgrammar', 'auto'
-        ]
-        v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']
-
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name
         if envs.VLLM_USE_V1:
-            valid_guided_backends = v1_valid_guided_backends
+            valid_guided_backends = get_args(GuidedDecodingBackendV1)
         else:
-            valid_guided_backends = v0_valid_guided_backends
+            valid_guided_backends = get_args(GuidedDecodingBackendV0)
         if backend not in valid_guided_backends:
             raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
                              f" must be one of {valid_guided_backends}")
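The `DecodingConfig` hunks above replace hand-maintained backend lists with `Literal` aliases and derive the valid choices via `typing.get_args`, so the type annotation and the runtime validation can no longer drift apart. A minimal sketch of that pattern:

```python
from typing import Literal, get_args

GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

def validate_backend(backend: str) -> str:
    # get_args on a Literal alias returns its values as a tuple.
    valid = get_args(GuidedDecodingBackendV1)  # ("auto", "xgrammar", "guidance")
    if backend not in valid:
        raise ValueError(f"Invalid guided_decoding_backend '{backend}', "
                         f"must be one of {valid}")
    return backend

validate_backend("xgrammar")    # passes
# validate_backend("outlines")  # would raise ValueError under the V1 alias
```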
@@ -20,11 +20,12 @@ from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
                          DecodingConfig, Device, DeviceConfig,
                          DistributedExecutorBackend, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
-                         ModelConfig, ModelImpl, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PoolType,
-                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
-                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
-                         VllmConfig, get_attr_docs, get_field)
+                         ModelConfig, ModelImpl, MultiModalConfig,
+                         ObservabilityConfig, ParallelConfig, PoolerConfig,
+                         PoolType, PromptAdapterConfig, SchedulerConfig,
+                         SchedulerPolicy, SpeculativeConfig, TaskOption,
+                         TokenizerPoolConfig, VllmConfig, get_attr_docs,
+                         get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

@@ -190,7 +191,8 @@ class EngineArgs:
         TokenizerPoolConfig.pool_type
     tokenizer_pool_extra_config: dict[str, Any] = \
         get_field(TokenizerPoolConfig, "extra_config")
-    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
+    limit_mm_per_prompt: Mapping[str, int] = \
+        get_field(MultiModalConfig, "limit_per_prompt")
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
     enable_lora: bool = False

@@ -252,7 +254,7 @@ class EngineArgs:
 
     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
-    reasoning_parser: Optional[str] = None
+    reasoning_parser: Optional[str] = DecodingConfig.reasoning_backend
     use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
 
     def __post_init__(self):
|
|||||||
'Examples:\n'
|
'Examples:\n'
|
||||||
'- 1k → 1000\n'
|
'- 1k → 1000\n'
|
||||||
'- 1K → 1024\n')
|
'- 1K → 1024\n')
|
||||||
parser.add_argument(
|
|
||||||
|
# Guided decoding arguments
|
||||||
|
guided_decoding_kwargs = get_kwargs(DecodingConfig)
|
||||||
|
guided_decoding_group = parser.add_argument_group(
|
||||||
|
title="DecodingConfig",
|
||||||
|
description=DecodingConfig.__doc__,
|
||||||
|
)
|
||||||
|
guided_decoding_group.add_argument(
|
||||||
'--guided-decoding-backend',
|
'--guided-decoding-backend',
|
||||||
type=str,
|
**guided_decoding_kwargs["guided_decoding_backend"])
|
||||||
default=DecodingConfig.guided_decoding_backend,
|
guided_decoding_group.add_argument(
|
||||||
help='Which engine will be used for guided decoding'
|
"--reasoning-parser",
|
||||||
' (JSON schema / regex etc) by default. Currently support '
|
# This choices is a special case because it's not static
|
||||||
'https://github.com/mlc-ai/xgrammar and '
|
choices=list(ReasoningParserManager.reasoning_parsers),
|
||||||
'https://github.com/guidance-ai/llguidance.'
|
**guided_decoding_kwargs["reasoning_backend"])
|
||||||
'Valid backend values are "xgrammar", "guidance", and "auto". '
|
|
||||||
'With "auto", we will make opinionated choices based on request '
|
|
||||||
'contents and what the backend libraries currently support, so '
|
|
||||||
'the behavior is subject to change in each release.')
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--logits-processor-pattern',
|
'--logits-processor-pattern',
|
||||||
type=optional_str,
|
type=optional_str,
|
||||||
@ -697,18 +703,14 @@ class EngineArgs:
|
|||||||
**tokenizer_kwargs["extra_config"])
|
**tokenizer_kwargs["extra_config"])
|
||||||
|
|
||||||
# Multimodal related configs
|
# Multimodal related configs
|
||||||
parser.add_argument(
|
multimodal_kwargs = get_kwargs(MultiModalConfig)
|
||||||
'--limit-mm-per-prompt',
|
multimodal_group = parser.add_argument_group(
|
||||||
type=nullable_kvs,
|
title="MultiModalConfig",
|
||||||
default=EngineArgs.limit_mm_per_prompt,
|
description=MultiModalConfig.__doc__,
|
||||||
# The default value is given in
|
)
|
||||||
# MultiModalConfig.get_default_limit_per_prompt
|
multimodal_group.add_argument('--limit-mm-per-prompt',
|
||||||
help=('For each multimodal plugin, limit how many '
|
**multimodal_kwargs["limit_per_prompt"])
|
||||||
'input instances to allow for each prompt. '
|
|
||||||
'Expects a comma-separated list of items, '
|
|
||||||
'e.g.: `image=16,video=2` allows a maximum of 16 '
|
|
||||||
'images and 2 videos per prompt. Defaults to '
|
|
||||||
'1 (V0) or 999 (V1) for each modality.'))
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--mm-processor-kwargs',
|
'--mm-processor-kwargs',
|
||||||
default=None,
|
default=None,
|
||||||
@ -1018,16 +1020,6 @@ class EngineArgs:
|
|||||||
"If enabled, the model will be able to generate reasoning content."
|
"If enabled, the model will be able to generate reasoning content."
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--reasoning-parser",
|
|
||||||
type=str,
|
|
||||||
choices=list(ReasoningParserManager.reasoning_parsers),
|
|
||||||
default=None,
|
|
||||||
help=
|
|
||||||
"Select the reasoning parser depending on the model that you're "
|
|
||||||
"using. This is used to parse the reasoning content into OpenAI "
|
|
||||||
"API format. Required for ``--enable-reasoning``.")
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--disable-cascade-attn",
|
"--disable-cascade-attn",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
|
vllm/entrypoints/cli/collect_env.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import argparse
+
+from vllm.collect_env import main as collect_env_main
+from vllm.entrypoints.cli.types import CLISubcommand
+from vllm.entrypoints.openai.cli_args import make_arg_parser
+from vllm.utils import FlexibleArgumentParser
+
+
+class CollectEnvSubcommand(CLISubcommand):
+    """The `serve` subcommand for the vLLM CLI. """
+
+    def __init__(self):
+        self.name = "collect-env"
+        super().__init__()
+
+    @staticmethod
+    def cmd(args: argparse.Namespace) -> None:
+        """Collect information about the environment."""
+        collect_env_main()
+
+    def subparser_init(
+            self,
+            subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
+        serve_parser = subparsers.add_parser(
+            "collect-env",
+            help="Start collecting environment information.",
+            description="Start collecting environment information.",
+            usage="vllm collect-env")
+        return make_arg_parser(serve_parser)
+
+
+def cmd_init() -> list[CLISubcommand]:
+    return [CollectEnvSubcommand()]
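With this new module registered (see the `CMD_MODULES` hunk below), environment information can be gathered with `vllm collect-env` instead of downloading `collect_env.py` by hand. A generic, self-contained sketch of the subcommand-dispatch pattern it plugs into; this is not the actual `vllm` entry point, just an illustration of how `cmd_init()`-style registration wires into argparse:

```python
import argparse

class CollectEnvCommand:
    """Hypothetical stand-in for vLLM's CLISubcommand-based classes."""

    name = "collect-env"

    @staticmethod
    def cmd(args: argparse.Namespace) -> None:
        print("collecting environment information...")

    def subparser_init(self, subparsers: argparse._SubParsersAction):
        return subparsers.add_parser(
            self.name, help="Start collecting environment information.")

def cmd_init():
    return [CollectEnvCommand()]

def main(argv=None) -> None:
    parser = argparse.ArgumentParser(prog="vllm-cli-sketch")
    subparsers = parser.add_subparsers(required=True)
    for cmd in cmd_init():
        sub = cmd.subparser_init(subparsers)
        # Each subcommand registers the callable that should handle it.
        sub.set_defaults(dispatch_function=cmd.cmd)
    args = parser.parse_args(argv)
    args.dispatch_function(args)

if __name__ == "__main__":
    main(["collect-env"])
```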
@@ -5,6 +5,7 @@ import signal
 import sys
 
 import vllm.entrypoints.cli.benchmark.main
+import vllm.entrypoints.cli.collect_env
 import vllm.entrypoints.cli.openai
 import vllm.entrypoints.cli.serve
 import vllm.version

@@ -15,6 +16,7 @@ CMD_MODULES = [
     vllm.entrypoints.cli.openai,
     vllm.entrypoints.cli.serve,
     vllm.entrypoints.cli.benchmark.main,
+    vllm.entrypoints.cli.collect_env,
 ]
 
 
@@ -422,6 +422,7 @@ class FusedMoE(torch.nn.Module):
 
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
 
         # Note: here we guard against accessing the TP and DP groups when
         # uninitialized (this happens when testing)
@@ -51,8 +51,8 @@ class Llama4MoE(nn.Module):
         renormalize: bool,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
-        router_scores = torch.sigmoid(router_scores.float()).to(
-            hidden_states.dtype)
+        # psuedo-standard is that the router scores are floats
+        router_scores = torch.sigmoid(router_scores.float())
         return (router_scores, router_indices.to(torch.int32))
 
     def __init__(self,

@@ -672,9 +672,9 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
             self.config,
             None,
             prefix=maybe_prefix(prefix, "multi_modal_projector"))
 
         self.language_model = _initialize_model(
-            vllm_config=vllm_config.with_hf_config(config.text_config),
+            vllm_config=vllm_config.with_hf_config(config.text_config,
+                                                   ["LlamaForCausalLM"]),
             prefix=maybe_prefix(prefix, "language_model"),
             model_class=Llama4ForCausalLM,
         )
@@ -824,7 +824,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
         # language_model is an Llama4ForCausalLM instance. We load it's
         # using llama4's load_weights routine.
         language_model_weights, other_weights = self.separate_weights(
-            weights, prefix="language_model.model.")
+            weights, prefix="language_model.")
         loader = AutoWeightsLoader(self)
         loaded_language_model_params = loader.load_weights(
             language_model_weights)
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
 import re
 import sys
 from abc import ABC, abstractmethod

@@ -1117,8 +1118,9 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
 
         if num_items > allowed_limit:
             raise ValueError(
-                f"You set or defaulted to {modality}={allowed_limit} "
-                f"in --limit-mm-per-prompt`, but passed {num_items} "
+                "You set or defaulted to "
+                f"'{json.dumps({modality: allowed_limit})}' in "
+                f"`--limit-mm-per-prompt`, but passed {num_items} "
                 f"{modality} items in the same prompt.")
 
         return mm_items
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import functools
+import json
 from collections import UserDict
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass

@@ -194,9 +195,9 @@ class MultiModalRegistry:
             max_items = self._limits_by_model[model_config][data_key]
             if num_items > max_items:
                 raise ValueError(
-                    f"You set {data_key}={max_items} (or defaulted to 1) in "
-                    f"`--limit-mm-per-prompt`, but found {num_items} items "
-                    "in the same prompt.")
+                    f"You set '{json.dumps({data_key: max_items})}' (or "
+                    "defaulted to 1) in `--limit-mm-per-prompt`, but found "
+                    f"{num_items} items in the same prompt.")
 
             input_dict = plugin.map_input(model_config, data_value,
                                           mm_processor_kwargs)
@@ -149,6 +149,7 @@ class Processor:
             "xgrammar", "xgrammar:disable-any-whitespace", "guidance",
             "guidance:disable-any-whitespace", "auto"
         ]
+
         engine_level_backend = self.decoding_config.guided_decoding_backend
         if engine_level_backend not in supported_backends:
             raise ValueError(f"Only {supported_backends} structured output is "

@@ -169,8 +170,15 @@ class Processor:
         if engine_level_backend.startswith("xgrammar"):
             # xgrammar with no fallback
             validate_xgrammar_grammar(params)
-            params.guided_decoding.backend = engine_level_backend
-        elif engine_level_backend == "auto":
+        elif engine_level_backend.startswith("guidance"):
+            # TODO: ideally we would have the LLTokenizer here as Lark syntax
+            # allows <|special_token|> and similar, see
+            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
+            # Without tokenizer these are disallowed in grammars.
+            validate_guidance_grammar(params, tokenizer=None)
+        else:
+            # NOTE: engine_level_backend must be "auto" here, because we have
+            # checked supported_backends above.
             # "auto" is an opt-in to opinionated behavior where we try to
             # choose a backend based on request contents. This is not the
             # default as it is less predictable and subject to change
@@ -183,14 +191,6 @@ class Processor:
             # are not supported in xgrammar. Fall back to guidance.
             params.guided_decoding.backend = "guidance"
 
-        if engine_level_backend.startswith("guidance"):
-            # TODO ideally we would have the LLTokenizer here as Lark syntax
-            # allows <|special_token|> and similar, see
-            # https://github.com/guidance-ai/llguidance/blob/main/docs/syntax.md#special-tokens
-            # Without tokenizer these are disallowed in grammars.
-            validate_guidance_grammar(params, tokenizer=None)
-            params.guided_decoding.backend = engine_level_backend
-
     def process_inputs(
         self,
         request_id: str,