[Model] H2O Danube3-4b (#6451)
Commit 14dbd5a767 (parent ed94e4f427)
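In short, this commit adds head size 120 to vLLM's supported PagedAttention head sizes so that H2O's Danube3-4b can run, and adds a dedicated model test. As a sanity check on where 120 comes from, a minimal sketch assuming the published h2o-danube3-4b config values (hidden_size 3840 with 32 attention heads; these numbers are an assumption, not part of the diff):

    # Assumed config values for h2oai/h2o-danube3-4b-base (not shown in this diff):
    hidden_size = 3840
    num_attention_heads = 32

    head_size = hidden_size // num_attention_heads
    assert head_size == 120  # the new case threaded through this commit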
@@ -23,7 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
 # Run basic model test
 docker exec cpu-test bash -c "
   pip install pytest Pillow protobuf
-  pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
+  pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B are not supported on CPU
 
 # online inference
 docker exec cpu-test bash -c "
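The new tests/models/test_danube3_4b.py is added to the ignore list because, as its docstring notes further down, head size 120 is not supported by the CPU backend yet.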
@@ -175,7 +175,7 @@ if __name__ == '__main__':
     parser.add_argument("--num-kv-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 192, 256],
+                        choices=[64, 80, 96, 112, 120, 128, 192, 256],
                         default=128)
     parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
     parser.add_argument("--use-alibi", action="store_true")
@@ -94,7 +94,7 @@ if __name__ == '__main__':
     parser.add_argument("--num-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 192, 256],
+                        choices=[64, 80, 96, 112, 120, 128, 192, 256],
                         default=128)
     parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
     parser.add_argument("--dtype",
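With the choices extended, the new head size can be exercised directly by passing `--head-size 120` to either benchmark script.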
@@ -751,6 +751,9 @@ void paged_attention_v1_launcher(
     case 112:
       LAUNCH_PAGED_ATTENTION_V1(112);
       break;
+    case 120:
+      LAUNCH_PAGED_ATTENTION_V1(120);
+      break;
     case 128:
       LAUNCH_PAGED_ATTENTION_V1(128);
       break;
@@ -912,6 +915,9 @@ void paged_attention_v2_launcher(
     case 112:
       LAUNCH_PAGED_ATTENTION_V2(112);
       break;
+    case 120:
+      LAUNCH_PAGED_ATTENTION_V2(120);
+      break;
     case 128:
       LAUNCH_PAGED_ATTENTION_V2(128);
       break;
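Both launchers switch on head_size so that each supported size becomes its own compile-time-specialized kernel instantiation, with unsupported sizes falling through to an error. A minimal, hypothetical Python sketch of the same dispatch shape (names are illustrative, not vLLM's API):

    # Hypothetical sketch of the launchers' head-size dispatch; the real CUDA
    # code expands LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) once per case so each
    # head size gets a dedicated kernel instantiation.
    SUPPORTED_HEAD_SIZES = (64, 80, 96, 112, 120, 128, 192, 256)

    def launch_paged_attention(head_size: int) -> None:
        if head_size not in SUPPORTED_HEAD_SIZES:
            raise ValueError(f"Unsupported head size: {head_size}")
        # ... dispatch to the kernel specialized for exactly this head size ...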
@@ -28,7 +28,7 @@ NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
 
 # FlashAttention forward only supports head dimension at most 128
 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
               ] if not is_hip() else [64, 80, 96, 112, 128]
 
 BLOCK_SIZES = [16, 32]
@@ -134,6 +134,8 @@ def test_paged_attention(
     seed: int,
     device: str,
 ) -> None:
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
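This guard (repeated in the cache tests below) skips fp8 KV-cache runs whenever head_size is not a multiple of 16. Of the parametrized head sizes, only the new 120 trips it (120 % 16 == 8), presumably because the fp8 cache kernels handle the head dimension in 16-element groups; the grouping rationale is an assumption, the diff itself only encodes the modulo check:

    HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]

    # Only 120 fails the fp8 guard; every other size divides evenly by 16.
    fp8_compatible = [h for h in HEAD_SIZES if h % 16 == 0]
    assert fp8_compatible == [64, 80, 96, 112, 128, 192, 256]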
@@ -11,7 +11,7 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [42]  # Arbitrary values for testing
 NUM_LAYERS = [1]  # Arbitrary values for testing
 NUM_HEADS = [8]  # Arbitrary values for testing
-HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
 BLOCK_SIZES = [8, 16, 32]
 
 # Arbitrary values for testing
@@ -52,6 +52,8 @@ def test_copy_blocks(
     kv_cache_dtype: str,
     device: str,
 ) -> None:
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
@@ -124,6 +126,8 @@ def test_reshape_and_cache(
     device: str,
     kv_cache_dtype: str,
 ) -> None:
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
@@ -325,6 +329,8 @@ def test_swap_blocks(
 ) -> None:
     if kv_cache_dtype == "fp8" and "cpu" in direction:
         pytest.skip()
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
@@ -10,7 +10,7 @@ from .allclose_default import get_default_atol, get_default_rtol
 
 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
 NUM_HEADS = [7, 17]  # Arbitrary values for testing
 BATCH_SIZES = [1, 5]  # Arbitrary values for testing
tests/models/test_danube3_4b.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+"""Compare the outputs of HF and vLLM when using greedy sampling.
+
+This tests danube3 separately because its head size isn't supported on CPU yet.
+
+Run `pytest tests/models/test_danube3_4b.py`.
+"""
+import pytest
+
+from .utils import check_outputs_equal
+
+MODELS = ["h2oai/h2o-danube3-4b-base"]
+
+target_dtype = "half"
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_tokens", [32])
+def test_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", [target_dtype])
+def test_model_print(
+    vllm_runner,
+    model: str,
+    dtype: str,
+) -> None:
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        # This test is for verifying whether the model's extra_repr
+        # can be printed correctly.
+        print(vllm_model.model.llm_engine.model_executor.driver_worker.
+              model_runner.model)
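check_outputs_equal comes from the test suite's local utils module. A rough sketch of the comparison it presumably performs, assuming each output is a (token_ids, text) pair as suggested by how the runners are used above (the exact helper is not shown in this diff):

    # Hypothetical sketch of check_outputs_equal, not the actual helper
    # from tests/models/utils.py.
    def check_outputs_equal(outputs_0_lst, outputs_1_lst, name_0, name_1):
        assert len(outputs_0_lst) == len(outputs_1_lst)
        for i, (out_0, out_1) in enumerate(zip(outputs_0_lst, outputs_1_lst)):
            ids_0, text_0 = out_0
            ids_1, text_1 = out_1
            assert ids_0 == ids_1, (
                f"Prompt {i}: {name_0} ids {ids_0} != {name_1} ids {ids_1}")
            assert text_0 == text_1, (
                f"Prompt {i}: {name_0} text {text_0!r} != {name_1} text {text_1!r}")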
@@ -31,7 +31,7 @@ class PagedAttention:
 
     @staticmethod
    def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 128, 192, 256]
+        return [64, 80, 96, 112, 120, 128, 192, 256]
 
     @staticmethod
     def get_kv_cache_shape(
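get_supported_head_sizes() is the list checked when a model is loaded, so without the 120 entry, Danube3-4b would be rejected before any kernel runs. A quick check one could run after this change (the import path is an assumption; the diff shows only the class body):

    # Assumed import path for illustration.
    from vllm.attention.ops.paged_attn import PagedAttention

    assert 120 in PagedAttention.get_supported_head_sizes()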
@@ -508,6 +508,12 @@ def create_kv_caches_with_random(
     seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+
+    if cache_dtype == "fp8" and head_size % 16:
+        raise ValueError(
+            f"Does not support key cache of type fp8 with head_size {head_size}"
+        )
+
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)