From 96b6f475dda40a0c7d557f73c36fe09c07be2e9c Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Fri, 2 Feb 2024 07:46:39 +0800 Subject: [PATCH] Remove hardcoded `device="cuda" ` to support more devices (#2503) Co-authored-by: Jiang Li Co-authored-by: Kunshang Ji --- benchmarks/benchmark_latency.py | 7 ++ benchmarks/benchmark_throughput.py | 10 ++- .../kernels/benchmark_paged_attention.py | 27 ++++--- tests/kernels/test_activation.py | 37 +++++---- tests/kernels/test_attention.py | 51 ++++++------- tests/kernels/test_cache.py | 40 +++++----- tests/kernels/test_layernorm.py | 17 +++-- tests/kernels/test_pos_encoding.py | 22 +++--- tests/kernels/test_prefix_prefill.py | 57 +++++--------- tests/lora/conftest.py | 4 +- tests/lora/test_layers.py | 36 +++++---- tests/lora/test_worker.py | 4 +- tests/samplers/test_rejection_sampler.py | 64 +++++++--------- tests/samplers/test_sampler.py | 37 +++++---- tests/worker/spec_decode/utils.py | 3 +- tests/worker/test_model_runner.py | 2 +- vllm/config.py | 6 ++ vllm/engine/arg_utils.py | 25 +++++-- vllm/engine/llm_engine.py | 12 ++- vllm/model_executor/layers/activation.py | 4 +- vllm/model_executor/layers/attention.py | 2 +- vllm/model_executor/layers/linear.py | 10 +-- .../model_executor/layers/quantization/awq.py | 3 - .../layers/quantization/gptq.py | 4 - .../layers/quantization/squeezellm.py | 6 +- .../model_executor/layers/rotary_embedding.py | 22 +++--- .../layers/vocab_parallel_embedding.py | 2 - vllm/model_executor/model_loader.py | 5 +- vllm/utils.py | 3 +- vllm/worker/cache_engine.py | 2 + vllm/worker/model_runner.py | 75 ++++++++++++------- vllm/worker/worker.py | 36 +++++---- 32 files changed, 343 insertions(+), 292 deletions(-) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 71731343..2eb9e2cb 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -25,6 +25,7 @@ def main(args: argparse.Namespace): dtype=args.dtype, enforce_eager=args.enforce_eager, kv_cache_dtype=args.kv_cache_dtype, + device=args.device, ) sampling_params = SamplingParams( @@ -135,5 +136,11 @@ if __name__ == '__main__': default=None, help=('path to save the pytorch profiler output. Can be visualized ' 'with ui.perfetto.dev or Tensorboard.')) + parser.add_argument( + "--device", + type=str, + default="cuda", + choices=["cuda"], + help='device type for vLLM execution, supporting CUDA only currently.') args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index d45d3330..1ad50252 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -72,6 +72,7 @@ def run_vllm( max_model_len: Optional[int], enforce_eager: bool, kv_cache_dtype: str, + device: str, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -85,6 +86,7 @@ def run_vllm( max_model_len=max_model_len, enforce_eager=enforce_eager, kv_cache_dtype=kv_cache_dtype, + device=device, ) # Add the requests to the engine. @@ -209,7 +211,7 @@ def main(args: argparse.Namespace): args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, args.max_model_len, args.enforce_eager, - args.kv_cache_dtype) + args.kv_cache_dtype, args.device) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -294,6 +296,12 @@ if __name__ == "__main__": default="auto", help= 'Data type for kv cache storage. 
If "auto", will use model data type.') + parser.add_argument( + "--device", + type=str, + default="cuda", + choices=["cuda"], + help='device type for vLLM execution, supporting CUDA only currently.') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 56fe1b92..d921dea1 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -25,18 +25,20 @@ def main( dtype: torch.dtype, seed: int, do_profile: bool, + device: str = "cuda", kv_cache_dtype: Optional[str] = None, ) -> None: random.seed(seed) torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) scale = float(1.0 / (head_size**0.5)) query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype, - device="cuda") + device=device) query.uniform_(-scale, scale) assert num_query_heads % num_kv_heads == 0 @@ -44,11 +46,11 @@ def main( if use_alibi: alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, - device="cuda") + device=device) context_lens = [context_len for _ in range(num_seqs)] max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") + context_lens = torch.tensor(context_lens, dtype=torch.int, device=device) # Create the block tables. max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size @@ -59,12 +61,17 @@ def main( for _ in range(max_num_blocks_per_seq) ] block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") + block_tables = torch.tensor(block_tables, dtype=torch.int, device=device) # Create the KV cache. - key_caches, value_caches = create_kv_caches_with_random( - NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype, - dtype) + key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + device=device) key_cache, value_cache = key_caches[0], value_caches[0] # Prepare for the paged attention kernel. @@ -84,7 +91,7 @@ def main( ) max_logits = torch.empty_like(exp_sums) - def run_benchmark(num_iters: int, profile: bool = False) -> float: + def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: torch.cuda.synchronize() if profile: torch.cuda.cudart().cudaProfilerStart() @@ -135,6 +142,7 @@ def main( # Warmup. print("Warming up...") + run_benchmark = run_cuda_benchmark run_benchmark(num_iters=3, profile=False) # Benchmark. @@ -175,6 +183,7 @@ if __name__ == '__main__': default="auto", help= 'Data type for kv cache storage. 
If "auto", will use model data type.') + parser.add_argument("--device", type=str, choices=["cuda"], default="cuda") args = parser.parse_args() print(args) diff --git a/tests/kernels/test_activation.py b/tests/kernels/test_activation.py index 826bf835..de0b4970 100644 --- a/tests/kernels/test_activation.py +++ b/tests/kernels/test_activation.py @@ -7,26 +7,29 @@ DTYPES = [torch.half, torch.bfloat16, torch.float] NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing D = [512, 4096, 5120, 13824] # Arbitrary values for testing SEEDS = [0] -DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_silu_and_mul( num_tokens: int, d: int, dtype: torch.dtype, seed: int, - device: int, + device: str, ) -> None: torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" - x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + x = torch.randn(num_tokens, 2 * d, dtype=dtype) layer = SiluAndMul() out = layer(x) ref_out = layer._forward(x) @@ -37,19 +40,20 @@ def test_silu_and_mul( @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_gelu_new( num_tokens: int, d: int, dtype: torch.dtype, seed: int, - device: int, + device: str, ) -> None: torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" - x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + x = torch.randn(num_tokens, d, dtype=dtype) layer = NewGELU() out = layer(x) ref_out = layer._forward(x) @@ -60,18 +64,19 @@ def test_gelu_new( @pytest.mark.parametrize("d", D) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("device", CUDA_DEVICES) def test_gelu_fast( num_tokens: int, d: int, dtype: torch.dtype, seed: int, - device: int, + device: str, ) -> None: torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" - x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + x = torch.randn(num_tokens, d, dtype=dtype) layer = FastGELU() out = layer(x) ref_out = layer._forward(x) diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index cbb1d406..92d63eb6 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -27,7 +27,9 @@ BLOCK_SIZES = [16, 32] USE_ALIBI = [False, True] KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] SEEDS = [0] -DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] def ref_masked_attention( @@ -91,7 +93,7 @@ def ref_single_query_cached_kv_attention( alibi_bias = None if alibi_slopes is 
not None: # Create the ALiBi bias used in the paged attention kernel. - position_ids = torch.arange(context_len, device=query.device).int() + position_ids = torch.arange(context_len).int() alibi_bias = (position_ids - context_len + 1).float() alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( 1, 1, -1) @@ -110,7 +112,7 @@ def ref_single_query_cached_kv_attention( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("device", CUDA_DEVICES) def test_paged_attention( kv_cache_factory, version: str, @@ -122,33 +124,28 @@ def test_paged_attention( dtype: torch.dtype, kv_cache_dtype: str, seed: int, - device: int, + device: str, ) -> None: random.seed(seed) torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) scale = float(1.0 / (head_size**0.5)) num_query_heads, num_kv_heads = num_heads - query = torch.empty(num_seqs, - num_query_heads, - head_size, - dtype=dtype, - device=gpu_id) + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) query.uniform_(-scale, scale) assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads alibi_slopes = None if use_alibi: - alibi_slopes = torch.randn(num_query_heads, - dtype=torch.float, - device=gpu_id) + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] context_lens[-1] = MAX_SEQ_LEN max_context_len = max(context_lens) - context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id) + context_lens = torch.tensor(context_lens, dtype=torch.int) # Create the block tables. max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size @@ -159,13 +156,13 @@ def test_paged_attention( for _ in range(max_num_blocks_per_seq) ] block_tables.append(block_table) - block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id) + block_tables = torch.tensor(block_tables, dtype=torch.int) # Create the KV caches. key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype, dtype, seed, - gpu_id) + device) key_cache, value_cache = key_caches[0], value_caches[0] # Call the paged attention kernel. 
@@ -193,12 +190,10 @@ def test_paged_attention( tmp_output = torch.empty( size=(num_seqs, num_heads, num_partitions, head_size), dtype=output.dtype, - device=output.device, ) exp_sums = torch.empty( size=(num_seqs, num_heads, num_partitions), dtype=torch.float32, - device=output.device, ) max_logits = torch.empty_like(exp_sums) ops.paged_attention_v2( @@ -229,14 +224,14 @@ def test_paged_attention( block_size, x) dequantized_key_cache = torch.empty(size=key_cache_shape, dtype=dtype, - device=gpu_id) + device=device) cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache) key_cache = dequantized_key_cache value_cache_shape = value_cache.shape dequantized_value_cache = torch.empty(size=value_cache_shape, dtype=dtype, - device=gpu_id) + device=device) cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache) value_cache = dequantized_value_cache @@ -283,7 +278,7 @@ def ref_multi_query_kv_attention( attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), diagonal=1) attn_mask = attn_mask * torch.finfo(dtype).min - attn_mask = attn_mask.to(dtype=dtype, device=query.device) + attn_mask = attn_mask.to(dtype=dtype) ref_output = ref_masked_attention( query[start_idx:end_idx], @@ -303,7 +298,7 @@ def ref_multi_query_kv_attention( @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_multi_query_kv_attention( num_seqs: int, @@ -311,12 +306,13 @@ def test_multi_query_kv_attention( head_size: int, dtype: torch.dtype, seed: int, - device: int, + device: str, ) -> None: random.seed(seed) torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. # As the xformers library is already tested with its own tests, we can use # a smaller MAX_SEQ_LEN here. 
@@ -329,8 +325,7 @@ def test_multi_query_kv_attention( qkv = torch.empty(num_tokens, num_query_heads + 2 * num_kv_heads, head_size, - dtype=dtype, - device=gpu_id) + dtype=dtype) qkv.uniform_(-scale, scale) query, key, value = qkv.split( [num_query_heads, num_kv_heads, num_kv_heads], dim=1) diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 275ef819..a90492f5 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -17,7 +17,9 @@ BLOCK_SIZES = [8, 16, 32] NUM_BLOCKS = [1024, 3600] # Arbitrary values for testing NUM_MAPPINGS = [256] # Arbitrary values for testing SEEDS = [0] -DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] @@ -29,7 +31,7 @@ KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("device", CUDA_DEVICES) @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) @torch.inference_mode() def test_copy_blocks( @@ -42,13 +44,14 @@ def test_copy_blocks( num_blocks: int, dtype: torch.dtype, seed: int, - device: int, kv_cache_dtype: str, + device: str, ) -> None: random.seed(seed) torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) # Generate random block mappings where each source block is mapped to two # destination blocks. assert 2 * num_mappings <= num_blocks @@ -66,7 +69,7 @@ def test_copy_blocks( key_caches, value_caches = kv_cache_factory(num_blocks, block_size, num_layers, num_heads, head_size, kv_cache_dtype, - dtype, seed, gpu_id) + dtype, seed, device) # Clone the KV caches. cloned_key_caches = [key_cache.clone() for key_cache in key_caches] @@ -98,7 +101,7 @@ def test_copy_blocks( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_reshape_and_cache( kv_cache_factory, @@ -109,29 +112,25 @@ def test_reshape_and_cache( num_blocks: int, dtype: torch.dtype, seed: int, - device: int, + device: str, ) -> None: random.seed(seed) torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) # Create a random slot mapping. num_slots = block_size * num_blocks slot_mapping = random.sample(range(num_slots), num_tokens) - slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.long) - qkv = torch.randn(num_tokens, - 3, - num_heads, - head_size, - dtype=dtype, - device=gpu_id) + qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype) _, key, value = qkv.unbind(dim=1) # Create the KV caches. key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, num_heads, head_size, dtype, - None, seed, gpu_id) + None, seed, device) key_cache, value_cache = key_caches[0], value_caches[0] # Clone the KV caches. 
@@ -166,7 +165,7 @@ def test_reshape_and_cache( @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_swap_blocks( kv_cache_factory, @@ -182,7 +181,8 @@ def test_swap_blocks( ) -> None: random.seed(seed) torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) src_device = f"{direction[0]}:{device}" if direction[ 0] == "cuda" else direction[0] dst_device = f"{direction[1]}:{device}" if direction[ diff --git a/tests/kernels/test_layernorm.py b/tests/kernels/test_layernorm.py index 8a06b3aa..b1e3c1a7 100644 --- a/tests/kernels/test_layernorm.py +++ b/tests/kernels/test_layernorm.py @@ -8,7 +8,9 @@ NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing ADD_RESIDUAL = [False, True] SEEDS = [0] -DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] @pytest.mark.parametrize("num_tokens", NUM_TOKENS) @@ -16,7 +18,7 @@ DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("add_residual", ADD_RESIDUAL) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_rms_norm( num_tokens: int, @@ -24,15 +26,16 @@ def test_rms_norm( add_residual: bool, dtype: torch.dtype, seed: int, - device: int, + device: str, ) -> None: torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" - layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + layer = RMSNorm(hidden_size).to(dtype=dtype) layer.weight.data.normal_(mean=1.0, std=0.1) scale = 1 / (2 * hidden_size) - x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) x *= scale residual = torch.randn_like(x) * scale if add_residual else None diff --git a/tests/kernels/test_pos_encoding.py b/tests/kernels/test_pos_encoding.py index aad310e2..19cbd600 100644 --- a/tests/kernels/test_pos_encoding.py +++ b/tests/kernels/test_pos_encoding.py @@ -13,7 +13,9 @@ NUM_HEADS = [7, 17] # Arbitrary values for testing BATCH_SIZES = [1, 5] # Arbitrary values for testing SEQ_LENS = [11, 8192] # Arbitrary values for testing SEEDS = [0] -DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -24,7 +26,7 @@ DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_rotary_embedding( is_neox_style: bool, @@ -35,28 +37,26 @@ def test_rotary_embedding( rotary_dim: Optional[int], dtype: torch.dtype, seed: int, - device: int, + device: str, max_position: int = 8192, base: int = 10000, ) -> None: if 
rotary_dim is None: rotary_dim = head_size torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) - gpu_id = f"cuda:{device}" + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) - rope = rope.to(dtype=dtype, device=gpu_id) + rope = rope.to(dtype=dtype) - positions = torch.randint(0, - max_position, (batch_size, seq_len), - device=gpu_id) + positions = torch.randint(0, max_position, (batch_size, seq_len)) query = torch.randn(batch_size, seq_len, num_heads * head_size, - dtype=dtype, - device=gpu_id) + dtype=dtype) key = torch.randn_like(query) # NOTE(woosuk): The reference implementation should be executed first diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 0531b051..ac93b325 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -11,19 +11,27 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask NUM_HEADS = [12] HEAD_SIZES = [128] DTYPES = [torch.float16] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] @pytest.mark.parametrize("num_heads", NUM_HEADS) @pytest.mark.parametrize("head_size", HEAD_SIZES) @pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_contexted_kv_attention( num_heads: int, head_size: int, dtype: torch.dtype, + device: str, ) -> None: random.seed(0) torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed(0) + torch.set_default_device(device) MAX_SEQ_LEN = 1024 MAX_CTX_LEN = 1024 BS = 10 @@ -35,24 +43,11 @@ def test_contexted_kv_attention( seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] num_tokens = sum(subquery_lens) - query = torch.empty(num_tokens, - num_heads, - head_size, - dtype=dtype, - device='cuda') + query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) query.uniform_(-1e-3, 1e-3) - output = torch.empty(num_tokens, - num_heads, - head_size, - dtype=dtype, - device='cuda') + output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype) - kv = torch.empty(sum(seq_lens), - 2, - num_heads, - head_size, - dtype=dtype, - device='cuda') + kv = torch.empty(sum(seq_lens), 2, num_heads, head_size, dtype=dtype) kv.uniform_(-1e-3, 1e-3) key, value = kv.unbind(dim=1) @@ -60,39 +55,27 @@ def test_contexted_kv_attention( block_size, num_heads, head_size, - dtype=dtype, - device='cuda') + dtype=dtype) v_cache = torch.zeros(cache_size, block_size, num_heads, head_size, - dtype=dtype, - device='cuda') - k = torch.zeros(sum(subquery_lens), - num_heads, - head_size, - dtype=dtype, - device='cuda') - v = torch.zeros(sum(subquery_lens), - num_heads, - head_size, - dtype=dtype, - device='cuda') - values = torch.arange(0, cache_size, dtype=torch.long, device='cuda') + dtype=dtype) + k = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype) + v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype) + values = torch.arange(0, cache_size, dtype=torch.long) values = values[torch.randperm(cache_size)] block_table = values[:BS * max_block_per_request].view( BS, max_block_per_request) - b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda') - b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda') + b_seq_len = torch.tensor(seq_lens, dtype=torch.long) + b_ctx_len = torch.tensor(ctx_lens, 
dtype=torch.long) b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], - dtype=torch.long, - device='cuda'), + dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN # copy kv to cache b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], - dtype=torch.long, - device='cuda'), + dtype=torch.long), dim=0) for i in range(BS): for j in range(subquery_lens[i]): diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index c1b3d04c..163c3c70 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -126,8 +126,8 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module: cleanup() get_model_old = get_model - def get_model_patched(model_config, lora_config=None): - return get_model_old(model_config, + def get_model_patched(model_config, device_config, lora_config=None): + return get_model_old(model_config, device_config, LoRAConfig(max_loras=4, max_lora_rank=8)) with patch("vllm.worker.model_runner.get_model", get_model_patched): diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 71c67113..f739bbea 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -34,6 +34,9 @@ TOLERANCES = { torch.float32: (5e-3, 5e-3), torch.bfloat16: (3e-2, 2e-2), } +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] def get_random_id_to_index(num_loras: int, @@ -151,14 +154,10 @@ def create_random_inputs( for _ in range(num_inputs): if input_type == torch.int: inputs.append( - torch.randint(low=int(low), - high=int(high), - size=input_size, - device="cuda")) + torch.randint(low=int(low), high=int(high), size=input_size)) else: inputs.append( - torch.rand(size=input_size, dtype=input_type, device="cuda") * - high + low) + torch.rand(size=input_size, dtype=input_type) * high + low) lora_id = random.choice(active_lora_ids) index_mapping += [lora_id] * input_size[0] @@ -169,8 +168,10 @@ def create_random_inputs( @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -def test_embeddings(dist_init, num_loras) -> None: +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_embeddings(dist_init, num_loras, device) -> None: + torch.set_default_device(device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -259,8 +260,10 @@ def test_embeddings(dist_init, num_loras) -> None: @torch.inference_mode() # @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.") @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None: +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None: + torch.set_default_device(device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -305,8 +308,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None: # Add empty embeddings_tensors for unoccupied lora slots. 
for _ in range(max_loras - len(embeddings_tensors)): - embeddings_tensors.append( - torch.zeros(embeddings_tensors[0].shape, device="cuda")) + embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape)) inputs, index_mapping, prompt_mapping = create_random_inputs( active_lora_ids=list(lora_dict.keys()), @@ -388,8 +390,10 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None: @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) -def test_lm_head_sampler(dist_init, num_loras) -> None: +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_lm_head_sampler(dist_init, num_loras, device) -> None: + torch.set_default_device(device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -432,7 +436,7 @@ def test_lm_head_sampler(dist_init, num_loras) -> None: ) lora_mapping = LoRAMapping(index_mapping, prompt_mapping) - input_ = torch.rand(20, 1024, device="cuda") + input_ = torch.rand(20, 1024) mapping_info = convert_mapping( lora_mapping, id_to_index, @@ -500,8 +504,10 @@ def test_lm_head_sampler(dist_init, num_loras) -> None: @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("orientation", ["row", "column"]) -def test_linear_parallel(dist_init, num_loras, orientation) -> None: +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_linear_parallel(dist_init, num_loras, orientation, device) -> None: + torch.set_default_device(device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, @@ -597,8 +603,10 @@ def test_linear_parallel(dist_init, num_loras, orientation) -> None: @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("repeats", [2, 3]) -def test_column_parallel_packed(dist_init, num_loras, repeats) -> None: +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None: + torch.set_default_device(device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 68c2c0b5..31a7c716 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -5,7 +5,8 @@ from unittest.mock import patch from vllm.lora.models import LoRAMapping from vllm.lora.request import LoRARequest -from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig +from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig, + DeviceConfig, LoRAConfig) from vllm.worker.worker import Worker @@ -25,6 +26,7 @@ def test_worker_apply_lora(sql_lora_files): ), parallel_config=ParallelConfig(1, 1, False), scheduler_config=SchedulerConfig(32, 32, 32, 256), + device_config=DeviceConfig("cuda"), local_rank=0, rank=0, lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 9d3ef3c6..99ee78ce 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -9,6 +9,10 @@ from vllm.model_executor.utils import set_random_seed from vllm.model_executor.layers.rejection_sampler import RejectionSampler +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + def mock_causal_accepted_tensor( k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor: @@ -39,11 +43,14 @@ def mock_causal_accepted_tensor( @pytest.mark.parametrize( "which_tokens_accepted", 
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) +@pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_correct_output_format(which_tokens_accepted: str, seed: int): +def test_correct_output_format(which_tokens_accepted: str, seed: int, + device: str): """Verify the output has correct format given predetermined accepted matrix. """ set_random_seed(seed) + torch.set_default_device(device) batch_size = 10 k = 5 @@ -66,18 +73,15 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int): recovered_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, k), - dtype=torch.int64, - device="cuda") + dtype=torch.int64) draft_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, k), - dtype=torch.int64, - device="cuda") + dtype=torch.int64) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), - dtype=torch.int64, - device="cuda") + dtype=torch.int64) rejection_sampler = RejectionSampler() rejection_sampler.init_gpu_tensors(rank=0) @@ -120,31 +124,24 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int): @pytest.mark.parametrize("k", list(range(1, 6))) @pytest.mark.parametrize("vocab_size", [30_000, 50_000]) @pytest.mark.parametrize("batch_size", list(range(1, 32))) +@pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() -def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int): +def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, + device: str): + torch.set_default_device(device) rejection_sampler = RejectionSampler() rejection_sampler.init_gpu_tensors(rank=0) - draft_probs = torch.rand(batch_size, - k, - vocab_size, - dtype=torch.float32, - device="cuda") - target_probs = torch.rand(batch_size, - k, - vocab_size, - dtype=torch.float32, - device="cuda") + draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), - dtype=torch.int64, - device="cuda") + dtype=torch.int64) draft_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, k), - dtype=torch.int64, - device="cuda") + dtype=torch.int64) rejection_sampler(target_probs, bonus_token_ids, draft_probs, draft_token_ids) @@ -153,36 +150,28 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int): @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) @pytest.mark.parametrize("which_token_ids", ["bonus_token_ids", "draft_token_ids"]) +@pytest.mark.parametrize("device", CUDA_DEVICES) @torch.inference_mode() def test_raises_when_vocab_oob(above_or_below_vocab_range: str, - which_token_ids: str): + which_token_ids: str, device: str): k = 3 batch_size = 5 vocab_size = 30_000 + torch.set_default_device(device) rejection_sampler = RejectionSampler(strict_mode=True) rejection_sampler.init_gpu_tensors(rank=0) - draft_probs = torch.rand(batch_size, - k, - vocab_size, - dtype=torch.float32, - device="cuda") - target_probs = torch.rand(batch_size, - k, - vocab_size, - dtype=torch.float32, - device="cuda") + draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), - dtype=torch.int64, - device="cuda") + dtype=torch.int64) draft_token_ids = torch.randint(low=0, high=vocab_size, 
size=(batch_size, k), - dtype=torch.int64, - device="cuda") + dtype=torch.int64) oob_token_ids = None if which_token_ids == "bonus_token_ids": @@ -237,6 +226,7 @@ def test_rejection_sampling_approximates_target_distribution( probabilities are exactly equal. Rejection sampling should still work without any NaNs or exceptions. """ + torch.set_default_device("cpu") set_random_seed(seed) helper = _CorrectnessTestHelper( diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 962183a2..d34f32d0 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -31,24 +31,26 @@ def _prepare_test( batch_size: int ) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]: vocab_size = 32000 - input_tensor = torch.rand((batch_size, 1024), - device="cuda", - dtype=torch.float16) + input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16) fake_logits = torch.full((batch_size, vocab_size), 1e-2, - device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(32000, fake_logits) - model_runner = ModelRunner(None, None, None, None) + model_runner = ModelRunner(None, None, None, None, None) return input_tensor, fake_logits, sampler, model_runner RANDOM_SEEDS = list(range(128)) +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] @pytest.mark.parametrize("seed", RANDOM_SEEDS) -def test_sampler_all_greedy(seed: int): +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_all_greedy(seed: int, device: str): set_random_seed(seed) + torch.set_default_device(device) batch_size = random.randint(1, 256) input_tensor, fake_logits, sampler, model_runner = _prepare_test( batch_size) @@ -81,8 +83,10 @@ def test_sampler_all_greedy(seed: int): @pytest.mark.parametrize("seed", RANDOM_SEEDS) -def test_sampler_all_random(seed: int): +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_all_random(seed: int, device: str): set_random_seed(seed) + torch.set_default_device(device) batch_size = random.randint(1, 256) input_tensor, fake_logits, sampler, model_runner = _prepare_test( batch_size) @@ -120,8 +124,10 @@ def test_sampler_all_random(seed: int): @pytest.mark.parametrize("seed", RANDOM_SEEDS) -def test_sampler_all_beam(seed: int): +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_all_beam(seed: int, device: str): set_random_seed(seed) + torch.set_default_device(device) batch_size = random.randint(1, 256) input_tensor, _, sampler, model_runner = _prepare_test(batch_size) @@ -156,8 +162,10 @@ def test_sampler_all_beam(seed: int): @pytest.mark.parametrize("seed", RANDOM_SEEDS) -def test_sampler_mixed(seed: int): +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_mixed(seed: int, device: str): set_random_seed(seed) + torch.set_default_device(device) batch_size = random.randint(1, 256) input_tensor, fake_logits, sampler, model_runner = _prepare_test( batch_size) @@ -212,8 +220,10 @@ def test_sampler_mixed(seed: int): @pytest.mark.parametrize("seed", RANDOM_SEEDS) -def test_sampler_logits_processors(seed: int): +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_logits_processors(seed: int, device: str): set_random_seed(seed) + torch.set_default_device(device) batch_size = random.randint(1, 256) input_tensor, _, sampler, model_runner = _prepare_test(batch_size) @@ -252,14 +262,15 @@ def test_sampler_logits_processors(seed: int): @pytest.mark.parametrize("seed", RANDOM_SEEDS) -def test_sampler_top_k_top_p(seed: int): 
+@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_sampler_top_k_top_p(seed: int, device: str): set_random_seed(seed) batch_size = random.randint(1, 256) top_k = random.randint(100, 500) top_p = random.random() * 0.1 vocab_size = 32000 input_tensor = torch.rand((batch_size, 1024), - device="cuda", + device=device, dtype=torch.float16) fake_logits = torch.normal(0, 5, @@ -267,7 +278,7 @@ def test_sampler_top_k_top_p(seed: int): device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(32000, fake_logits) - model_runner = ModelRunner(None, None, None, None) + model_runner = ModelRunner(None, None, None, None, None) generation_model = GenerationMixin() generation_config = GenerationConfig(top_k=top_k, diff --git a/tests/worker/spec_decode/utils.py b/tests/worker/spec_decode/utils.py index e0db7700..8d74509f 100644 --- a/tests/worker/spec_decode/utils.py +++ b/tests/worker/spec_decode/utils.py @@ -84,7 +84,7 @@ def create_worker(cls: type, ) (model_config, cache_config, parallel_config, scheduler_config, - _) = engine_args.create_engine_configs() + device_config, _) = engine_args.create_engine_configs() distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) @@ -93,6 +93,7 @@ def create_worker(cls: type, model_config=model_config, parallel_config=parallel_config, scheduler_config=scheduler_config, + device_config=device_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 5d9ad052..f44895a7 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -6,7 +6,7 @@ from vllm.worker.model_runner import ModelRunner def test_prepare_prompt(): - model_runner = ModelRunner(None, None, None, None) + model_runner = ModelRunner(None, None, None, None, None) model_runner.set_block_size(16) batch_size = random.randint(1, 256) diff --git a/vllm/config.py b/vllm/config.py index 4fb7357a..1dfc0d63 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -444,6 +444,12 @@ class SchedulerConfig: f"({self.max_num_seqs}).") +class DeviceConfig: + + def __init__(self, device: str = "cuda") -> None: + self.device = torch.device(device) + + @dataclass class LoRAConfig: max_lora_rank: int diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 231ce332..d5e63e25 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -3,8 +3,8 @@ import dataclasses from dataclasses import dataclass from typing import Optional, Tuple -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig, LoRAConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, LoRAConfig) @dataclass @@ -43,6 +43,7 @@ class EngineArgs: lora_extra_vocab_size: int = 256 lora_dtype = 'auto' max_cpu_loras: Optional[int] = None + device: str = 'cuda' def __post_init__(self): if self.tokenizer is None: @@ -127,13 +128,13 @@ class EngineArgs: '--kv-cache-dtype', type=str, choices=['auto', 'fp8_e5m2'], - default='auto', + default=EngineArgs.kv_cache_dtype, help='Data type for kv cache storage. If "auto", will use model ' 'data type. Note FP8 is not supported when cuda version is ' 'lower than 11.8.') parser.add_argument('--max-model-len', type=int, - default=None, + default=EngineArgs.max_model_len, help='model context length. 
If unspecified, ' 'will be automatically derived from the model.') # Parallel arguments @@ -154,6 +155,7 @@ class EngineArgs: parser.add_argument( '--max-parallel-loading-workers', type=int, + default=EngineArgs.max_parallel_loading_workers, help='load model sequentially in multiple batches, ' 'to avoid RAM OOM when using tensor ' 'parallel and large models') @@ -200,7 +202,7 @@ class EngineArgs: '-q', type=str, choices=['awq', 'gptq', 'squeezellm', None], - default=None, + default=EngineArgs.quantization, help='Method used to quantize the weights. If ' 'None, we first check the `quantization_config` ' 'attribute in the model config file. If that is ' @@ -255,6 +257,13 @@ class EngineArgs: help=('Maximum number of LoRAs to store in CPU memory. ' 'Must be >= than max_num_seqs. ' 'Defaults to max_num_seqs.')) + parser.add_argument( + "--device", + type=str, + default=EngineArgs.device, + choices=["cuda"], + help=('Device type for vLLM execution. ' + 'Currently, only CUDA-compatible devices are supported.')) return parser @classmethod @@ -268,7 +277,8 @@ class EngineArgs: def create_engine_configs( self, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, - Optional[LoRAConfig]]: + DeviceConfig, Optional[LoRAConfig]]: + device_config = DeviceConfig(self.device) model_config = ModelConfig(self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.download_dir, self.load_format, @@ -296,7 +306,8 @@ class EngineArgs: lora_dtype=self.lora_dtype, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras and self.max_cpu_loras > 0 else None) if self.enable_lora else None - return model_config, cache_config, parallel_config, scheduler_config, lora_config + return (model_config, cache_config, parallel_config, scheduler_config, + device_config, lora_config) @dataclass diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e60efc5e..92568450 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -6,8 +6,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union) from vllm.lora.request import LoRARequest -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig, LoRAConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, LoRAConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import StatLogger, Stats @@ -53,6 +53,7 @@ class LLMEngine: management. parallel_config: The configuration related to distributed execution. scheduler_config: The configuration related to the request scheduler. + device_config: The configuration related to the device. placement_group: Ray placement group for distributed execution. Required for distributed execution. log_stats: Whether to log statistics. @@ -64,6 +65,7 @@ class LLMEngine: cache_config: CacheConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, + device_config: DeviceConfig, lora_config: Optional[LoRAConfig], placement_group: Optional["PlacementGroup"], log_stats: bool, @@ -85,6 +87,7 @@ class LLMEngine: f"quantization={model_config.quantization}, " f"enforce_eager={model_config.enforce_eager}, " f"kv_cache_dtype={cache_config.cache_dtype}, " + f"device_config={device_config.device}, " f"seed={model_config.seed})") # TODO(woosuk): Print more configs in debug mode. 
@@ -93,6 +96,7 @@ class LLMEngine: self.lora_config = lora_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + self.device_config = device_config self.log_stats = log_stats self._verify_args() @@ -138,6 +142,7 @@ class LLMEngine: self.model_config, self.parallel_config, self.scheduler_config, + self.device_config, local_rank=0, rank=0, distributed_init_method=distributed_init_method, @@ -233,6 +238,7 @@ class LLMEngine: model_config = copy.deepcopy(self.model_config) parallel_config = copy.deepcopy(self.parallel_config) scheduler_config = copy.deepcopy(self.scheduler_config) + device_config = copy.deepcopy(self.device_config) for rank, (worker, (node_id, _)) in enumerate(zip(self.workers, @@ -244,6 +250,7 @@ class LLMEngine: model_config, parallel_config, scheduler_config, + device_config, local_rank, rank, distributed_init_method, @@ -257,6 +264,7 @@ class LLMEngine: model_config, parallel_config, scheduler_config, + device_config, driver_local_rank, driver_rank, distributed_init_method, diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 1af120d1..95902ae3 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -89,9 +89,7 @@ class ScaledActivation(nn.Module): if params_dtype is None: params_dtype = torch.get_default_dtype() self.scales = nn.Parameter( - torch.empty(intermediate_size_per_partition, - dtype=params_dtype, - device="cuda")) + torch.empty(intermediate_size_per_partition, dtype=params_dtype)) set_weight_attrs(self.scales, {"weight_loader": self.weight_loader}) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 91ed43f0..2ce9d60f 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -200,7 +200,7 @@ def _make_alibi_bias( seq_len: int, dtype: torch.dtype, ) -> LowerTriangularMaskWithTensorBias: - bias = torch.arange(seq_len, dtype=dtype, device="cuda") + bias = torch.arange(seq_len, dtype=dtype) # NOTE(zhuohan): HF uses # `bias = bias[None, :].repeat(prompt_len, 1)` # here. 
We find that both biases give the same results, but diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 5e1d63a6..55d38b76 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -54,7 +54,6 @@ class UnquantizedLinearMethod(LinearMethodBase): params_dtype: torch.dtype) -> Dict[str, Any]: weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, - device=torch.cuda.current_device(), dtype=params_dtype), requires_grad=False) set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) @@ -113,9 +112,7 @@ class ReplicatedLinear(torch.nn.Module): self.register_parameter(name, weight) if bias: self.bias = Parameter( - torch.empty(self.output_size, - device=torch.cuda.current_device(), - dtype=self.params_dtype)) + torch.empty(self.output_size, dtype=self.params_dtype)) set_weight_attrs(self.bias, {"output_dim": 0}) else: self.register_parameter("bias", None) @@ -183,7 +180,6 @@ class ColumnParallelLinear(torch.nn.Module): if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, - device=torch.cuda.current_device(), dtype=params_dtype)) set_weight_attrs(self.bias, { "output_dim": 0, @@ -509,9 +505,7 @@ class RowParallelLinear(torch.nn.Module): if bias: self.bias = Parameter( - torch.empty(self.output_size, - device=torch.cuda.current_device(), - dtype=params_dtype)) + torch.empty(self.output_size, dtype=params_dtype)) set_weight_attrs(self.bias, { "output_dim": 0, "weight_loader": self.weight_loader, diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 4d3fd3ec..681f9582 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -96,7 +96,6 @@ class AWQLinearMethod(LinearMethodBase): torch.empty( input_size_per_partition, output_size_per_partition // self.quant_config.pack_factor, - device="cuda", dtype=torch.int32, ), requires_grad=False, @@ -112,7 +111,6 @@ class AWQLinearMethod(LinearMethodBase): torch.empty( input_size_per_partition // self.quant_config.group_size, output_size_per_partition // self.quant_config.pack_factor, - device="cuda", dtype=torch.int32, ), requires_grad=False, @@ -128,7 +126,6 @@ class AWQLinearMethod(LinearMethodBase): torch.empty( input_size_per_partition // self.quant_config.group_size, output_size_per_partition, - device="cuda", dtype=params_dtype, ), requires_grad=False, diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 8fe96e7d..7218760f 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -127,7 +127,6 @@ class GPTQLinearMethod(LinearMethodBase): torch.empty( input_size_per_partition // self.quant_config.pack_factor, output_size_per_partition, - device="cuda", dtype=torch.int32, ), requires_grad=False, @@ -145,7 +144,6 @@ class GPTQLinearMethod(LinearMethodBase): i // self.quant_config.group_size for i in range(input_size_per_partition) ], - device="cuda", dtype=torch.int32, ), requires_grad=False, @@ -156,7 +154,6 @@ class GPTQLinearMethod(LinearMethodBase): torch.empty( scale_and_zero_size, output_size_per_partition // self.quant_config.pack_factor, - device="cuda", dtype=torch.int32, ), requires_grad=False, @@ -172,7 +169,6 @@ class GPTQLinearMethod(LinearMethodBase): torch.empty( scale_and_zero_size, output_size_per_partition, - device="cuda", dtype=params_dtype, ), requires_grad=False, 
diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 1932bd14..9244e885 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -80,7 +80,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase): torch.empty( input_size_per_partition // self.quant_config.pack_factor, output_size_per_partition, - device="cuda", dtype=torch.int32, ), requires_grad=False, @@ -96,7 +95,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase): torch.empty( output_size, self.quant_config.weight_bits**2, - device="cuda", dtype=params_dtype, ), requires_grad=False, @@ -118,12 +116,12 @@ class SqueezeLLMLinearMethod(LinearMethodBase): out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) if is_hip(): - out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float) + out_f = torch.zeros(out_shape, dtype=torch.float) ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table) out = out_f.to(dtype=torch.float16) else: # NOTE: The output tensor should be zero-initialized. - out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) + out = torch.zeros(out_shape, dtype=torch.float16) ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table) if bias is not None: diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 91c093e3..93ec5c12 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -77,16 +77,13 @@ class RotaryEmbedding(nn.Module): # create the cache on GPU for faster initialization. This may cause # a slight numerical difference between the HF implementation and ours. inv_freq = 1.0 / (base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / - self.rotary_dim)) + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) return inv_freq def _compute_cos_sin_cache(self) -> torch.Tensor: """Compute the cos and sin cache.""" inv_freq = self._compute_inv_freq(self.base) - t = torch.arange(self.max_position_embeddings, - dtype=torch.float, - device="cuda") + t = torch.arange(self.max_position_embeddings, dtype=torch.float) freqs = torch.einsum("i,j -> ij", t, inv_freq) cos = freqs.cos() @@ -174,7 +171,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): # Thus, the maximum length after applying the rope scaling is # self.max_position_embeddings * self.scaling_factor. 
max_len = self.max_position_embeddings * self.scaling_factor - t = torch.arange(max_len, dtype=torch.float, device="cuda") + t = torch.arange(max_len, dtype=torch.float) t = t / self.scaling_factor freqs = torch.einsum("i,j -> ij", t, inv_freq) @@ -214,7 +211,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): (self.scaling_factor - 1))**(self.rotary_dim / (self.rotary_dim - 2)) inv_freq = self._compute_inv_freq(base) - t = torch.arange(max_len, dtype=torch.float, device="cuda") + t = torch.arange(max_len, dtype=torch.float) freqs = torch.einsum("i,j -> ij", t, inv_freq) cos = freqs.cos() @@ -297,9 +294,9 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): is_neox_style) def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: - pos_freqs = self.base**(torch.arange( - 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / - self.rotary_dim) + pos_freqs = self.base**( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / + self.rotary_dim) inv_freq_extrapolation = 1.0 / pos_freqs inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) @@ -308,8 +305,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): self.max_position_embeddings) # Get n-d rotational scaling corrected for extrapolation inv_freq_mask = (1 - _yarn_linear_ramp_mask( - low, high, self.rotary_dim // 2, dtype=torch.float, - device="cuda")) * self.extrapolation_factor + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor inv_freq = inv_freq_interpolation * ( 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask return inv_freq @@ -317,7 +314,6 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): def _compute_cos_sin_cache(self) -> torch.Tensor: inv_freq = self._compute_inv_freq(self.scaling_factor) t = torch.arange(self.max_position_embeddings * self.scaling_factor, - device="cuda", dtype=torch.float32) freqs = torch.einsum("i,j -> ij", t, inv_freq) cos = (freqs.cos() * self.mscale) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 9c5fb890..6d13cf81 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -77,7 +77,6 @@ class VocabParallelEmbedding(torch.nn.Module): self.weight = Parameter( torch.empty(self.num_embeddings_per_partition, self.embedding_dim, - device=torch.cuda.current_device(), dtype=params_dtype)) set_weight_attrs(self.weight, { "parallel_dim": 0, @@ -139,7 +138,6 @@ class ParallelLMHead(VocabParallelEmbedding): if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, - device=torch.cuda.current_device(), dtype=params_dtype)) set_weight_attrs(self.bias, { "parallel_dim": 0, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index cd21c778..4b1e13d9 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -5,7 +5,7 @@ from typing import Optional, Type import torch import torch.nn as nn -from vllm.config import ModelConfig, LoRAConfig +from vllm.config import DeviceConfig, ModelConfig, LoRAConfig from vllm.model_executor.models import ModelRegistry from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) @@ -38,6 +38,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: def get_model(model_config: ModelConfig, + device_config: DeviceConfig, lora_config: Optional[LoRAConfig] = None) -> nn.Module: model_class = 
_get_model_architecture(model_config) @@ -64,7 +65,7 @@ def get_model(model_config: ModelConfig, with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. - with torch.device("cuda"): + with torch.device(device_config.device): if getattr(model_class, "supports_lora", False): model = model_class(model_config.hf_config, linear_method, lora_config) diff --git a/vllm/utils.py b/vllm/utils.py index dc817414..9e9126a2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -228,7 +228,8 @@ def create_kv_caches_with_random( device: Optional[str] = "cuda", ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: torch.random.manual_seed(seed) - torch.cuda.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) if isinstance(cache_dtype, str): if cache_dtype == "auto": diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index f57e1ed7..bbe33989 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -104,11 +104,13 @@ class CacheEngine: size=(self.num_cpu_blocks, *key_block_shape), dtype=self.dtype, pin_memory=pin_memory, + device="cpu", ) value_blocks = torch.empty( size=(self.num_cpu_blocks, *value_block_shape), dtype=self.dtype, pin_memory=pin_memory, + device="cpu", ) cpu_cache.append((key_blocks, value_blocks)) return cpu_cache diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2df9fd52..fce0009e 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -5,7 +5,7 @@ import numpy as np import torch import torch.nn as nn -from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig +from vllm.config import DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig from vllm.logger import init_logger from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.model_executor.parallel_utils.communication_op import ( @@ -35,6 +35,7 @@ class ModelRunner: model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, + device_config: DeviceConfig, lora_config: Optional[LoRAConfig], kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, @@ -49,7 +50,10 @@ class ModelRunner: # FIXME(woosuk): This is a hack to make the tests work. Refactor this. self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None) - self.device = torch.device(torch.cuda.current_device()) + self.device_config = (device_config + if device_config is not None else DeviceConfig()) + self.device = self.device_config.device + self.model = None self.block_size = None # Set after initial profiling. 
self.lora_manager = None @@ -72,7 +76,8 @@ class ModelRunner: self.kv_cache_dtype = kv_cache_dtype def load_model(self) -> None: - self.model = get_model(self.model_config, self.lora_config) + self.model = get_model(self.model_config, self.device_config, + self.lora_config) vocab_size = self.model.config.vocab_size @@ -182,22 +187,25 @@ class ModelRunner: input_tokens = _make_tensor_with_pad(input_tokens, max_prompt_len, pad=0, - dtype=torch.long) + dtype=torch.long, + device=self.device) input_positions = _make_tensor_with_pad(input_positions, max_prompt_len, pad=0, - dtype=torch.long) + dtype=torch.long, + device=self.device) slot_mapping = _make_tensor_with_pad(slot_mapping, max_prompt_len, pad=_PAD_SLOT_ID, - dtype=torch.long) + dtype=torch.long, + device=self.device) lora_index_mapping = [ _pad_to_max(mapping, max_prompt_len, pad=0) for mapping in lora_index_mapping ] context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, - device='cuda') + device=self.device) # Prepare prefix block tables max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) block_tables = _make_tensor_with_pad( @@ -205,15 +213,16 @@ class ModelRunner: max_len=max_prompt_block_table_len, pad=0, dtype=torch.int, + device=self.device, ) start_loc_tensor = torch.arange(0, len(prompt_lens) * max_prompt_len, max_prompt_len, dtype=torch.long, - device='cuda') + device=self.device) prompt_lens_tensor = torch.tensor(prompt_lens, dtype=torch.long, - device='cuda') + device=self.device) input_metadata = InputMetadata( is_prompt=True, @@ -305,20 +314,20 @@ class ModelRunner: max_len=1, pad=0, dtype=torch.long, - device="cuda") + device=self.device) input_positions = _make_tensor_with_pad(input_positions, max_len=1, pad=0, dtype=torch.long, - device="cuda") + device=self.device) slot_mapping = _make_tensor_with_pad(slot_mapping, max_len=1, pad=_PAD_SLOT_ID, dtype=torch.long, - device="cuda") + device=self.device) context_lens = torch.tensor(context_lens, dtype=torch.int, - device="cuda") + device=self.device) if use_captured_graph: # The shape of graph_block_tables is @@ -327,7 +336,7 @@ class ModelRunner: for i, block_table in enumerate(block_tables): if block_table: input_block_tables[i, :len(block_table)] = block_table - block_tables = torch.tensor(input_block_tables, device="cuda") + block_tables = torch.tensor(input_block_tables, device=self.device) else: max_block_table_len = max( len(block_table) for block_table in block_tables) @@ -336,7 +345,7 @@ class ModelRunner: max_len=max_block_table_len, pad=0, dtype=torch.int, - device="cuda", + device=self.device, ) lora_index_mapping = [ @@ -355,7 +364,8 @@ class ModelRunner: use_cuda_graph=use_captured_graph, kv_cache_dtype=self.kv_cache_dtype, ) - return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests + return (input_tokens, input_positions, input_metadata, + lora_index_mapping, lora_prompt_mapping, lora_requests) def _prepare_sample( self, @@ -410,9 +420,13 @@ class ModelRunner: selected_token_indices = _async_h2d(selected_token_indices, dtype=torch.long, + target_device=self.device, pin_memory=not self.in_wsl) categorized_sample_indices = { - t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl) + t: _async_h2d(seq_ids, + dtype=torch.int, + target_device=self.device, + pin_memory=not self.in_wsl) for t, seq_ids in categorized_sample_indices.items() } @@ -511,7 +525,8 @@ class ModelRunner: perform_sampling=False, ) - return input_tokens, input_positions, input_metadata, 
sampling_metadata, lora_requests, lora_mapping + return (input_tokens, input_positions, input_metadata, + sampling_metadata, lora_requests, lora_mapping) @torch.inference_mode() def execute_model( @@ -519,8 +534,9 @@ class ModelRunner: seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> Optional[SamplerOutput]: - input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping = ( - self.prepare_input_tensors(seq_group_metadata_list)) + (input_tokens, input_positions, input_metadata, sampling_metadata, + lora_requests, + lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list) if self.lora_config: self.set_active_loras(lora_requests, lora_mapping) @@ -789,14 +805,10 @@ def _make_tensor_with_pad( max_len: int, pad: int, dtype: torch.dtype, - device: Union[str, torch.device] = "cuda", - pin_memory: bool = False, + device: Optional[Union[str, torch.device]], ) -> torch.Tensor: padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] - return torch.tensor(padded_x, - dtype=dtype, - device=device, - pin_memory=pin_memory and str(device) == "cpu") + return torch.tensor(padded_x, dtype=dtype, device=device) def _get_graph_batch_size(batch_size: int) -> int: @@ -808,6 +820,11 @@ def _get_graph_batch_size(batch_size: int) -> int: return (batch_size + 7) // 8 * 8 -def _async_h2d(data: list, dtype, pin_memory): - t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory) - return t.to(device="cuda", non_blocking=True) +def _async_h2d( + data: list, + dtype: torch.dtype, + target_device: Union[str, torch.device], + pin_memory: bool, +) -> torch.Tensor: + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") + return t.to(device=target_device, non_blocking=True) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index a74adfa5..c97e82a5 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -6,8 +6,8 @@ from typing import Dict, List, Tuple, Set, Optional import torch import torch.distributed -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig, LoRAConfig) +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, LoRAConfig) from vllm.model_executor import set_random_seed from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) @@ -33,6 +33,7 @@ class Worker: model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, + device_config: DeviceConfig, local_rank: int, rank: int, distributed_init_method: str, @@ -43,6 +44,7 @@ class Worker: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + self.device_config = device_config self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method @@ -54,6 +56,7 @@ class Worker: self.model_runner = ModelRunner(model_config, parallel_config, scheduler_config, + device_config, lora_config=self.lora_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker) @@ -65,21 +68,24 @@ class Worker: self.gpu_cache = None def init_model(self) -> None: - # torch.distributed.all_reduce does not free the input tensor until - # the synchronization point. This causes the memory usage to grow - # as the number of all_reduce calls increases. This env var disables - # this behavior. 
-        # Related issue:
-        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
-        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+        if self.device_config.device.type == "cuda":
+            # torch.distributed.all_reduce does not free the input tensor until
+            # the synchronization point. This causes the memory usage to grow
+            # as the number of all_reduce calls increases. This env var disables
+            # this behavior.
+            # Related issue:
+            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
+            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
 
-        # This env var set by Ray causes exceptions with graph building.
-        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-        self.device = torch.device(f"cuda:{self.local_rank}")
-        torch.cuda.set_device(self.device)
-
-        _check_if_gpu_supports_dtype(self.model_config.dtype)
+            # This env var set by Ray causes exceptions with graph building.
+            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+            self.device = torch.device(f"cuda:{self.local_rank}")
+            torch.cuda.set_device(self.device)
+            _check_if_gpu_supports_dtype(self.model_config.dtype)
+        else:
+            raise RuntimeError(
+                f"Unsupported device type: {self.device_config.device}")
         # Initialize the distributed environment.
         init_distributed_environment(self.parallel_config, self.rank,
                                      self.distributed_init_method)
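
The hunks above thread a DeviceConfig from Worker through ModelRunner down to get_model(), and tensor allocations now read the target from device_config.device instead of a hardcoded "cuda". The vllm/config.py hunk is not reproduced in this excerpt, so the following is only a minimal sketch of what the new code assumes about DeviceConfig: it exposes a .device attribute that behaves like a torch.device (ModelRunner passes it straight to tensor factories, and Worker branches on its .type). The real class may carry additional fields.

    import torch

    class DeviceConfig:
        """Minimal sketch; the actual class lives in vllm/config.py."""

        def __init__(self, device: str = "cuda") -> None:
            # Normalize the user-supplied string ("cuda", "cuda:1", ...) into a
            # torch.device so ModelRunner can hand it to torch.tensor()/empty()
            # and Worker can check `device.type == "cuda"`.
            self.device = torch.device(device)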
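In model_runner.py the implicit device="cuda" defaults on the padding helpers are replaced by an explicit, required device argument taken from self.device. The standalone illustration below mirrors the signature shown in the diff; the body of _pad_to_max and the example inputs are reconstructions for illustration, not part of this patch.

    from typing import List, Optional, Union

    import torch


    def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
        # Right-pad a single sequence out to max_len.
        return x + [pad] * (max_len - len(x))


    def _make_tensor_with_pad(
        x: List[List[int]],
        max_len: int,
        pad: int,
        dtype: torch.dtype,
        device: Optional[Union[str, torch.device]],
    ) -> torch.Tensor:
        padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
        return torch.tensor(padded_x, dtype=dtype, device=device)


    # Example: two prompts of unequal length padded onto the CPU.
    tokens = _make_tensor_with_pad([[1, 2, 3], [4]], max_len=4, pad=0,
                                   dtype=torch.long, device="cpu")
    # tokens -> tensor([[1, 2, 3, 0],
    #                   [4, 0, 0, 0]])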
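The refactored _async_h2d stages the tensor in (optionally pinned) host memory and then issues a non-blocking copy to the target device, so sampling-metadata transfers can overlap with other work when the target is CUDA. A usage sketch follows; the sample index data and the CPU fallback are hypothetical and not part of the patch.

    from typing import Union

    import torch


    def _async_h2d(
        data: list,
        dtype: torch.dtype,
        target_device: Union[str, torch.device],
        pin_memory: bool,
    ) -> torch.Tensor:
        # Build the tensor on the host first; pinned pages let the copy below
        # run asynchronously with respect to the default CUDA stream.
        t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
        return t.to(device=target_device, non_blocking=True)


    use_cuda = torch.cuda.is_available()
    selected_token_indices = _async_h2d(
        [0, 5, 9],                       # hypothetical flattened indices
        dtype=torch.long,
        target_device="cuda" if use_cuda else "cpu",
        pin_memory=use_cuda,             # pinning requires a CUDA runtime
    )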