Remove hardcoded device="cuda" to support more devices (#2503)

Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>

parent c410f5d020
commit 96b6f475dd
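This commit removes the assumption that every tensor, module, and config lives on "cuda" and instead threads an explicit device setting from the command line down to model loading. The sketch below is not code from the patch: the DeviceConfig body mirrors the config hunk further down, while the argument handling condenses what the benchmark hunks below do; anything else here is illustrative only.

import argparse

import torch


class DeviceConfig:

    def __init__(self, device: str = "cuda") -> None:
        self.device = torch.device(device)


if __name__ == "__main__":
    # Mirrors the new --device flag added to the benchmarks and EngineArgs.
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda", choices=["cuda"])
    args = parser.parse_args()

    device_config = DeviceConfig(args.device)
    # Downstream code targets this torch.device instead of a literal "cuda".
    print(device_config.device)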
@@ -25,6 +25,7 @@ def main(args: argparse.Namespace):
         dtype=args.dtype,
         enforce_eager=args.enforce_eager,
         kv_cache_dtype=args.kv_cache_dtype,
+        device=args.device,
     )
 
     sampling_params = SamplingParams(
@@ -135,5 +136,11 @@ if __name__ == '__main__':
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        choices=["cuda"],
+        help='device type for vLLM execution, supporting CUDA only currently.')
     args = parser.parse_args()
     main(args)
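With the flag above in place, the benchmark forwards `device` to the `LLM` constructor. A hedged usage sketch of that entry point (the model name is an arbitrary small example and a CUDA-capable machine is assumed; `device` is simply passed through to the engine arguments):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", device="cuda")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=16))
print(outputs[0].outputs[0].text)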
@@ -72,6 +72,7 @@ def run_vllm(
     max_model_len: Optional[int],
     enforce_eager: bool,
     kv_cache_dtype: str,
+    device: str,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -85,6 +86,7 @@ def run_vllm(
         max_model_len=max_model_len,
         enforce_eager=enforce_eager,
         kv_cache_dtype=kv_cache_dtype,
+        device=device,
     )
 
     # Add the requests to the engine.
@@ -209,7 +211,7 @@ def main(args: argparse.Namespace):
                                 args.seed, args.n, args.use_beam_search,
                                 args.trust_remote_code, args.dtype,
                                 args.max_model_len, args.enforce_eager,
-                                args.kv_cache_dtype)
+                                args.kv_cache_dtype, args.device)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -294,6 +296,12 @@ if __name__ == "__main__":
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda",
+        choices=["cuda"],
+        help='device type for vLLM execution, supporting CUDA only currently.')
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
@ -25,18 +25,20 @@ def main(
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
do_profile: bool,
|
||||
device: str = "cuda",
|
||||
kv_cache_dtype: Optional[str] = None,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
query = torch.empty(num_seqs,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device="cuda")
|
||||
device=device)
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
@ -44,11 +46,11 @@ def main(
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads,
|
||||
dtype=torch.float,
|
||||
device="cuda")
|
||||
device=device)
|
||||
|
||||
context_lens = [context_len for _ in range(num_seqs)]
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
@ -59,12 +61,17 @@ def main(
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
|
||||
|
||||
# Create the KV cache.
|
||||
key_caches, value_caches = create_kv_caches_with_random(
|
||||
NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
|
||||
dtype)
|
||||
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
|
||||
block_size,
|
||||
1,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
kv_cache_dtype,
|
||||
dtype,
|
||||
device=device)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Prepare for the paged attention kernel.
|
||||
@ -84,7 +91,7 @@ def main(
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
|
||||
def run_benchmark(num_iters: int, profile: bool = False) -> float:
|
||||
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
|
||||
torch.cuda.synchronize()
|
||||
if profile:
|
||||
torch.cuda.cudart().cudaProfilerStart()
|
||||
@ -135,6 +142,7 @@ def main(
|
||||
|
||||
# Warmup.
|
||||
print("Warming up...")
|
||||
run_benchmark = run_cuda_benchmark
|
||||
run_benchmark(num_iters=3, profile=False)
|
||||
|
||||
# Benchmark.
|
||||
@ -175,6 +183,7 @@ if __name__ == '__main__':
|
||||
default="auto",
|
||||
help=
|
||||
'Data type for kv cache storage. If "auto", will use model data type.')
|
||||
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
|
||||
|
@ -7,26 +7,29 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
|
||||
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
|
||||
D = [512, 4096, 5120, 13824] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_silu_and_mul(
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
gpu_id = f"cuda:{device}"
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
|
||||
layer = SiluAndMul()
|
||||
out = layer(x)
|
||||
ref_out = layer._forward(x)
|
||||
@ -37,19 +40,20 @@ def test_silu_and_mul(
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_gelu_new(
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
gpu_id = f"cuda:{device}"
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
x = torch.randn(num_tokens, d, dtype=dtype)
|
||||
layer = NewGELU()
|
||||
out = layer(x)
|
||||
ref_out = layer._forward(x)
|
||||
@ -60,18 +64,19 @@ def test_gelu_new(
|
||||
@pytest.mark.parametrize("d", D)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_gelu_fast(
|
||||
num_tokens: int,
|
||||
d: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
gpu_id = f"cuda:{device}"
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
x = torch.randn(num_tokens, d, dtype=dtype)
|
||||
layer = FastGELU()
|
||||
out = layer(x)
|
||||
ref_out = layer._forward(x)
|
||||
|
@ -27,7 +27,9 @@ BLOCK_SIZES = [16, 32]
|
||||
USE_ALIBI = [False, True]
|
||||
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
|
||||
SEEDS = [0]
|
||||
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
def ref_masked_attention(
|
||||
@ -91,7 +93,7 @@ def ref_single_query_cached_kv_attention(
|
||||
alibi_bias = None
|
||||
if alibi_slopes is not None:
|
||||
# Create the ALiBi bias used in the paged attention kernel.
|
||||
position_ids = torch.arange(context_len, device=query.device).int()
|
||||
position_ids = torch.arange(context_len).int()
|
||||
alibi_bias = (position_ids - context_len + 1).float()
|
||||
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
|
||||
1, 1, -1)
|
||||
@ -110,7 +112,7 @@ def ref_single_query_cached_kv_attention(
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_paged_attention(
|
||||
kv_cache_factory,
|
||||
version: str,
|
||||
@ -122,33 +124,28 @@ def test_paged_attention(
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str,
|
||||
seed: int,
|
||||
device: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
gpu_id = f"cuda:{device}"
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
scale = float(1.0 / (head_size**0.5))
|
||||
num_query_heads, num_kv_heads = num_heads
|
||||
query = torch.empty(num_seqs,
|
||||
num_query_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=gpu_id)
|
||||
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads,
|
||||
dtype=torch.float,
|
||||
device=gpu_id)
|
||||
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
|
||||
|
||||
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
context_lens[-1] = MAX_SEQ_LEN
|
||||
max_context_len = max(context_lens)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int)
|
||||
|
||||
# Create the block tables.
|
||||
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
|
||||
@ -159,13 +156,13 @@ def test_paged_attention(
|
||||
for _ in range(max_num_blocks_per_seq)
|
||||
]
|
||||
block_tables.append(block_table)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
|
||||
num_kv_heads, head_size,
|
||||
kv_cache_dtype, dtype, seed,
|
||||
gpu_id)
|
||||
device)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Call the paged attention kernel.
|
||||
@ -193,12 +190,10 @@ def test_paged_attention(
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
ops.paged_attention_v2(
|
||||
@ -229,14 +224,14 @@ def test_paged_attention(
|
||||
block_size, x)
|
||||
dequantized_key_cache = torch.empty(size=key_cache_shape,
|
||||
dtype=dtype,
|
||||
device=gpu_id)
|
||||
device=device)
|
||||
cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
|
||||
key_cache = dequantized_key_cache
|
||||
|
||||
value_cache_shape = value_cache.shape
|
||||
dequantized_value_cache = torch.empty(size=value_cache_shape,
|
||||
dtype=dtype,
|
||||
device=gpu_id)
|
||||
device=device)
|
||||
cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
|
||||
value_cache = dequantized_value_cache
|
||||
|
||||
@ -283,7 +278,7 @@ def ref_multi_query_kv_attention(
|
||||
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
|
||||
diagonal=1)
|
||||
attn_mask = attn_mask * torch.finfo(dtype).min
|
||||
attn_mask = attn_mask.to(dtype=dtype, device=query.device)
|
||||
attn_mask = attn_mask.to(dtype=dtype)
|
||||
|
||||
ref_output = ref_masked_attention(
|
||||
query[start_idx:end_idx],
|
||||
@ -303,7 +298,7 @@ def ref_multi_query_kv_attention(
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_multi_query_kv_attention(
|
||||
num_seqs: int,
|
||||
@ -311,12 +306,13 @@ def test_multi_query_kv_attention(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
gpu_id = f"cuda:{device}"
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
|
||||
# As the xformers library is already tested with its own tests, we can use
|
||||
# a smaller MAX_SEQ_LEN here.
|
||||
@ -329,8 +325,7 @@ def test_multi_query_kv_attention(
|
||||
qkv = torch.empty(num_tokens,
|
||||
num_query_heads + 2 * num_kv_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=gpu_id)
|
||||
dtype=dtype)
|
||||
qkv.uniform_(-scale, scale)
|
||||
query, key, value = qkv.split(
|
||||
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)
|
||||
|
@ -17,7 +17,9 @@ BLOCK_SIZES = [8, 16, 32]
|
||||
NUM_BLOCKS = [1024, 3600] # Arbitrary values for testing
|
||||
NUM_MAPPINGS = [256] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
|
||||
|
||||
|
||||
@ -29,7 +31,7 @@ KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
|
||||
@torch.inference_mode()
|
||||
def test_copy_blocks(
|
||||
@ -42,13 +44,14 @@ def test_copy_blocks(
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
kv_cache_dtype: str,
|
||||
device: str,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
gpu_id = f"cuda:{device}"
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
# Generate random block mappings where each source block is mapped to two
|
||||
# destination blocks.
|
||||
assert 2 * num_mappings <= num_blocks
|
||||
@ -66,7 +69,7 @@ def test_copy_blocks(
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
|
||||
num_layers, num_heads,
|
||||
head_size, kv_cache_dtype,
|
||||
dtype, seed, gpu_id)
|
||||
dtype, seed, device)
|
||||
|
||||
# Clone the KV caches.
|
||||
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
|
||||
@ -98,7 +101,7 @@ def test_copy_blocks(
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_reshape_and_cache(
|
||||
kv_cache_factory,
|
||||
@ -109,29 +112,25 @@ def test_reshape_and_cache(
|
||||
num_blocks: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
gpu_id = f"cuda:{device}"
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
# Create a random slot mapping.
|
||||
num_slots = block_size * num_blocks
|
||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long)
|
||||
|
||||
qkv = torch.randn(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device=gpu_id)
|
||||
qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
# Create the KV caches.
|
||||
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
|
||||
num_heads, head_size, dtype,
|
||||
None, seed, gpu_id)
|
||||
None, seed, device)
|
||||
key_cache, value_cache = key_caches[0], value_caches[0]
|
||||
|
||||
# Clone the KV caches.
|
||||
@ -166,7 +165,7 @@ def test_reshape_and_cache(
|
||||
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_swap_blocks(
|
||||
kv_cache_factory,
|
||||
@ -182,7 +181,8 @@ def test_swap_blocks(
|
||||
) -> None:
|
||||
random.seed(seed)
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
src_device = f"{direction[0]}:{device}" if direction[
|
||||
0] == "cuda" else direction[0]
|
||||
dst_device = f"{direction[1]}:{device}" if direction[
|
||||
|
@@ -8,7 +8,9 @@ NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
 HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing
 ADD_RESIDUAL = [False, True]
 SEEDS = [0]
-DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
+CUDA_DEVICES = [
+    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
+]
 
 
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -16,7 +18,7 @@ DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
-@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
 @torch.inference_mode()
 def test_rms_norm(
     num_tokens: int,
@@ -24,15 +26,16 @@ def test_rms_norm(
     add_residual: bool,
     dtype: torch.dtype,
     seed: int,
-    device: int,
+    device: str,
 ) -> None:
     torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
-    gpu_id = f"cuda:{device}"
-    layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+    torch.set_default_device(device)
+    layer = RMSNorm(hidden_size).to(dtype=dtype)
     layer.weight.data.normal_(mean=1.0, std=0.1)
     scale = 1 / (2 * hidden_size)
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id)
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype)
     x *= scale
     residual = torch.randn_like(x) * scale if add_residual else None
 
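The kernel and layer tests in this diff converge on a single pattern: a CUDA_DEVICES list of device strings replaces the old integer DEVICES list, CUDA seeding is guarded by torch.cuda.is_available(), and torch.set_default_device(device) replaces per-tensor device= arguments. A minimal self-contained sketch of that pattern (the op under test here is a stand-in, not one of vLLM's kernels):

import pytest
import torch

SEEDS = [0]
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_scaled_add(seed: int, device: str) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
    # Every tensor below is created on `device` without an explicit device= arg.
    torch.set_default_device(device)
    x = torch.randn(16, 32, dtype=torch.float16)
    y = torch.randn(16, 32, dtype=torch.float16)
    out = x + 2.0 * y
    assert out.device == torch.device(device)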
@ -13,7 +13,9 @@ NUM_HEADS = [7, 17] # Arbitrary values for testing
|
||||
BATCH_SIZES = [1, 5] # Arbitrary values for testing
|
||||
SEQ_LENS = [11, 8192] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@ -24,7 +26,7 @@ DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
|
||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_rotary_embedding(
|
||||
is_neox_style: bool,
|
||||
@ -35,28 +37,26 @@ def test_rotary_embedding(
|
||||
rotary_dim: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: int,
|
||||
device: str,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
gpu_id = f"cuda:{device}"
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
|
||||
rope = rope.to(dtype=dtype, device=gpu_id)
|
||||
rope = rope.to(dtype=dtype)
|
||||
|
||||
positions = torch.randint(0,
|
||||
max_position, (batch_size, seq_len),
|
||||
device=gpu_id)
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query = torch.randn(batch_size,
|
||||
seq_len,
|
||||
num_heads * head_size,
|
||||
dtype=dtype,
|
||||
device=gpu_id)
|
||||
dtype=dtype)
|
||||
key = torch.randn_like(query)
|
||||
|
||||
# NOTE(woosuk): The reference implementation should be executed first
|
||||
|
@ -11,19 +11,27 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
|
||||
NUM_HEADS = [12]
|
||||
HEAD_SIZES = [128]
|
||||
DTYPES = [torch.float16]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_contexted_kv_attention(
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: str,
|
||||
) -> None:
|
||||
random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(0)
|
||||
torch.set_default_device(device)
|
||||
MAX_SEQ_LEN = 1024
|
||||
MAX_CTX_LEN = 1024
|
||||
BS = 10
|
||||
@ -35,24 +43,11 @@ def test_contexted_kv_attention(
|
||||
seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]
|
||||
|
||||
num_tokens = sum(subquery_lens)
|
||||
query = torch.empty(num_tokens,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||
query.uniform_(-1e-3, 1e-3)
|
||||
output = torch.empty(num_tokens,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
|
||||
|
||||
kv = torch.empty(sum(seq_lens),
|
||||
2,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
kv = torch.empty(sum(seq_lens), 2, num_heads, head_size, dtype=dtype)
|
||||
kv.uniform_(-1e-3, 1e-3)
|
||||
key, value = kv.unbind(dim=1)
|
||||
|
||||
@ -60,39 +55,27 @@ def test_contexted_kv_attention(
|
||||
block_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
dtype=dtype)
|
||||
v_cache = torch.zeros(cache_size,
|
||||
block_size,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
k = torch.zeros(sum(subquery_lens),
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
v = torch.zeros(sum(subquery_lens),
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
values = torch.arange(0, cache_size, dtype=torch.long, device='cuda')
|
||||
dtype=dtype)
|
||||
k = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype)
|
||||
v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype)
|
||||
values = torch.arange(0, cache_size, dtype=torch.long)
|
||||
values = values[torch.randperm(cache_size)]
|
||||
block_table = values[:BS * max_block_per_request].view(
|
||||
BS, max_block_per_request)
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda')
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda')
|
||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
|
||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||
b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1],
|
||||
dtype=torch.long,
|
||||
device='cuda'),
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
max_input_len = MAX_SEQ_LEN
|
||||
# copy kv to cache
|
||||
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
|
||||
dtype=torch.long,
|
||||
device='cuda'),
|
||||
dtype=torch.long),
|
||||
dim=0)
|
||||
for i in range(BS):
|
||||
for j in range(subquery_lens[i]):
|
||||
|
@@ -126,8 +126,8 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
     get_model_old = get_model
 
-    def get_model_patched(model_config, lora_config=None):
-        return get_model_old(model_config,
+    def get_model_patched(model_config, device_config, lora_config=None):
+        return get_model_old(model_config, device_config,
                              LoRAConfig(max_loras=4, max_lora_rank=8))
 
     with patch("vllm.worker.model_runner.get_model", get_model_patched):
@ -34,6 +34,9 @@ TOLERANCES = {
|
||||
torch.float32: (5e-3, 5e-3),
|
||||
torch.bfloat16: (3e-2, 2e-2),
|
||||
}
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
def get_random_id_to_index(num_loras: int,
|
||||
@ -151,14 +154,10 @@ def create_random_inputs(
|
||||
for _ in range(num_inputs):
|
||||
if input_type == torch.int:
|
||||
inputs.append(
|
||||
torch.randint(low=int(low),
|
||||
high=int(high),
|
||||
size=input_size,
|
||||
device="cuda"))
|
||||
torch.randint(low=int(low), high=int(high), size=input_size))
|
||||
else:
|
||||
inputs.append(
|
||||
torch.rand(size=input_size, dtype=input_type, device="cuda") *
|
||||
high + low)
|
||||
torch.rand(size=input_size, dtype=input_type) * high + low)
|
||||
|
||||
lora_id = random.choice(active_lora_ids)
|
||||
index_mapping += [lora_id] * input_size[0]
|
||||
@ -169,8 +168,10 @@ def create_random_inputs(
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
def test_embeddings(dist_init, num_loras) -> None:
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_embeddings(dist_init, num_loras, device) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
@ -259,8 +260,10 @@ def test_embeddings(dist_init, num_loras) -> None:
|
||||
@torch.inference_mode()
|
||||
# @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.")
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None:
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
@ -305,8 +308,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None:
|
||||
|
||||
# Add empty embeddings_tensors for unoccupied lora slots.
|
||||
for _ in range(max_loras - len(embeddings_tensors)):
|
||||
embeddings_tensors.append(
|
||||
torch.zeros(embeddings_tensors[0].shape, device="cuda"))
|
||||
embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))
|
||||
|
||||
inputs, index_mapping, prompt_mapping = create_random_inputs(
|
||||
active_lora_ids=list(lora_dict.keys()),
|
||||
@ -388,8 +390,10 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None:
|
||||
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
def test_lm_head_sampler(dist_init, num_loras) -> None:
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_lm_head_sampler(dist_init, num_loras, device) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
@ -432,7 +436,7 @@ def test_lm_head_sampler(dist_init, num_loras) -> None:
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
input_ = torch.rand(20, 1024, device="cuda")
|
||||
input_ = torch.rand(20, 1024)
|
||||
mapping_info = convert_mapping(
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
@ -500,8 +504,10 @@ def test_lm_head_sampler(dist_init, num_loras) -> None:
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("orientation", ["row", "column"])
|
||||
def test_linear_parallel(dist_init, num_loras, orientation) -> None:
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
@ -597,8 +603,10 @@ def test_linear_parallel(dist_init, num_loras, orientation) -> None:
|
||||
@torch.inference_mode()
|
||||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("repeats", [2, 3])
|
||||
def test_column_parallel_packed(dist_init, num_loras, repeats) -> None:
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
|
@@ -5,7 +5,8 @@ from unittest.mock import patch
 
 from vllm.lora.models import LoRAMapping
 from vllm.lora.request import LoRARequest
-from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig
+from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
+                         DeviceConfig, LoRAConfig)
 from vllm.worker.worker import Worker
 
 
@@ -25,6 +26,7 @@ def test_worker_apply_lora(sql_lora_files):
         ),
         parallel_config=ParallelConfig(1, 1, False),
         scheduler_config=SchedulerConfig(32, 32, 32, 256),
+        device_config=DeviceConfig("cuda"),
         local_rank=0,
         rank=0,
         lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,
@ -9,6 +9,10 @@ from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
|
||||
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
def mock_causal_accepted_tensor(
|
||||
k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
|
||||
@ -39,11 +43,14 @@ def mock_causal_accepted_tensor(
|
||||
@pytest.mark.parametrize(
|
||||
"which_tokens_accepted",
|
||||
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_correct_output_format(which_tokens_accepted: str, seed: int):
|
||||
def test_correct_output_format(which_tokens_accepted: str, seed: int,
|
||||
device: str):
|
||||
"""Verify the output has correct format given predetermined accepted matrix.
|
||||
"""
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
batch_size = 10
|
||||
k = 5
|
||||
@ -66,18 +73,15 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int):
|
||||
recovered_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device="cuda")
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device="cuda")
|
||||
dtype=torch.int64)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64,
|
||||
device="cuda")
|
||||
dtype=torch.int64)
|
||||
|
||||
rejection_sampler = RejectionSampler()
|
||||
rejection_sampler.init_gpu_tensors(rank=0)
|
||||
@ -120,31 +124,24 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int):
|
||||
@pytest.mark.parametrize("k", list(range(1, 6)))
|
||||
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
|
||||
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int):
|
||||
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
|
||||
device: str):
|
||||
torch.set_default_device(device)
|
||||
rejection_sampler = RejectionSampler()
|
||||
rejection_sampler.init_gpu_tensors(rank=0)
|
||||
|
||||
draft_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device="cuda")
|
||||
target_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device="cuda")
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64,
|
||||
device="cuda")
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device="cuda")
|
||||
dtype=torch.int64)
|
||||
|
||||
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
|
||||
draft_token_ids)
|
||||
@ -153,36 +150,28 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int):
|
||||
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
|
||||
@pytest.mark.parametrize("which_token_ids",
|
||||
["bonus_token_ids", "draft_token_ids"])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
|
||||
which_token_ids: str):
|
||||
which_token_ids: str, device: str):
|
||||
k = 3
|
||||
batch_size = 5
|
||||
vocab_size = 30_000
|
||||
torch.set_default_device(device)
|
||||
|
||||
rejection_sampler = RejectionSampler(strict_mode=True)
|
||||
rejection_sampler.init_gpu_tensors(rank=0)
|
||||
|
||||
draft_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device="cuda")
|
||||
target_probs = torch.rand(batch_size,
|
||||
k,
|
||||
vocab_size,
|
||||
dtype=torch.float32,
|
||||
device="cuda")
|
||||
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
|
||||
bonus_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, 1),
|
||||
dtype=torch.int64,
|
||||
device="cuda")
|
||||
dtype=torch.int64)
|
||||
draft_token_ids = torch.randint(low=0,
|
||||
high=vocab_size,
|
||||
size=(batch_size, k),
|
||||
dtype=torch.int64,
|
||||
device="cuda")
|
||||
dtype=torch.int64)
|
||||
|
||||
oob_token_ids = None
|
||||
if which_token_ids == "bonus_token_ids":
|
||||
@ -237,6 +226,7 @@ def test_rejection_sampling_approximates_target_distribution(
|
||||
probabilities are exactly equal. Rejection sampling should
|
||||
still work without any NaNs or exceptions.
|
||||
"""
|
||||
torch.set_default_device("cpu")
|
||||
set_random_seed(seed)
|
||||
|
||||
helper = _CorrectnessTestHelper(
|
||||
|
@ -31,24 +31,26 @@ def _prepare_test(
|
||||
batch_size: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
|
||||
vocab_size = 32000
|
||||
input_tensor = torch.rand((batch_size, 1024),
|
||||
device="cuda",
|
||||
dtype=torch.float16)
|
||||
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
|
||||
fake_logits = torch.full((batch_size, vocab_size),
|
||||
1e-2,
|
||||
device=input_tensor.device,
|
||||
dtype=input_tensor.dtype)
|
||||
sampler = MockLogitsSampler(32000, fake_logits)
|
||||
model_runner = ModelRunner(None, None, None, None)
|
||||
model_runner = ModelRunner(None, None, None, None, None)
|
||||
return input_tensor, fake_logits, sampler, model_runner
|
||||
|
||||
|
||||
RANDOM_SEEDS = list(range(128))
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_sampler_all_greedy(seed: int):
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_all_greedy(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
|
||||
batch_size)
|
||||
@ -81,8 +83,10 @@ def test_sampler_all_greedy(seed: int):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_sampler_all_random(seed: int):
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_all_random(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
|
||||
batch_size)
|
||||
@ -120,8 +124,10 @@ def test_sampler_all_random(seed: int):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_sampler_all_beam(seed: int):
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_all_beam(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
|
||||
|
||||
@ -156,8 +162,10 @@ def test_sampler_all_beam(seed: int):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_sampler_mixed(seed: int):
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_mixed(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
|
||||
batch_size)
|
||||
@ -212,8 +220,10 @@ def test_sampler_mixed(seed: int):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_sampler_logits_processors(seed: int):
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_logits_processors(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
batch_size = random.randint(1, 256)
|
||||
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
|
||||
|
||||
@ -252,14 +262,15 @@ def test_sampler_logits_processors(seed: int):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
|
||||
def test_sampler_top_k_top_p(seed: int):
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_sampler_top_k_top_p(seed: int, device: str):
|
||||
set_random_seed(seed)
|
||||
batch_size = random.randint(1, 256)
|
||||
top_k = random.randint(100, 500)
|
||||
top_p = random.random() * 0.1
|
||||
vocab_size = 32000
|
||||
input_tensor = torch.rand((batch_size, 1024),
|
||||
device="cuda",
|
||||
device=device,
|
||||
dtype=torch.float16)
|
||||
fake_logits = torch.normal(0,
|
||||
5,
|
||||
@ -267,7 +278,7 @@ def test_sampler_top_k_top_p(seed: int):
|
||||
device=input_tensor.device,
|
||||
dtype=input_tensor.dtype)
|
||||
sampler = MockLogitsSampler(32000, fake_logits)
|
||||
model_runner = ModelRunner(None, None, None, None)
|
||||
model_runner = ModelRunner(None, None, None, None, None)
|
||||
|
||||
generation_model = GenerationMixin()
|
||||
generation_config = GenerationConfig(top_k=top_k,
|
||||
|
@@ -84,7 +84,7 @@ def create_worker(cls: type,
     )
 
     (model_config, cache_config, parallel_config, scheduler_config,
-     _) = engine_args.create_engine_configs()
+     device_config, _) = engine_args.create_engine_configs()
 
     distributed_init_method = get_distributed_init_method(
         get_ip(), get_open_port())
@@ -93,6 +93,7 @@ def create_worker(cls: type,
         model_config=model_config,
         parallel_config=parallel_config,
         scheduler_config=scheduler_config,
+        device_config=device_config,
         local_rank=0,
         rank=0,
         distributed_init_method=distributed_init_method,
@@ -6,7 +6,7 @@ from vllm.worker.model_runner import ModelRunner
 
 
 def test_prepare_prompt():
-    model_runner = ModelRunner(None, None, None, None)
+    model_runner = ModelRunner(None, None, None, None, None)
     model_runner.set_block_size(16)
 
     batch_size = random.randint(1, 256)
@@ -444,6 +444,12 @@ class SchedulerConfig:
                 f"({self.max_num_seqs}).")
 
 
+class DeviceConfig:
+
+    def __init__(self, device: str = "cuda") -> None:
+        self.device = torch.device(device)
+
+
 @dataclass
 class LoRAConfig:
     max_lora_rank: int
@@ -3,8 +3,8 @@ import dataclasses
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
-from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, LoRAConfig)
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, LoRAConfig)
 
 
 @dataclass
@@ -43,6 +43,7 @@ class EngineArgs:
     lora_extra_vocab_size: int = 256
     lora_dtype = 'auto'
     max_cpu_loras: Optional[int] = None
+    device: str = 'cuda'
 
     def __post_init__(self):
         if self.tokenizer is None:
@@ -127,13 +128,13 @@ class EngineArgs:
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8_e5m2'],
-            default='auto',
+            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. Note FP8 is not supported when cuda version is '
            'lower than 11.8.')
        parser.add_argument('--max-model-len',
                            type=int,
-                            default=None,
+                            default=EngineArgs.max_model_len,
                            help='model context length. If unspecified, '
                            'will be automatically derived from the model.')
        # Parallel arguments
@@ -154,6 +155,7 @@ class EngineArgs:
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
+            default=EngineArgs.max_parallel_loading_workers,
            help='load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models')
@@ -200,7 +202,7 @@ class EngineArgs:
            '-q',
            type=str,
            choices=['awq', 'gptq', 'squeezellm', None],
-            default=None,
+            default=EngineArgs.quantization,
            help='Method used to quantize the weights. If '
            'None, we first check the `quantization_config` '
            'attribute in the model config file. If that is '
@@ -255,6 +257,13 @@ class EngineArgs:
                            help=('Maximum number of LoRAs to store in CPU memory. '
                                  'Must be >= than max_num_seqs. '
                                  'Defaults to max_num_seqs.'))
+        parser.add_argument(
+            "--device",
+            type=str,
+            default=EngineArgs.device,
+            choices=["cuda"],
+            help=('Device type for vLLM execution. '
+                  'Currently, only CUDA-compatible devices are supported.'))
        return parser
 
    @classmethod
@@ -268,7 +277,8 @@ class EngineArgs:
     def create_engine_configs(
         self,
     ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
-               Optional[LoRAConfig]]:
+               DeviceConfig, Optional[LoRAConfig]]:
+        device_config = DeviceConfig(self.device)
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
@@ -296,7 +306,8 @@ class EngineArgs:
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
-        return model_config, cache_config, parallel_config, scheduler_config, lora_config
+        return (model_config, cache_config, parallel_config, scheduler_config,
+                device_config, lora_config)
 
 
 @dataclass
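create_engine_configs() now returns the device configuration as part of its tuple; the worker-creation test hunk earlier in this diff unpacks it the same way. A hedged end-to-end sketch (add_cli_args and from_cli_args are existing EngineArgs helpers; the model name is arbitrary and resolving its config requires network access):

import argparse

from vllm.engine.arg_utils import EngineArgs

parser = argparse.ArgumentParser()
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args(["--model", "facebook/opt-125m", "--device", "cuda"])
engine_args = EngineArgs.from_cli_args(args)

(model_config, cache_config, parallel_config, scheduler_config,
 device_config, lora_config) = engine_args.create_engine_configs()
print(device_config.device)  # device(type='cuda')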
@ -6,8 +6,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
|
||||
Union)
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, LoRAConfig)
|
||||
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, LoRAConfig)
|
||||
from vllm.core.scheduler import Scheduler, SchedulerOutputs
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.engine.metrics import StatLogger, Stats
|
||||
@ -53,6 +53,7 @@ class LLMEngine:
|
||||
management.
|
||||
parallel_config: The configuration related to distributed execution.
|
||||
scheduler_config: The configuration related to the request scheduler.
|
||||
device_config: The configuration related to the device.
|
||||
placement_group: Ray placement group for distributed execution.
|
||||
Required for distributed execution.
|
||||
log_stats: Whether to log statistics.
|
||||
@ -64,6 +65,7 @@ class LLMEngine:
|
||||
cache_config: CacheConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
device_config: DeviceConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
placement_group: Optional["PlacementGroup"],
|
||||
log_stats: bool,
|
||||
@ -85,6 +87,7 @@ class LLMEngine:
|
||||
f"quantization={model_config.quantization}, "
|
||||
f"enforce_eager={model_config.enforce_eager}, "
|
||||
f"kv_cache_dtype={cache_config.cache_dtype}, "
|
||||
f"device_config={device_config.device}, "
|
||||
f"seed={model_config.seed})")
|
||||
# TODO(woosuk): Print more configs in debug mode.
|
||||
|
||||
@ -93,6 +96,7 @@ class LLMEngine:
|
||||
self.lora_config = lora_config
|
||||
self.parallel_config = parallel_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.device_config = device_config
|
||||
self.log_stats = log_stats
|
||||
self._verify_args()
|
||||
|
||||
@ -138,6 +142,7 @@ class LLMEngine:
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
self.scheduler_config,
|
||||
self.device_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
@ -233,6 +238,7 @@ class LLMEngine:
|
||||
model_config = copy.deepcopy(self.model_config)
|
||||
parallel_config = copy.deepcopy(self.parallel_config)
|
||||
scheduler_config = copy.deepcopy(self.scheduler_config)
|
||||
device_config = copy.deepcopy(self.device_config)
|
||||
|
||||
for rank, (worker, (node_id,
|
||||
_)) in enumerate(zip(self.workers,
|
||||
@ -244,6 +250,7 @@ class LLMEngine:
|
||||
model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
device_config,
|
||||
local_rank,
|
||||
rank,
|
||||
distributed_init_method,
|
||||
@ -257,6 +264,7 @@ class LLMEngine:
|
||||
model_config,
|
||||
parallel_config,
|
||||
scheduler_config,
|
||||
device_config,
|
||||
driver_local_rank,
|
||||
driver_rank,
|
||||
distributed_init_method,
|
||||
|
@@ -89,9 +89,7 @@ class ScaledActivation(nn.Module):
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.scales = nn.Parameter(
-            torch.empty(intermediate_size_per_partition,
-                        dtype=params_dtype,
-                        device="cuda"))
+            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
         set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -200,7 +200,7 @@ def _make_alibi_bias(
     seq_len: int,
     dtype: torch.dtype,
 ) -> LowerTriangularMaskWithTensorBias:
-    bias = torch.arange(seq_len, dtype=dtype, device="cuda")
+    bias = torch.arange(seq_len, dtype=dtype)
     # NOTE(zhuohan): HF uses
     #     `bias = bias[None, :].repeat(prompt_len, 1)`
     # here. We find that both biases give the same results, but
@ -54,7 +54,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
weight = Parameter(torch.empty(output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
@ -113,9 +112,7 @@ class ReplicatedLinear(torch.nn.Module):
|
||||
self.register_parameter(name, weight)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=self.params_dtype))
|
||||
torch.empty(self.output_size, dtype=self.params_dtype))
|
||||
set_weight_attrs(self.bias, {"output_dim": 0})
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
@ -183,7 +180,6 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
@ -509,9 +505,7 @@ class RowParallelLinear(torch.nn.Module):
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
torch.empty(self.output_size, dtype=params_dtype))
|
||||
set_weight_attrs(self.bias, {
|
||||
"output_dim": 0,
|
||||
"weight_loader": self.weight_loader,
|
||||
|
@ -96,7 +96,6 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
torch.empty(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
@ -112,7 +111,6 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
@ -128,7 +126,6 @@ class AWQLinearMethod(LinearMethodBase):
|
||||
torch.empty(
|
||||
input_size_per_partition // self.quant_config.group_size,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
|
@ -127,7 +127,6 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
@ -145,7 +144,6 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
@ -156,7 +154,6 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
@ -172,7 +169,6 @@ class GPTQLinearMethod(LinearMethodBase):
|
||||
torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
|
@ -80,7 +80,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
requires_grad=False,
|
||||
@ -96,7 +95,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
torch.empty(
|
||||
output_size,
|
||||
self.quant_config.weight_bits**2,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
@ -118,12 +116,12 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
if is_hip():
|
||||
out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
|
||||
out_f = torch.zeros(out_shape, dtype=torch.float)
|
||||
ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
|
||||
out = out_f.to(dtype=torch.float16)
|
||||
else:
|
||||
# NOTE: The output tensor should be zero-initialized.
|
||||
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
|
||||
out = torch.zeros(out_shape, dtype=torch.float16)
|
||||
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
|
||||
|
||||
if bias is not None:
|
||||
|
@ -77,16 +77,13 @@ class RotaryEmbedding(nn.Module):
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim))
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq

def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings,
dtype=torch.float,
device="cuda")
t = torch.arange(self.max_position_embeddings, dtype=torch.float)

freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
@ -174,7 +171,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * self.scaling_factor
t = torch.arange(max_len, dtype=torch.float, device="cuda")
t = torch.arange(max_len, dtype=torch.float)
t = t / self.scaling_factor

freqs = torch.einsum("i,j -> ij", t, inv_freq)
@ -214,7 +211,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
(self.scaling_factor - 1))**(self.rotary_dim /
(self.rotary_dim - 2))
inv_freq = self._compute_inv_freq(base)
t = torch.arange(max_len, dtype=torch.float, device="cuda")
t = torch.arange(max_len, dtype=torch.float)

freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
@ -297,9 +294,9 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style)

def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim)
pos_freqs = self.base**(
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

@ -308,8 +305,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2, dtype=torch.float,
device="cuda")) * self.extrapolation_factor
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
@ -317,7 +314,6 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)

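Note: the rotary-embedding hunks drop the explicit `device="cuda"` so the cos/sin cache is built on the default device. A minimal sketch of the general pattern, under assumptions (a toy `TinyRope` class, not the vLLM `RotaryEmbedding`): compute the cache with plain `torch.arange` and register it as a buffer, so moving the module later relocates the cache with it.

    import torch
    import torch.nn as nn

    class TinyRope(nn.Module):
        def __init__(self, rotary_dim: int, max_pos: int, base: float = 10000.0):
            super().__init__()
            # Same inv_freq / outer-product construction as in the diff,
            # without any device argument.
            inv_freq = 1.0 / (base**(
                torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
            t = torch.arange(max_pos, dtype=torch.float)
            freqs = torch.einsum("i,j -> ij", t, inv_freq)
            cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
            self.register_buffer("cos_sin_cache", cache, persistent=False)

    rope = TinyRope(rotary_dim=64, max_pos=2048)
    rope.to("cuda" if torch.cuda.is_available() else "cpu")
    print(rope.cos_sin_cache.device)
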
@ -77,7 +77,6 @@ class VocabParallelEmbedding(torch.nn.Module):
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.weight, {
"parallel_dim": 0,
@ -139,7 +138,6 @@ class ParallelLMHead(VocabParallelEmbedding):
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.bias, {
"parallel_dim": 0,

@ -5,7 +5,7 @@ from typing import Optional, Type
import torch
import torch.nn as nn

from vllm.config import ModelConfig, LoRAConfig
from vllm.config import DeviceConfig, ModelConfig, LoRAConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
@ -38,6 +38,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:


def get_model(model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig] = None) -> nn.Module:
model_class = _get_model_architecture(model_config)

@ -64,7 +65,7 @@ def get_model(model_config: ModelConfig,
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
with torch.device("cuda"):
with torch.device(device_config.device):
if getattr(model_class, "supports_lora", False):
model = model_class(model_config.hf_config, linear_method,
lora_config)

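Note: the model-loader hunk relies on `torch.device(...)` as a context manager, which makes the chosen device the default for tensor construction inside the block. A minimal standalone sketch (the `target` selection here is illustrative, not vLLM code):

    import torch
    import torch.nn as nn

    # Inside the context, newly created tensors and parameters default to
    # `target`, so the model needs no hardcoded device= argument.
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with torch.device(target):
        layer = nn.Linear(16, 16)   # weight/bias allocated on `target`
    print(layer.weight.device)
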
@ -228,7 +228,8 @@ def create_kv_caches_with_random(
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)

if isinstance(cache_dtype, str):
if cache_dtype == "auto":

@ -104,11 +104,13 @@ class CacheEngine:
size=(self.num_cpu_blocks, *key_block_shape),
dtype=self.dtype,
pin_memory=pin_memory,
device="cpu",
)
value_blocks = torch.empty(
size=(self.num_cpu_blocks, *value_block_shape),
dtype=self.dtype,
pin_memory=pin_memory,
device="cpu",
)
cpu_cache.append((key_blocks, value_blocks))
return cpu_cache

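Note: the cache-engine hunk pins CPU swap blocks via a `pin_memory` flag instead of assuming CUDA. A short sketch of the underlying idea (shapes and the guard are illustrative assumptions, not the vLLM `CacheEngine`): pinned, page-locked CPU buffers allow asynchronous copies to an accelerator, while CPU-only builds simply leave the flag off.

    import torch

    pin = torch.cuda.is_available()   # pinning requires a CUDA runtime
    cpu_block = torch.empty((16, 8, 128), dtype=torch.float16,
                            pin_memory=pin, device="cpu")
    print(cpu_block.is_pinned())
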
@ -5,7 +5,7 @@ import numpy as np
import torch
import torch.nn as nn

from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
from vllm.config import DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import (
@ -35,6 +35,7 @@ class ModelRunner:
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
@ -49,7 +50,10 @@ class ModelRunner:
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
self.sliding_window = (model_config.get_sliding_window()
if model_config is not None else None)
self.device = torch.device(torch.cuda.current_device())
self.device_config = (device_config
if device_config is not None else DeviceConfig())
self.device = self.device_config.device

self.model = None
self.block_size = None  # Set after initial profiling.
self.lora_manager = None
@ -72,7 +76,8 @@ class ModelRunner:
self.kv_cache_dtype = kv_cache_dtype

def load_model(self) -> None:
self.model = get_model(self.model_config, self.lora_config)
self.model = get_model(self.model_config, self.device_config,
self.lora_config)

vocab_size = self.model.config.vocab_size

@ -182,22 +187,25 @@ class ModelRunner:
input_tokens = _make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=torch.long)
dtype=torch.long,
device=self.device)
input_positions = _make_tensor_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=torch.long)
dtype=torch.long,
device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long)
dtype=torch.long,
device=self.device)
lora_index_mapping = [
_pad_to_max(mapping, max_prompt_len, pad=0)
for mapping in lora_index_mapping
]
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device='cuda')
device=self.device)
# Prepare prefix block tables
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
block_tables = _make_tensor_with_pad(
@ -205,15 +213,16 @@ class ModelRunner:
max_len=max_prompt_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
)
start_loc_tensor = torch.arange(0,
len(prompt_lens) * max_prompt_len,
max_prompt_len,
dtype=torch.long,
device='cuda')
device=self.device)
prompt_lens_tensor = torch.tensor(prompt_lens,
dtype=torch.long,
device='cuda')
device=self.device)

input_metadata = InputMetadata(
is_prompt=True,
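Note: the prompt-preparation hunks route every padded tensor through the runner's `self.device`. A simplified stand-in for the padding helper used above (a sketch, not the exact vLLM `_make_tensor_with_pad`): pad each row to `max_len`, then build the tensor directly on the requested device instead of a hardcoded "cuda".

    import torch
    from typing import List, Union

    def make_tensor_with_pad(x: List[List[int]], max_len: int, pad: int,
                             dtype: torch.dtype,
                             device: Union[str, torch.device]) -> torch.Tensor:
        # Right-pad every row with `pad` up to max_len, then place on `device`.
        padded = [row + [pad] * (max_len - len(row)) for row in x]
        return torch.tensor(padded, dtype=dtype, device=device)

    tokens = make_tensor_with_pad([[1, 2, 3], [4]], max_len=4, pad=0,
                                  dtype=torch.long, device="cpu")
    print(tokens)
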
@ -305,20 +314,20 @@ class ModelRunner:
max_len=1,
pad=0,
dtype=torch.long,
device="cuda")
device=self.device)
input_positions = _make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
dtype=torch.long,
device="cuda")
device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_len=1,
pad=_PAD_SLOT_ID,
dtype=torch.long,
device="cuda")
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device="cuda")
device=self.device)

if use_captured_graph:
# The shape of graph_block_tables is
@ -327,7 +336,7 @@ class ModelRunner:
for i, block_table in enumerate(block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device="cuda")
block_tables = torch.tensor(input_block_tables, device=self.device)
else:
max_block_table_len = max(
len(block_table) for block_table in block_tables)
@ -336,7 +345,7 @@ class ModelRunner:
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device="cuda",
device=self.device,
)

lora_index_mapping = [
@ -355,7 +364,8 @@ class ModelRunner:
use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype,
)
return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
return (input_tokens, input_positions, input_metadata,
lora_index_mapping, lora_prompt_mapping, lora_requests)

def _prepare_sample(
self,
@ -410,9 +420,13 @@ class ModelRunner:

selected_token_indices = _async_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=not self.in_wsl)
categorized_sample_indices = {
t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
t: _async_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=not self.in_wsl)
for t, seq_ids in categorized_sample_indices.items()
}

@ -511,7 +525,8 @@ class ModelRunner:
perform_sampling=False,
)

return input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping
return (input_tokens, input_positions, input_metadata,
sampling_metadata, lora_requests, lora_mapping)

@torch.inference_mode()
def execute_model(
@ -519,8 +534,9 @@ class ModelRunner:
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Optional[SamplerOutput]:
input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping = (
self.prepare_input_tensors(seq_group_metadata_list))
(input_tokens, input_positions, input_metadata, sampling_metadata,
lora_requests,
lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)

if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
@ -789,14 +805,10 @@ def _make_tensor_with_pad(
max_len: int,
pad: int,
dtype: torch.dtype,
device: Union[str, torch.device] = "cuda",
pin_memory: bool = False,
device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x,
dtype=dtype,
device=device,
pin_memory=pin_memory and str(device) == "cpu")
return torch.tensor(padded_x, dtype=dtype, device=device)


def _get_graph_batch_size(batch_size: int) -> int:
@ -808,6 +820,11 @@ def _get_graph_batch_size(batch_size: int) -> int:
return (batch_size + 7) // 8 * 8


def _async_h2d(data: list, dtype, pin_memory):
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
return t.to(device="cuda", non_blocking=True)
def _async_h2d(
data: list,
dtype: torch.dtype,
target_device: Union[str, torch.device],
pin_memory: bool,
) -> torch.Tensor:
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
return t.to(device=target_device, non_blocking=True)

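Note: a usage sketch for the reworked `_async_h2d` above (the device selection and the standalone wrapper are assumptions for illustration). Data is first materialized in optionally pinned CPU memory, then copied to the target device; `non_blocking=True` only overlaps with compute when the source is pinned and the target is a CUDA device.

    import torch

    def async_h2d(data, dtype, target_device, pin_memory):
        # Stage on CPU (pinned if requested), then copy asynchronously.
        t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
        return t.to(device=target_device, non_blocking=True)

    target = "cuda" if torch.cuda.is_available() else "cpu"
    idx = async_h2d([0, 1, 2, 3], torch.long, target,
                    pin_memory=(target == "cuda"))
    print(idx.device)
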
@ -6,8 +6,8 @@ from typing import Dict, List, Tuple, Set, Optional
import torch
import torch.distributed

from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, LoRAConfig)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
@ -33,6 +33,7 @@ class Worker:
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
@ -43,6 +44,7 @@ class Worker:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
@ -54,6 +56,7 @@ class Worker:
self.model_runner = ModelRunner(model_config,
parallel_config,
scheduler_config,
device_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
@ -65,21 +68,24 @@ class Worker:
self.gpu_cache = None

def init_model(self) -> None:
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)

_check_if_gpu_supports_dtype(self.model_config.dtype)
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)

_check_if_gpu_supports_dtype(self.model_config.dtype)
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")

# Initialize the distributed environment.
init_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method)
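Note: a standalone sketch of the guard pattern introduced in `init_model` above (shown outside the Worker class as an assumption; the function name and message wording are illustrative). CUDA-specific environment tweaks and `torch.cuda.set_device` only run for CUDA devices; any other device type fails loudly until a backend-specific branch is added.

    import torch

    def init_device(device: torch.device, local_rank: int) -> torch.device:
        if device.type == "cuda":
            dev = torch.device(f"cuda:{local_rank}")
            torch.cuda.set_device(dev)   # bind this worker to its local GPU
            return dev
        raise RuntimeError(f"Unsupported device type: {device}")

    if torch.cuda.is_available():
        print(init_device(torch.device("cuda"), local_rank=0))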