diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 369417b2..453ae7b6 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -76,6 +76,7 @@ def main(): max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, speculative_config={ + "method": "eagle", "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, "draft_tensor_parallel_size": args.draft_tp, diff --git a/tests/models/registry.py b/tests/models/registry.py index b43bdb9c..896b6c3b 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -374,6 +374,10 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { "DeepSeekMTPModel": _HfExamplesInfo("luccafong/deepseek_mtp_main_random", speculative_model="luccafong/deepseek_mtp_draft_random", # noqa: E501 trust_remote_code=True), + "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B", + trust_remote_code=True, + speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B", + tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"), # noqa: E501 } _TRANSFORMERS_MODELS = { diff --git a/tests/v1/e2e/test_ngram_spec_decode.py b/tests/v1/e2e/test_spec_decode.py similarity index 65% rename from tests/v1/e2e/test_ngram_spec_decode.py rename to tests/v1/e2e/test_spec_decode.py index 7c7c2f02..67371498 100644 --- a/tests/v1/e2e/test_ngram_spec_decode.py +++ b/tests/v1/e2e/test_spec_decode.py @@ -53,6 +53,11 @@ def model_name(): return "meta-llama/Meta-Llama-3-8B-Instruct" +@pytest.fixture +def eagle_model_name(): + return "yuhuili/EAGLE-LLaMA3-Instruct-8B" + + def test_ngram_correctness( monkeypatch: pytest.MonkeyPatch, test_prompts: list[list[dict[str, Any]]], @@ -95,3 +100,47 @@ def test_ngram_correctness( # Upon failure, inspect the outputs to check for inaccuracy. assert matches > int(0.7 * len(ref_outputs)) del spec_llm + + +def test_eagle_correctness( + monkeypatch: pytest.MonkeyPatch, + test_prompts: list[list[dict[str, Any]]], + sampling_config: SamplingParams, + model_name: str, + eagle_model_name: str, +): + ''' + Compare the outputs of a original LLM and a speculative LLM + should be the same when using eagle speculative decoding. + ''' + with monkeypatch.context() as m: + m.setenv("VLLM_USE_V1", "1") + + ref_llm = LLM(model=model_name, max_model_len=1024) + ref_outputs = ref_llm.chat(test_prompts, sampling_config) + del ref_llm + + spec_llm = LLM( + model=model_name, + speculative_config={ + "method": "eagle", + "model": eagle_model_name, + "num_speculative_tokens": 3, + }, + max_model_len=1024, + ) + spec_outputs = spec_llm.chat(test_prompts, sampling_config) + matches = 0 + misses = 0 + for ref_output, spec_output in zip(ref_outputs, spec_outputs): + if ref_output.outputs[0].text == spec_output.outputs[0].text: + matches += 1 + else: + misses += 1 + print(f"ref_output: {ref_output.outputs[0].text}") + print(f"spec_output: {spec_output.outputs[0].text}") + + # Heuristic: expect at least 70% of the prompts to match exactly + # Upon failure, inspect the outputs to check for inaccuracy. + assert matches > int(0.7 * len(ref_outputs)) + del spec_llm diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 03934ba0..b0a0a20a 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -414,7 +414,7 @@ class DefaultModelLoader(BaseModelLoader): return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) - def _get_all_weights( + def get_all_weights( self, model_config: ModelConfig, model: nn.Module, @@ -453,7 +453,7 @@ class DefaultModelLoader(BaseModelLoader): weights_to_load = {name for name, _ in model.named_parameters()} loaded_weights = model.load_weights( - self._get_all_weights(model_config, model)) + self.get_all_weights(model_config, model)) self.counter_after_loading_weights = time.perf_counter() logger.info( "Loading weights took %.2f seconds", diff --git a/vllm/model_executor/models/llama_eagle.py b/vllm/model_executor/models/llama_eagle.py new file mode 100644 index 00000000..28ad6128 --- /dev/null +++ b/vllm/model_executor/models/llama_eagle.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Iterable, Set, Tuple + +import torch +import torch.nn as nn +from transformers import LlamaConfig + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.llama import (LlamaDecoderLayer, + LlamaForCausalLM) + +from .utils import AutoWeightsLoader, maybe_prefix + +logger = init_logger(__name__) + + +class LlamaDecoderLayer(LlamaDecoderLayer): + + def __init__( + self, + config: LlamaConfig, + disable_input_layernorm: bool, + prefix: str = "", + ) -> None: + super().__init__(config, prefix=prefix) + + # Skip the input_layernorm + # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427 + if disable_input_layernorm: + del self.input_layernorm + self.input_layernorm = nn.Identity() + + +class LlamaModel(nn.Module): + + def __init__( + self, + *, + model_config: ModelConfig, + start_layer_id: int = 0, + prefix: str = "", + ) -> None: + super().__init__() + self.config = model_config.hf_config + self.vocab_size = self.config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.config.vocab_size, + self.config.hidden_size, + prefix=maybe_prefix(prefix, "embed_tokens"), + ) + self.layers = nn.ModuleList([ + LlamaDecoderLayer( + self.config, + i == 0, + prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"), + ) for i in range(self.config.num_hidden_layers) + ]) + self.fc = torch.nn.Linear(self.config.hidden_size * 2, + self.config.hidden_size, + bias=False) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + input_embeds = self.embed_tokens(input_ids) + hidden_states = self.fc( + torch.cat((input_embeds, hidden_states), dim=-1)) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + residual, + ) + return hidden_states + residual + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class EagleLlamaForCausalLM(LlamaForCausalLM): + + def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0): + nn.Module.__init__(self) + self.config = model_config.hf_config + self.model = LlamaModel(model_config=model_config, + start_layer_id=start_layer_id, + prefix="model") + + logit_scale = getattr(self.config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.config.vocab_size, + scale=logit_scale) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + return self.model(input_ids, positions, hidden_states) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader( + self, + skip_prefixes=(["lm_head."] + if self.config.tie_word_embeddings else None), + ) + + model_weights = {} + for name, loaded_weight in weights: + if "lm_head" not in name: + name = "model." + name + model_weights[name] = loaded_weight + + loader.load_weights(model_weights.items()) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 6a70f6bb..0d13d699 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -206,6 +206,7 @@ _MULTIMODAL_MODELS = { _SPECULATIVE_DECODING_MODELS = { "EAGLEModel": ("eagle", "EAGLE"), + "EagleLlamaForCausalLM": ("llama_eagle", "EagleLlamaForCausalLM"), "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"), "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), diff --git a/vllm/transformers_utils/configs/eagle.py b/vllm/transformers_utils/configs/eagle.py index dd806061..3a9ad3e0 100644 --- a/vllm/transformers_utils/configs/eagle.py +++ b/vllm/transformers_utils/configs/eagle.py @@ -5,6 +5,7 @@ from typing import Optional, Union from transformers import AutoConfig, PretrainedConfig +import vllm.envs as envs from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config @@ -41,8 +42,10 @@ class EAGLEConfig(PretrainedConfig): self.truncated_vocab_size = self.model.vocab_size if \ truncated_vocab_size is None else truncated_vocab_size - if "architectures" not in kwargs: + if not envs.VLLM_USE_V1: kwargs["architectures"] = ["EAGLEModel"] + else: + kwargs["architectures"] = ["EagleLlamaForCausalLM"] super().__init__(**kwargs) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 3aaaf34b..2322463c 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -4,8 +4,11 @@ import torch.nn as nn import triton import triton.language as tl -from vllm.config import VllmConfig +from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context +from vllm.model_executor.model_loader.loader import get_model_loader +from vllm.model_executor.model_loader.utils import set_default_torch_dtype +from vllm.model_executor.models.llama_eagle import EagleLlamaForCausalLM from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.sample.metadata import SamplingMetadata @@ -21,8 +24,12 @@ class EagleProposer: self.num_speculative_tokens = ( vllm_config.speculative_config.num_speculative_tokens) self.block_size = vllm_config.cache_config.block_size - self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs, - device=device) + # We need +1 here because the arange is used to set query_start_loc, + # which has one more element than batch_size. + self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs + + 1, + device=device, + dtype=torch.int32) def propose( self, @@ -54,7 +61,9 @@ class EagleProposer: # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] input_ids[last_token_indices] = next_token_ids - seq_lens = target_positions[last_token_indices] + 1 + # FA requires seq_len to have dtype int32. + seq_lens = (target_positions[last_token_indices] + 1).int() + # FIXME(woosuk): The below two ops cause synchronization. Optimize. max_seq_len = seq_lens.max().item() max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item() @@ -98,7 +107,7 @@ class EagleProposer: hidden_states = sample_hidden_states attn_metadata.num_actual_tokens = batch_size attn_metadata.max_query_len = 1 - attn_metadata.query_start_loc = self.arange[:batch_size] + attn_metadata.query_start_loc = self.arange[:batch_size + 1] for _ in range(self.num_speculative_tokens - 1): # Update the inputs. input_ids = draft_token_ids_list[-1] @@ -176,26 +185,28 @@ class EagleProposer: return cu_num_tokens, token_indices def load_model(self, target_model: nn.Module) -> None: - self.model = DummyEagleModel() - self.model.get_input_embeddings = target_model.get_input_embeddings - self.model.compute_logits = target_model.compute_logits + loader = get_model_loader(self.vllm_config.load_config) + target_layer_num = self.vllm_config.model_config.get_num_layers( + self.vllm_config.parallel_config) + draft_model_config = \ + self.vllm_config.speculative_config.draft_model_config + # FIXME(lily): This does not handle with distributed inference. + target_device = self.vllm_config.device_config.device + # We need to set the vllm_config here to register attention + # layers in the forward context. + with set_default_torch_dtype( + draft_model_config.dtype), set_current_vllm_config( + self.vllm_config): + self.model = EagleLlamaForCausalLM( + model_config=draft_model_config, + start_layer_id=target_layer_num).to(target_device) -# FIXME(woosuk): This is a dummy model for testing. -# Remove this once we have a real model. -class DummyEagleModel(nn.Module): - - def __init__(self): - super().__init__() - - def forward( - self, - input_ids: torch.Tensor, - hidden_states: torch.Tensor, - positions: torch.Tensor, - ) -> torch.Tensor: - input_embeddings = self.get_input_embeddings(input_ids) - return hidden_states + input_embeddings # Dummy return. + self.model.load_weights( + loader.get_all_weights( + self.vllm_config.speculative_config.draft_model_config, + self.model)) + self.model.lm_head = target_model.lm_head # FIXME(woosuk): The logic here is duplicated with the main sampling code. diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index debb7072..0e70d77e 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1191,9 +1191,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): if spec_decode_metadata is None: # input_ids can be None for multimodal models. + # We need to slice token_ids, positions, and hidden_states + # because the eagle head does not use cuda graph and should + # not include padding. target_token_ids = self.input_ids[:num_scheduled_tokens] - target_positions = positions - target_hidden_states = hidden_states + target_positions = positions[:num_scheduled_tokens] + target_hidden_states = hidden_states[:num_scheduled_tokens] target_slot_mapping = attn_metadata.slot_mapping cu_num_tokens = attn_metadata.query_start_loc else: