[Core] Change LoRA embedding sharding to support loading methods (#5038)

Antoni Baum 2024-06-06 19:07:57 -07:00 committed by GitHub
parent a31cab7556
commit ccdc490dda
11 changed files with 658 additions and 126 deletions

View File

@ -46,6 +46,7 @@ steps:
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
- label: Distributed Tests (Multiple Groups)
#mirror_hardwares: [amd]
@ -138,14 +139,7 @@ steps:
num_gpus: 4
# This test runs llama 13B, so it is required to run on 4 GPUs.
commands:
# Temporarily run this way because we cannot clean up GPU mem usage
# for multi GPU tests.
# TODO(sang): Fix it.
- pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
- pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
- pytest -v -s lora/test_long_context.py::test_self_consistency
- pytest -v -s lora/test_long_context.py::test_quality
- pytest -v -s lora/test_long_context.py::test_max_len
- pytest -v -s -x lora/test_long_context.py
- label: Tensorizer Test
#mirror_hardwares: [amd]

View File

@ -1,6 +1,8 @@
import contextlib
import gc
import os
import subprocess
import sys
from typing import Any, Dict, List, Optional, Tuple, TypeVar
import pytest
@ -522,3 +524,22 @@ def caplog_vllm(temporary_enable_log_propagate, caplog):
# To capture vllm log, we should enable propagate=True temporarily
# because caplog depends on logs propagated to the root logger.
yield caplog
@pytest.fixture(scope="session")
def num_gpus_available():
"""Get number of GPUs without initializing the CUDA context
in current process."""
try:
out = subprocess.run([
sys.executable, "-c",
"import torch; print(torch.cuda.device_count())"
],
capture_output=True,
check=True,
text=True)
except subprocess.CalledProcessError as e:
logger.warning("Failed to get number of GPUs.", exc_info=e)
return 0
return int(out.stdout.strip())
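For illustration, a minimal sketch (not part of this commit) of how a test can use the new num_gpus_available fixture to skip cleanly on machines with too few GPUs, mirroring the pattern the LoRA tests below adopt; the test name and tp_size values are placeholders:

import pytest

@pytest.mark.parametrize("tp_size", [1, 2, 4])
def test_needs_multiple_gpus(tp_size, num_gpus_available):
    # Skip instead of failing when the host has fewer GPUs than the
    # requested tensor-parallel degree.
    if num_gpus_available < tp_size:
        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
    ...  # run the tp_size-GPU workload here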

View File

@ -42,10 +42,24 @@ def cleanup():
ray.shutdown()
@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
"""Allow subdirectories to skip global cleanup by overriding this fixture.
This can provide a ~10x speedup for non-GPU unit tests since they don't need
to initialize torch.
"""
if request.node.get_closest_marker("skip_global_cleanup"):
return False
return True
@pytest.fixture(autouse=True)
def cleanup_fixture():
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
yield
cleanup()
if should_do_global_cleanup_after_test:
cleanup()
@pytest.fixture
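For reference, a hedged sketch (not part of this commit) of how a test opts out of the global cleanup through the new skip_global_cleanup marker, as the long-context LoRA tests below now do; the test body is a placeholder:

import pytest

@pytest.mark.skip_global_cleanup
def test_cpu_only_helper():
    # No GPU or Ray state is created here, so the per-test cleanup
    # (and the torch initialization it requires) can be skipped safely.
    assert 2 + 2 == 4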

View File

@ -2,6 +2,7 @@ import random
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from unittest.mock import patch
import pytest
import torch
@ -32,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
from vllm.model_executor.utils import set_random_seed
from .utils import DummyLoRAManager
@ -427,7 +428,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
logits_processor = LogitsProcessor(
vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
lora_logits_processor = LogitsProcessorWithLoRA(
logits_processor, 1024, linear.weight.dtype, linear.weight.device)
logits_processor, 1024, linear.weight.dtype, linear.weight.device,
None)
lora_logits_processor.create_lora_weights(max_loras, lora_config)
return linear, logits_processor, lora_logits_processor
@ -867,3 +869,216 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
torch.allclose(ref_q, actual_q)
torch.allclose(ref_k, actual_k)
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("seed", list(range(256)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
random.seed(seed)
vocab_size = random.randint(4000, 64000)
added_vocab_size = random.randint(0, 1024)
org_vocab_size = vocab_size - added_vocab_size
last_org_vocab_end_index = 0
last_added_vocab_end_index = org_vocab_size
computed_vocab_size = 0
computed_org_vocab_size = 0
computed_added_vocab_size = 0
vocab_size_padded = -1
all_org_tokens = []
all_added_tokens = []
token_ids = []
for tp_rank in range(tp_size):
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
return_value=tp_rank
), patch(
"vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
return_value=tp_size):
vocab_embedding = VocabParallelEmbedding(
vocab_size, 1, org_num_embeddings=org_vocab_size)
vocab_size_padded = vocab_embedding.num_embeddings_padded
shard_indices = vocab_embedding.shard_indices
# Assert that the ranges are contiguous
assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
assert (shard_indices.added_vocab_start_index ==
last_added_vocab_end_index)
# Ensure that we are not exceeding the vocab size
computed_vocab_size += shard_indices.num_elements_padded
computed_org_vocab_size += shard_indices.num_org_elements
computed_added_vocab_size += shard_indices.num_added_elements
# Ensure that the ranges are not overlapping
all_org_tokens.extend(
range(shard_indices.org_vocab_start_index,
shard_indices.org_vocab_end_index))
all_added_tokens.extend(
range(shard_indices.added_vocab_start_index,
shard_indices.added_vocab_end_index))
token_ids.extend(
range(shard_indices.org_vocab_start_index,
shard_indices.org_vocab_end_index))
token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
shard_indices.num_org_elements))
token_ids.extend(
range(shard_indices.added_vocab_start_index,
shard_indices.added_vocab_end_index))
token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
shard_indices.num_added_elements))
last_org_vocab_end_index = shard_indices.org_vocab_end_index
last_added_vocab_end_index = shard_indices.added_vocab_end_index
assert computed_vocab_size == vocab_size_padded
assert computed_org_vocab_size == org_vocab_size
assert computed_added_vocab_size == added_vocab_size
# Ensure that the ranges are not overlapping
assert len(all_org_tokens) == len(set(all_org_tokens))
assert len(all_added_tokens) == len(set(all_added_tokens))
assert not set(all_org_tokens).intersection(set(all_added_tokens))
token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
assert reindex_mapping is not None or tp_size == 1
if reindex_mapping is not None:
reindexed_token_ids = token_ids_tensor[reindex_mapping]
expected = torch.tensor(list(range(0, vocab_size)))
assert reindexed_token_ids[:vocab_size].equal(expected)
assert torch.all(reindexed_token_ids[vocab_size:] == -1)
def test_get_masked_input_and_mask():
x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
# base tp 1 case, no padding
modified_x, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=8,
added_vocab_start_index=8,
added_vocab_end_index=12,
num_org_vocab_padding=0)
assert torch.equal(x, modified_x)
# tp 2 case, no padding
modified_x_rank_0, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=4,
added_vocab_start_index=8,
added_vocab_end_index=10,
num_org_vocab_padding=0)
modified_x_rank_1, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=4,
org_vocab_end_index=8,
added_vocab_start_index=10,
added_vocab_end_index=12,
num_org_vocab_padding=0)
assert torch.equal(modified_x_rank_0,
torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
assert torch.equal(modified_x_rank_1,
torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))
# tp 4 case, no padding
modified_x_rank_0, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=2,
added_vocab_start_index=8,
added_vocab_end_index=9,
num_org_vocab_padding=0)
modified_x_rank_1, _ = get_masked_input_and_mask(x,
org_vocab_start_index=2,
org_vocab_end_index=4,
added_vocab_start_index=9,
added_vocab_end_index=10,
num_org_vocab_padding=0)
modified_x_rank_2, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=4,
org_vocab_end_index=6,
added_vocab_start_index=10,
added_vocab_end_index=11,
num_org_vocab_padding=0)
modified_x_rank_3, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=6,
org_vocab_end_index=8,
added_vocab_start_index=11,
added_vocab_end_index=12,
num_org_vocab_padding=0)
assert torch.equal(modified_x_rank_0,
torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
assert torch.equal(modified_x_rank_1,
torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
assert torch.equal(modified_x_rank_2,
torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
assert torch.equal(modified_x_rank_3,
torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))
# base tp 1 case, with padding
modified_x, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=8,
added_vocab_start_index=8,
added_vocab_end_index=12,
num_org_vocab_padding=2)
assert torch.equal(modified_x,
torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))
# tp 2 case, with padding
modified_x_rank_0, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=4,
added_vocab_start_index=8,
added_vocab_end_index=10,
num_org_vocab_padding=2)
modified_x_rank_1, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=4,
org_vocab_end_index=8,
added_vocab_start_index=10,
added_vocab_end_index=12,
num_org_vocab_padding=2)
assert torch.equal(modified_x_rank_0,
torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
assert torch.equal(modified_x_rank_1,
torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))
# tp 4 case, with padding
modified_x_rank_0, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=2,
added_vocab_start_index=8,
added_vocab_end_index=9,
num_org_vocab_padding=2)
modified_x_rank_1, _ = get_masked_input_and_mask(x,
org_vocab_start_index=2,
org_vocab_end_index=4,
added_vocab_start_index=9,
added_vocab_end_index=10,
num_org_vocab_padding=2)
modified_x_rank_2, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=4,
org_vocab_end_index=6,
added_vocab_start_index=10,
added_vocab_end_index=11,
num_org_vocab_padding=2)
modified_x_rank_3, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=6,
org_vocab_end_index=8,
added_vocab_start_index=11,
added_vocab_end_index=12,
num_org_vocab_padding=2)
assert torch.equal(modified_x_rank_0,
torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
assert torch.equal(modified_x_rank_1,
torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
assert torch.equal(modified_x_rank_2,
torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
assert torch.equal(modified_x_rank_3,
torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))
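To make the sharded-to-full reindexing concrete, here is a small illustrative sketch (not part of this commit) using the example from the LogitsProcessorWithLoRA comment further down: org_vocab_size = 4, added_vocab_size = 2, padded to 8 with tp_size = 2:

import torch

# Gathered (sharded) column order: rank 0 holds base tokens [0, 1], added
# token 4 and one padding slot; rank 1 holds [2, 3], 5 and one padding slot.
# -1 marks padding.
sharded_token_ids = torch.tensor([0, 1, 4, -1, 2, 3, 5, -1])
sharded_to_full = torch.tensor([0, 1, 4, 5, 2, 6, 3, 7])

# Reindexing restores a 1:1 index -> token_id layout with padding at the end.
print(sharded_token_ids[sharded_to_full])  # tensor([ 0,  1,  2,  3,  4,  5, -1, -1])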

View File

@ -36,11 +36,10 @@ def do_sample(llm, lora_path: str, lora_id: int):
return generated_texts
@pytest.mark.parametrize("tp_size", [1])
def test_llama_lora(sql_lora_files, tp_size):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < tp_size:
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
@pytest.mark.parametrize("tp_size", [1, 2, 4])
def test_llama_lora(sql_lora_files, tp_size, num_gpus_available):
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
@ -80,11 +79,9 @@ def test_llama_lora(sql_lora_files, tp_size):
print("removing lora")
@pytest.mark.skip("Requires multiple GPUs")
def test_llama_tensor_parallel_equality(sql_lora_files):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
if num_gpus_available < 4:
pytest.skip("Not enough GPUs for tensor parallelism 4")
llm_tp1 = vllm.LLM(MODEL_PATH,
enable_lora=True,

View File

@ -102,22 +102,21 @@ def batched_generate(
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
@pytest.fixture
@pytest.fixture(scope="module")
def lora_llm(long_context_infos):
scaling_factors = [
context_len_to_scaling_factor[info["context_length"]]
for info in long_context_infos.values()
]
llm = vllm.LLM(
"meta-llama/Llama-2-13b-chat-hf",
enable_lora=True,
max_num_seqs=16,
max_loras=2,
long_lora_scaling_factors=tuple(scaling_factors),
max_num_batched_tokens=4096 * 8,
tensor_parallel_size=4,
)
llm = vllm.LLM("meta-llama/Llama-2-13b-chat-hf",
enable_lora=True,
max_num_seqs=16,
max_loras=2,
long_lora_scaling_factors=tuple(scaling_factors),
max_num_batched_tokens=4096 * 8,
tensor_parallel_size=4,
distributed_executor_backend="mp")
yield llm
del llm
@ -154,6 +153,7 @@ def test_rotary_emb_replaced(dist_init):
assert rotary_emb_count == 32
@pytest.mark.skip_global_cleanup
def test_batched_rope_kernel(lora_llm, long_context_infos):
"""We test the batched kernel by comparing the results of batched an
non-batched generation.
@ -188,6 +188,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
f"same:\n{batched}\n{non_batched}")
@pytest.mark.skip_global_cleanup
def test_self_consistency(lora_llm, long_context_infos):
"""We test consistency of the batched kernel by permuting batched
inputs and comparing the results to the non-permuted batched results.
@ -227,6 +228,7 @@ def test_self_consistency(lora_llm, long_context_infos):
f"\n{permutated_batched_results[permutation[i]]}")
@pytest.mark.skip_global_cleanup
def test_quality(lora_llm, long_context_infos):
"""We test the quality of the answers given by the LoRA model by
comparing the generated text to the merged model's outputs.
@ -257,6 +259,7 @@ def test_quality(lora_llm, long_context_infos):
assert np.mean(scores) > 0.5
@pytest.mark.skip_global_cleanup
def test_max_len(lora_llm, long_context_infos):
"""Test that we raise an ValueError when the input of a given LoRA
model exceeds the maximum length."""

View File

@ -1,3 +1,4 @@
import multiprocessing as mp
import os
import shutil
from tempfile import TemporaryDirectory
@ -18,9 +19,7 @@ prompts = [
# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
seed=0,
temperature=0,
max_tokens=256,
ignore_eos=True,
)
@ -43,48 +42,85 @@ def test_filter_subtensors():
assert tensor.equal(state_dict[key])
@pytest.mark.parametrize("enable_lora", [False, True])
def test_sharded_state_loader(enable_lora):
weights_patterns = ("*.bin", "*.pt", "*.safetensors")
with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir:
@pytest.fixture(scope="module")
def llama_2_7b_files():
with TemporaryDirectory() as cache_dir:
input_dir = snapshot_download("meta-llama/Llama-2-7b-hf",
cache_dir=cache_dir)
cache_dir=cache_dir,
ignore_patterns="*.bin*")
yield input_dir
llm = LLM(
model=input_dir,
worker_use_ray=True,
gpu_memory_utilization=0.3,
)
# Dump worker states to output directory
model_executor = llm.llm_engine.model_executor
model_executor.save_sharded_state(path=output_dir)
# Copy metadata files to output directory
for file in os.listdir(input_dir):
if not any(file.endswith(ext) for ext in weights_patterns):
shutil.copy(f"{input_dir}/{file}", output_dir)
del llm.llm_engine.model_executor
def _run_writer(input_dir, output_dir, weights_patterns, **kwargs):
llm_sharded_writer = LLM(model=input_dir, **kwargs)
llm_before = LLM(
model=input_dir,
worker_use_ray=True,
enable_lora=enable_lora,
gpu_memory_utilization=0.3,
)
gen_before = llm_before.generate(prompts, sampling_params)
out_before = [gen.outputs[0].__dict__ for gen in gen_before]
del llm_before.llm_engine.model_executor
# Dump worker states to output directory
llm_sharded_writer.llm_engine.model_executor.save_sharded_state(
path=output_dir)
# Copy metadata files to output directory
for file in os.listdir(input_dir):
if not any(file.endswith(ext) for ext in weights_patterns):
shutil.copy(f"{input_dir}/{file}", output_dir)
llm_after = LLM(
model=output_dir,
worker_use_ray=True,
enable_lora=enable_lora,
gpu_memory_utilization=0.3,
load_format="sharded_state",
)
gen_after = llm_after.generate(prompts, sampling_params)
out_after = [gen.outputs[0].__dict__ for gen in gen_after]
del llm_after.llm_engine.model_executor
def _run_generate(input_dir, queue: mp.Queue, **kwargs):
llm = LLM(model=input_dir, **kwargs)
gen = llm.generate(prompts, sampling_params)
queue.put([g.outputs[0].__dict__ for g in gen])
queue.close()
queue.join_thread()
@pytest.mark.parametrize("enable_lora", [False, True])
@pytest.mark.parametrize("tp_size", [1, 2])
def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available,
llama_2_7b_files):
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
weights_patterns = ("*.safetensors", )
gpu_memory_utilization = 0.8
input_dir = llama_2_7b_files
ctx = mp.get_context("spawn")
# Run in separate processes for memory & CUDA isolation
with TemporaryDirectory() as output_dir:
p = ctx.Process(target=_run_writer,
args=(input_dir, output_dir, weights_patterns),
kwargs=dict(
tensor_parallel_size=tp_size,
distributed_executor_backend="mp",
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=True,
))
p.start()
p.join()
queue = ctx.Queue()
p = ctx.Process(target=_run_generate,
args=(input_dir, queue),
kwargs=dict(
distributed_executor_backend="mp",
enable_lora=enable_lora,
gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tp_size,
))
p.start()
p.join()
out_before = queue.get()
p = ctx.Process(target=_run_generate,
args=(output_dir, queue),
kwargs=dict(
distributed_executor_backend="mp",
enable_lora=enable_lora,
gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tp_size,
load_format="sharded_state",
))
p.start()
p.join()
out_after = queue.get()
assert out_before == out_after
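Outside of the test harness, the save/load flow exercised above looks roughly like the following sketch (paths and tp_size are placeholders; the test runs each step in a separate spawned process for CUDA isolation, which a standalone script would not need):

import os
import shutil

from vllm import LLM

def save_sharded_checkpoint(input_dir: str, output_dir: str, tp_size: int = 2):
    # Load normally, then dump per-worker sharded weights to output_dir.
    llm = LLM(model=input_dir, tensor_parallel_size=tp_size,
              distributed_executor_backend="mp", enforce_eager=True)
    os.makedirs(output_dir, exist_ok=True)
    llm.llm_engine.model_executor.save_sharded_state(path=output_dir)
    # Copy config/tokenizer files so output_dir is a self-contained model dir.
    for file in os.listdir(input_dir):
        if not file.endswith(".safetensors"):
            shutil.copy(os.path.join(input_dir, file), output_dir)

def load_sharded_checkpoint(output_dir: str, tp_size: int = 2) -> LLM:
    # Reload directly from the sharded checkpoint.
    return LLM(model=output_dir, tensor_parallel_size=tp_size,
               distributed_executor_backend="mp", load_format="sharded_state")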

View File

@ -215,19 +215,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:
lora_vocab_start_idx = self.base_layer.org_vocab_size
weights_idx = None
if self.base_layer.vocab_end_index > lora_vocab_start_idx:
if self.base_layer.num_added_embeddings_per_partition > 0:
# We can start adding lora weights
weights_idx = max(
lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
self.embeddings_slice = (self.base_layer.vocab_start_index -
self.base_layer.org_vocab_size +
weights_idx,
self.base_layer.vocab_end_index -
self.base_layer.org_vocab_size)
self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
self.embeddings_weights.fill_(0)
self.embeddings_weights = self.base_layer.weight.data[
self.base_layer.num_org_embeddings_per_partition:self.
base_layer.num_org_embeddings_per_partition +
self.base_layer.num_added_embeddings_per_partition]
self.embeddings_slice = (
self.base_layer.shard_indices.added_vocab_start_index -
self.base_layer.org_vocab_size,
self.base_layer.shard_indices.added_vocab_end_index -
self.base_layer.org_vocab_size)
self.base_layer.weight.data[
self.base_layer.num_org_embeddings_per_partition:].fill_(0)
else:
self.embeddings_slice = None
self.embeddings_weights = None
@ -1025,19 +1025,31 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
"""
LoRA wrapper for LogitsProcessor, with extra logic to handle the
application of the LoRA adapter and added LoRA vocabulary.
def __init__(
self,
base_layer: LogitsProcessor,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
) -> None:
Args:
base_layer: LogitsProcessor layer
hidden_size: hidden size of the model
dtype: data type of the model
device: device of the model
sharded_to_full_mapping: index mapping from sharded vocab to full vocab
received from base_layer.get_sharded_to_full_mapping(). If None,
no reindexing will be done.
"""
def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
dtype: torch.dtype, device: torch.device,
sharded_to_full_mapping: Optional[List[int]]) -> None:
super().__init__()
self.base_layer = base_layer
self.hidden_size = hidden_size
self.dtype = dtype
self.device = device
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.sharded_to_full_mapping = sharded_to_full_mapping
@property
def logits_as_input(self):
@ -1098,6 +1110,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
dtype=self.dtype,
device=self.device,
)
if self.sharded_to_full_mapping is not None:
self.sharded_to_full_mapping_gpu = torch.tensor(
self.sharded_to_full_mapping,
device=self.device,
dtype=torch.long)
else:
self.sharded_to_full_mapping_gpu = None
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
@ -1154,6 +1173,25 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
if logits is None:
return None
if self.sharded_to_full_mapping_gpu is not None:
# Reindex full logits tensor to ensure 1:1 mapping between
# index and token_id
# Example for:
# org_vocab_size = 4
# added_vocab_size = 2
# pad_to_size = 8
# tp_size = 2
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 4, -1, 2, 3, 5, -1]
# Therefore, the mapping is expected to be:
# [0, 1, 4, 5, 2, 6, 3, 7] so that when we reindex,
# we get:
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 2, 3, 4, 5, -1, -1]
logits = logits[:, self.sharded_to_full_mapping_gpu]
lora_logits = torch.empty(
self.embeddings_tensors.shape[0] + 1,
self.embeddings_tensors.shape[1],

View File

@ -67,7 +67,8 @@ def from_layer_logits_processor(
model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
lm_head.weight.dtype, lm_head.weight.device)
lm_head.weight.dtype, lm_head.weight.device,
lm_head.get_sharded_to_full_mapping())
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret

View File

@ -1,4 +1,5 @@
from typing import Optional, Sequence
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple
import torch
import torch.nn.functional as F
@ -18,18 +19,107 @@ def pad_vocab_size(vocab_size: int,
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
rank: int) -> Sequence[int]:
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int,
rank: int,
offset: int = 0) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
return index_f + offset, index_l + offset
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
world_size: int) -> Sequence[int]:
def vocab_range_from_global_vocab_size(global_vocab_size: int,
rank: int,
world_size: int,
offset: int = 0) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank)
rank,
offset=offset)
@dataclass
class VocabParallelEmbeddingShardIndices:
"""Indices for a shard of a vocab parallel embedding."""
padded_org_vocab_start_index: int
padded_org_vocab_end_index: int
padded_added_vocab_start_index: int
padded_added_vocab_end_index: int
org_vocab_start_index: int
org_vocab_end_index: int
added_vocab_start_index: int
added_vocab_end_index: int
@property
def num_org_elements(self) -> int:
return self.org_vocab_end_index - self.org_vocab_start_index
@property
def num_added_elements(self) -> int:
return self.added_vocab_end_index - self.added_vocab_start_index
@property
def num_org_elements_padded(self) -> int:
return (self.padded_org_vocab_end_index -
self.padded_org_vocab_start_index)
@property
def num_added_elements_padded(self) -> int:
return (self.padded_added_vocab_end_index -
self.padded_added_vocab_start_index)
@property
def num_org_vocab_padding(self) -> int:
return self.num_org_elements_padded - self.num_org_elements
@property
def num_added_vocab_padding(self) -> int:
return self.num_added_elements_padded - self.num_added_elements
@property
def num_elements_padded(self) -> int:
return self.num_org_elements_padded + self.num_added_elements_padded
def __post_init__(self):
# sanity checks
assert (self.padded_org_vocab_start_index <=
self.padded_org_vocab_end_index)
assert (self.padded_added_vocab_start_index <=
self.padded_added_vocab_end_index)
assert self.org_vocab_start_index <= self.org_vocab_end_index
assert self.added_vocab_start_index <= self.added_vocab_end_index
assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
assert (self.added_vocab_start_index <=
self.padded_added_vocab_start_index)
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index
assert self.num_org_elements <= self.num_org_elements_padded
assert self.num_added_elements <= self.num_added_elements_padded
@torch.jit.script
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.jit.script will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index *
org_vocab_mask) + (added_offset * added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask
class VocabParallelEmbedding(torch.nn.Module):
@ -38,13 +128,36 @@ class VocabParallelEmbedding(torch.nn.Module):
Adapted from torch.nn.Embedding; note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.
In order to support various loading methods, we ensure that LoRA-added
embeddings are always at the end of TP-sharded tensors. In other words,
we shard base embeddings and LoRA embeddings separately (both padded),
and place them in the same tensor.
In this example, we will have the original vocab size = 1010,
added vocab size = 16 and padding to 64. Therefore, the total
vocab size with padding will be 1088 (because we first pad 1010 to
1024, add 16, and then pad to 1088).
Therefore, the tensor format looks like the following:
TP1, rank 0 (no sharding):
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1025 | -1 | ... | -1 |
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |
TP2, rank 0:
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
TP2, rank 1:
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |
Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""
""" # noqa: E501
def __init__(self,
num_embeddings: int,
@ -55,21 +168,39 @@ class VocabParallelEmbedding(torch.nn.Module):
super().__init__()
# Keep the input dimensions.
tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings,
padding_size)
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings,
self.padding_size)
assert self.org_vocab_size_padded <= self.num_embeddings_padded
self.shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocabulary dimension.
self.vocab_start_index, self.vocab_end_index = (
vocab_range_from_global_vocab_size(
self.num_embeddings_padded, get_tensor_model_parallel_rank(),
self.tp_size))
self.num_embeddings_per_partition = (self.vocab_end_index -
self.vocab_start_index)
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
self.tp_size)
assert (self.shard_indices.num_elements_padded ==
self.num_embeddings_per_partition)
self.num_org_embeddings_per_partition = (
self.shard_indices.org_vocab_end_index -
self.shard_indices.org_vocab_start_index)
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
@ -79,28 +210,107 @@ class VocabParallelEmbedding(torch.nn.Module):
"weight_loader": self.weight_loader
})
@classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
vocab_size: int, org_vocab_size: int, tp_rank: int,
tp_size: int) -> VocabParallelEmbeddingShardIndices:
"""Get start and end indices for vocab parallel embedding, following the
layout outlined in the class docstring, based on the given tp_rank and
tp_size."""
num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
padded_org_vocab_start_index, padded_org_vocab_end_index = (
vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
tp_size))
padded_added_vocab_start_index, padded_added_vocab_end_index = (
vocab_range_from_global_vocab_size(num_added_embeddings_padded,
tp_rank,
tp_size,
offset=org_vocab_size))
# remove padding
org_vocab_start_index = min(padded_org_vocab_start_index,
org_vocab_size)
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
added_vocab_start_index = min(padded_added_vocab_start_index,
vocab_size)
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
return VocabParallelEmbeddingShardIndices(
padded_org_vocab_start_index, padded_org_vocab_end_index,
padded_added_vocab_start_index, padded_added_vocab_end_index,
org_vocab_start_index, org_vocab_end_index,
added_vocab_start_index, added_vocab_end_index)
def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
"""Get a mapping that can be used to reindex the gathered
logits for sampling.
During sampling, we gather logits from all ranks. The relationship
of index->token_id will follow the same format as outlined in the class
docstring. However, after the gather, we want to reindex the final
logits tensor to map index->token_id one-to-one (the index is always
equal to the token_id it corresponds to). The indices returned by this
method allow us to do that.
"""
if self.tp_size < 2:
return None
base_embeddings: List[int] = []
added_embeddings: List[int] = []
padding: List[int] = []
for tp_rank in range(self.tp_size):
shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
range_start = self.num_embeddings_per_partition * tp_rank
range_end = self.num_embeddings_per_partition * (tp_rank + 1)
base_embeddings.extend(
range(range_start,
range_start + shard_indices.num_org_elements))
padding.extend(
range(range_start + shard_indices.num_org_elements,
range_start + shard_indices.num_org_elements_padded))
added_embeddings.extend(
range(
range_start + shard_indices.num_org_elements_padded,
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements))
padding.extend(
range(
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements,
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements_padded))
assert (range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements_padded == range_end)
ret = base_embeddings + added_embeddings + padding
assert len(ret) == self.num_embeddings_padded
return ret
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
parallel_dim = param.parallel_dim
assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
loaded_weight = loaded_weight[self.vocab_start_index:self.
vocab_end_index]
loaded_weight = loaded_weight[self.shard_indices.org_vocab_start_index:
self.shard_indices.org_vocab_end_index]
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0]:].data.fill_(0)
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
input_mask = ((input_ < self.vocab_start_index) |
(input_ >= self.vocab_end_index))
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
masked_input, input_mask = get_masked_input_and_mask(
input_, self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index)
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight)
# Mask the output embedding.
if self.tp_size > 1:
output_parallel[input_mask, :] = 0.0
output_parallel.masked_fill_(input_mask.unsqueeze(1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
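As a quick check on the class docstring example above (original vocab 1010, 16 added tokens, padding to 64), the padded sizes fall out of the same rounding used by pad_vocab_size; a self-contained sketch, with the helper re-declared locally and the default padding of 64 assumed:

def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # Same rounding as the helper at the top of this file.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

org_vocab_size_padded = pad_vocab_size(1010)       # 1024
num_embeddings_padded = pad_vocab_size(1024 + 16)  # 1088
rows_per_rank_tp2 = num_embeddings_padded // 2     # 544, i.e. indices 0..543 per rank
print(org_vocab_size_padded, num_embeddings_padded, rows_per_rank_tp2)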

View File

@ -35,6 +35,7 @@ _BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
_NUM_WARMUP_ITERS = 2
class ModelInput(NamedTuple):
@ -975,16 +976,18 @@ class CUDAGraphRunner:
**kwargs,
) -> None:
assert self._graph is None
# Run the model once without capturing the graph.
# Run the model a few times without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
**kwargs,
)
# Note one iteration is not enough for torch.jit.script
for _ in range(_NUM_WARMUP_ITERS):
self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
**kwargs,
)
torch.cuda.synchronize()
# Capture the graph.