diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d8030ab2..b48ef31b 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -46,6 +46,7 @@ steps: - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py - pytest -v -s spec_decode/e2e/test_integration_dist.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - label: Distributed Tests (Multiple Groups) #mirror_hardwares: [amd] @@ -138,14 +139,7 @@ steps: num_gpus: 4 # This test runs llama 13B, so it is required to run on 4 GPUs. commands: - # Temporarily run this way because we cannot clean up GPU mem usage - # for multi GPU tests. - # TODO(sang): Fix it. - - pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced - - pytest -v -s lora/test_long_context.py::test_batched_rope_kernel - - pytest -v -s lora/test_long_context.py::test_self_consistency - - pytest -v -s lora/test_long_context.py::test_quality - - pytest -v -s lora/test_long_context.py::test_max_len + - pytest -v -s -x lora/test_long_context.py - label: Tensorizer Test #mirror_hardwares: [amd] diff --git a/tests/conftest.py b/tests/conftest.py index a481daa3..1a7037eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,8 @@ import contextlib import gc import os +import subprocess +import sys from typing import Any, Dict, List, Optional, Tuple, TypeVar import pytest @@ -522,3 +524,22 @@ def caplog_vllm(temporary_enable_log_propagate, caplog): # To capture vllm log, we should enable propagate=True temporarily # because caplog depends on logs propagated to the root logger. yield caplog + + +@pytest.fixture(scope="session") +def num_gpus_available(): + """Get number of GPUs without initializing the CUDA context + in current process.""" + + try: + out = subprocess.run([ + sys.executable, "-c", + "import torch; print(torch.cuda.device_count())" + ], + capture_output=True, + check=True, + text=True) + except subprocess.CalledProcessError as e: + logger.warning("Failed to get number of GPUs.", exc_info=e) + return 0 + return int(out.stdout.strip()) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index e5cf9cd4..40033306 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -42,10 +42,24 @@ def cleanup(): ray.shutdown() +@pytest.fixture() +def should_do_global_cleanup_after_test(request) -> bool: + """Allow subdirectories to skip global cleanup by overriding this fixture. + This can provide a ~10x speedup for non-GPU unit tests since they don't need + to initialize torch. 
+ """ + + if request.node.get_closest_marker("skip_global_cleanup"): + return False + + return True + + @pytest.fixture(autouse=True) -def cleanup_fixture(): +def cleanup_fixture(should_do_global_cleanup_after_test: bool): yield - cleanup() + if should_do_global_cleanup_after_test: + cleanup() @pytest.fixture diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 9a2c8b04..fc4445c6 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -2,6 +2,7 @@ import random from copy import deepcopy from dataclasses import dataclass from typing import Dict, List, Optional, Tuple +from unittest.mock import patch import pytest import torch @@ -32,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask) from vllm.model_executor.utils import set_random_seed from .utils import DummyLoRAManager @@ -427,7 +428,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, logits_processor = LogitsProcessor( vocab_size + lora_config.lora_extra_vocab_size, vocab_size) lora_logits_processor = LogitsProcessorWithLoRA( - logits_processor, 1024, linear.weight.dtype, linear.weight.device) + logits_processor, 1024, linear.weight.dtype, linear.weight.device, + None) lora_logits_processor.create_lora_weights(max_loras, lora_config) return linear, logits_processor, lora_logits_processor @@ -867,3 +869,216 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device, torch.allclose(ref_q, actual_q) torch.allclose(ref_k, actual_k) + + +@pytest.mark.parametrize("tp_size", [1, 2, 4, 8]) +@pytest.mark.parametrize("seed", list(range(256))) +def test_vocab_parallel_embedding_indices(tp_size, seed): + random.seed(seed) + vocab_size = random.randint(4000, 64000) + added_vocab_size = random.randint(0, 1024) + org_vocab_size = vocab_size - added_vocab_size + last_org_vocab_end_index = 0 + last_added_vocab_end_index = org_vocab_size + computed_vocab_size = 0 + computed_org_vocab_size = 0 + computed_added_vocab_size = 0 + vocab_size_padded = -1 + + all_org_tokens = [] + all_added_tokens = [] + token_ids = [] + + for tp_rank in range(tp_size): + with patch( + "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank", + return_value=tp_rank + ), patch( + "vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size", + return_value=tp_size): + vocab_embedding = VocabParallelEmbedding( + vocab_size, 1, org_num_embeddings=org_vocab_size) + vocab_size_padded = vocab_embedding.num_embeddings_padded + shard_indices = vocab_embedding.shard_indices + # Assert that the ranges are contiguous + assert shard_indices.org_vocab_start_index == last_org_vocab_end_index + assert (shard_indices.added_vocab_start_index == + last_added_vocab_end_index) + + # Ensure that we are not exceeding the vocab size + computed_vocab_size += shard_indices.num_elements_padded + computed_org_vocab_size += shard_indices.num_org_elements + computed_added_vocab_size += shard_indices.num_added_elements + + # Ensure that the ranges are not overlapping + all_org_tokens.extend( + range(shard_indices.org_vocab_start_index, + shard_indices.org_vocab_end_index)) + all_added_tokens.extend( + range(shard_indices.added_vocab_start_index, + 
shard_indices.added_vocab_end_index)) + + token_ids.extend( + range(shard_indices.org_vocab_start_index, + shard_indices.org_vocab_end_index)) + token_ids.extend([-1] * (shard_indices.num_org_elements_padded - + shard_indices.num_org_elements)) + token_ids.extend( + range(shard_indices.added_vocab_start_index, + shard_indices.added_vocab_end_index)) + token_ids.extend([-1] * (shard_indices.num_added_elements_padded - + shard_indices.num_added_elements)) + + last_org_vocab_end_index = shard_indices.org_vocab_end_index + last_added_vocab_end_index = shard_indices.added_vocab_end_index + + assert computed_vocab_size == vocab_size_padded + assert computed_org_vocab_size == org_vocab_size + assert computed_added_vocab_size == added_vocab_size + + # Ensure that the ranges are not overlapping + assert len(all_org_tokens) == len(set(all_org_tokens)) + assert len(all_added_tokens) == len(set(all_added_tokens)) + assert not set(all_org_tokens).intersection(set(all_added_tokens)) + + token_ids_tensor = torch.tensor(token_ids, dtype=torch.long) + reindex_mapping = vocab_embedding.get_sharded_to_full_mapping() + assert reindex_mapping is not None or tp_size == 1 + if reindex_mapping is not None: + reindexed_token_ids = token_ids_tensor[reindex_mapping] + expected = torch.tensor(list(range(0, vocab_size))) + assert reindexed_token_ids[:vocab_size].equal(expected) + assert torch.all(reindexed_token_ids[vocab_size:] == -1) + + +def test_get_masked_input_and_mask(): + x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + + # base tp 1 case, no padding + modified_x, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=8, + added_vocab_start_index=8, + added_vocab_end_index=12, + num_org_vocab_padding=0) + assert torch.equal(x, modified_x) + + # tp 2 case, no padding + modified_x_rank_0, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=4, + added_vocab_start_index=8, + added_vocab_end_index=10, + num_org_vocab_padding=0) + modified_x_rank_1, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=4, + org_vocab_end_index=8, + added_vocab_start_index=10, + added_vocab_end_index=12, + num_org_vocab_padding=0) + assert torch.equal(modified_x_rank_0, + torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0])) + assert torch.equal(modified_x_rank_1, + torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5])) + + # tp 4 case, no padding + modified_x_rank_0, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=2, + added_vocab_start_index=8, + added_vocab_end_index=9, + num_org_vocab_padding=0) + modified_x_rank_1, _ = get_masked_input_and_mask(x, + org_vocab_start_index=2, + org_vocab_end_index=4, + added_vocab_start_index=9, + added_vocab_end_index=10, + num_org_vocab_padding=0) + modified_x_rank_2, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=4, + org_vocab_end_index=6, + added_vocab_start_index=10, + added_vocab_end_index=11, + num_org_vocab_padding=0) + modified_x_rank_3, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=6, + org_vocab_end_index=8, + added_vocab_start_index=11, + added_vocab_end_index=12, + num_org_vocab_padding=0) + assert torch.equal(modified_x_rank_0, + torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0])) + assert torch.equal(modified_x_rank_1, + torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0])) + assert torch.equal(modified_x_rank_2, + torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0])) + assert torch.equal(modified_x_rank_3, + torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 
2])) + + # base tp 1 case, with padding + modified_x, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=8, + added_vocab_start_index=8, + added_vocab_end_index=12, + num_org_vocab_padding=2) + assert torch.equal(modified_x, + torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13])) + + # tp 2 case, with padding + modified_x_rank_0, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=4, + added_vocab_start_index=8, + added_vocab_end_index=10, + num_org_vocab_padding=2) + modified_x_rank_1, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=4, + org_vocab_end_index=8, + added_vocab_start_index=10, + added_vocab_end_index=12, + num_org_vocab_padding=2) + assert torch.equal(modified_x_rank_0, + torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0])) + assert torch.equal(modified_x_rank_1, + torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7])) + + # tp 4 case, with padding + modified_x_rank_0, _ = get_masked_input_and_mask(x, + org_vocab_start_index=0, + org_vocab_end_index=2, + added_vocab_start_index=8, + added_vocab_end_index=9, + num_org_vocab_padding=2) + modified_x_rank_1, _ = get_masked_input_and_mask(x, + org_vocab_start_index=2, + org_vocab_end_index=4, + added_vocab_start_index=9, + added_vocab_end_index=10, + num_org_vocab_padding=2) + modified_x_rank_2, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=4, + org_vocab_end_index=6, + added_vocab_start_index=10, + added_vocab_end_index=11, + num_org_vocab_padding=2) + modified_x_rank_3, _ = get_masked_input_and_mask( + x, + org_vocab_start_index=6, + org_vocab_end_index=8, + added_vocab_start_index=11, + added_vocab_end_index=12, + num_org_vocab_padding=2) + assert torch.equal(modified_x_rank_0, + torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0])) + assert torch.equal(modified_x_rank_1, + torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0])) + assert torch.equal(modified_x_rank_2, + torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0])) + assert torch.equal(modified_x_rank_3, + torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4])) diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py index f5a571e8..7143a99b 100644 --- a/tests/lora/test_llama.py +++ b/tests/lora/test_llama.py @@ -36,11 +36,10 @@ def do_sample(llm, lora_path: str, lora_id: int): return generated_texts -@pytest.mark.parametrize("tp_size", [1]) -def test_llama_lora(sql_lora_files, tp_size): - # Cannot use as it will initialize torch.cuda too early... - # if torch.cuda.device_count() < tp_size: - # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") +@pytest.mark.parametrize("tp_size", [1, 2, 4]) +def test_llama_lora(sql_lora_files, tp_size, num_gpus_available): + if num_gpus_available < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") llm = vllm.LLM(MODEL_PATH, enable_lora=True, @@ -80,11 +79,9 @@ def test_llama_lora(sql_lora_files, tp_size): print("removing lora") -@pytest.mark.skip("Requires multiple GPUs") -def test_llama_tensor_parallel_equality(sql_lora_files): - # Cannot use as it will initialize torch.cuda too early... 
- # if torch.cuda.device_count() < 4: - # pytest.skip(f"Not enough GPUs for tensor parallelism {4}") +def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available): + if num_gpus_available < 4: + pytest.skip("Not enough GPUs for tensor parallelism 4") llm_tp1 = vllm.LLM(MODEL_PATH, enable_lora=True, diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index 4361e545..b58145ed 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -102,22 +102,21 @@ def batched_generate( return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] -@pytest.fixture +@pytest.fixture(scope="module") def lora_llm(long_context_infos): scaling_factors = [ context_len_to_scaling_factor[info["context_length"]] for info in long_context_infos.values() ] - llm = vllm.LLM( - "meta-llama/Llama-2-13b-chat-hf", - enable_lora=True, - max_num_seqs=16, - max_loras=2, - long_lora_scaling_factors=tuple(scaling_factors), - max_num_batched_tokens=4096 * 8, - tensor_parallel_size=4, - ) + llm = vllm.LLM("meta-llama/Llama-2-13b-chat-hf", + enable_lora=True, + max_num_seqs=16, + max_loras=2, + long_lora_scaling_factors=tuple(scaling_factors), + max_num_batched_tokens=4096 * 8, + tensor_parallel_size=4, + distributed_executor_backend="mp") yield llm del llm @@ -154,6 +153,7 @@ def test_rotary_emb_replaced(dist_init): assert rotary_emb_count == 32 +@pytest.mark.skip_global_cleanup def test_batched_rope_kernel(lora_llm, long_context_infos): """We test the batched kernel by comparing the results of batched an non-batched generation. @@ -188,6 +188,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos): f"same:\n{batched}\n{non_batched}") +@pytest.mark.skip_global_cleanup def test_self_consistency(lora_llm, long_context_infos): """We test consistency of the batched kernel by permuting batched inputs and comparing the results to the non-permuted batched results. @@ -227,6 +228,7 @@ def test_self_consistency(lora_llm, long_context_infos): f"\n{permutated_batched_results[permutation[i]]}") +@pytest.mark.skip_global_cleanup def test_quality(lora_llm, long_context_infos): """We test the quality of the answers given by the LoRA model by comparing the generated text to the merged model's outputs. @@ -257,6 +259,7 @@ def test_quality(lora_llm, long_context_infos): assert np.mean(scores) > 0.5 +@pytest.mark.skip_global_cleanup def test_max_len(lora_llm, long_context_infos): """Test that we raise an ValueError when the input of a given LoRA model exceeds the maximum length.""" diff --git a/tests/test_sharded_state_loader.py b/tests/test_sharded_state_loader.py index 8540e98d..022fb36b 100644 --- a/tests/test_sharded_state_loader.py +++ b/tests/test_sharded_state_loader.py @@ -1,3 +1,4 @@ +import multiprocessing as mp import os import shutil from tempfile import TemporaryDirectory @@ -18,9 +19,7 @@ prompts = [ # Create a sampling params object. 
sampling_params = SamplingParams( - temperature=0.8, - top_p=0.95, - seed=0, + temperature=0, max_tokens=256, ignore_eos=True, ) @@ -43,48 +42,85 @@ def test_filter_subtensors(): assert tensor.equal(state_dict[key]) -@pytest.mark.parametrize("enable_lora", [False, True]) -def test_sharded_state_loader(enable_lora): - weights_patterns = ("*.bin", "*.pt", "*.safetensors") - - with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir: +@pytest.fixture(scope="module") +def llama_2_7b_files(): + with TemporaryDirectory() as cache_dir: input_dir = snapshot_download("meta-llama/Llama-2-7b-hf", - cache_dir=cache_dir) + cache_dir=cache_dir, + ignore_patterns="*.bin*") + yield input_dir - llm = LLM( - model=input_dir, - worker_use_ray=True, - gpu_memory_utilization=0.3, - ) - # Dump worker states to output directory - model_executor = llm.llm_engine.model_executor - model_executor.save_sharded_state(path=output_dir) - # Copy metadata files to output directory - for file in os.listdir(input_dir): - if not any(file.endswith(ext) for ext in weights_patterns): - shutil.copy(f"{input_dir}/{file}", output_dir) - del llm.llm_engine.model_executor +def _run_writer(input_dir, output_dir, weights_patterns, **kwargs): + llm_sharded_writer = LLM(model=input_dir, **kwargs) - llm_before = LLM( - model=input_dir, - worker_use_ray=True, - enable_lora=enable_lora, - gpu_memory_utilization=0.3, - ) - gen_before = llm_before.generate(prompts, sampling_params) - out_before = [gen.outputs[0].__dict__ for gen in gen_before] - del llm_before.llm_engine.model_executor + # Dump worker states to output directory + llm_sharded_writer.llm_engine.model_executor.save_sharded_state( + path=output_dir) + # Copy metadata files to output directory + for file in os.listdir(input_dir): + if not any(file.endswith(ext) for ext in weights_patterns): + shutil.copy(f"{input_dir}/{file}", output_dir) - llm_after = LLM( - model=output_dir, - worker_use_ray=True, - enable_lora=enable_lora, - gpu_memory_utilization=0.3, - load_format="sharded_state", - ) - gen_after = llm_after.generate(prompts, sampling_params) - out_after = [gen.outputs[0].__dict__ for gen in gen_after] - del llm_after.llm_engine.model_executor + +def _run_generate(input_dir, queue: mp.Queue, **kwargs): + llm = LLM(model=input_dir, **kwargs) + gen = llm.generate(prompts, sampling_params) + queue.put([g.outputs[0].__dict__ for g in gen]) + queue.close() + queue.join_thread() + + +@pytest.mark.parametrize("enable_lora", [False, True]) +@pytest.mark.parametrize("tp_size", [1, 2]) +def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available, + llama_2_7b_files): + if num_gpus_available < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + weights_patterns = ("*.safetensors", ) + gpu_memory_utilization = 0.8 + input_dir = llama_2_7b_files + ctx = mp.get_context("spawn") + + # Run in separate processes for memory & CUDA isolation + with TemporaryDirectory() as output_dir: + p = ctx.Process(target=_run_writer, + args=(input_dir, output_dir, weights_patterns), + kwargs=dict( + tensor_parallel_size=tp_size, + distributed_executor_backend="mp", + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=True, + )) + p.start() + p.join() + + queue = ctx.Queue() + + p = ctx.Process(target=_run_generate, + args=(input_dir, queue), + kwargs=dict( + distributed_executor_backend="mp", + enable_lora=enable_lora, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tp_size, + )) + p.start() + p.join() + out_before 
= queue.get() + + p = ctx.Process(target=_run_generate, + args=(output_dir, queue), + kwargs=dict( + distributed_executor_backend="mp", + enable_lora=enable_lora, + gpu_memory_utilization=gpu_memory_utilization, + tensor_parallel_size=tp_size, + load_format="sharded_state", + )) + p.start() + p.join() + out_after = queue.get() assert out_before == out_after diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 24b74476..e3ab1708 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -215,19 +215,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: - lora_vocab_start_idx = self.base_layer.org_vocab_size - weights_idx = None - if self.base_layer.vocab_end_index > lora_vocab_start_idx: + if self.base_layer.num_added_embeddings_per_partition > 0: # We can start adding lora weights - weights_idx = max( - lora_vocab_start_idx - self.base_layer.vocab_start_index, 0) - self.embeddings_slice = (self.base_layer.vocab_start_index - - self.base_layer.org_vocab_size + - weights_idx, - self.base_layer.vocab_end_index - - self.base_layer.org_vocab_size) - self.embeddings_weights = self.base_layer.weight.data[weights_idx:] - self.embeddings_weights.fill_(0) + self.embeddings_weights = self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:self. + base_layer.num_org_embeddings_per_partition + + self.base_layer.num_added_embeddings_per_partition] + self.embeddings_slice = ( + self.base_layer.shard_indices.added_vocab_start_index - + self.base_layer.org_vocab_size, + self.base_layer.shard_indices.added_vocab_end_index - + self.base_layer.org_vocab_size) + self.base_layer.weight.data[ + self.base_layer.num_org_embeddings_per_partition:].fill_(0) else: self.embeddings_slice = None self.embeddings_weights = None @@ -1025,19 +1025,31 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA): class LogitsProcessorWithLoRA(BaseLayerWithLoRA): + """ + LoRA wrapper for LogitsProcessor, with extra logic to handle the + application of the LoRA adapter and added LoRA vocabulary. - def __init__( - self, - base_layer: LogitsProcessor, - hidden_size: int, - dtype: torch.dtype, - device: torch.device, - ) -> None: + Args: + base_layer: LogitsProcessor layer + hidden_size: hidden size of the model + dtype: data type of the model + device: device of the model + sharded_to_full_mapping: index mapping from sharded vocab to full vocab + received from base_layer.get_sharded_to_full_mapping(). If None, + no reindexing will be done. + """ + + def __init__(self, base_layer: LogitsProcessor, hidden_size: int, + dtype: torch.dtype, device: torch.device, + sharded_to_full_mapping: Optional[List[int]]) -> None: super().__init__() self.base_layer = base_layer self.hidden_size = hidden_size self.dtype = dtype self.device = device + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + self.sharded_to_full_mapping = sharded_to_full_mapping @property def logits_as_input(self): @@ -1098,6 +1110,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): dtype=self.dtype, device=self.device, ) + if self.sharded_to_full_mapping is not None: + self.sharded_to_full_mapping_gpu = torch.tensor( + self.sharded_to_full_mapping, + device=self.device, + dtype=torch.long) + else: + self.sharded_to_full_mapping_gpu = None # Lazily initialized. 
self.indices: torch.Tensor self.indices_len: List[int] @@ -1154,6 +1173,25 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA): if logits is None: return None + if self.sharded_to_full_mapping_gpu is not None: + # Reindex full logits tensor to ensure 1:1 mapping between + # index and token_id + # Example for: + # org_vocab_size = 4 + # added_vocab_size = 2 + # pad_to_size = 8 + # tp_size = 2 + + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 4, -1, 2, 3, 5, -1] + + # Therefore, the mapping is expected to be: + # [0, 1, 4, 6, 2, 3, 5, 7] so that when we reindex, + # we get: + # indices: [0, 1, 2, 3, 4, 5, 6, 7] + # token_id: [0, 1, 2, 3, 4, 5, -1, -1] + logits = logits[:, self.sharded_to_full_mapping_gpu] + lora_logits = torch.empty( self.embeddings_tensors.shape[0] + 1, self.embeddings_tensors.shape[1], diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index fcc7f247..b0198a50 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -67,7 +67,8 @@ def from_layer_logits_processor( model_config: Optional[PretrainedConfig] = None, ) -> LogitsProcessorWithLoRA: ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, - lm_head.weight.dtype, lm_head.weight.device) + lm_head.weight.dtype, lm_head.weight.device, + lm_head.get_sharded_to_full_mapping()) ret.create_lora_weights(max_loras, lora_config, model_config) return ret diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 4585b167..60eb5b40 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -1,4 +1,5 @@ -from typing import Optional, Sequence +from dataclasses import dataclass +from typing import List, Optional, Sequence, Tuple import torch import torch.nn.functional as F @@ -18,18 +19,107 @@ def pad_vocab_size(vocab_size: int, return ((vocab_size + pad_to - 1) // pad_to) * pad_to -def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, - rank: int) -> Sequence[int]: +def vocab_range_from_per_partition_vocab_size( + per_partition_vocab_size: int, + rank: int, + offset: int = 0) -> Sequence[int]: index_f = rank * per_partition_vocab_size index_l = index_f + per_partition_vocab_size - return index_f, index_l + return index_f + offset, index_l + offset -def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, - world_size: int) -> Sequence[int]: +def vocab_range_from_global_vocab_size(global_vocab_size: int, + rank: int, + world_size: int, + offset: int = 0) -> Sequence[int]: per_partition_vocab_size = divide(global_vocab_size, world_size) return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, - rank) + rank, + offset=offset) + + +@dataclass +class VocabParallelEmbeddingShardIndices: + """Indices for a shard of a vocab parallel embedding.""" + padded_org_vocab_start_index: int + padded_org_vocab_end_index: int + padded_added_vocab_start_index: int + padded_added_vocab_end_index: int + + org_vocab_start_index: int + org_vocab_end_index: int + added_vocab_start_index: int + added_vocab_end_index: int + + @property + def num_org_elements(self) -> int: + return self.org_vocab_end_index - self.org_vocab_start_index + + @property + def num_added_elements(self) -> int: + return self.added_vocab_end_index - self.added_vocab_start_index + + @property + def num_org_elements_padded(self) -> int: + return (self.padded_org_vocab_end_index - + self.padded_org_vocab_start_index) + + @property + def num_added_elements_padded(self) 
-> int: + return (self.padded_added_vocab_end_index - + self.padded_added_vocab_start_index) + + @property + def num_org_vocab_padding(self) -> int: + return self.num_org_elements_padded - self.num_org_elements + + @property + def num_added_vocab_padding(self) -> int: + return self.num_added_elements_padded - self.num_added_elements + + @property + def num_elements_padded(self) -> int: + return self.num_org_elements_padded + self.num_added_elements_padded + + def __post_init__(self): + # sanity checks + assert (self.padded_org_vocab_start_index <= + self.padded_org_vocab_end_index) + assert (self.padded_added_vocab_start_index <= + self.padded_added_vocab_end_index) + + assert self.org_vocab_start_index <= self.org_vocab_end_index + assert self.added_vocab_start_index <= self.added_vocab_end_index + + assert self.org_vocab_start_index <= self.padded_org_vocab_start_index + assert (self.added_vocab_start_index <= + self.padded_added_vocab_start_index) + assert self.org_vocab_end_index <= self.padded_org_vocab_end_index + assert self.added_vocab_end_index <= self.padded_added_vocab_end_index + + assert self.num_org_elements <= self.num_org_elements_padded + assert self.num_added_elements <= self.num_added_elements_padded + + +@torch.jit.script +def get_masked_input_and_mask( + input_: torch.Tensor, org_vocab_start_index: int, + org_vocab_end_index: int, num_org_vocab_padding: int, + added_vocab_start_index: int, + added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]: + # torch.jit.script will fuse all of the pointwise ops below + # into a single kernel, making it very fast + org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < + org_vocab_end_index) + added_vocab_mask = (input_ >= added_vocab_start_index) & ( + input_ < added_vocab_end_index) + added_offset = added_vocab_start_index - ( + org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding + valid_offset = (org_vocab_start_index * + org_vocab_mask) + (added_offset * added_vocab_mask) + vocab_mask = org_vocab_mask | added_vocab_mask + input_ = vocab_mask * (input_ - valid_offset) + return input_, ~vocab_mask class VocabParallelEmbedding(torch.nn.Module): @@ -38,13 +128,36 @@ class VocabParallelEmbedding(torch.nn.Module): Adapted from torch.nn.Embedding, note that we pad the vocabulary size to make sure it is divisible by the number of model parallel GPUs. + In order to support various loading methods, we ensure that LoRA-added + embeddings are always at the end of TP-sharded tensors. In other words, + we shard base embeddings and LoRA embeddings separately (both padded), + and place them in the same tensor. + In this example, we will have the original vocab size = 1010, + added vocab size = 16 and padding to 64. Therefore, the total + vocab size with padding will be 1088 (because we first pad 1010 to + 1024, add 16, and then pad to 1088). + Therefore, the tensor format looks like the following: + TP1, rank 0 (no sharding): + |< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >| + corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1015 | -1 | ... | -1 | + index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 | + + TP2, rank 0: + |< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >| + corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1000 | ... | 1015 | -1 | ... | -1 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 520 | ... 
| 543 | + TP2, rank 1: + |< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >| + corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 | + index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 | + Args: num_embeddings: vocabulary size. embedding_dim: size of hidden state. params_dtype: type of the parameters. org_num_embeddings: original vocabulary size (without LoRA). padding_size: padding size for the vocabulary. - """ + """ # noqa: E501 def __init__(self, num_embeddings: int, @@ -55,21 +168,39 @@ class VocabParallelEmbedding(torch.nn.Module): super().__init__() # Keep the input dimensions. + tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() self.num_embeddings = num_embeddings + self.padding_size = padding_size self.org_vocab_size = org_num_embeddings or num_embeddings - self.num_embeddings_padded = pad_vocab_size(num_embeddings, - padding_size) + num_added_embeddings = num_embeddings - self.org_vocab_size + self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size, + self.padding_size) + self.num_embeddings_padded = pad_vocab_size( + self.org_vocab_size_padded + num_added_embeddings, + self.padding_size) + assert self.org_vocab_size_padded <= self.num_embeddings_padded + + self.shard_indices = self._get_indices(self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, tp_rank, + self.tp_size) self.embedding_dim = embedding_dim if params_dtype is None: params_dtype = torch.get_default_dtype() - self.tp_size = get_tensor_model_parallel_world_size() # Divide the weight matrix along the vocaburaly dimension. - self.vocab_start_index, self.vocab_end_index = ( - vocab_range_from_global_vocab_size( - self.num_embeddings_padded, get_tensor_model_parallel_rank(), - self.tp_size)) - self.num_embeddings_per_partition = (self.vocab_end_index - - self.vocab_start_index) + self.num_added_embeddings = self.num_embeddings - self.org_vocab_size + self.num_embeddings_per_partition = divide(self.num_embeddings_padded, + self.tp_size) + assert (self.shard_indices.num_elements_padded == + self.num_embeddings_per_partition) + self.num_org_embeddings_per_partition = ( + self.shard_indices.org_vocab_end_index - + self.shard_indices.org_vocab_start_index) + self.num_added_embeddings_per_partition = ( + self.shard_indices.added_vocab_end_index - + self.shard_indices.added_vocab_start_index) self.weight = Parameter( torch.empty(self.num_embeddings_per_partition, self.embedding_dim, @@ -79,28 +210,107 @@ class VocabParallelEmbedding(torch.nn.Module): "weight_loader": self.weight_loader }) + @classmethod + def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int, + vocab_size: int, org_vocab_size: int, tp_rank: int, + tp_size: int) -> VocabParallelEmbeddingShardIndices: + """Get start and end indices for vocab parallel embedding, following the + layout outlined in the class docstring, based on the given tp_rank and + tp_size.""" + num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded + padded_org_vocab_start_index, padded_org_vocab_end_index = ( + vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank, + tp_size)) + padded_added_vocab_start_index, padded_added_vocab_end_index = ( + vocab_range_from_global_vocab_size(num_added_embeddings_padded, + tp_rank, + tp_size, + offset=org_vocab_size)) + # remove padding + org_vocab_start_index = 
min(padded_org_vocab_start_index, + org_vocab_size) + org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size) + added_vocab_start_index = min(padded_added_vocab_start_index, + vocab_size) + added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size) + return VocabParallelEmbeddingShardIndices( + padded_org_vocab_start_index, padded_org_vocab_end_index, + padded_added_vocab_start_index, padded_added_vocab_end_index, + org_vocab_start_index, org_vocab_end_index, + added_vocab_start_index, added_vocab_end_index) + + def get_sharded_to_full_mapping(self) -> Optional[List[int]]: + """Get a mapping that can be used to reindex the gathered + logits for sampling. + + During sampling, we gather logits from all ranks. The relationship + of index->token_id will follow the same format as outlined in the class + docstring. However, after the gather, we want to reindex the final + logits tensor to map index->token_id one-to-one (the index is always + equal the token_id it corresponds to). The indices returned by this + method allow us to do that. + """ + if self.tp_size < 2: + return None + + base_embeddings: List[int] = [] + added_embeddings: List[int] = [] + padding: List[int] = [] + for tp_rank in range(self.tp_size): + shard_indices = self._get_indices(self.num_embeddings_padded, + self.org_vocab_size_padded, + self.num_embeddings, + self.org_vocab_size, tp_rank, + self.tp_size) + range_start = self.num_embeddings_per_partition * tp_rank + range_end = self.num_embeddings_per_partition * (tp_rank + 1) + base_embeddings.extend( + range(range_start, + range_start + shard_indices.num_org_elements)) + padding.extend( + range(range_start + shard_indices.num_org_elements, + range_start + shard_indices.num_org_elements_padded)) + added_embeddings.extend( + range( + range_start + shard_indices.num_org_elements_padded, + range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements)) + padding.extend( + range( + range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements, + range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded)) + assert (range_start + shard_indices.num_org_elements_padded + + shard_indices.num_added_elements_padded == range_end) + ret = base_embeddings + added_embeddings + padding + assert len(ret) == self.num_embeddings_padded + return ret + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): parallel_dim = param.parallel_dim assert loaded_weight.shape[parallel_dim] == self.org_vocab_size - loaded_weight = loaded_weight[self.vocab_start_index:self. - vocab_end_index] + loaded_weight = loaded_weight[self.shard_indices.org_vocab_start_index: + self.shard_indices.org_vocab_end_index] param[:loaded_weight.shape[0]].data.copy_(loaded_weight) + param[loaded_weight.shape[0]:].data.fill_(0) def forward(self, input_): if self.tp_size > 1: # Build the mask. - input_mask = ((input_ < self.vocab_start_index) | - (input_ >= self.vocab_end_index)) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 + masked_input, input_mask = get_masked_input_and_mask( + input_, self.shard_indices.org_vocab_start_index, + self.shard_indices.org_vocab_end_index, + self.shard_indices.num_org_vocab_padding, + self.shard_indices.added_vocab_start_index, + self.shard_indices.added_vocab_end_index) else: masked_input = input_ # Get the embeddings. output_parallel = F.embedding(masked_input, self.weight) # Mask the output embedding. 
if self.tp_size > 1: - output_parallel[input_mask, :] = 0.0 + output_parallel.masked_fill_(input_mask.unsqueeze(1), 0) # Reduce across all the model parallel GPUs. output = tensor_model_parallel_all_reduce(output_parallel) return output diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 67c03ad6..c59288b4 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -35,6 +35,7 @@ _BATCH_SIZE_ALIGNMENT = 8 _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) ] +_NUM_WARMUP_ITERS = 2 class ModelInput(NamedTuple): @@ -975,16 +976,18 @@ class CUDAGraphRunner: **kwargs, ) -> None: assert self._graph is None - # Run the model once without capturing the graph. + # Run the model a few times without capturing the graph. # This is to make sure that the captured graph does not include the # kernel launches for initial benchmarking (e.g., Triton autotune). - self.model( - input_ids, - positions, - kv_caches, - attn_metadata, - **kwargs, - ) + # Note one iteration is not enough for torch.jit.script + for _ in range(_NUM_WARMUP_ITERS): + self.model( + input_ids, + positions, + kv_caches, + attn_metadata, + **kwargs, + ) torch.cuda.synchronize() # Capture the graph.
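
The masking trick used by get_masked_input_and_mask and the masked_fill_ call in VocabParallelEmbedding.forward can be illustrated with a small, self-contained PyTorch sketch (illustrative only, not part of the patch): each simulated rank looks up only the token ids that fall inside its vocab shard, zeroes the rows for out-of-shard ids, and the tensor-parallel all-reduce (modelled here as a plain sum) reconstructs the unsharded embedding lookup. The vocab size, hidden size, and shard boundaries below are made up for illustration and ignore the added-vocab and padding bookkeeping handled by VocabParallelEmbeddingShardIndices.

import torch

torch.manual_seed(0)
vocab_size, hidden_size = 8, 4
full_weight = torch.randn(vocab_size, hidden_size)
token_ids = torch.tensor([0, 3, 5, 7, 2])

# Two simulated ranks, each owning half of the vocabulary: [0, 4) and [4, 8).
partial_outputs = []
for start, end in [(0, 4), (4, 8)]:
    shard = full_weight[start:end]
    in_shard = (token_ids >= start) & (token_ids < end)
    # Shift ids into the shard-local index space; out-of-shard ids are
    # clamped to 0 so the lookup stays in bounds, and their output rows are
    # then zeroed, mirroring get_masked_input_and_mask + masked_fill_.
    local_ids = (token_ids - start) * in_shard
    out = shard[local_ids]
    out.masked_fill_(~in_shard.unsqueeze(1), 0)
    partial_outputs.append(out)

# Summing the per-rank results (the all-reduce in the real code) matches an
# unsharded embedding lookup exactly.
assert torch.equal(sum(partial_outputs), full_weight[token_ids])

The same idea extends to the added LoRA vocabulary: it is sharded separately and offset past the (padded) base shard, which is what the added_offset and num_org_vocab_padding arithmetic in get_masked_input_and_mask accounts for.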