[Model] Adds Phi-3 support (#4298)
commit 96e90fdeb3 (parent a395a638c2)
diff --git a/README.md b/README.md
@@ -78,6 +78,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
 - Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
 - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
+- Phi3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.)
 - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
 - Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
 - Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
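For reference, a minimal usage sketch (not part of this diff) showing how one of the newly listed Phi-3 checkpoints would be served through vLLM's offline `LLM` API; the model name comes from the list above, the sampling values are illustrative:

from vllm import LLM, SamplingParams

# Load a newly supported Phi-3 checkpoint and run a single greedy request.
llm = LLM(model="microsoft/Phi-3-mini-4k-instruct")
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["What is rotary position embedding?"], params)
print(outputs[0].outputs[0].text)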
diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst
@@ -115,6 +115,10 @@ Alongside each architecture, we include some popular models that use it.
     - Phi
     - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
     -
+  * - :code:`Phi3ForCausalLM`
+    - Phi-3
+    - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc.
+    -
   * - :code:`QWenLMHeadModel`
     - Qwen
     - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
diff --git a/vllm/config.py b/vllm/config.py
@@ -1056,7 +1056,7 @@ def _get_and_verify_max_len(
         derived_max_model_len = default_max_len

     rope_scaling = getattr(hf_config, "rope_scaling", None)
-    if rope_scaling is not None:
+    if rope_scaling is not None and rope_scaling["type"] != "su":
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
         if rope_scaling["type"] == "yarn":
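To see why the added guard is needed: a hedged sketch of the `rope_scaling` block a Phi-3-style config carries (factor lists abridged to toy values). The "su" type carries per-dimension rescale factors and no single `factor`, so the factor-based max-length derivation above must not run for it:

# Abridged, illustrative shape of a Phi-3-style rope_scaling entry; real
# configs carry one factor per rotary dimension pair.
rope_scaling = {
    "type": "su",
    "short_factor": [1.0, 1.0, 1.05],  # toy values
    "long_factor": [1.3, 1.5, 2.0],    # toy values
}
# The new guard skips this dict; the assert in the branch would fail on it:
assert "factor" not in rope_scaling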
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
@@ -338,6 +338,114 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         return cache


+class Phi3SuScaledRotaryEmbedding(nn.Module):
+    """Phi3 family of models scaled rotary embedding.
+
+    Based on the original RotaryEmbedding implementation.
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        original_max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        short_factor: List[float],
+        long_factor: List[float],
+        short_mscale: float = 1.1,
+        long_mscale: float = 1.225,
+    ):
+        super().__init__()
+
+        if rotary_dim != head_size:
+            raise ValueError(
+                f"`Phi3SuScaledRotaryEmbedding` does not support "
+                f"rotary_dim != head_size ({rotary_dim}!={head_size}).")
+        if is_neox_style is False:
+            raise ValueError(
+                "`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
+
+        self.head_size = head_size
+        self.max_position_embeddings = max_position_embeddings
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.base = base
+        self.short_factor = short_factor
+        self.long_factor = long_factor
+        self.short_mscale = short_mscale
+        self.long_mscale = long_mscale
+
+        # Cos/sin cache for prompts within the original context window,
+        # built from the short rescale factors.
+        short_cache = self._compute_cos_sin_cache(
+            original_max_position_embeddings, short_factor, short_mscale)
+        short_cache = short_cache.to(torch.get_default_dtype())
+        self.register_buffer("short_cos_sin_cache",
+                             short_cache,
+                             persistent=False)
+
+        # Cos/sin cache for the extended window, built from the long factors.
+        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
+                                                 long_factor, long_mscale)
+        long_cache = long_cache.to(torch.get_default_dtype())
+        self.register_buffer("long_cos_sin_cache",
+                             long_cache,
+                             persistent=False)
+
+        # Both caches concatenated, so forward() can select the right half
+        # by offsetting the position indices.
+        long_short_cache = torch.cat(
+            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
+        self.register_buffer("long_short_cos_sin_cache",
+                             long_short_cache,
+                             persistent=False)
+
+    def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
+        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
+        inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
+            0, self.head_size, 2, dtype=torch.float) / self.head_size)))
+        return inv_freq
+
+    def _compute_cos_sin_cache(
+        self,
+        max_position_embeddings: int,
+        rescale_factors: List[float],
+        mscale: float,
+    ) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(rescale_factors)
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos() * mscale
+        sin = freqs.sin() * mscale
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+
+        # Shift indices into the long half of the cache whenever any
+        # position exceeds the original context window.
+        k = self.original_max_position_embeddings
+        long_prompt_offset = (torch.any(positions > k).float() *
+                              torch.full_like(positions, k)).long()
+        idx = (torch.add(positions, long_prompt_offset)
+               if long_prompt_offset is not None else positions)
+        self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
+            idx.device)
+        idx = torch.add(idx, offsets) if offsets is not None else idx
+        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
+
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        cos = cos.repeat(1, 2).unsqueeze(-2)
+        sin = sin.repeat(1, 2).unsqueeze(-2)
+
+        query = query * cos + _rotate_neox(query) * sin
+        key = key * cos + _rotate_neox(key) * sin
+
+        return query.flatten(-2), key.flatten(-2)
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
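A small standalone sketch of the cache-selection trick in `forward()` above: the short and long cos/sin caches are concatenated along dim 0, and whenever any position exceeds `original_max_position_embeddings`, every index is shifted by that length so lookups land in the long half. Toy numbers, assuming `k = 4`:

import torch

k = 4  # stand-in for original_max_position_embeddings

def cache_indices(positions: torch.Tensor) -> torch.Tensor:
    # Offset all positions by k if any position falls past the original
    # window, mirroring the long_prompt_offset computation above.
    offset = (torch.any(positions > k).float() *
              torch.full_like(positions, k)).long()
    return positions + offset

print(cache_indices(torch.tensor([0, 1, 2])))  # tensor([0, 1, 2]) -> short half
print(cache_indices(torch.tensor([0, 3, 5])))  # tensor([4, 7, 9]) -> long half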
@@ -349,16 +457,25 @@ def get_rope(
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
+    if rope_scaling is not None:
+        # Transform every list value into a tuple so the cache key is hashable.
+        rope_scaling_tuple = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in rope_scaling.items()
+        }
+        rope_scaling_args = tuple(rope_scaling_tuple.items())
+    else:
+        rope_scaling_args = None
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
-           tuple(rope_scaling.items()) if rope_scaling is not None else None)
+           rope_scaling_args)
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]

     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style)
     else:
         scaling_type = rope_scaling["type"]
-        scaling_factor = rope_scaling["factor"]
+        # The "su" type has no single scaling factor; skip reading it.
+        if scaling_type != "su":
+            scaling_factor = rope_scaling["factor"]
         if scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
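The list-to-tuple conversion added above exists because `_ROPE_DICT` keys must be hashable, and the "su" scaling dict carries list values; a quick illustration:

unhashable = tuple({"short_factor": [1.0, 2.0]}.items())
try:
    hash(unhashable)  # raises: a tuple containing a list is unhashable
except TypeError:
    pass

hashable = tuple({"short_factor": (1.0, 2.0)}.items())
hash(hashable)  # fine once list values become tuples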
@@ -383,6 +500,19 @@ def get_rope(
                                                   base, is_neox_style,
                                                   scaling_factor,
                                                   **extra_kwargs)
+        elif scaling_type == "su":
+            short_factor = rope_scaling["short_factor"]
+            long_factor = rope_scaling["long_factor"]
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("short_mscale", "long_mscale")
+            }
+            rotary_emb = Phi3SuScaledRotaryEmbedding(
+                head_size, rotary_dim, max_position, original_max_position,
+                base, is_neox_style, short_factor, long_factor, **extra_kwargs)
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
         _ROPE_DICT[key] = rotary_emb
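A hedged sketch of how the new branch is reached; `head_size=96` matches Phi-3-mini's head dimension, but the factor lists here are placeholders rather than real checkpoint values:

rope = get_rope(
    head_size=96,
    rotary_dim=96,
    max_position=131072,
    base=10000,
    is_neox_style=True,
    rope_scaling={
        "type": "su",
        "original_max_position_embeddings": 4096,
        # 48 entries = head_size / 2; placeholder values for illustration.
        "short_factor": [1.0] * 48,
        "long_factor": [2.0] * 48,
    },
)
assert isinstance(rope, Phi3SuScaledRotaryEmbedding)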
diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py
@@ -46,6 +46,7 @@ _MODELS = {
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
+    "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
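Phi-3's decoder layout matches Llama closely enough that the registry simply points `Phi3ForCausalLM` at the existing Llama implementation; the fused checkpoint weights and su-scaled RoPE are handled by the llama.py changes below. A sketch of what resolving the architecture name yields, assuming `ModelRegistry` as the lookup path:

from vllm.model_executor.models import ModelRegistry

model_cls = ModelRegistry.load_model_cls("Phi3ForCausalLM")
print(model_cls.__name__)  # -> "LlamaForCausalLM"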
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
@@ -180,6 +180,10 @@ class LlamaDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+                config, "original_max_position_embeddings", None):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
         sliding_window = getattr(config, "sliding_window", None)
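The copy above matters because Phi-3 configs expose `original_max_position_embeddings` at the top level, while `get_rope` reads it out of the `rope_scaling` dict; a toy illustration with a hypothetical stand-in config object:

class FakePhi3Config:
    # Hypothetical stand-in for a transformers PretrainedConfig; toy values.
    max_position_embeddings = 131072
    original_max_position_embeddings = 4096
    rope_scaling = {"type": "su", "short_factor": [1.0], "long_factor": [2.0]}

config = FakePhi3Config()
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
        config, "original_max_position_embeddings", None):
    rope_scaling["original_max_position_embeddings"] = (
        config.original_max_position_embeddings)
assert rope_scaling["original_max_position_embeddings"] == 4096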
@@ -378,11 +382,11 @@ class LlamaForCausalLM(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
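The leading dots matter for Phi-3 in particular: its checkpoints ship already-fused projection weights, and with bare substring matching the fused names collide with the shard patterns. A quick illustration:

fused_attn = "model.layers.0.self_attn.qkv_proj.weight"
fused_mlp = "model.layers.0.mlp.gate_up_proj.weight"

# Bare patterns false-positive on the fused weights...
assert "v_proj" in fused_attn   # "qkv_proj" ends with "v_proj"
assert "up_proj" in fused_mlp   # "gate_up_proj" ends with "up_proj"

# ...while the dotted patterns only match true per-shard names.
assert ".v_proj" not in fused_attn
assert ".up_proj" not in fused_mlp
assert ".v_proj" in "model.layers.0.self_attn.v_proj.weight"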