[Model] Support Cohere2ForCausalLM (Cohere R7B) (#11203)
parent b3b1526f03
commit bddbbcb132
docs/source/models/supported_models.rst
@@ -118,9 +118,9 @@ Text Generation (``--task generate``)
     - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
     - ✅︎
     - ✅︎
-  * - :code:`CohereForCausalLM`
+  * - :code:`CohereForCausalLM`, :code:`Cohere2ForCausalLM`
     - Command-R
-    - :code:`CohereForAI/c4ai-command-r-v01`, etc.
+    - :code:`CohereForAI/c4ai-command-r-v01`, :code:`CohereForAI/c4ai-command-r7b-12-2024`, etc.
     - ✅︎
     - ✅︎
   * - :code:`DbrxForCausalLM`
tests/models/registry.py
@@ -53,6 +53,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     # ChatGLMModel supports multimodal
     "CohereForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r-v01",
                                          trust_remote_code=True),
+    "Cohere2ForCausalLM": _HfExamplesInfo("CohereForAI/c4ai-command-r7b-12-2024", # noqa: E501
+                                          trust_remote_code=True),
     "DbrxForCausalLM": _HfExamplesInfo("databricks/dbrx-instruct"),
     "DeciLMForCausalLM": _HfExamplesInfo("Deci/DeciLM-7B-instruct",
                                          trust_remote_code=True),
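The new test-registry entry mirrors the existing Command-R one: `trust_remote_code=True` is carried over from that entry, and the `# noqa: E501` only silences the long-line lint check for the checkpoint name.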
tests/models/test_initialization.py
@@ -1,6 +1,7 @@
 from unittest.mock import patch

 import pytest
+import transformers
 from transformers import PretrainedConfig

 from vllm import LLM
@@ -11,6 +12,9 @@ from .registry import HF_EXAMPLE_MODELS
 @pytest.mark.parametrize("model_arch", HF_EXAMPLE_MODELS.get_supported_archs())
 def test_can_initialize(model_arch):
     model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
+    if (model_arch == "Cohere2ForCausalLM"
+            and transformers.__version__ < "4.48.0"):
+        pytest.skip(reason="Model introduced in HF >= 4.48.0")
     if not model_info.is_available_online:
         pytest.skip("Model is not available online")

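One caveat about the skip guard above: it compares version strings lexicographically, which only works while both strings happen to sort in release order. A small sketch of the difference, using the packaging library (my choice for illustration; the commit itself uses the plain string compare):

    # String comparison is character-by-character, so "4.9.0" sorts *after*
    # "4.48.0" even though release 4.9 predates 4.48.
    from packaging import version

    print("4.9.0" < "4.48.0")                                # False (lexicographic)
    print(version.parse("4.9.0") < version.parse("4.48.0"))  # True  (semantic)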
vllm/model_executor/models/commandr.py
@@ -48,7 +48,7 @@ from vllm.model_executor.utils import set_weight_attrs
 from vllm.sequence import IntermediateTensors

 from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (is_pp_missing_parameter,
+from .utils import (extract_layer_index, is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)

@@ -171,12 +171,26 @@ class CohereAttention(nn.Module):
             rope_scaling=self.rope_scaling,
             is_neox_style=False,
         )
+
+        sliding_window = getattr(config, "sliding_window", None)
+        # Model v2 has sliding windows, v1 does not
+        self.v1 = sliding_window is None
+
+        layer_idx = extract_layer_index(prefix)
+        layer_has_sliding_window = (
+            getattr(config, "sliding_window_pattern", False)
+            and (layer_idx + 1) % self.config.sliding_window_pattern != 0)
+
+        self.sliding_window = (sliding_window
+                               if layer_has_sliding_window else None)
+
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
                               num_kv_heads=self.num_kv_heads,
                               cache_config=cache_config,
                               quant_config=quant_config,
+                              per_layer_sliding_window=self.sliding_window,
                               prefix=f"{prefix}.attn")
         if self.use_qk_norm:
             self.q_norm = LayerNorm(param_shape=(self.num_heads,
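To make the layer pattern concrete, here is a standalone sketch of the same arithmetic. The config values (`sliding_window=4096`, `sliding_window_pattern=4`) are assumptions based on the R7B release and should be checked against the checkpoint's config.json:

    # Standalone sketch of the per-layer window selection in the diff above.
    sliding_window = 4096        # assumed R7B value
    sliding_window_pattern = 4   # assumed: every 4th layer attends globally

    for layer_idx in range(8):
        # Mirrors the commit: layers whose 1-based index is a multiple of the
        # pattern get no window (global attention); all others use the window.
        has_window = (layer_idx + 1) % sliding_window_pattern != 0
        print(layer_idx, sliding_window if has_window else None)
    # layers 0-2 -> 4096, layer 3 -> None, layers 4-6 -> 4096, layer 7 -> None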
@@ -206,6 +220,7 @@ class CohereAttention(nn.Module):
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         if self.use_qk_norm:
             q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        if self.v1 or self.sliding_window:
+            q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
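This guard is the key behavioral change in the forward pass: the original Command-R (`self.v1`) applies rotary embeddings on every layer, while Cohere2 applies them only on sliding-window layers. The global-attention layers run with no positional embedding at all, which matches the NoPE-style design described for Command R7B.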
vllm/model_executor/models/registry.py
@@ -41,6 +41,7 @@ _TEXT_GENERATION_MODELS = {
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
     # ChatGLMModel supports multimodal
     "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
+    "Cohere2ForCausalLM": ("commandr", "CohereForCausalLM"),
     "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
     "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
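Note that both architecture names map to the same `CohereForCausalLM` class in commandr.py, so the R7B checkpoint loads through the existing Command-R implementation. A minimal usage sketch (assuming access to the gated checkpoint and Transformers >= 4.48.0 for the config class):

    from vllm import LLM, SamplingParams

    # Resolves through the new Cohere2ForCausalLM -> commandr mapping above.
    llm = LLM(model="CohereForAI/c4ai-command-r7b-12-2024")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=32))
    print(outputs[0].outputs[0].text)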