Remove hardcoded device="cuda" to support more devices (#2503)

Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
Authored by Kunshang Ji on 2024-02-02 07:46:39 +08:00, committed by GitHub
parent c410f5d020
commit 96b6f475dd
32 changed files with 343 additions and 292 deletions
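Every file below follows the same pattern: the hardcoded device="cuda" is replaced by a device value that flows from a new --device CLI flag through EngineArgs into a DeviceConfig, and from there into the Worker, ModelRunner, and model loader; the kernel and layer tests parametrize over explicit "cuda:{i}" strings and call torch.set_default_device() once instead of passing device= to every tensor constructor. A minimal sketch of that plumbing, with simplified names that are not the actual vLLM classes:

import torch


class DeviceConfig:
    # Mirrors the new vllm.config.DeviceConfig: the --device string is
    # normalized to a torch.device exactly once, here.
    def __init__(self, device: str = "cuda") -> None:
        self.device = torch.device(device)


class ModelRunnerSketch:
    # Toy stand-in for ModelRunner: it receives the DeviceConfig instead of
    # assuming torch.cuda.current_device().
    def __init__(self, device_config: DeviceConfig) -> None:
        self.device = device_config.device

    def make_input(self, token_ids: list) -> torch.Tensor:
        # Tensors are created on the configured device, not a hardcoded "cuda".
        return torch.tensor(token_ids, dtype=torch.long, device=self.device)


runner = ModelRunnerSketch(
    DeviceConfig("cuda" if torch.cuda.is_available() else "cpu"))
print(runner.make_input([1, 2, 3]).device)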

View File

@ -25,6 +25,7 @@ def main(args: argparse.Namespace):
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
device=args.device,
)
sampling_params = SamplingParams(
@ -135,5 +136,11 @@ if __name__ == '__main__':
default=None,
help=('path to save the pytorch profiler output. Can be visualized '
'with ui.perfetto.dev or Tensorboard.'))
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda"],
help='Device type for vLLM execution. Currently, only CUDA is supported.')
args = parser.parse_args()
main(args)
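For the benchmark scripts the change is mechanical: a new --device argparse option whose value is forwarded to the LLM constructor. A hedged, self-contained sketch of that flow (build_llm stands in for vllm.LLM, which is not imported here):

import argparse


def build_llm(model: str, device: str) -> dict:
    # Placeholder for vllm.LLM(model=..., device=...); it only echoes its arguments.
    return {"model": model, "device": device}


parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument(
    "--device",
    type=str,
    default="cuda",
    choices=["cuda"],
    help="Device type for vLLM execution. Currently, only CUDA is supported.")
args = parser.parse_args([])  # pass sys.argv[1:] (or nothing) in a real script
print(build_llm(args.model, args.device))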

View File

@ -72,6 +72,7 @@ def run_vllm(
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
device: str,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
@ -85,6 +86,7 @@ def run_vllm(
max_model_len=max_model_len,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
device=device,
)
# Add the requests to the engine.
@ -209,7 +211,7 @@ def main(args: argparse.Namespace):
args.seed, args.n, args.use_beam_search,
args.trust_remote_code, args.dtype,
args.max_model_len, args.enforce_eager,
args.kv_cache_dtype)
args.kv_cache_dtype, args.device)
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@ -294,6 +296,12 @@ if __name__ == "__main__":
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda"],
help='Device type for vLLM execution. Currently, only CUDA is supported.')
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model

View File

@ -25,10 +25,12 @@ def main(
dtype: torch.dtype,
seed: int,
do_profile: bool,
device: str = "cuda",
kv_cache_dtype: Optional[str] = None,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
scale = float(1.0 / (head_size**0.5))
@ -36,7 +38,7 @@ def main(
num_query_heads,
head_size,
dtype=dtype,
device="cuda")
device=device)
query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0
@ -44,11 +46,11 @@ def main(
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
dtype=torch.float,
device="cuda")
device=device)
context_lens = [context_len for _ in range(num_seqs)]
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)
# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@ -59,12 +61,17 @@ def main(
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)
# Create the KV cache.
key_caches, value_caches = create_kv_caches_with_random(
NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype,
dtype)
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
block_size,
1,
num_kv_heads,
head_size,
kv_cache_dtype,
dtype,
device=device)
key_cache, value_cache = key_caches[0], value_caches[0]
# Prepare for the paged attention kernel.
@ -84,7 +91,7 @@ def main(
)
max_logits = torch.empty_like(exp_sums)
def run_benchmark(num_iters: int, profile: bool = False) -> float:
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize()
if profile:
torch.cuda.cudart().cudaProfilerStart()
@ -135,6 +142,7 @@ def main(
# Warmup.
print("Warming up...")
run_benchmark = run_cuda_benchmark
run_benchmark(num_iters=3, profile=False)
# Benchmark.
@ -175,6 +183,7 @@ if __name__ == '__main__':
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type.')
parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
args = parser.parse_args()
print(args)

View File

@ -7,26 +7,29 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS = [7, 83, 2048] # Arbitrary values for testing
D = [512, 4096, 5120, 13824] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_silu_and_mul(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id)
torch.set_default_device(device)
x = torch.randn(num_tokens, 2 * d, dtype=dtype)
layer = SiluAndMul()
out = layer(x)
ref_out = layer._forward(x)
@ -37,19 +40,20 @@ def test_silu_and_mul(
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_gelu_new(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = NewGELU()
out = layer(x)
ref_out = layer._forward(x)
@ -60,18 +64,19 @@ def test_gelu_new(
@pytest.mark.parametrize("d", D)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_gelu_fast(
num_tokens: int,
d: int,
dtype: torch.dtype,
seed: int,
device: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id)
torch.set_default_device(device)
x = torch.randn(num_tokens, d, dtype=dtype)
layer = FastGELU()
out = layer(x)
ref_out = layer._forward(x)
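The kernel tests replace the integer DEVICES indices with explicit "cuda:i" strings, call torch.set_default_device(device) once, and guard torch.cuda.manual_seed() with torch.cuda.is_available(), so the per-tensor device= arguments disappear. A small sketch of the same pattern (assumes pytest and a CUDA build of PyTorch):

import pytest
import torch

CUDA_DEVICES = [
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]


@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_elementwise_add(device: str) -> None:
    torch.random.manual_seed(0)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(0)
    # Everything below lands on `device` without repeating device=... each time.
    torch.set_default_device(device)
    x = torch.randn(16, 32)
    y = torch.randn(16, 32)
    torch.testing.assert_close(x + y, torch.add(x, y))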

View File

@ -27,7 +27,9 @@ BLOCK_SIZES = [16, 32]
USE_ALIBI = [False, True]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
def ref_masked_attention(
@ -91,7 +93,7 @@ def ref_single_query_cached_kv_attention(
alibi_bias = None
if alibi_slopes is not None:
# Create the ALiBi bias used in the paged attention kernel.
position_ids = torch.arange(context_len, device=query.device).int()
position_ids = torch.arange(context_len).int()
alibi_bias = (position_ids - context_len + 1).float()
alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view(
1, 1, -1)
@ -110,7 +112,7 @@ def ref_single_query_cached_kv_attention(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_paged_attention(
kv_cache_factory,
version: str,
@ -122,33 +124,28 @@ def test_paged_attention(
dtype: torch.dtype,
kv_cache_dtype: str,
seed: int,
device: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
torch.set_default_device(device)
scale = float(1.0 / (head_size**0.5))
num_query_heads, num_kv_heads = num_heads
query = torch.empty(num_seqs,
num_query_heads,
head_size,
dtype=dtype,
device=gpu_id)
query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype)
query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0
num_queries_per_kv = num_query_heads // num_kv_heads
alibi_slopes = None
if use_alibi:
alibi_slopes = torch.randn(num_query_heads,
dtype=torch.float,
device=gpu_id)
alibi_slopes = torch.randn(num_query_heads, dtype=torch.float)
context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
context_lens[-1] = MAX_SEQ_LEN
max_context_len = max(context_lens)
context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id)
context_lens = torch.tensor(context_lens, dtype=torch.int)
# Create the block tables.
max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@ -159,13 +156,13 @@ def test_paged_attention(
for _ in range(max_num_blocks_per_seq)
]
block_tables.append(block_table)
block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id)
block_tables = torch.tensor(block_tables, dtype=torch.int)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1,
num_kv_heads, head_size,
kv_cache_dtype, dtype, seed,
gpu_id)
device)
key_cache, value_cache = key_caches[0], value_caches[0]
# Call the paged attention kernel.
@ -193,12 +190,10 @@ def test_paged_attention(
tmp_output = torch.empty(
size=(num_seqs, num_heads, num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_v2(
@ -229,14 +224,14 @@ def test_paged_attention(
block_size, x)
dequantized_key_cache = torch.empty(size=key_cache_shape,
dtype=dtype,
device=gpu_id)
device=device)
cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache)
key_cache = dequantized_key_cache
value_cache_shape = value_cache.shape
dequantized_value_cache = torch.empty(size=value_cache_shape,
dtype=dtype,
device=gpu_id)
device=device)
cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache)
value_cache = dequantized_value_cache
@ -283,7 +278,7 @@ def ref_multi_query_kv_attention(
attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype),
diagonal=1)
attn_mask = attn_mask * torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device=query.device)
attn_mask = attn_mask.to(dtype=dtype)
ref_output = ref_masked_attention(
query[start_idx:end_idx],
@ -303,7 +298,7 @@ def ref_multi_query_kv_attention(
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_multi_query_kv_attention(
num_seqs: int,
@ -311,12 +306,13 @@ def test_multi_query_kv_attention(
head_size: int,
dtype: torch.dtype,
seed: int,
device: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
torch.set_default_device(device)
# MAX_SEQ_LEN sometimes causes OOM in the reference implementation.
# As the xformers library is already tested with its own tests, we can use
# a smaller MAX_SEQ_LEN here.
@ -329,8 +325,7 @@ def test_multi_query_kv_attention(
qkv = torch.empty(num_tokens,
num_query_heads + 2 * num_kv_heads,
head_size,
dtype=dtype,
device=gpu_id)
dtype=dtype)
qkv.uniform_(-scale, scale)
query, key, value = qkv.split(
[num_query_heads, num_kv_heads, num_kv_heads], dim=1)

View File

@ -17,7 +17,9 @@ BLOCK_SIZES = [8, 16, 32]
NUM_BLOCKS = [1024, 3600] # Arbitrary values for testing
NUM_MAPPINGS = [256] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
@ -29,7 +31,7 @@ KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE)
@torch.inference_mode()
def test_copy_blocks(
@ -42,13 +44,14 @@ def test_copy_blocks(
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: int,
kv_cache_dtype: str,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
torch.set_default_device(device)
# Generate random block mappings where each source block is mapped to two
# destination blocks.
assert 2 * num_mappings <= num_blocks
@ -66,7 +69,7 @@ def test_copy_blocks(
key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
num_layers, num_heads,
head_size, kv_cache_dtype,
dtype, seed, gpu_id)
dtype, seed, device)
# Clone the KV caches.
cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
@ -98,7 +101,7 @@ def test_copy_blocks(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_reshape_and_cache(
kv_cache_factory,
@ -109,29 +112,25 @@ def test_reshape_and_cache(
num_blocks: int,
dtype: torch.dtype,
seed: int,
device: int,
device: str,
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
torch.set_default_device(device)
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.long)
qkv = torch.randn(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device=gpu_id)
qkv = torch.randn(num_tokens, 3, num_heads, head_size, dtype=dtype)
_, key, value = qkv.unbind(dim=1)
# Create the KV caches.
key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
num_heads, head_size, dtype,
None, seed, gpu_id)
None, seed, device)
key_cache, value_cache = key_caches[0], value_caches[0]
# Clone the KV caches.
@ -166,7 +165,7 @@ def test_reshape_and_cache(
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_swap_blocks(
kv_cache_factory,
@ -182,6 +181,7 @@ def test_swap_blocks(
) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
src_device = f"{direction[0]}:{device}" if direction[
0] == "cuda" else direction[0]

View File

@ -8,7 +8,9 @@ NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
HIDDEN_SIZES = [768, 5120, 8192] # Arbitrary values for testing
ADD_RESIDUAL = [False, True]
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@ -16,7 +18,7 @@ DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_rms_norm(
num_tokens: int,
@ -24,15 +26,16 @@ def test_rms_norm(
add_residual: bool,
dtype: torch.dtype,
seed: int,
device: int,
device: str,
) -> None:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id)
torch.set_default_device(device)
layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
x *= scale
residual = torch.randn_like(x) * scale if add_residual else None

View File

@ -13,7 +13,9 @@ NUM_HEADS = [7, 17] # Arbitrary values for testing
BATCH_SIZES = [1, 5] # Arbitrary values for testing
SEQ_LENS = [11, 8192] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@ -24,7 +26,7 @@ DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_rotary_embedding(
is_neox_style: bool,
@ -35,28 +37,26 @@ def test_rotary_embedding(
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: int,
device: str,
max_position: int = 8192,
base: int = 10000,
) -> None:
if rotary_dim is None:
rotary_dim = head_size
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
gpu_id = f"cuda:{device}"
torch.set_default_device(device)
if rotary_dim is None:
rotary_dim = head_size
rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
rope = rope.to(dtype=dtype, device=gpu_id)
rope = rope.to(dtype=dtype)
positions = torch.randint(0,
max_position, (batch_size, seq_len),
device=gpu_id)
positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size,
seq_len,
num_heads * head_size,
dtype=dtype,
device=gpu_id)
dtype=dtype)
key = torch.randn_like(query)
# NOTE(woosuk): The reference implementation should be executed first

View File

@ -11,19 +11,27 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask
NUM_HEADS = [12]
HEAD_SIZES = [128]
DTYPES = [torch.float16]
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_contexted_kv_attention(
num_heads: int,
head_size: int,
dtype: torch.dtype,
device: str,
) -> None:
random.seed(0)
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed(0)
torch.set_default_device(device)
MAX_SEQ_LEN = 1024
MAX_CTX_LEN = 1024
BS = 10
@ -35,24 +43,11 @@ def test_contexted_kv_attention(
seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)]
num_tokens = sum(subquery_lens)
query = torch.empty(num_tokens,
num_heads,
head_size,
dtype=dtype,
device='cuda')
query = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
query.uniform_(-1e-3, 1e-3)
output = torch.empty(num_tokens,
num_heads,
head_size,
dtype=dtype,
device='cuda')
output = torch.empty(num_tokens, num_heads, head_size, dtype=dtype)
kv = torch.empty(sum(seq_lens),
2,
num_heads,
head_size,
dtype=dtype,
device='cuda')
kv = torch.empty(sum(seq_lens), 2, num_heads, head_size, dtype=dtype)
kv.uniform_(-1e-3, 1e-3)
key, value = kv.unbind(dim=1)
@ -60,39 +55,27 @@ def test_contexted_kv_attention(
block_size,
num_heads,
head_size,
dtype=dtype,
device='cuda')
dtype=dtype)
v_cache = torch.zeros(cache_size,
block_size,
num_heads,
head_size,
dtype=dtype,
device='cuda')
k = torch.zeros(sum(subquery_lens),
num_heads,
head_size,
dtype=dtype,
device='cuda')
v = torch.zeros(sum(subquery_lens),
num_heads,
head_size,
dtype=dtype,
device='cuda')
values = torch.arange(0, cache_size, dtype=torch.long, device='cuda')
dtype=dtype)
k = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype)
v = torch.zeros(sum(subquery_lens), num_heads, head_size, dtype=dtype)
values = torch.arange(0, cache_size, dtype=torch.long)
values = values[torch.randperm(cache_size)]
block_table = values[:BS * max_block_per_request].view(
BS, max_block_per_request)
b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda')
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda')
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1],
dtype=torch.long,
device='cuda'),
dtype=torch.long),
dim=0)
max_input_len = MAX_SEQ_LEN
# copy kv to cache
b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1],
dtype=torch.long,
device='cuda'),
dtype=torch.long),
dim=0)
for i in range(BS):
for j in range(subquery_lens[i]):

View File

@ -126,8 +126,8 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module:
cleanup()
get_model_old = get_model
def get_model_patched(model_config, lora_config=None):
return get_model_old(model_config,
def get_model_patched(model_config, device_config, lora_config=None):
return get_model_old(model_config, device_config,
LoRAConfig(max_loras=4, max_lora_rank=8))
with patch("vllm.worker.model_runner.get_model", get_model_patched):

View File

@ -34,6 +34,9 @@ TOLERANCES = {
torch.float32: (5e-3, 5e-3),
torch.bfloat16: (3e-2, 2e-2),
}
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
def get_random_id_to_index(num_loras: int,
@ -151,14 +154,10 @@ def create_random_inputs(
for _ in range(num_inputs):
if input_type == torch.int:
inputs.append(
torch.randint(low=int(low),
high=int(high),
size=input_size,
device="cuda"))
torch.randint(low=int(low), high=int(high), size=input_size))
else:
inputs.append(
torch.rand(size=input_size, dtype=input_type, device="cuda") *
high + low)
torch.rand(size=input_size, dtype=input_type) * high + low)
lora_id = random.choice(active_lora_ids)
index_mapping += [lora_id] * input_size[0]
@ -169,8 +168,10 @@ def create_random_inputs(
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
def test_embeddings(dist_init, num_loras) -> None:
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_embeddings(dist_init, num_loras, device) -> None:
torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@ -259,8 +260,10 @@ def test_embeddings(dist_init, num_loras) -> None:
@torch.inference_mode()
# @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.")
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None:
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@ -305,8 +308,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None:
# Add empty embeddings_tensors for unoccupied lora slots.
for _ in range(max_loras - len(embeddings_tensors)):
embeddings_tensors.append(
torch.zeros(embeddings_tensors[0].shape, device="cuda"))
embeddings_tensors.append(torch.zeros(embeddings_tensors[0].shape))
inputs, index_mapping, prompt_mapping = create_random_inputs(
active_lora_ids=list(lora_dict.keys()),
@ -388,8 +390,10 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None:
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
def test_lm_head_sampler(dist_init, num_loras) -> None:
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_lm_head_sampler(dist_init, num_loras, device) -> None:
torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@ -432,7 +436,7 @@ def test_lm_head_sampler(dist_init, num_loras) -> None:
)
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
input_ = torch.rand(20, 1024, device="cuda")
input_ = torch.rand(20, 1024)
mapping_info = convert_mapping(
lora_mapping,
id_to_index,
@ -500,8 +504,10 @@ def test_lm_head_sampler(dist_init, num_loras) -> None:
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("orientation", ["row", "column"])
def test_linear_parallel(dist_init, num_loras, orientation) -> None:
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_linear_parallel(dist_init, num_loras, orientation, device) -> None:
torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,
@ -597,8 +603,10 @@ def test_linear_parallel(dist_init, num_loras, orientation) -> None:
@torch.inference_mode()
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
@pytest.mark.parametrize("repeats", [2, 3])
def test_column_parallel_packed(dist_init, num_loras, repeats) -> None:
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None:
torch.set_default_device(device)
max_loras = 8
lora_config = LoRAConfig(max_loras=max_loras,
max_lora_rank=8,

View File

@ -5,7 +5,8 @@ from unittest.mock import patch
from vllm.lora.models import LoRAMapping
from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
DeviceConfig, LoRAConfig)
from vllm.worker.worker import Worker
@ -25,6 +26,7 @@ def test_worker_apply_lora(sql_lora_files):
),
parallel_config=ParallelConfig(1, 1, False),
scheduler_config=SchedulerConfig(32, 32, 32, 256),
device_config=DeviceConfig("cuda"),
local_rank=0,
rank=0,
lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32,

View File

@ -9,6 +9,10 @@ from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
def mock_causal_accepted_tensor(
k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
@ -39,11 +43,14 @@ def mock_causal_accepted_tensor(
@pytest.mark.parametrize(
"which_tokens_accepted",
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int):
def test_correct_output_format(which_tokens_accepted: str, seed: int,
device: str):
"""Verify the output has correct format given predetermined accepted matrix.
"""
set_random_seed(seed)
torch.set_default_device(device)
batch_size = 10
k = 5
@ -66,18 +73,15 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int):
recovered_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device="cuda")
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device="cuda")
dtype=torch.int64)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64,
device="cuda")
dtype=torch.int64)
rejection_sampler = RejectionSampler()
rejection_sampler.init_gpu_tensors(rank=0)
@ -120,31 +124,24 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int):
@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int):
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str):
torch.set_default_device(device)
rejection_sampler = RejectionSampler()
rejection_sampler.init_gpu_tensors(rank=0)
draft_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device="cuda")
target_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device="cuda")
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64,
device="cuda")
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device="cuda")
dtype=torch.int64)
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids)
@ -153,36 +150,28 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int):
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
["bonus_token_ids", "draft_token_ids"])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
which_token_ids: str):
which_token_ids: str, device: str):
k = 3
batch_size = 5
vocab_size = 30_000
torch.set_default_device(device)
rejection_sampler = RejectionSampler(strict_mode=True)
rejection_sampler.init_gpu_tensors(rank=0)
draft_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device="cuda")
target_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device="cuda")
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64,
device="cuda")
dtype=torch.int64)
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device="cuda")
dtype=torch.int64)
oob_token_ids = None
if which_token_ids == "bonus_token_ids":
@ -237,6 +226,7 @@ def test_rejection_sampling_approximates_target_distribution(
probabilities are exactly equal. Rejection sampling should
still work without any NaNs or exceptions.
"""
torch.set_default_device("cpu")
set_random_seed(seed)
helper = _CorrectnessTestHelper(

View File

@ -31,24 +31,26 @@ def _prepare_test(
batch_size: int
) -> Tuple[torch.Tensor, torch.Tensor, MockLogitsSampler, ModelRunner]:
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024),
device="cuda",
dtype=torch.float16)
input_tensor = torch.rand((batch_size, 1024), dtype=torch.float16)
fake_logits = torch.full((batch_size, vocab_size),
1e-2,
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits)
model_runner = ModelRunner(None, None, None, None)
model_runner = ModelRunner(None, None, None, None, None)
return input_tensor, fake_logits, sampler, model_runner
RANDOM_SEEDS = list(range(128))
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_greedy(seed: int):
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_greedy(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
@ -81,8 +83,10 @@ def test_sampler_all_greedy(seed: int):
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_random(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
@ -120,8 +124,10 @@ def test_sampler_all_random(seed: int):
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_all_beam(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
@ -156,8 +162,10 @@ def test_sampler_all_beam(seed: int):
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_mixed(seed: int):
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_mixed(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, model_runner = _prepare_test(
batch_size)
@ -212,8 +220,10 @@ def test_sampler_mixed(seed: int):
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_logits_processors(seed: int, device: str):
set_random_seed(seed)
torch.set_default_device(device)
batch_size = random.randint(1, 256)
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
@ -252,14 +262,15 @@ def test_sampler_logits_processors(seed: int):
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_top_k_top_p(seed: int):
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_sampler_top_k_top_p(seed: int, device: str):
set_random_seed(seed)
batch_size = random.randint(1, 256)
top_k = random.randint(100, 500)
top_p = random.random() * 0.1
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024),
device="cuda",
device=device,
dtype=torch.float16)
fake_logits = torch.normal(0,
5,
@ -267,7 +278,7 @@ def test_sampler_top_k_top_p(seed: int):
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits)
model_runner = ModelRunner(None, None, None, None)
model_runner = ModelRunner(None, None, None, None, None)
generation_model = GenerationMixin()
generation_config = GenerationConfig(top_k=top_k,

View File

@ -84,7 +84,7 @@ def create_worker(cls: type,
)
(model_config, cache_config, parallel_config, scheduler_config,
_) = engine_args.create_engine_configs()
device_config, _) = engine_args.create_engine_configs()
distributed_init_method = get_distributed_init_method(
get_ip(), get_open_port())
@ -93,6 +93,7 @@ def create_worker(cls: type,
model_config=model_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,

View File

@ -6,7 +6,7 @@ from vllm.worker.model_runner import ModelRunner
def test_prepare_prompt():
model_runner = ModelRunner(None, None, None, None)
model_runner = ModelRunner(None, None, None, None, None)
model_runner.set_block_size(16)
batch_size = random.randint(1, 256)

View File

@ -444,6 +444,12 @@ class SchedulerConfig:
f"({self.max_num_seqs}).")
class DeviceConfig:
def __init__(self, device: str = "cuda") -> None:
self.device = torch.device(device)
@dataclass
class LoRAConfig:
max_lora_rank: int
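DeviceConfig is the new home for the execution device: it normalizes the --device string into a torch.device that the engine, worker, and model runner consume. A quick usage sketch that mirrors the class added above:

import torch


class DeviceConfig:
    # Same shape as the class added above in vllm/config.py.
    def __init__(self, device: str = "cuda") -> None:
        self.device = torch.device(device)


cfg = DeviceConfig("cuda")
assert cfg.device.type == "cuda"
# Downstream code can branch on cfg.device.type or hand cfg.device to tensor factories.
x = torch.zeros(4, device=cfg.device if torch.cuda.is_available() else "cpu")
print(x.device)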

View File

@ -3,8 +3,8 @@ import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, LoRAConfig)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
@dataclass
@ -43,6 +43,7 @@ class EngineArgs:
lora_extra_vocab_size: int = 256
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'cuda'
def __post_init__(self):
if self.tokenizer is None:
@ -127,13 +128,13 @@ class EngineArgs:
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8_e5m2'],
default='auto',
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. Note FP8 is not supported when cuda version is '
'lower than 11.8.')
parser.add_argument('--max-model-len',
type=int,
default=None,
default=EngineArgs.max_model_len,
help='model context length. If unspecified, '
'will be automatically derived from the model.')
# Parallel arguments
@ -154,6 +155,7 @@ class EngineArgs:
parser.add_argument(
'--max-parallel-loading-workers',
type=int,
default=EngineArgs.max_parallel_loading_workers,
help='load model sequentially in multiple batches, '
'to avoid RAM OOM when using tensor '
'parallel and large models')
@ -200,7 +202,7 @@ class EngineArgs:
'-q',
type=str,
choices=['awq', 'gptq', 'squeezellm', None],
default=None,
default=EngineArgs.quantization,
help='Method used to quantize the weights. If '
'None, we first check the `quantization_config` '
'attribute in the model config file. If that is '
@ -255,6 +257,13 @@ class EngineArgs:
help=('Maximum number of LoRAs to store in CPU memory. '
'Must be >= than max_num_seqs. '
'Defaults to max_num_seqs.'))
parser.add_argument(
"--device",
type=str,
default=EngineArgs.device,
choices=["cuda"],
help=('Device type for vLLM execution. '
'Currently, only CUDA-compatible devices are supported.'))
return parser
@classmethod
@ -268,7 +277,8 @@ class EngineArgs:
def create_engine_configs(
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
Optional[LoRAConfig]]:
DeviceConfig, Optional[LoRAConfig]]:
device_config = DeviceConfig(self.device)
model_config = ModelConfig(self.model, self.tokenizer,
self.tokenizer_mode, self.trust_remote_code,
self.download_dir, self.load_format,
@ -296,7 +306,8 @@ class EngineArgs:
lora_dtype=self.lora_dtype,
max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
and self.max_cpu_loras > 0 else None) if self.enable_lora else None
return model_config, cache_config, parallel_config, scheduler_config, lora_config
return (model_config, cache_config, parallel_config, scheduler_config,
device_config, lora_config)
@dataclass
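In EngineArgs the device string becomes a first-class field with a --device flag (restricted to "cuda" for now), and create_engine_configs grows a DeviceConfig entry in its returned tuple; several argparse defaults are also switched to reference the dataclass fields instead of repeating literals. A hedged sketch of that defaults-from-dataclass idiom, with a simplified stand-in class:

import argparse
from dataclasses import dataclass


@dataclass
class EngineArgsSketch:
    # Simplified stand-in for vllm.engine.arg_utils.EngineArgs.
    kv_cache_dtype: str = "auto"
    device: str = "cuda"

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        # Defaults come from the dataclass fields, so each value is defined once.
        parser.add_argument("--kv-cache-dtype",
                            type=str,
                            choices=["auto", "fp8_e5m2"],
                            default=EngineArgsSketch.kv_cache_dtype)
        parser.add_argument("--device",
                            type=str,
                            choices=["cuda"],
                            default=EngineArgsSketch.device)
        return parser


args = EngineArgsSketch.add_cli_args(argparse.ArgumentParser()).parse_args([])
print(args.kv_cache_dtype, args.device)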

View File

@ -6,8 +6,8 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
Union)
from vllm.lora.request import LoRARequest
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, LoRAConfig)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.metrics import StatLogger, Stats
@ -53,6 +53,7 @@ class LLMEngine:
management.
parallel_config: The configuration related to distributed execution.
scheduler_config: The configuration related to the request scheduler.
device_config: The configuration related to the device.
placement_group: Ray placement group for distributed execution.
Required for distributed execution.
log_stats: Whether to log statistics.
@ -64,6 +65,7 @@ class LLMEngine:
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
placement_group: Optional["PlacementGroup"],
log_stats: bool,
@ -85,6 +87,7 @@ class LLMEngine:
f"quantization={model_config.quantization}, "
f"enforce_eager={model_config.enforce_eager}, "
f"kv_cache_dtype={cache_config.cache_dtype}, "
f"device_config={device_config.device}, "
f"seed={model_config.seed})")
# TODO(woosuk): Print more configs in debug mode.
@ -93,6 +96,7 @@ class LLMEngine:
self.lora_config = lora_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.log_stats = log_stats
self._verify_args()
@ -138,6 +142,7 @@ class LLMEngine:
self.model_config,
self.parallel_config,
self.scheduler_config,
self.device_config,
local_rank=0,
rank=0,
distributed_init_method=distributed_init_method,
@ -233,6 +238,7 @@ class LLMEngine:
model_config = copy.deepcopy(self.model_config)
parallel_config = copy.deepcopy(self.parallel_config)
scheduler_config = copy.deepcopy(self.scheduler_config)
device_config = copy.deepcopy(self.device_config)
for rank, (worker, (node_id,
_)) in enumerate(zip(self.workers,
@ -244,6 +250,7 @@ class LLMEngine:
model_config,
parallel_config,
scheduler_config,
device_config,
local_rank,
rank,
distributed_init_method,
@ -257,6 +264,7 @@ class LLMEngine:
model_config,
parallel_config,
scheduler_config,
device_config,
driver_local_rank,
driver_rank,
distributed_init_method,

View File

@ -89,9 +89,7 @@ class ScaledActivation(nn.Module):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.scales = nn.Parameter(
torch.empty(intermediate_size_per_partition,
dtype=params_dtype,
device="cuda"))
torch.empty(intermediate_size_per_partition, dtype=params_dtype))
set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})
def forward(self, x: torch.Tensor) -> torch.Tensor:

View File

@ -200,7 +200,7 @@ def _make_alibi_bias(
seq_len: int,
dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
bias = torch.arange(seq_len, dtype=dtype, device="cuda")
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but

View File

@ -54,7 +54,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
params_dtype: torch.dtype) -> Dict[str, Any]:
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
@ -113,9 +112,7 @@ class ReplicatedLinear(torch.nn.Module):
self.register_parameter(name, weight)
if bias:
self.bias = Parameter(
torch.empty(self.output_size,
device=torch.cuda.current_device(),
dtype=self.params_dtype))
torch.empty(self.output_size, dtype=self.params_dtype))
set_weight_attrs(self.bias, {"output_dim": 0})
else:
self.register_parameter("bias", None)
@ -183,7 +180,6 @@ class ColumnParallelLinear(torch.nn.Module):
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
@ -509,9 +505,7 @@ class RowParallelLinear(torch.nn.Module):
if bias:
self.bias = Parameter(
torch.empty(self.output_size,
device=torch.cuda.current_device(),
dtype=params_dtype))
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,

View File

@ -96,7 +96,6 @@ class AWQLinearMethod(LinearMethodBase):
torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
@ -112,7 +111,6 @@ class AWQLinearMethod(LinearMethodBase):
torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
@ -128,7 +126,6 @@ class AWQLinearMethod(LinearMethodBase):
torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,

View File

@ -127,7 +127,6 @@ class GPTQLinearMethod(LinearMethodBase):
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
@ -145,7 +144,6 @@ class GPTQLinearMethod(LinearMethodBase):
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
@ -156,7 +154,6 @@ class GPTQLinearMethod(LinearMethodBase):
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
@ -172,7 +169,6 @@ class GPTQLinearMethod(LinearMethodBase):
torch.empty(
scale_and_zero_size,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,

View File

@ -80,7 +80,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
@ -96,7 +95,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
torch.empty(
output_size,
self.quant_config.weight_bits**2,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
@ -118,12 +116,12 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
if is_hip():
out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
out_f = torch.zeros(out_shape, dtype=torch.float)
ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
out = out_f.to(dtype=torch.float16)
else:
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
out = torch.zeros(out_shape, dtype=torch.float16)
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
if bias is not None:

View File

@ -77,16 +77,13 @@ class RotaryEmbedding(nn.Module):
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim))
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings,
dtype=torch.float,
device="cuda")
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
@ -174,7 +171,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * self.scaling_factor
t = torch.arange(max_len, dtype=torch.float, device="cuda")
t = torch.arange(max_len, dtype=torch.float)
t = t / self.scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq)
@ -214,7 +211,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
(self.scaling_factor - 1))**(self.rotary_dim /
(self.rotary_dim - 2))
inv_freq = self._compute_inv_freq(base)
t = torch.arange(max_len, dtype=torch.float, device="cuda")
t = torch.arange(max_len, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
@ -297,8 +294,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
pos_freqs = self.base**(
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
@ -308,8 +305,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2, dtype=torch.float,
device="cuda")) * self.extrapolation_factor
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
@ -317,7 +314,6 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)

View File

@ -77,7 +77,6 @@ class VocabParallelEmbedding(torch.nn.Module):
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.weight, {
"parallel_dim": 0,
@ -139,7 +138,6 @@ class ParallelLMHead(VocabParallelEmbedding):
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.bias, {
"parallel_dim": 0,

View File

@ -5,7 +5,7 @@ from typing import Optional, Type
import torch
import torch.nn as nn
from vllm.config import ModelConfig, LoRAConfig
from vllm.config import DeviceConfig, ModelConfig, LoRAConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
@ -38,6 +38,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig] = None) -> nn.Module:
model_class = _get_model_architecture(model_config)
@ -64,7 +65,7 @@ def get_model(model_config: ModelConfig,
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
with torch.device("cuda"):
with torch.device(device_config.device):
if getattr(model_class, "supports_lora", False):
model = model_class(model_config.hf_config, linear_method,
lora_config)
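get_model now receives the DeviceConfig and instantiates the model under with torch.device(device_config.device) rather than a hardcoded with torch.device("cuda"), which is exactly what lets the layer code above omit per-parameter device arguments. A minimal sketch of device-scoped instantiation (nn.Linear is only a placeholder for a vLLM model class):

import torch
import torch.nn as nn


def get_model_sketch(device: torch.device) -> nn.Module:
    # Weights created inside the context land on `device`, mirroring
    # get_model(model_config, device_config, ...).
    with torch.device(device):
        return nn.Linear(32, 32, bias=True)


model = get_model_sketch(
    torch.device("cuda" if torch.cuda.is_available() else "cpu"))
print(next(model.parameters()).device)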

View File

@ -228,6 +228,7 @@ def create_kv_caches_with_random(
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
if isinstance(cache_dtype, str):

View File

@ -104,11 +104,13 @@ class CacheEngine:
size=(self.num_cpu_blocks, *key_block_shape),
dtype=self.dtype,
pin_memory=pin_memory,
device="cpu",
)
value_blocks = torch.empty(
size=(self.num_cpu_blocks, *value_block_shape),
dtype=self.dtype,
pin_memory=pin_memory,
device="cpu",
)
cpu_cache.append((key_blocks, value_blocks))
return cpu_cache

View File

@ -5,7 +5,7 @@ import numpy as np
import torch
import torch.nn as nn
from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
from vllm.config import DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import (
@ -35,6 +35,7 @@ class ModelRunner:
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
@ -49,7 +50,10 @@ class ModelRunner:
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
self.sliding_window = (model_config.get_sliding_window()
if model_config is not None else None)
self.device = torch.device(torch.cuda.current_device())
self.device_config = (device_config
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
self.model = None
self.block_size = None # Set after initial profiling.
self.lora_manager = None
@ -72,7 +76,8 @@ class ModelRunner:
self.kv_cache_dtype = kv_cache_dtype
def load_model(self) -> None:
self.model = get_model(self.model_config, self.lora_config)
self.model = get_model(self.model_config, self.device_config,
self.lora_config)
vocab_size = self.model.config.vocab_size
@ -182,22 +187,25 @@ class ModelRunner:
input_tokens = _make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=torch.long)
dtype=torch.long,
device=self.device)
input_positions = _make_tensor_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=torch.long)
dtype=torch.long,
device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long)
dtype=torch.long,
device=self.device)
lora_index_mapping = [
_pad_to_max(mapping, max_prompt_len, pad=0)
for mapping in lora_index_mapping
]
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device='cuda')
device=self.device)
# Prepare prefix block tables
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
block_tables = _make_tensor_with_pad(
@ -205,15 +213,16 @@ class ModelRunner:
max_len=max_prompt_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
)
start_loc_tensor = torch.arange(0,
len(prompt_lens) * max_prompt_len,
max_prompt_len,
dtype=torch.long,
device='cuda')
device=self.device)
prompt_lens_tensor = torch.tensor(prompt_lens,
dtype=torch.long,
device='cuda')
device=self.device)
input_metadata = InputMetadata(
is_prompt=True,
@ -305,20 +314,20 @@ class ModelRunner:
max_len=1,
pad=0,
dtype=torch.long,
device="cuda")
device=self.device)
input_positions = _make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
dtype=torch.long,
device="cuda")
device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_len=1,
pad=_PAD_SLOT_ID,
dtype=torch.long,
device="cuda")
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device="cuda")
device=self.device)
if use_captured_graph:
# The shape of graph_block_tables is
@ -327,7 +336,7 @@ class ModelRunner:
for i, block_table in enumerate(block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device="cuda")
block_tables = torch.tensor(input_block_tables, device=self.device)
else:
max_block_table_len = max(
len(block_table) for block_table in block_tables)
@ -336,7 +345,7 @@ class ModelRunner:
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device="cuda",
device=self.device,
)
lora_index_mapping = [
@ -355,7 +364,8 @@ class ModelRunner:
use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype,
)
return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
return (input_tokens, input_positions, input_metadata,
lora_index_mapping, lora_prompt_mapping, lora_requests)
def _prepare_sample(
self,
@ -410,9 +420,13 @@ class ModelRunner:
selected_token_indices = _async_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=not self.in_wsl)
categorized_sample_indices = {
t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
t: _async_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=not self.in_wsl)
for t, seq_ids in categorized_sample_indices.items()
}
@ -511,7 +525,8 @@ class ModelRunner:
perform_sampling=False,
)
return input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping
return (input_tokens, input_positions, input_metadata,
sampling_metadata, lora_requests, lora_mapping)
@torch.inference_mode()
def execute_model(
@ -519,8 +534,9 @@ class ModelRunner:
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Optional[SamplerOutput]:
input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping = (
self.prepare_input_tensors(seq_group_metadata_list))
(input_tokens, input_positions, input_metadata, sampling_metadata,
lora_requests,
lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
@ -789,14 +805,10 @@ def _make_tensor_with_pad(
max_len: int,
pad: int,
dtype: torch.dtype,
device: Union[str, torch.device] = "cuda",
pin_memory: bool = False,
device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x,
dtype=dtype,
device=device,
pin_memory=pin_memory and str(device) == "cpu")
return torch.tensor(padded_x, dtype=dtype, device=device)
def _get_graph_batch_size(batch_size: int) -> int:
@ -808,6 +820,11 @@ def _get_graph_batch_size(batch_size: int) -> int:
return (batch_size + 7) // 8 * 8
def _async_h2d(data: list, dtype, pin_memory):
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
return t.to(device="cuda", non_blocking=True)
def _async_h2d(
data: list,
dtype: torch.dtype,
target_device: Union[str, torch.device],
pin_memory: bool,
) -> torch.Tensor:
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
return t.to(device=target_device, non_blocking=True)
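The rewritten _async_h2d builds the host tensor on the CPU (optionally pinned) and then copies it to an explicit target device with non_blocking=True, and _make_tensor_with_pad likewise takes device as a required parameter. A self-contained sketch of the async host-to-device copy; async_h2d here is an illustrative name, not the vLLM-internal helper:

import torch
from typing import Union


def async_h2d(data: list,
              dtype: torch.dtype,
              target_device: Union[str, torch.device],
              pin_memory: bool) -> torch.Tensor:
    # Build on the CPU (optionally pinned), then issue an async copy to the target.
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    return t.to(device=target_device, non_blocking=True)


if torch.cuda.is_available():
    out = async_h2d([1, 2, 3], torch.long, "cuda:0", pin_memory=True)
else:
    out = async_h2d([1, 2, 3], torch.long, "cpu", pin_memory=False)
print(out.device)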

View File

@ -6,8 +6,8 @@ from typing import Dict, List, Tuple, Set, Optional
import torch
import torch.distributed
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, LoRAConfig)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
@ -33,6 +33,7 @@ class Worker:
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
@ -43,6 +44,7 @@ class Worker:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
@ -54,6 +56,7 @@ class Worker:
self.model_runner = ModelRunner(model_config,
parallel_config,
scheduler_config,
device_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
@ -65,6 +68,7 @@ class Worker:
self.gpu_cache = None
def init_model(self) -> None:
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
@ -79,7 +83,9 @@ class Worker:
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method)
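Finally, Worker.init_model branches on device_config.device.type: CUDA-specific setup (selecting the device, checking dtype support, the all_reduce env-var workaround) runs only for CUDA, and any other device type raises until it is supported. A rough sketch of that guard; init_device is a simplified, illustrative name rather than the worker's actual method:

import torch


def init_device(device: torch.device, local_rank: int = 0) -> None:
    if device.type == "cuda":
        # CUDA-only setup; the real worker also validates the model dtype and
        # sets an env var to work around all_reduce memory growth.
        torch.cuda.set_device(torch.device(f"cuda:{local_rank}"))
    else:
        # Message mirrors the worker's RuntimeError above.
        raise RuntimeError(f"Not support device type: {device}")


if torch.cuda.is_available():
    init_device(torch.device("cuda"))
else:
    try:
        init_device(torch.device("cpu"))
    except RuntimeError as exc:
        print(exc)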