[Model] H2O Danube3-4b (#6451)

Authored by Joe on 2024-07-26 20:47:50 -07:00; committed by GitHub
parent ed94e4f427
commit 14dbd5a767
10 changed files with 79 additions and 7 deletions
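Danube3-4b uses an attention head size of 120, which sits between the previously supported 112 and 128; the changes below thread this new value through the CPU test exclusions, the kernel benchmarks, the CUDA paged-attention launchers, the kernel tests, and the paged-attention support list. As a minimal sketch of where the number comes from (assuming the published `h2oai/h2o-danube3-4b-base` config exposes `hidden_size` and `num_attention_heads` with no separate head-dim override; treat the expected output as illustrative):

```python
# Illustrative sketch: derive the attention head size from the published HF config.
# Assumes network access and that head size == hidden_size // num_attention_heads.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("h2oai/h2o-danube3-4b-base")
head_size = cfg.hidden_size // cfg.num_attention_heads
print(head_size)  # expected: 120, the value added throughout this commit
```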


@@ -23,7 +23,7 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
 # Run basic model test
 docker exec cpu-test bash -c "
   pip install pytest Pillow protobuf
-  pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
+  pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
 # online inference
 docker exec cpu-test bash -c "


@@ -175,7 +175,7 @@ if __name__ == '__main__':
     parser.add_argument("--num-kv-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 192, 256],
+                        choices=[64, 80, 96, 112, 120, 128, 192, 256],
                         default=128)
     parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
     parser.add_argument("--use-alibi", action="store_true")


@@ -94,7 +94,7 @@ if __name__ == '__main__':
     parser.add_argument("--num-heads", type=int, default=8)
     parser.add_argument("--head-size",
                         type=int,
-                        choices=[64, 80, 96, 112, 128, 192, 256],
+                        choices=[64, 80, 96, 112, 120, 128, 192, 256],
                         default=128)
     parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
     parser.add_argument("--dtype",


@@ -751,6 +751,9 @@ void paged_attention_v1_launcher(
     case 112:
       LAUNCH_PAGED_ATTENTION_V1(112);
       break;
+    case 120:
+      LAUNCH_PAGED_ATTENTION_V1(120);
+      break;
     case 128:
       LAUNCH_PAGED_ATTENTION_V1(128);
       break;
@@ -912,6 +915,9 @@ void paged_attention_v2_launcher(
     case 112:
       LAUNCH_PAGED_ATTENTION_V2(112);
       break;
+    case 120:
+      LAUNCH_PAGED_ATTENTION_V2(120);
+      break;
     case 128:
       LAUNCH_PAGED_ATTENTION_V2(128);
       break;
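For context, the paged-attention kernels are specialized on head size at compile time, so each supported value needs an explicit switch case in both the v1 and v2 launchers; the new `case 120` blocks simply instantiate the kernels for that size. A loose Python analogy of this dispatch pattern (purely illustrative, not the actual CUDA code):

```python
# Loose analogy of the launcher switch: one specialized entry per supported head size.
SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]

def launch_paged_attention_v1(head_size: int) -> None:
    # In the CUDA code each case expands LAUNCH_PAGED_ATTENTION_V1(head_size)
    # with the value baked in; sizes without a case fall through to an error.
    if head_size not in SUPPORTED_HEAD_SIZES:
        raise ValueError(f"Unsupported head size: {head_size}")
    print(f"dispatching kernel specialized for head_size={head_size}")

launch_paged_attention_v1(120)  # dispatches after this change instead of erroring
```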


@@ -28,7 +28,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
 # FlashAttention forward only supports head dimension at most 128
 # https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
-HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256
 ] if not is_hip() else [64, 80, 96, 112, 128]
 BLOCK_SIZES = [16, 32]
@@ -134,6 +134,8 @@ def test_paged_attention(
     seed: int,
     device: str,
 ) -> None:
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
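The new skip guard mirrors the fp8 restriction enforced elsewhere in this commit: fp8 KV-cache cases only run when the head size is divisible by 16, and 120 is the only entry in `HEAD_SIZES` that fails that check (120 % 16 == 8). A quick illustration of which parametrized cases keep their fp8 variants:

```python
# Which head sizes still run with kv_cache_dtype="fp8" after the new skip guard.
HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
fp8_head_sizes = [h for h in HEAD_SIZES if h % 16 == 0]
print(fp8_head_sizes)  # [64, 80, 96, 112, 128, 192, 256] -- only 120 is skipped
```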


@@ -11,7 +11,7 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [42] # Arbitrary values for testing
 NUM_LAYERS = [1] # Arbitrary values for testing
 NUM_HEADS = [8] # Arbitrary values for testing
-HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
 BLOCK_SIZES = [8, 16, 32]
 # Arbitrary values for testing
@@ -52,6 +52,8 @@ def test_copy_blocks(
     kv_cache_dtype: str,
     device: str,
 ) -> None:
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
@@ -124,6 +126,8 @@ def test_reshape_and_cache(
     device: str,
     kv_cache_dtype: str,
 ) -> None:
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
@@ -325,6 +329,8 @@ def test_swap_blocks(
 ) -> None:
     if kv_cache_dtype == "fp8" and "cpu" in direction:
         pytest.skip()
+    if kv_cache_dtype == "fp8" and head_size % 16:
+        pytest.skip()
     random.seed(seed)
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():


@@ -10,7 +10,7 @@ from .allclose_default import get_default_atol, get_default_rtol
 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
-HEAD_SIZES = [64, 80, 96, 112, 128, 192, 256]
+HEAD_SIZES = [64, 80, 96, 112, 120, 128, 192, 256]
 ROTARY_DIMS = [None, 32] # None means rotary dim == head size
 NUM_HEADS = [7, 17] # Arbitrary values for testing
 BATCH_SIZES = [1, 5] # Arbitrary values for testing


@@ -0,0 +1,52 @@
+"""Compare the outputs of HF and vLLM when using greedy sampling.
+
+This tests danube3 separately because its head size isn't supported on CPU yet.
+
+Run `pytest tests/models/test_danube3_4b.py`.
+"""
+import pytest
+
+from .utils import check_outputs_equal
+
+MODELS = ["h2oai/h2o-danube3-4b-base"]
+
+target_dtype = "half"
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_tokens", [32])
+def test_models(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+) -> None:
+    with hf_runner(model, dtype=dtype) as hf_model:
+        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
+
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
+
+    check_outputs_equal(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+    )
+
+
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("dtype", [target_dtype])
+def test_model_print(
+    vllm_runner,
+    model: str,
+    dtype: str,
+) -> None:
+    with vllm_runner(model, dtype=dtype) as vllm_model:
+        # This test is for verifying whether the model's extra_repr
+        # can be printed correctly.
+        print(vllm_model.model.llm_engine.model_executor.driver_worker.
+              model_runner.model)
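Beyond the automated comparison above, a quick manual smoke test of the new model through vLLM's public API might look like the following sketch (illustrative only; greedy decoding via temperature=0 mirrors `generate_greedy`, and the prompt is arbitrary):

```python
# Illustrative manual check, not part of the test suite.
from vllm import LLM, SamplingParams

llm = LLM(model="h2oai/h2o-danube3-4b-base", dtype="half")
params = SamplingParams(temperature=0.0, max_tokens=32)  # greedy, as in the test
for output in llm.generate(["The capital of France is"], params):
    print(output.outputs[0].text)
```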


@@ -31,7 +31,7 @@ class PagedAttention:
     @staticmethod
     def get_supported_head_sizes() -> List[int]:
-        return [64, 80, 96, 112, 128, 192, 256]
+        return [64, 80, 96, 112, 120, 128, 192, 256]

     @staticmethod
     def get_kv_cache_shape(
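`get_supported_head_sizes()` is consulted when validating whether the paged-attention kernels can serve a given model, so adding 120 here is what actually unblocks Danube3-4b at runtime. A hypothetical sketch of that kind of check (illustrative, not vLLM's exact code path):

```python
# Hypothetical validation sketch mirroring how a backend might use the list above.
from typing import List

def get_supported_head_sizes() -> List[int]:
    return [64, 80, 96, 112, 120, 128, 192, 256]

def validate_head_size(head_size: int) -> None:
    supported = get_supported_head_sizes()
    if head_size not in supported:
        raise ValueError(
            f"Head size {head_size} is not supported by PagedAttention. "
            f"Supported head sizes are: {supported}.")

validate_head_size(120)  # passes after this change; raised before it
```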


@@ -508,6 +508,12 @@ def create_kv_caches_with_random(
     seed: int = 0,
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
+    if cache_dtype == "fp8" and head_size % 16:
+        raise ValueError(
+            f"Does not support key cache of type fp8 with head_size {head_size}"
+        )
     torch.random.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed(seed)