[Model] Adds Phi-3 support (#4298)
Parent: a395a638c2
Commit: 96e90fdeb3
@@ -78,6 +78,7 @@ vLLM seamlessly supports many Hugging Face models, including the following architectures:
 - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
 - Orion (`OrionStarAI/Orion-14B-Base`, `OrionStarAI/Orion-14B-Chat`, etc.)
 - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
+- Phi3 (`microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, etc.)
 - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
 - Qwen2 (`Qwen/Qwen1.5-7B`, `Qwen/Qwen1.5-7B-Chat`, etc.)
 - Qwen2MoE (`Qwen/Qwen1.5-MoE-A2.7B`, `Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc.)
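With the entry above in place, Phi-3 mini loads like any other supported model. A minimal offline-inference sketch using vLLM's standard `LLM` API (prompt and sampling settings are illustrative; `trust_remote_code` may be needed depending on the installed transformers version):

    from vllm import LLM, SamplingParams

    llm = LLM(model="microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)
    params = SamplingParams(temperature=0.0, max_tokens=64)
    outputs = llm.generate(["Explain rotary position embeddings in one sentence."], params)
    print(outputs[0].outputs[0].text)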
@@ -115,6 +115,10 @@ Alongside each architecture, we include some popular models that use it.
     - Phi
     - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
     -
+  * - :code:`Phi3ForCausalLM`
+    - Phi-3
+    - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc.
+    -
   * - :code:`QWenLMHeadModel`
     - Qwen
     - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
@@ -1056,7 +1056,7 @@ def _get_and_verify_max_len(
         derived_max_model_len = default_max_len

     rope_scaling = getattr(hf_config, "rope_scaling", None)
-    if rope_scaling is not None:
+    if rope_scaling is not None and rope_scaling["type"] != "su":
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
         if rope_scaling["type"] == "yarn":
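The added "su" check keeps `_get_and_verify_max_len` from multiplying the derived length by a scaling factor: Phi-3's "su" entry carries `short_factor`/`long_factor` lists rather than a single `factor`, and its `max_position_embeddings` already reflects the extended window. A stand-alone sketch of the resulting behavior (helper name and structure are illustrative, not vLLM's actual code):

    from typing import Optional

    def effective_max_len(derived_max_model_len: float,
                          rope_scaling: Optional[dict]) -> float:
        # "su" configs already store the long context length, so only the
        # other scaling types get multiplied by rope_scaling["factor"].
        if rope_scaling is None or rope_scaling["type"] == "su":
            return derived_max_model_len
        if rope_scaling["type"] == "yarn":
            derived_max_model_len = rope_scaling["original_max_position_embeddings"]
        return derived_max_model_len * rope_scaling["factor"]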
@@ -338,6 +338,114 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
         return cache


+class Phi3SuScaledRotaryEmbedding(nn.Module):
+    """Phi3 family of models scaled rotary embedding.
+
+    Based on the original RotaryEmbedding implementation.
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        original_max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        short_factor: List[float],
+        long_factor: List[float],
+        short_mscale: float = 1.1,
+        long_mscale: float = 1.225,
+    ):
+        super().__init__()
+
+        if rotary_dim != head_size:
+            raise ValueError(
+                f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \
+                head_size ({rotary_dim}!={head_size}).")
+        if is_neox_style is False:
+            raise ValueError(
+                "`Phi3SuScaledRotaryEmbedding` only supports neox_style.")
+
+        self.head_size = head_size
+        self.max_position_embeddings = max_position_embeddings
+        self.original_max_position_embeddings = original_max_position_embeddings
+        self.base = base
+        self.short_factor = short_factor
+        self.long_factor = long_factor
+        self.short_mscale = short_mscale
+        self.long_mscale = long_mscale
+
+        short_cache = self._compute_cos_sin_cache(
+            original_max_position_embeddings, short_factor, short_mscale)
+        short_cache = short_cache.to(torch.get_default_dtype())
+        self.register_buffer("short_cos_sin_cache",
+                             short_cache,
+                             persistent=False)
+
+        long_cache = self._compute_cos_sin_cache(max_position_embeddings,
+                                                 long_factor, long_mscale)
+        long_cache = long_cache.to(torch.get_default_dtype())
+        self.register_buffer("long_cos_sin_cache",
+                             long_cache,
+                             persistent=False)
+
+        long_short_cache = torch.cat(
+            [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
+        self.register_buffer("long_short_cos_sin_cache",
+                             long_short_cache,
+                             persistent=False)
+
+    def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
+        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
+        inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
+            0, self.head_size, 2, dtype=torch.float) / self.head_size)))
+        return inv_freq
+
+    def _compute_cos_sin_cache(
+        self,
+        max_position_embeddings: int,
+        rescale_factors: List[float],
+        mscale: float,
+    ) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(rescale_factors)
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos() * mscale
+        sin = freqs.sin() * mscale
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+
+        k = self.original_max_position_embeddings
+        long_prompt_offset = (torch.any(positions > k).float() *
+                              torch.full_like(positions, k)).long()
+        idx = (torch.add(positions, long_prompt_offset)
+               if long_prompt_offset is not None else positions)
+        self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
+            idx.device)
+        idx = torch.add(idx, offsets) if offsets is not None else idx
+        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
+
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        cos = cos.repeat(1, 2).unsqueeze(-2)
+        sin = sin.repeat(1, 2).unsqueeze(-2)
+
+        query = query * cos + _rotate_neox(query) * sin
+        key = key * cos + _rotate_neox(key) * sin
+
+        return query.flatten(-2), key.flatten(-2)
+
+
 _ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}
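`forward` above picks rows out of one concatenated cache: when any position in the batch exceeds `original_max_position_embeddings`, every index is shifted by that amount, which lands the lookup in the long-context half of `long_short_cos_sin_cache`. A small numeric sketch of that indexing (toy values, not tied to a real model):

    import torch

    k = 4                                   # stands in for original_max_position_embeddings
    positions = torch.tensor([0, 1, 2, 5])  # one position exceeds k, so the prompt is "long"
    long_prompt_offset = (torch.any(positions > k).float() *
                          torch.full_like(positions, k)).long()
    idx = positions + long_prompt_offset    # tensor([4, 5, 6, 9]) -> rows in the long half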
@@ -349,17 +457,26 @@ def get_rope(
     is_neox_style: bool = True,
     rope_scaling: Optional[Dict[str, Any]] = None,
 ) -> RotaryEmbedding:
+    if rope_scaling is not None:
+        # Transforms every value that is a list into a tuple for caching calls
+        rope_scaling_tuple = {
+            k: tuple(v) if isinstance(v, list) else v
+            for k, v in rope_scaling.items()
+        }
+        rope_scaling_args = tuple(rope_scaling_tuple.items())
+    else:
+        rope_scaling_args = None
     key = (head_size, rotary_dim, max_position, base, is_neox_style,
-           tuple(rope_scaling.items()) if rope_scaling is not None else None)
+           rope_scaling_args)
     if key in _ROPE_DICT:
         return _ROPE_DICT[key]

     if rope_scaling is None:
         rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
                                      is_neox_style)
     else:
         scaling_type = rope_scaling["type"]
-        scaling_factor = rope_scaling["factor"]
+        if scaling_type != "su":
+            scaling_factor = rope_scaling["factor"]
         if scaling_type == "linear":
             rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
                                                       max_position, base,
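The list-to-tuple conversion exists because `_ROPE_DICT` keys must be hashable, and the "su" configuration introduces `short_factor`/`long_factor` lists into `rope_scaling`. A quick illustration with made-up values:

    rope_scaling = {"type": "su", "short_factor": [1.0, 1.1], "long_factor": [2.0, 2.5]}

    # A key built from raw lists would raise "TypeError: unhashable type: 'list'".
    rope_scaling_args = tuple(
        (k, tuple(v) if isinstance(v, list) else v) for k, v in rope_scaling.items())
    hash(rope_scaling_args)  # fine: every element is now hashable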
@@ -383,6 +500,19 @@ def get_rope(
                                                     base, is_neox_style,
                                                     scaling_factor,
                                                     **extra_kwargs)
+        elif scaling_type == "su":
+            short_factor = rope_scaling["short_factor"]
+            long_factor = rope_scaling["long_factor"]
+            original_max_position = rope_scaling[
+                "original_max_position_embeddings"]
+            extra_kwargs = {
+                k: v
+                for k, v in rope_scaling.items()
+                if k in ("short_mscale", "long_mscale")
+            }
+            rotary_emb = Phi3SuScaledRotaryEmbedding(
+                head_size, rotary_dim, max_position, original_max_position,
+                base, is_neox_style, short_factor, long_factor, **extra_kwargs)
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
     _ROPE_DICT[key] = rotary_emb
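Putting the pieces together, a "su" entry dispatches to the new embedding class. A rough example of the dict shape and the resulting call (factor values, head size, and context lengths are placeholders; the real per-dimension factors come from the model's config.json, and `original_max_position_embeddings` is injected by the model code, as the llama.py hunk below shows):

    rope_scaling = {
        "type": "su",
        "short_factor": [1.0] * 48,                # head_size // 2 entries
        "long_factor": [4.0] * 48,
        "original_max_position_embeddings": 4096,  # copied in from the HF config
    }
    rotary_emb = get_rope(head_size=96, rotary_dim=96, max_position=131072,
                          base=10000, is_neox_style=True, rope_scaling=rope_scaling)
    # isinstance(rotary_emb, Phi3SuScaledRotaryEmbedding) -> True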
@@ -46,6 +46,7 @@ _MODELS = {
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
     "OrionForCausalLM": ("orion", "OrionForCausalLM"),
     "PhiForCausalLM": ("phi", "PhiForCausalLM"),
+    "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
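Phi-3 is architecturally a Llama-style decoder, so the registry points `Phi3ForCausalLM` at the existing Llama implementation instead of adding a new model file. A rough sketch of how such an entry resolves to a class (module path assumed from the pattern of the other entries; vLLM's actual loader has more error handling):

    import importlib

    module_name, cls_name = ("llama", "LlamaForCausalLM")  # _MODELS["Phi3ForCausalLM"]
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    model_cls = getattr(module, cls_name)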
@@ -180,6 +180,10 @@ class LlamaDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+                config, "original_max_position_embeddings", None):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
         sliding_window = getattr(config, "sliding_window", None)
@@ -378,11 +382,11 @@ class LlamaForCausalLM(nn.Module):
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            ("qkv_proj", "q_proj", "q"),
-            ("qkv_proj", "k_proj", "k"),
-            ("qkv_proj", "v_proj", "v"),
-            ("gate_up_proj", "gate_proj", 0),
-            ("gate_up_proj", "up_proj", 1),
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
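The leading dots matter because `load_weights` matches shard names by substring. Phi-3 checkpoints ship fused `qkv_proj`/`gate_up_proj` tensors, and without the dot the fused `gate_up_proj` weight would also match the `up_proj` pattern and get mis-routed as a shard. A two-line check of that collision (the weight name is illustrative):

    name = "model.layers.0.mlp.gate_up_proj.weight"  # fused tensor, as in Phi-3 checkpoints
    print("up_proj" in name)    # True  -> the old mapping would treat it as an up_proj shard
    print(".up_proj" in name)   # False -> the dotted mapping leaves fused weights alone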