[Core] Change LoRA embedding sharding to support loading methods (#5038)
parent a31cab7556 · commit ccdc490dda
@ -46,6 +46,7 @@ steps:
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
- pytest -v -s spec_decode/e2e/test_integration_dist.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py

- label: Distributed Tests (Multiple Groups)
#mirror_hardwares: [amd]

@ -138,14 +139,7 @@ steps:
num_gpus: 4
# This test runs llama 13B, so it is required to run on 4 GPUs.
commands:
# Temporarily run this way because we cannot clean up GPU mem usage
# for multi GPU tests.
# TODO(sang): Fix it.
- pytest -v -s lora/test_long_context.py::test_rotary_emb_replaced
- pytest -v -s lora/test_long_context.py::test_batched_rope_kernel
- pytest -v -s lora/test_long_context.py::test_self_consistency
- pytest -v -s lora/test_long_context.py::test_quality
- pytest -v -s lora/test_long_context.py::test_max_len
- pytest -v -s -x lora/test_long_context.py

- label: Tensorizer Test
#mirror_hardwares: [amd]

@ -1,6 +1,8 @@
import contextlib
import gc
import os
import subprocess
import sys
from typing import Any, Dict, List, Optional, Tuple, TypeVar

import pytest
@ -522,3 +524,22 @@ def caplog_vllm(temporary_enable_log_propagate, caplog):
# To capture vllm log, we should enable propagate=True temporarily
# because caplog depends on logs propagated to the root logger.
yield caplog


@pytest.fixture(scope="session")
def num_gpus_available():
"""Get number of GPUs without initializing the CUDA context
in current process."""

try:
out = subprocess.run([
sys.executable, "-c",
"import torch; print(torch.cuda.device_count())"
],
capture_output=True,
check=True,
text=True)
except subprocess.CalledProcessError as e:
logger.warning("Failed to get number of GPUs.", exc_info=e)
return 0
return int(out.stdout.strip())

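A minimal sketch of how a test can consume this session-scoped fixture; the test name and tp_size values below are illustrative, and the skip pattern mirrors the LoRA test changes later in this diff:

```python
import pytest


@pytest.mark.parametrize("tp_size", [1, 2])
def test_requires_gpus(tp_size, num_gpus_available):
    # Skip early when the machine has fewer GPUs than the requested
    # tensor-parallel size, without initializing CUDA in this process.
    if num_gpus_available < tp_size:
        pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
    # ... multi-GPU test body goes here ...
```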
@ -42,10 +42,24 @@ def cleanup():
ray.shutdown()


@pytest.fixture()
def should_do_global_cleanup_after_test(request) -> bool:
"""Allow subdirectories to skip global cleanup by overriding this fixture.
This can provide a ~10x speedup for non-GPU unit tests since they don't need
to initialize torch.
"""

if request.node.get_closest_marker("skip_global_cleanup"):
return False

return True


@pytest.fixture(autouse=True)
def cleanup_fixture():
def cleanup_fixture(should_do_global_cleanup_after_test: bool):
yield
cleanup()
if should_do_global_cleanup_after_test:
cleanup()


@pytest.fixture

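A small sketch of how a test opts out of the global cleanup via the marker checked above (assuming `skip_global_cleanup` is available as a pytest marker for this suite, as the long-context tests later in this diff use it):

```python
import pytest


@pytest.mark.skip_global_cleanup
def test_cpu_only_helper():
    # With the marker present, should_do_global_cleanup_after_test returns
    # False, so cleanup_fixture yields without tearing down GPU/ray state.
    assert 2 + 2 == 4
```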
@ -2,6 +2,7 @@ import random
from copy import deepcopy
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from unittest.mock import patch

import pytest
import torch
@ -32,7 +33,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
ParallelLMHead, VocabParallelEmbedding, get_masked_input_and_mask)
from vllm.model_executor.utils import set_random_seed

from .utils import DummyLoRAManager
@ -427,7 +428,8 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
logits_processor = LogitsProcessor(
vocab_size + lora_config.lora_extra_vocab_size, vocab_size)
lora_logits_processor = LogitsProcessorWithLoRA(
logits_processor, 1024, linear.weight.dtype, linear.weight.device)
logits_processor, 1024, linear.weight.dtype, linear.weight.device,
None)
lora_logits_processor.create_lora_weights(max_loras, lora_config)

return linear, logits_processor, lora_logits_processor
@ -867,3 +869,216 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,

torch.allclose(ref_q, actual_q)
torch.allclose(ref_k, actual_k)


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("seed", list(range(256)))
def test_vocab_parallel_embedding_indices(tp_size, seed):
random.seed(seed)
vocab_size = random.randint(4000, 64000)
added_vocab_size = random.randint(0, 1024)
org_vocab_size = vocab_size - added_vocab_size
last_org_vocab_end_index = 0
last_added_vocab_end_index = org_vocab_size
computed_vocab_size = 0
computed_org_vocab_size = 0
computed_added_vocab_size = 0
vocab_size_padded = -1

all_org_tokens = []
all_added_tokens = []
token_ids = []

for tp_rank in range(tp_size):
with patch(
"vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_rank",
return_value=tp_rank
), patch(
"vllm.model_executor.layers.vocab_parallel_embedding.get_tensor_model_parallel_world_size",
return_value=tp_size):
vocab_embedding = VocabParallelEmbedding(
vocab_size, 1, org_num_embeddings=org_vocab_size)
vocab_size_padded = vocab_embedding.num_embeddings_padded
shard_indices = vocab_embedding.shard_indices
# Assert that the ranges are contiguous
assert shard_indices.org_vocab_start_index == last_org_vocab_end_index
assert (shard_indices.added_vocab_start_index ==
last_added_vocab_end_index)

# Ensure that we are not exceeding the vocab size
computed_vocab_size += shard_indices.num_elements_padded
computed_org_vocab_size += shard_indices.num_org_elements
computed_added_vocab_size += shard_indices.num_added_elements

# Ensure that the ranges are not overlapping
all_org_tokens.extend(
range(shard_indices.org_vocab_start_index,
shard_indices.org_vocab_end_index))
all_added_tokens.extend(
range(shard_indices.added_vocab_start_index,
shard_indices.added_vocab_end_index))

token_ids.extend(
range(shard_indices.org_vocab_start_index,
shard_indices.org_vocab_end_index))
token_ids.extend([-1] * (shard_indices.num_org_elements_padded -
shard_indices.num_org_elements))
token_ids.extend(
range(shard_indices.added_vocab_start_index,
shard_indices.added_vocab_end_index))
token_ids.extend([-1] * (shard_indices.num_added_elements_padded -
shard_indices.num_added_elements))

last_org_vocab_end_index = shard_indices.org_vocab_end_index
last_added_vocab_end_index = shard_indices.added_vocab_end_index

assert computed_vocab_size == vocab_size_padded
assert computed_org_vocab_size == org_vocab_size
assert computed_added_vocab_size == added_vocab_size

# Ensure that the ranges are not overlapping
assert len(all_org_tokens) == len(set(all_org_tokens))
assert len(all_added_tokens) == len(set(all_added_tokens))
assert not set(all_org_tokens).intersection(set(all_added_tokens))

token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
reindex_mapping = vocab_embedding.get_sharded_to_full_mapping()
assert reindex_mapping is not None or tp_size == 1
if reindex_mapping is not None:
reindexed_token_ids = token_ids_tensor[reindex_mapping]
expected = torch.tensor(list(range(0, vocab_size)))
assert reindexed_token_ids[:vocab_size].equal(expected)
assert torch.all(reindexed_token_ids[vocab_size:] == -1)


def test_get_masked_input_and_mask():
x = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

# base tp 1 case, no padding
modified_x, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=8,
added_vocab_start_index=8,
added_vocab_end_index=12,
num_org_vocab_padding=0)
assert torch.equal(x, modified_x)

# tp 2 case, no padding
modified_x_rank_0, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=4,
added_vocab_start_index=8,
added_vocab_end_index=10,
num_org_vocab_padding=0)
modified_x_rank_1, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=4,
org_vocab_end_index=8,
added_vocab_start_index=10,
added_vocab_end_index=12,
num_org_vocab_padding=0)
assert torch.equal(modified_x_rank_0,
torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 0, 0]))
assert torch.equal(modified_x_rank_1,
torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 4, 5]))

# tp 4 case, no padding
modified_x_rank_0, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=2,
added_vocab_start_index=8,
added_vocab_end_index=9,
num_org_vocab_padding=0)
modified_x_rank_1, _ = get_masked_input_and_mask(x,
org_vocab_start_index=2,
org_vocab_end_index=4,
added_vocab_start_index=9,
added_vocab_end_index=10,
num_org_vocab_padding=0)
modified_x_rank_2, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=4,
org_vocab_end_index=6,
added_vocab_start_index=10,
added_vocab_end_index=11,
num_org_vocab_padding=0)
modified_x_rank_3, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=6,
org_vocab_end_index=8,
added_vocab_start_index=11,
added_vocab_end_index=12,
num_org_vocab_padding=0)
assert torch.equal(modified_x_rank_0,
torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0]))
assert torch.equal(modified_x_rank_1,
torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 2, 0, 0]))
assert torch.equal(modified_x_rank_2,
torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0]))
assert torch.equal(modified_x_rank_3,
torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 2]))

# base tp 1 case, with padding
modified_x, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=8,
added_vocab_start_index=8,
added_vocab_end_index=12,
num_org_vocab_padding=2)
assert torch.equal(modified_x,
torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 10, 11, 12, 13]))

# tp 2 case, with padding
modified_x_rank_0, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=4,
added_vocab_start_index=8,
added_vocab_end_index=10,
num_org_vocab_padding=2)
modified_x_rank_1, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=4,
org_vocab_end_index=8,
added_vocab_start_index=10,
added_vocab_end_index=12,
num_org_vocab_padding=2)
assert torch.equal(modified_x_rank_0,
torch.tensor([0, 1, 2, 3, 0, 0, 0, 0, 6, 7, 0, 0]))
assert torch.equal(modified_x_rank_1,
torch.tensor([0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]))

# tp 4 case, with padding
modified_x_rank_0, _ = get_masked_input_and_mask(x,
org_vocab_start_index=0,
org_vocab_end_index=2,
added_vocab_start_index=8,
added_vocab_end_index=9,
num_org_vocab_padding=2)
modified_x_rank_1, _ = get_masked_input_and_mask(x,
org_vocab_start_index=2,
org_vocab_end_index=4,
added_vocab_start_index=9,
added_vocab_end_index=10,
num_org_vocab_padding=2)
modified_x_rank_2, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=4,
org_vocab_end_index=6,
added_vocab_start_index=10,
added_vocab_end_index=11,
num_org_vocab_padding=2)
modified_x_rank_3, _ = get_masked_input_and_mask(
x,
org_vocab_start_index=6,
org_vocab_end_index=8,
added_vocab_start_index=11,
added_vocab_end_index=12,
num_org_vocab_padding=2)
assert torch.equal(modified_x_rank_0,
torch.tensor([0, 1, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0]))
assert torch.equal(modified_x_rank_1,
torch.tensor([0, 0, 0, 1, 0, 0, 0, 0, 0, 4, 0, 0]))
assert torch.equal(modified_x_rank_2,
torch.tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 4, 0]))
assert torch.equal(modified_x_rank_3,
torch.tensor([0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 4]))

@ -36,11 +36,10 @@ def do_sample(llm, lora_path: str, lora_id: int):
return generated_texts


@pytest.mark.parametrize("tp_size", [1])
def test_llama_lora(sql_lora_files, tp_size):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < tp_size:
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
@pytest.mark.parametrize("tp_size", [1, 2, 4])
def test_llama_lora(sql_lora_files, tp_size, num_gpus_available):
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
@ -80,11 +79,9 @@ def test_llama_lora(sql_lora_files, tp_size):
print("removing lora")


@pytest.mark.skip("Requires multiple GPUs")
def test_llama_tensor_parallel_equality(sql_lora_files):
# Cannot use as it will initialize torch.cuda too early...
# if torch.cuda.device_count() < 4:
# pytest.skip(f"Not enough GPUs for tensor parallelism {4}")
def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
if num_gpus_available < 4:
pytest.skip("Not enough GPUs for tensor parallelism 4")

llm_tp1 = vllm.LLM(MODEL_PATH,
enable_lora=True,

@ -102,22 +102,21 @@ def batched_generate(
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]


@pytest.fixture
@pytest.fixture(scope="module")
def lora_llm(long_context_infos):
scaling_factors = [
context_len_to_scaling_factor[info["context_length"]]
for info in long_context_infos.values()
]

llm = vllm.LLM(
"meta-llama/Llama-2-13b-chat-hf",
enable_lora=True,
max_num_seqs=16,
max_loras=2,
long_lora_scaling_factors=tuple(scaling_factors),
max_num_batched_tokens=4096 * 8,
tensor_parallel_size=4,
)
llm = vllm.LLM("meta-llama/Llama-2-13b-chat-hf",
enable_lora=True,
max_num_seqs=16,
max_loras=2,
long_lora_scaling_factors=tuple(scaling_factors),
max_num_batched_tokens=4096 * 8,
tensor_parallel_size=4,
distributed_executor_backend="mp")
yield llm
del llm

@ -154,6 +153,7 @@ def test_rotary_emb_replaced(dist_init):
assert rotary_emb_count == 32


@pytest.mark.skip_global_cleanup
def test_batched_rope_kernel(lora_llm, long_context_infos):
"""We test the batched kernel by comparing the results of batched and
non-batched generation.
@ -188,6 +188,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
f"same:\n{batched}\n{non_batched}")


@pytest.mark.skip_global_cleanup
def test_self_consistency(lora_llm, long_context_infos):
"""We test consistency of the batched kernel by permuting batched
inputs and comparing the results to the non-permuted batched results.
@ -227,6 +228,7 @@ def test_self_consistency(lora_llm, long_context_infos):
f"\n{permutated_batched_results[permutation[i]]}")


@pytest.mark.skip_global_cleanup
def test_quality(lora_llm, long_context_infos):
"""We test the quality of the answers given by the LoRA model by
comparing the generated text to the merged model's outputs.
@ -257,6 +259,7 @@ def test_quality(lora_llm, long_context_infos):
assert np.mean(scores) > 0.5


@pytest.mark.skip_global_cleanup
def test_max_len(lora_llm, long_context_infos):
"""Test that we raise a ValueError when the input of a given LoRA
model exceeds the maximum length."""

@ -1,3 +1,4 @@
import multiprocessing as mp
import os
import shutil
from tempfile import TemporaryDirectory
@ -18,9 +19,7 @@ prompts = [

# Create a sampling params object.
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
seed=0,
temperature=0,
max_tokens=256,
ignore_eos=True,
)
@ -43,48 +42,85 @@ def test_filter_subtensors():
assert tensor.equal(state_dict[key])


@pytest.mark.parametrize("enable_lora", [False, True])
def test_sharded_state_loader(enable_lora):
weights_patterns = ("*.bin", "*.pt", "*.safetensors")

with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir:
@pytest.fixture(scope="module")
def llama_2_7b_files():
with TemporaryDirectory() as cache_dir:
input_dir = snapshot_download("meta-llama/Llama-2-7b-hf",
cache_dir=cache_dir)
cache_dir=cache_dir,
ignore_patterns="*.bin*")
yield input_dir

llm = LLM(
model=input_dir,
worker_use_ray=True,
gpu_memory_utilization=0.3,
)

# Dump worker states to output directory
model_executor = llm.llm_engine.model_executor
model_executor.save_sharded_state(path=output_dir)
# Copy metadata files to output directory
for file in os.listdir(input_dir):
if not any(file.endswith(ext) for ext in weights_patterns):
shutil.copy(f"{input_dir}/{file}", output_dir)
del llm.llm_engine.model_executor
def _run_writer(input_dir, output_dir, weights_patterns, **kwargs):
llm_sharded_writer = LLM(model=input_dir, **kwargs)

llm_before = LLM(
model=input_dir,
worker_use_ray=True,
enable_lora=enable_lora,
gpu_memory_utilization=0.3,
)
gen_before = llm_before.generate(prompts, sampling_params)
out_before = [gen.outputs[0].__dict__ for gen in gen_before]
del llm_before.llm_engine.model_executor
# Dump worker states to output directory
llm_sharded_writer.llm_engine.model_executor.save_sharded_state(
path=output_dir)
# Copy metadata files to output directory
for file in os.listdir(input_dir):
if not any(file.endswith(ext) for ext in weights_patterns):
shutil.copy(f"{input_dir}/{file}", output_dir)

llm_after = LLM(
model=output_dir,
worker_use_ray=True,
enable_lora=enable_lora,
gpu_memory_utilization=0.3,
load_format="sharded_state",
)
gen_after = llm_after.generate(prompts, sampling_params)
out_after = [gen.outputs[0].__dict__ for gen in gen_after]
del llm_after.llm_engine.model_executor

def _run_generate(input_dir, queue: mp.Queue, **kwargs):
llm = LLM(model=input_dir, **kwargs)
gen = llm.generate(prompts, sampling_params)
queue.put([g.outputs[0].__dict__ for g in gen])
queue.close()
queue.join_thread()


@pytest.mark.parametrize("enable_lora", [False, True])
@pytest.mark.parametrize("tp_size", [1, 2])
def test_sharded_state_loader(enable_lora, tp_size, num_gpus_available,
llama_2_7b_files):
if num_gpus_available < tp_size:
pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")

weights_patterns = ("*.safetensors", )
gpu_memory_utilization = 0.8
input_dir = llama_2_7b_files
ctx = mp.get_context("spawn")

# Run in separate processes for memory & CUDA isolation
with TemporaryDirectory() as output_dir:
p = ctx.Process(target=_run_writer,
args=(input_dir, output_dir, weights_patterns),
kwargs=dict(
tensor_parallel_size=tp_size,
distributed_executor_backend="mp",
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=True,
))
p.start()
p.join()

queue = ctx.Queue()

p = ctx.Process(target=_run_generate,
args=(input_dir, queue),
kwargs=dict(
distributed_executor_backend="mp",
enable_lora=enable_lora,
gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tp_size,
))
p.start()
p.join()
out_before = queue.get()

p = ctx.Process(target=_run_generate,
args=(output_dir, queue),
kwargs=dict(
distributed_executor_backend="mp",
enable_lora=enable_lora,
gpu_memory_utilization=gpu_memory_utilization,
tensor_parallel_size=tp_size,
load_format="sharded_state",
))
p.start()
p.join()
out_after = queue.get()

assert out_before == out_after

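The test above runs each LLM in a spawned subprocess so that GPU memory and the CUDA context are fully released between stages. A stripped-down sketch of that pattern (the names here are illustrative and not part of the diff):

```python
import multiprocessing as mp


def _worker(payload, queue):
    # Any CUDA-heavy work would happen here; the child process owns the
    # CUDA context, so all GPU memory is released when it exits.
    queue.put(payload * 2)


def run_isolated(payload):
    ctx = mp.get_context("spawn")  # "spawn" avoids inheriting CUDA state
    queue = ctx.Queue()
    p = ctx.Process(target=_worker, args=(payload, queue))
    p.start()
    result = queue.get()  # drain before join so a large payload cannot block
    p.join()
    return result


assert run_isolated(21) == 42
```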
@ -215,19 +215,19 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
lora_config: LoRAConfig,
model_config: Optional[PretrainedConfig] = None) -> None:

lora_vocab_start_idx = self.base_layer.org_vocab_size
weights_idx = None
if self.base_layer.vocab_end_index > lora_vocab_start_idx:
if self.base_layer.num_added_embeddings_per_partition > 0:
# We can start adding lora weights
weights_idx = max(
lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
self.embeddings_slice = (self.base_layer.vocab_start_index -
self.base_layer.org_vocab_size +
weights_idx,
self.base_layer.vocab_end_index -
self.base_layer.org_vocab_size)
self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
self.embeddings_weights.fill_(0)
self.embeddings_weights = self.base_layer.weight.data[
self.base_layer.num_org_embeddings_per_partition:self.
base_layer.num_org_embeddings_per_partition +
self.base_layer.num_added_embeddings_per_partition]
self.embeddings_slice = (
self.base_layer.shard_indices.added_vocab_start_index -
self.base_layer.org_vocab_size,
self.base_layer.shard_indices.added_vocab_end_index -
self.base_layer.org_vocab_size)
self.base_layer.weight.data[
self.base_layer.num_org_embeddings_per_partition:].fill_(0)
else:
self.embeddings_slice = None
self.embeddings_weights = None
@ -1025,19 +1025,31 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):


class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
"""
LoRA wrapper for LogitsProcessor, with extra logic to handle the
application of the LoRA adapter and added LoRA vocabulary.

def __init__(
self,
base_layer: LogitsProcessor,
hidden_size: int,
dtype: torch.dtype,
device: torch.device,
) -> None:
Args:
base_layer: LogitsProcessor layer
hidden_size: hidden size of the model
dtype: data type of the model
device: device of the model
sharded_to_full_mapping: index mapping from sharded vocab to full vocab
received from base_layer.get_sharded_to_full_mapping(). If None,
no reindexing will be done.
"""

def __init__(self, base_layer: LogitsProcessor, hidden_size: int,
dtype: torch.dtype, device: torch.device,
sharded_to_full_mapping: Optional[List[int]]) -> None:
super().__init__()
self.base_layer = base_layer
self.hidden_size = hidden_size
self.dtype = dtype
self.device = device
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.sharded_to_full_mapping = sharded_to_full_mapping

@property
def logits_as_input(self):
@ -1098,6 +1110,13 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
dtype=self.dtype,
device=self.device,
)
if self.sharded_to_full_mapping is not None:
self.sharded_to_full_mapping_gpu = torch.tensor(
self.sharded_to_full_mapping,
device=self.device,
dtype=torch.long)
else:
self.sharded_to_full_mapping_gpu = None
# Lazily initialized.
self.indices: torch.Tensor
self.indices_len: List[int]
@ -1154,6 +1173,25 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
if logits is None:
return None

if self.sharded_to_full_mapping_gpu is not None:
# Reindex full logits tensor to ensure 1:1 mapping between
# index and token_id
# Example for:
# org_vocab_size = 4
# added_vocab_size = 2
# pad_to_size = 8
# tp_size = 2

# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 4, -1, 2, 3, 5, -1]

# Therefore, the mapping is expected to be:
# [0, 1, 4, 5, 2, 6, 3, 7] so that when we reindex,
# we get:
# indices: [0, 1, 2, 3, 4, 5, 6, 7]
# token_id: [0, 1, 2, 3, 4, 5, -1, -1]
logits = logits[:, self.sharded_to_full_mapping_gpu]

lora_logits = torch.empty(
self.embeddings_tensors.shape[0] + 1,
self.embeddings_tensors.shape[1],

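The toy example in the comment above can be checked with a few lines of plain torch (a standalone sketch, not part of the diff):

```python
import torch

# Gathered column -> token_id layout from the comment above.
token_id_per_column = torch.tensor([0, 1, 4, -1, 2, 3, 5, -1])
mapping = torch.tensor([0, 1, 4, 5, 2, 6, 3, 7])

# logits[:, mapping] makes column t hold the logit of token t; applying the
# same gather to the token-id layout shows the resulting order directly.
print(token_id_per_column[mapping])  # tensor([ 0,  1,  2,  3,  4,  5, -1, -1])
```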
@ -67,7 +67,8 @@ def from_layer_logits_processor(
model_config: Optional[PretrainedConfig] = None,
) -> LogitsProcessorWithLoRA:
ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim,
lm_head.weight.dtype, lm_head.weight.device)
lm_head.weight.dtype, lm_head.weight.device,
lm_head.get_sharded_to_full_mapping())
ret.create_lora_weights(max_loras, lora_config, model_config)
return ret

@ -1,4 +1,5 @@
from typing import Optional, Sequence
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple

import torch
import torch.nn.functional as F
@ -18,18 +19,107 @@ def pad_vocab_size(vocab_size: int,
return ((vocab_size + pad_to - 1) // pad_to) * pad_to


def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int,
rank: int) -> Sequence[int]:
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int,
rank: int,
offset: int = 0) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
return index_f + offset, index_l + offset


def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
world_size: int) -> Sequence[int]:
def vocab_range_from_global_vocab_size(global_vocab_size: int,
rank: int,
world_size: int,
offset: int = 0) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size,
rank)
rank,
offset=offset)


@dataclass
class VocabParallelEmbeddingShardIndices:
"""Indices for a shard of a vocab parallel embedding."""
padded_org_vocab_start_index: int
padded_org_vocab_end_index: int
padded_added_vocab_start_index: int
padded_added_vocab_end_index: int

org_vocab_start_index: int
org_vocab_end_index: int
added_vocab_start_index: int
added_vocab_end_index: int

@property
def num_org_elements(self) -> int:
return self.org_vocab_end_index - self.org_vocab_start_index

@property
def num_added_elements(self) -> int:
return self.added_vocab_end_index - self.added_vocab_start_index

@property
def num_org_elements_padded(self) -> int:
return (self.padded_org_vocab_end_index -
self.padded_org_vocab_start_index)

@property
def num_added_elements_padded(self) -> int:
return (self.padded_added_vocab_end_index -
self.padded_added_vocab_start_index)

@property
def num_org_vocab_padding(self) -> int:
return self.num_org_elements_padded - self.num_org_elements

@property
def num_added_vocab_padding(self) -> int:
return self.num_added_elements_padded - self.num_added_elements

@property
def num_elements_padded(self) -> int:
return self.num_org_elements_padded + self.num_added_elements_padded

def __post_init__(self):
# sanity checks
assert (self.padded_org_vocab_start_index <=
self.padded_org_vocab_end_index)
assert (self.padded_added_vocab_start_index <=
self.padded_added_vocab_end_index)

assert self.org_vocab_start_index <= self.org_vocab_end_index
assert self.added_vocab_start_index <= self.added_vocab_end_index

assert self.org_vocab_start_index <= self.padded_org_vocab_start_index
assert (self.added_vocab_start_index <=
self.padded_added_vocab_start_index)
assert self.org_vocab_end_index <= self.padded_org_vocab_end_index
assert self.added_vocab_end_index <= self.padded_added_vocab_end_index

assert self.num_org_elements <= self.num_org_elements_padded
assert self.num_added_elements <= self.num_added_elements_padded


@torch.jit.script
def get_masked_input_and_mask(
input_: torch.Tensor, org_vocab_start_index: int,
org_vocab_end_index: int, num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
# torch.jit.script will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ <
org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index)
added_offset = added_vocab_start_index - (
org_vocab_end_index - org_vocab_start_index) - num_org_vocab_padding
valid_offset = (org_vocab_start_index *
org_vocab_mask) + (added_offset * added_vocab_mask)
vocab_mask = org_vocab_mask | added_vocab_mask
input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask


class VocabParallelEmbedding(torch.nn.Module):
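A single call, mirroring the "tp 2, rank 1, with padding" case from the tests earlier in this diff, shows what the helper returns for a shard that owns original ids [4, 8) and added ids [10, 12) with 2 rows of original-vocab padding (sketch assumes vllm is importable):

```python
import torch

from vllm.model_executor.layers.vocab_parallel_embedding import (
    get_masked_input_and_mask)

ids = torch.arange(12)
masked, mask = get_masked_input_and_mask(ids,
                                         org_vocab_start_index=4,
                                         org_vocab_end_index=8,
                                         num_org_vocab_padding=2,
                                         added_vocab_start_index=10,
                                         added_vocab_end_index=12)
# masked -> [0, 0, 0, 0, 0, 1, 2, 3, 0, 0, 6, 7]
# mask   -> True for every id that does not live on this shard.
```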
@ -38,13 +128,36 @@ class VocabParallelEmbedding(torch.nn.Module):
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel GPUs.

In order to support various loading methods, we ensure that LoRA-added
embeddings are always at the end of TP-sharded tensors. In other words,
we shard base embeddings and LoRA embeddings separately (both padded),
and place them in the same tensor.
In this example, we will have the original vocab size = 1010,
added vocab size = 16 and padding to 64. Therefore, the total
vocab size with padding will be 1088 (because we first pad 1010 to
1024, add 16, and then pad to 1088).
Therefore, the tensor format looks like the following:
TP1, rank 0 (no sharding):
|< --------BASE-------- >|< -BASE PADDING-- >|< -----LORA------ >|< -LORA PADDING-- >|
corresponding token_id: | 0 | 1 | ... | 1009 | -1 | ... | -1 | 1010 | ... | 1025 | -1 | ... | -1 |
index: | 0 | 1 | ... | 1009 | 1010 | ... | 1023 | 1024 | ... | 1039 | 1040 | ... | 1087 |

TP2, rank 0:
|< --------------------BASE--------------------- >|< -----LORA------ >|< -LORA PADDING- >|
corresponding token_id: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 1010 | ... | 1025 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 527 | 528 | ... | 543 |
TP2, rank 1:
|< -----------BASE----------- >|< -BASE PADDING- >|< -----------LORA PADDING----------- >|
corresponding token_id: | 512 | 513 | 514 | ... | 1009 | -1 | ... | -1 | -1 | ... | -1 | -1 | ... | -1 |
index: | 0 | 1 | 2 | ... | 497 | 498 | ... | 511 | 512 | ... | 519 | 520 | ... | 543 |

Args:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""
""" # noqa: E501

def __init__(self,
num_embeddings: int,
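The sizes quoted in the docstring example follow directly from the pad_vocab_size helper earlier in this file; a quick standalone check (pad_vocab_size is copied locally here, with pad_to defaulted to the 64 used in the example):

```python
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
    # Round up to the next multiple of pad_to.
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to


org_vocab_size, added_vocab_size = 1010, 16
org_padded = pad_vocab_size(org_vocab_size)                   # 1024
total_padded = pad_vocab_size(org_padded + added_vocab_size)  # 1088
per_rank = total_padded // 2                                  # 544 per rank at tp_size=2
print(org_padded, total_padded, per_rank)
```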
@ -55,21 +168,39 @@ class VocabParallelEmbedding(torch.nn.Module):
super().__init__()

# Keep the input dimensions.
tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_embeddings = num_embeddings
self.padding_size = padding_size
self.org_vocab_size = org_num_embeddings or num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings,
padding_size)
num_added_embeddings = num_embeddings - self.org_vocab_size
self.org_vocab_size_padded = pad_vocab_size(self.org_vocab_size,
self.padding_size)
self.num_embeddings_padded = pad_vocab_size(
self.org_vocab_size_padded + num_added_embeddings,
self.padding_size)
assert self.org_vocab_size_padded <= self.num_embeddings_padded

self.shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocaburaly dimension.
self.vocab_start_index, self.vocab_end_index = (
vocab_range_from_global_vocab_size(
self.num_embeddings_padded, get_tensor_model_parallel_rank(),
self.tp_size))
self.num_embeddings_per_partition = (self.vocab_end_index -
self.vocab_start_index)
self.num_added_embeddings = self.num_embeddings - self.org_vocab_size
self.num_embeddings_per_partition = divide(self.num_embeddings_padded,
self.tp_size)
assert (self.shard_indices.num_elements_padded ==
self.num_embeddings_per_partition)
self.num_org_embeddings_per_partition = (
self.shard_indices.org_vocab_end_index -
self.shard_indices.org_vocab_start_index)
self.num_added_embeddings_per_partition = (
self.shard_indices.added_vocab_end_index -
self.shard_indices.added_vocab_start_index)
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
@ -79,28 +210,107 @@ class VocabParallelEmbedding(torch.nn.Module):
"weight_loader": self.weight_loader
})

@classmethod
def _get_indices(cls, vocab_size_padded: int, org_vocab_size_padded: int,
vocab_size: int, org_vocab_size: int, tp_rank: int,
tp_size: int) -> VocabParallelEmbeddingShardIndices:
"""Get start and end indices for vocab parallel embedding, following the
layout outlined in the class docstring, based on the given tp_rank and
tp_size."""
num_added_embeddings_padded = vocab_size_padded - org_vocab_size_padded
padded_org_vocab_start_index, padded_org_vocab_end_index = (
vocab_range_from_global_vocab_size(org_vocab_size_padded, tp_rank,
tp_size))
padded_added_vocab_start_index, padded_added_vocab_end_index = (
vocab_range_from_global_vocab_size(num_added_embeddings_padded,
tp_rank,
tp_size,
offset=org_vocab_size))
# remove padding
org_vocab_start_index = min(padded_org_vocab_start_index,
org_vocab_size)
org_vocab_end_index = min(padded_org_vocab_end_index, org_vocab_size)
added_vocab_start_index = min(padded_added_vocab_start_index,
vocab_size)
added_vocab_end_index = min(padded_added_vocab_end_index, vocab_size)
return VocabParallelEmbeddingShardIndices(
padded_org_vocab_start_index, padded_org_vocab_end_index,
padded_added_vocab_start_index, padded_added_vocab_end_index,
org_vocab_start_index, org_vocab_end_index,
added_vocab_start_index, added_vocab_end_index)

def get_sharded_to_full_mapping(self) -> Optional[List[int]]:
"""Get a mapping that can be used to reindex the gathered
logits for sampling.

During sampling, we gather logits from all ranks. The relationship
of index->token_id will follow the same format as outlined in the class
docstring. However, after the gather, we want to reindex the final
logits tensor to map index->token_id one-to-one (the index is always
equal to the token_id it corresponds to). The indices returned by this
method allow us to do that.
"""
if self.tp_size < 2:
return None

base_embeddings: List[int] = []
added_embeddings: List[int] = []
padding: List[int] = []
for tp_rank in range(self.tp_size):
shard_indices = self._get_indices(self.num_embeddings_padded,
self.org_vocab_size_padded,
self.num_embeddings,
self.org_vocab_size, tp_rank,
self.tp_size)
range_start = self.num_embeddings_per_partition * tp_rank
range_end = self.num_embeddings_per_partition * (tp_rank + 1)
base_embeddings.extend(
range(range_start,
range_start + shard_indices.num_org_elements))
padding.extend(
range(range_start + shard_indices.num_org_elements,
range_start + shard_indices.num_org_elements_padded))
added_embeddings.extend(
range(
range_start + shard_indices.num_org_elements_padded,
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements))
padding.extend(
range(
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements,
range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements_padded))
assert (range_start + shard_indices.num_org_elements_padded +
shard_indices.num_added_elements_padded == range_end)
ret = base_embeddings + added_embeddings + padding
assert len(ret) == self.num_embeddings_padded
return ret

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
parallel_dim = param.parallel_dim
assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
loaded_weight = loaded_weight[self.vocab_start_index:self.
vocab_end_index]
loaded_weight = loaded_weight[self.shard_indices.org_vocab_start_index:
self.shard_indices.org_vocab_end_index]
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
param[loaded_weight.shape[0]:].data.fill_(0)

def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
input_mask = ((input_ < self.vocab_start_index) |
(input_ >= self.vocab_end_index))
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
masked_input, input_mask = get_masked_input_and_mask(
input_, self.shard_indices.org_vocab_start_index,
self.shard_indices.org_vocab_end_index,
self.shard_indices.num_org_vocab_padding,
self.shard_indices.added_vocab_start_index,
self.shard_indices.added_vocab_end_index)
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight)
# Mask the output embedding.
if self.tp_size > 1:
output_parallel[input_mask, :] = 0.0
output_parallel.masked_fill_(input_mask.unsqueeze(1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output

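A pure-torch sketch (no distributed setup; sizes are illustrative) of why the masked lookup in forward() composes with the all-reduce to reproduce a full embedding lookup: each rank zeroes the rows for ids outside its shard, so summing the per-rank partial outputs recovers the dense result.

```python
import torch

vocab, dim, tp = 8, 4, 2
full_weight = torch.randn(vocab, dim)
input_ids = torch.tensor([1, 6, 3, 0])

output = torch.zeros(len(input_ids), dim)
for rank in range(tp):
    start, end = rank * vocab // tp, (rank + 1) * vocab // tp
    shard = full_weight[start:end]
    # Ids outside [start, end) are remapped to 0 and their rows zeroed.
    mask = (input_ids < start) | (input_ids >= end)
    masked = (input_ids - start).clamp(min=0)
    masked[mask] = 0
    partial = shard[masked]
    partial.masked_fill_(mask.unsqueeze(1), 0)
    output += partial  # stand-in for tensor_model_parallel_all_reduce

assert torch.equal(output, full_weight[input_ids])
```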
@ -35,6 +35,7 @@ _BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]
_NUM_WARMUP_ITERS = 2


class ModelInput(NamedTuple):
@ -975,16 +976,18 @@ class CUDAGraphRunner:
**kwargs,
) -> None:
assert self._graph is None
# Run the model once without capturing the graph.
# Run the model a few times without capturing the graph.
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
**kwargs,
)
# Note one iteration is not enough for torch.jit.script
for _ in range(_NUM_WARMUP_ITERS):
self.model(
input_ids,
positions,
kv_caches,
attn_metadata,
**kwargs,
)
torch.cuda.synchronize()

# Capture the graph.
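For context on why warmup runs precede capture, here is a generic PyTorch warmup-then-capture sketch; it is standalone, requires a CUDA device, and is not vLLM's CUDAGraphRunner.

```python
import torch

_NUM_WARMUP_ITERS = 2  # mirrors the constant added above

model = torch.nn.Linear(16, 16).cuda()
static_input = torch.randn(8, 16, device="cuda")

# Warm up on a side stream so autotuning/JIT kernel launches are not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(_NUM_WARMUP_ITERS):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy new data into static_input, then replay() re-runs the captured
# kernels and refreshes static_output in place.
static_input.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
```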