TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)

This commit is contained in:
Zhuohan Li 2023-10-02 15:36:09 -07:00 committed by GitHub
parent 84e4e37d14
commit ba0bfd40e2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 819 additions and 1547 deletions

View File

@ -28,4 +28,4 @@ jobs:
pip install pylint==2.8.2
- name: Analysing the code with pylint
run: |
pylint vllm
pylint vllm tests

View File

@ -28,4 +28,4 @@ jobs:
pip install toml==0.10.2
- name: Running yapf
run: |
yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
yapf --diff --recursive vllm tests

View File

@ -8,7 +8,7 @@
[MASTER]
# Files or directories to be skipped. They should be base names, not paths.
ignore=docs,parallel_utils
ignore=docs
# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.

View File

@ -44,7 +44,6 @@ YAPF_FLAGS=(
YAPF_EXCLUDES=(
'--exclude' 'build/**'
'--exclude' 'vllm/model_executor/parallel_utils/**'
)
# Format specified files
@ -72,7 +71,7 @@ format_changed() {
# Format all files
format_all() {
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm tests
}
## This flag formats individual files. --files *must* be the first command line
@ -96,7 +95,7 @@ echo 'vLLM yapf: Done'
# Run Pylint
echo 'vLLM Pylint:'
pylint vllm
pylint vllm tests
if ! git diff --quiet &>/dev/null; then
echo 'Reformatted files. Please review and stage the changes.'

View File

@ -14,6 +14,7 @@ app = vllm.entrypoints.api_server.app
class AsyncLLMEngineWithStats(AsyncLLMEngine):
# pylint: disable=redefined-outer-name
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._num_aborts = 0

View File

@ -24,6 +24,7 @@ def _query_server(prompt: str) -> dict:
def api_server():
script_path = Path(__file__).parent.joinpath(
"api_server_async_engine.py").absolute()
# pylint: disable=consider-using-with
uvicorn_process = subprocess.Popen([
sys.executable, "-u",
str(script_path), "--model", "facebook/opt-125m"
@ -32,6 +33,7 @@ def api_server():
uvicorn_process.terminate()
# pylint: disable=redefined-outer-name, unused-argument
def test_api_server(api_server):
"""
Run the API server and test it.
@ -47,6 +49,7 @@ def test_api_server(api_server):
prompts = ["Hello world"] * 1
result = None
while not result:
# pylint: disable=bare-except
try:
for result in pool.map(_query_server, prompts):
break

View File

@ -32,12 +32,12 @@ class MockEngine:
self.request_id = None
def add_request(self, **kwargs):
del kwargs # Unused
self.add_request_calls += 1
return
def abort_request(self, request_id):
del request_id # Unused
self.abort_request_calls += 1
return
class MockAsyncLLMEngine(AsyncLLMEngine):

View File

@ -7,22 +7,22 @@ from vllm.outputs import RequestOutput
class DummyEvent:
def __init__(self):
self._flag = False
self.flag = False
def set(self):
self._flag = True
self.flag = True
def clear(self):
self._flag = False
self.flag = False
def test_request_tracker():
tracker = RequestTracker()
tracker.new_requests_event = DummyEvent()
stream_1 = tracker.add_request("1")
assert tracker.new_requests_event._flag
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event._flag
assert not tracker.new_requests_event.flag
assert len(new) == 1
assert new[0]["request_id"] == "1"
assert not finished
@ -30,9 +30,9 @@ def test_request_tracker():
stream_2 = tracker.add_request("2")
stream_3 = tracker.add_request("3")
assert tracker.new_requests_event._flag
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event._flag
assert not tracker.new_requests_event.flag
assert len(new) == 2
assert new[0]["request_id"] == "2"
assert new[1]["request_id"] == "3"
@ -43,7 +43,7 @@ def test_request_tracker():
# request_ids must be unique
with pytest.raises(KeyError):
tracker.add_request("1")
assert not tracker.new_requests_event._flag
assert not tracker.new_requests_event.flag
tracker.abort_request("1")
new, finished = tracker.get_new_and_finished_requests()
@ -54,7 +54,7 @@ def test_request_tracker():
stream_4 = tracker.add_request("4")
tracker.abort_request("4")
assert tracker.new_requests_event._flag
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "4" in finished
@ -62,11 +62,11 @@ def test_request_tracker():
assert stream_4.finished
stream_5 = tracker.add_request("5")
assert tracker.new_requests_event._flag
assert tracker.new_requests_event.flag
tracker.process_request_output(
RequestOutput("2", "output", [], [], finished=True))
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event._flag
assert not tracker.new_requests_event.flag
assert len(finished) == 1
assert "2" in finished
assert len(new) == 1

View File

@ -8,6 +8,7 @@ from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer
_TEST_PROMPTS = [
# pylint: disable=line-too-long
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",

View File

@ -0,0 +1,82 @@
"""Test the communication operators.
Run `pytest tests/distributed/test_comm_ops.py --forked`.
"""
from multiprocessing import Process
import pytest
import torch
from vllm.config import ParallelConfig
from vllm.engine.ray_utils import get_open_port
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce,
tensor_model_parallel_all_gather,
)
from vllm.worker.worker import _init_distributed_environment
def init_test_distributed_environment(pipeline_parallel_size: int,
tensor_parallel_size: int, rank: int,
distributed_init_port: str):
parallel_config = ParallelConfig(pipeline_parallel_size,
tensor_parallel_size,
worker_use_ray=True)
distributed_init_method = f"tcp://localhost:{distributed_init_port}"
torch.cuda.set_device(rank)
_init_distributed_environment(parallel_config, rank,
distributed_init_method)
def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
num_elements = 8
all_tensors = [
torch.arange(num_elements, dtype=torch.float32, device="cuda") *
(r + 1) for r in range(tensor_parallel_size)
]
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank]
t = tensor_model_parallel_all_reduce(t)
assert torch.allclose(t, expected)
def all_gather_test_worker(tensor_parallel_size: int, rank: int,
distributed_init_port: str):
init_test_distributed_environment(1, tensor_parallel_size, rank,
distributed_init_port)
num_dimensions = 3
tensor_size = list(range(2, num_dimensions + 2))
total_size = 1
for s in tensor_size:
total_size *= s
for all_gather_dimension in range(num_dimensions):
all_tensors = [
torch.arange(total_size, dtype=torch.float32,
device="cuda").reshape(tensor_size) * (r + 1)
for r in range(tensor_parallel_size)
]
expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[rank]
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
assert torch.allclose(t, expected)
@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("test_target",
[all_reduce_test_worker, all_gather_test_worker])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
distributed_init_port = get_open_port()
processes = []
for rank in range(tensor_parallel_size):
p = Process(target=test_target,
args=(tensor_parallel_size, rank, distributed_init_port))
p.start()
processes.append(p)
for p in processes:
p.join()
assert all(p.exitcode == 0 for p in processes)

View File

@ -5,6 +5,7 @@ from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import detokenize_incrementally
TRUTH = [
# pylint: disable=line-too-long
"Hello here, this is a simple test",
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
"我很感谢你的热情"

View File

@ -29,8 +29,8 @@ def test_silu_and_mul(
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
activation_ops.silu_and_mul(out, x)
ref_out = ref_silu_and_mul(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@ -49,8 +49,8 @@ def test_gelu_new(
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
activation_ops.gelu_new(out, x)
ref_out = get_activation("gelu_new")(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
@ -68,8 +68,8 @@ def test_gelu_fast(
) -> None:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
activation_ops.gelu_fast(out, x)
ref_out = get_activation("gelu_fast")(x)
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)

View File

@ -106,14 +106,14 @@ def test_reshape_and_cache(
# Create a random slot mapping.
num_slots = block_size * num_blocks
slot_mapping = random.sample(range(num_slots), num_tokens)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda")
qkv = torch.randn(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device='cuda')
device="cuda")
_, key, value = qkv.unbind(dim=1)
# Create the KV caches.
@ -132,7 +132,7 @@ def test_reshape_and_cache(
# Run the reference implementation.
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
block_indicies = block_indicies.cpu().tolist()
block_offsets = slot_mapping % block_size
block_offsets = block_offsets.cpu().tolist()

View File

@ -140,7 +140,7 @@ def test_rotary_embedding(
cos = freqs.cos()
sin = freqs.sin()
cos_sin_cache = torch.cat((cos, sin), dim=-1)
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
out_query = query.clone()

View File

@ -1,3 +1,4 @@
# pylint: disable=protected-access
import pytest
import random
from typing import Tuple
@ -108,7 +109,7 @@ def test_sampler_all_random(seed: int):
def test_sampler_all_beam(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
input_tensor, _, sampler, worker = _prepare_test(batch_size)
seq_group_metadata_list = []
for i in range(batch_size):

View File

@ -1,7 +1,7 @@
from vllm.model_executor.layers.quantized_linear.awq import (
AWQColumnParallelLinear, AWQRowParallelLinear)
from vllm.model_executor.parallel_utils.tensor_parallel import (
ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)
_QUANTIZED_LINEAR_REGISTRY = {
"awq": (AWQColumnParallelLinear, AWQRowParallelLinear),

View File

@ -4,8 +4,8 @@ import torch
from torch.nn.parameter import Parameter
from vllm import quantization_ops
from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear)
class AWQColumnParallelLinear(ColumnParallelLinear):

View File

@ -5,8 +5,8 @@ import torch
import torch.nn as nn
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.parallel_utils.tensor_parallel import (
gather_from_tensor_model_parallel_region)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_gather)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs
@ -92,7 +92,7 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = gather_from_tensor_model_parallel_region(logits)
logits = tensor_model_parallel_all_gather(logits)
# Remove paddings in vocab (if any).
logits = logits[:, :vocab_size]
return logits

View File

@ -39,8 +39,9 @@ from vllm.model_executor.weight_utils import (
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.aquila import AquilaConfig
@ -56,16 +57,18 @@ class AquilaMLP(nn.Module):
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
self.gate_up_proj = ColumnParallelLinear(
hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@ -130,14 +133,12 @@ class AquilaAttention(nn.Module):
self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
@ -230,7 +231,7 @@ class AquilaModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
perform_initialization=False)
)
self.layers = nn.ModuleList([
AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
@ -270,11 +271,12 @@ class AquilaForCausalLM(nn.Module):
self.config = config
self.model = AquilaModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.lm_head = ColumnParallelLinear(
config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size)
def forward(

View File

@ -39,8 +39,9 @@ from vllm.model_executor.weight_utils import (
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
@ -81,16 +82,18 @@ class BaiChuanMLP(nn.Module):
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
self.gate_up_proj = ColumnParallelLinear(
hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@ -133,14 +136,12 @@ class BaiChuanAttention(nn.Module):
3 * hidden_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
# Create the alibi slopes and slice them.
if self.postion_embedding == "ALIBI":
@ -249,7 +250,7 @@ class BaiChuanModel(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
perform_initialization=False)
)
self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding)
for _ in range(config.num_hidden_layers)
@ -288,11 +289,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
super().__init__()
self.config = config
self.model = BaiChuanModel(config, position_embedding)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.lm_head = ColumnParallelLinear(
config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size)
def forward(

View File

@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -85,14 +86,12 @@ class BloomAttention(nn.Module):
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False,
)
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False,
)
# Create the alibi slopes and slice them.
@ -129,15 +128,17 @@ class BloomMLP(nn.Module):
def __init__(self, config: BloomConfig):
super().__init__()
hidden_size = config.hidden_size
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
4 * hidden_size,
gather_output=False,
perform_initialization=False)
self.dense_h_to_4h = ColumnParallelLinear(
hidden_size,
4 * hidden_size,
gather_output=False,
)
self.act = get_act_fn("gelu")
self.dense_4h_to_h = RowParallelLinear(4 * hidden_size,
hidden_size,
input_is_parallel=True,
perform_initialization=False)
self.dense_4h_to_h = RowParallelLinear(
4 * hidden_size,
hidden_size,
input_is_parallel=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.dense_h_to_4h(x)
@ -208,7 +209,9 @@ class BloomModel(nn.Module):
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size, self.embed_dim, perform_initialization=False)
config.vocab_size,
self.embed_dim,
)
self.word_embeddings_layernorm = nn.LayerNorm(
self.embed_dim, eps=config.layer_norm_epsilon)

View File

@ -36,9 +36,11 @@ from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
reduce_from_tensor_model_parallel_region)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs import RWConfig
@ -109,7 +111,6 @@ class FalconAttention(nn.Module):
self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
elif self.multi_query:
@ -120,7 +121,6 @@ class FalconAttention(nn.Module):
self.total_num_heads * self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
self.key_value = FalconLinear(self.hidden_size,
@ -135,7 +135,6 @@ class FalconAttention(nn.Module):
self.head_dim,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True,
)
@ -151,7 +150,6 @@ class FalconAttention(nn.Module):
self.hidden_size,
bias=config.bias,
input_is_parallel=True,
perform_initialization=False,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results)
@ -231,7 +229,6 @@ class FalconMLP(nn.Module):
4 * hidden_size,
bias=config.bias,
gather_output=False,
perform_initialization=False,
skip_bias_add=True)
self.act = nn.GELU()
self.reduce_row_parallel_results = not (config.new_decoder_architecture
@ -241,7 +238,6 @@ class FalconMLP(nn.Module):
hidden_size,
bias=config.bias,
input_is_parallel=True,
perform_initialization=False,
skip_bias_add=True,
reduce_results=self.reduce_row_parallel_results)
@ -325,7 +321,7 @@ class FalconDecoderLayer(nn.Module):
# only one all-reduce operator to reduce the results from
# both MLP and Attention layers.
mlp_output += attention_output
mlp_output = reduce_from_tensor_model_parallel_region(mlp_output)
mlp_output = tensor_model_parallel_all_reduce(mlp_output)
if attention_bias is not None:
mlp_output += attention_bias
if mlp_bias is not None:
@ -347,7 +343,9 @@ class FalconModel(nn.Module):
# Embedding + LN Embedding
self.word_embeddings = VocabParallelEmbedding(
config.vocab_size, self.embed_dim, perform_initialization=False)
config.vocab_size,
self.embed_dim,
)
# Transformer blocks
self.h = nn.ModuleList([
@ -389,11 +387,12 @@ class FalconForCausalLM(nn.Module):
super().__init__()
self.config = config
self.transformer = FalconModel(config)
self.lm_head = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.lm_head = ColumnParallelLinear(
config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size)
def forward(

View File

@ -36,8 +36,9 @@ from vllm.model_executor.weight_utils import (
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -56,16 +57,18 @@ class GPT2Attention(nn.Module):
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim**-0.5
self.c_attn = ColumnParallelLinear(self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.c_attn = ColumnParallelLinear(
self.hidden_size,
3 * self.hidden_size,
bias=True,
gather_output=False,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale)
@ -95,16 +98,18 @@ class GPT2MLP(nn.Module):
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.c_fc = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
gather_output=False,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

View File

@ -37,8 +37,9 @@ from vllm.model_executor.weight_utils import (
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -62,29 +63,31 @@ class GPTBigCodeAttention(nn.Module):
if self.multi_query:
self.num_kv_heads = 1
self.kv_dim = self.head_dim
self.c_attn_q = ColumnParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_attn_q = ColumnParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
gather_output=False,
)
self.c_attn_kv = nn.Linear(self.hidden_size,
2 * self.kv_dim,
bias=True)
else:
self.num_kv_heads = self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.c_attn = ColumnParallelLinear(self.hidden_size,
self.hidden_size +
2 * self.kv_dim,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_attn = ColumnParallelLinear(
self.hidden_size,
self.hidden_size + 2 * self.kv_dim,
bias=True,
gather_output=False,
)
self.c_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
input_is_parallel=True,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scale,
@ -124,16 +127,18 @@ class GPTBigMLP(nn.Module):
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=True,
gather_output=False,
perform_initialization=False)
self.c_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False)
self.c_fc = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
gather_output=False,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
input_is_parallel=True,
)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

View File

@ -34,8 +34,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -49,16 +50,18 @@ class GPTJAttention(nn.Module):
self.hidden_size = config.hidden_size
self.head_size = self.hidden_size // self.total_num_heads
self.qkv_proj = ColumnParallelLinear(config.hidden_size,
3 * config.hidden_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(config.hidden_size,
config.hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
self.qkv_proj = ColumnParallelLinear(
config.hidden_size,
3 * config.hidden_size,
bias=False,
gather_output=False,
)
self.out_proj = RowParallelLinear(
config.hidden_size,
config.hidden_size,
bias=False,
input_is_parallel=True,
)
tp_world_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_world_size == 0
@ -102,14 +105,16 @@ class GPTJMLP(nn.Module):
def __init__(self, intermediate_size: int, config: GPTJConfig):
super().__init__()
hidden_size = config.n_embd
self.fc_in = ColumnParallelLinear(hidden_size,
intermediate_size,
gather_output=False,
perform_initialization=False)
self.fc_out = RowParallelLinear(intermediate_size,
hidden_size,
input_is_parallel=True,
perform_initialization=False)
self.fc_in = ColumnParallelLinear(
hidden_size,
intermediate_size,
gather_output=False,
)
self.fc_out = RowParallelLinear(
intermediate_size,
hidden_size,
input_is_parallel=True,
)
self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@ -159,9 +164,10 @@ class GPTJModel(nn.Module):
super().__init__()
self.config = config
self.embed_dim = config.n_embd
self.wte = VocabParallelEmbedding(config.vocab_size,
self.embed_dim,
perform_initialization=False)
self.wte = VocabParallelEmbedding(
config.vocab_size,
self.embed_dim,
)
self.h = nn.ModuleList(
[GPTJBlock(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@ -199,10 +205,11 @@ class GPTJForCausalLM(nn.Module):
self.config = config
assert not config.tie_word_embeddings
self.transformer = GPTJModel(config)
self.lm_head = ColumnParallelLinear(config.n_embd,
config.vocab_size,
gather_output=False,
perform_initialization=False)
self.lm_head = ColumnParallelLinear(
config.n_embd,
config.vocab_size,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size)
def forward(

View File

@ -34,8 +34,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -59,11 +60,12 @@ class GPTNeoXAttention(nn.Module):
config.hidden_size,
3 * config.hidden_size,
gather_output=False,
perform_initialization=False)
self.dense = RowParallelLinear(config.hidden_size,
config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
)
self.dense = RowParallelLinear(
config.hidden_size,
config.hidden_size,
input_is_parallel=True,
)
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
@ -100,14 +102,16 @@ class GPTNeoXMLP(nn.Module):
def __init__(self, config: GPTNeoXConfig):
super().__init__()
self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
config.intermediate_size,
gather_output=False,
perform_initialization=False)
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
config.hidden_size,
input_is_parallel=True,
perform_initialization=False)
self.dense_h_to_4h = ColumnParallelLinear(
config.hidden_size,
config.intermediate_size,
gather_output=False,
)
self.dense_4h_to_h = RowParallelLinear(
config.intermediate_size,
config.hidden_size,
input_is_parallel=True,
)
self.act = get_act_fn(config.hidden_act)
def forward(self, hidden_states):
@ -169,9 +173,10 @@ class GPTNeoXModel(nn.Module):
super().__init__()
self.config = config
self.embed_in = VocabParallelEmbedding(config.vocab_size,
config.hidden_size,
perform_initialization=False)
self.embed_in = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList(
[GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
@ -209,11 +214,12 @@ class GPTNeoXForCausalLM(nn.Module):
super().__init__()
self.config = config
self.gpt_neox = GPTNeoXModel(config)
self.embed_out = ColumnParallelLinear(config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.embed_out = ColumnParallelLinear(
config.hidden_size,
config.vocab_size,
bias=False,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size)
def forward(

View File

@ -12,8 +12,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding)
from vllm.model_executor.weight_utils import (
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
load_tensor_parallel_weights)
@ -31,16 +32,18 @@ class InternLMMLP(nn.Module):
hidden_act: str,
):
super().__init__()
self.gate_up_proj = ColumnParallelLinear(hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False)
self.gate_up_proj = ColumnParallelLinear(
hidden_size,
2 * intermediate_size,
bias=False,
gather_output=False,
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
@ -80,14 +83,12 @@ class InternLMAttention(nn.Module):
3 * self.total_num_heads * self.head_dim,
bias=True,
gather_output=False,
perform_initialization=False,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=True,
input_is_parallel=True,
perform_initialization=False,
)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
@ -176,7 +177,9 @@ class InternLMModel(nn.Module):
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False)
vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
InternLMDecoderLayer(config)
for _ in range(config.num_hidden_layers)
@ -216,11 +219,12 @@ class InternLMForCausalLM(nn.Module):
self.config = config
self.model = InternLMModel(config)
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ColumnParallelLinear(config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False)
self.lm_head = ColumnParallelLinear(
config.hidden_size,
vocab_size,
bias=False,
gather_output=False,
)
self.sampler = Sampler(config.vocab_size)
def forward(

View File

@ -39,8 +39,7 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.quantized_linear import ParallelLinear
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
@ -64,13 +63,11 @@ class LlamaMLP(nn.Module):
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config)
self.down_proj = ParallelLinear.row(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
@ -127,7 +124,6 @@ class LlamaAttention(nn.Module):
self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config,
)
self.o_proj = ParallelLinear.row(
@ -135,7 +131,6 @@ class LlamaAttention(nn.Module):
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config,
)
self.attn = PagedAttentionWithRoPE(
@ -241,7 +236,9 @@ class LlamaModel(nn.Module):
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False)
vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
@ -291,7 +288,6 @@ class LlamaForCausalLM(nn.Module):
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=None)
self.sampler = Sampler(config.vocab_size)

View File

@ -38,8 +38,7 @@ from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.quantized_linear import ParallelLinear
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding)
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.weight_utils import (
convert_pyslice_to_tensor, hf_model_weights_iterator,
@ -64,13 +63,11 @@ class MistralMLP(nn.Module):
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config)
self.down_proj = ParallelLinear.row(intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
@ -116,7 +113,6 @@ class MistralAttention(nn.Module):
self.head_dim,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=quant_config,
)
self.o_proj = ParallelLinear.row(
@ -124,7 +120,6 @@ class MistralAttention(nn.Module):
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
quant_config=quant_config,
)
self.attn = PagedAttentionWithRoPE(self.num_heads,
@ -225,7 +220,9 @@ class MistralModel(nn.Module):
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.embed_tokens = VocabParallelEmbedding(
vocab_size, config.hidden_size, perform_initialization=False)
vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
MistralDecoderLayer(config, quant_config)
for _ in range(config.num_hidden_layers)
@ -275,7 +272,6 @@ class MistralForCausalLM(nn.Module):
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False,
quant_config=None)
self.sampler = Sampler(config.vocab_size)

View File

@ -15,8 +15,9 @@ from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.mpt import MPTConfig
@ -53,7 +54,6 @@ class MPTAttention(nn.Module):
3 * self.d_model,
bias=not config.no_bias,
gather_output=False,
perform_initialization=False,
)
if self.qk_ln:
self.q_ln = nn.LayerNorm(self.d_model)
@ -63,7 +63,6 @@ class MPTAttention(nn.Module):
self.d_model,
bias=not config.no_bias,
input_is_parallel=True,
perform_initialization=False,
)
tp_world_size = get_tensor_model_parallel_world_size()
@ -113,17 +112,19 @@ class MPTMLP(nn.Module):
hidden_size = config.d_model
expansion_ratio = config.expansion_ratio
intermediate_size = expansion_ratio * hidden_size
self.up_proj = ColumnParallelLinear(hidden_size,
intermediate_size,
bias=not config.no_bias,
gather_output=False,
perform_initialization=False)
self.up_proj = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=not config.no_bias,
gather_output=False,
)
self.act = get_act_fn("gelu")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=not config.no_bias,
input_is_parallel=True,
perform_initialization=False)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=not config.no_bias,
input_is_parallel=True,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x, _ = self.up_proj(x)
@ -172,9 +173,10 @@ class MPTModel(nn.Module):
assert config.embedding_fraction == 1.0
assert config.norm_type == "low_precision_layernorm"
self.wte = VocabParallelEmbedding(config.vocab_size,
config.d_model,
perform_initialization=False)
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.d_model,
)
self.blocks = nn.ModuleList(
[MPTBlock(config) for _ in range(config.n_layers)])
self.norm_f = nn.LayerNorm(config.d_model)

View File

@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
@ -73,16 +74,18 @@ class OPTAttention(nn.Module):
self.head_dim = embed_dim // total_num_heads
self.scaling = self.head_dim**-0.5
self.qkv_proj = ColumnParallelLinear(embed_dim,
3 * embed_dim,
bias=bias,
gather_output=False,
perform_initialization=False)
self.out_proj = RowParallelLinear(embed_dim,
embed_dim,
bias=bias,
input_is_parallel=True,
perform_initialization=False)
self.qkv_proj = ColumnParallelLinear(
embed_dim,
3 * embed_dim,
bias=bias,
gather_output=False,
)
self.out_proj = RowParallelLinear(
embed_dim,
embed_dim,
bias=bias,
input_is_parallel=True,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
scale=self.scaling)
@ -120,16 +123,18 @@ class OPTDecoderLayer(nn.Module):
self.self_attn_layer_norm = nn.LayerNorm(
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
self.fc1 = ColumnParallelLinear(self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
gather_output=False,
perform_initialization=False)
self.fc2 = RowParallelLinear(config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
input_is_parallel=True,
perform_initialization=False)
self.fc1 = ColumnParallelLinear(
self.embed_dim,
config.ffn_dim,
bias=config.enable_bias,
gather_output=False,
)
self.fc2 = RowParallelLinear(
config.ffn_dim,
self.embed_dim,
bias=config.enable_bias,
input_is_parallel=True,
)
self.final_layer_norm = nn.LayerNorm(
self.embed_dim,
elementwise_affine=config.layer_norm_elementwise_affine)
@ -182,7 +187,7 @@ class OPTDecoder(nn.Module):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.word_embed_proj_dim,
perform_initialization=False)
)
# Positional embeddings are replicated (not sharded).
self.embed_positions = OPTLearnedPositionalEmbedding(
config.max_position_embeddings, config.hidden_size)

View File

@ -28,7 +28,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.parallel_utils.tensor_parallel import (
from vllm.model_executor.parallel_utils.layers import (
VocabParallelEmbedding,
ColumnParallelLinear,
RowParallelLinear,
@ -53,14 +53,12 @@ class QWenMLP(nn.Module):
2 * intermediate_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
@ -98,14 +96,12 @@ class QWenAttention(nn.Module):
3 * hidden_size,
bias=True,
gather_output=False,
perform_initialization=False,
)
self.c_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
input_is_parallel=True,
perform_initialization=False,
)
self.scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithRoPE(
@ -190,9 +186,10 @@ class QWenModel(nn.Module):
self.vocab_size = config.vocab_size
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.wte = VocabParallelEmbedding(vocab_size,
config.hidden_size,
perform_initialization=False)
self.wte = VocabParallelEmbedding(
vocab_size,
config.hidden_size,
)
self.h = nn.ModuleList(
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
@ -235,7 +232,6 @@ class QWenLMHeadModel(nn.Module):
vocab_size,
bias=False,
gather_output=False,
perform_initialization=False,
)
self.sampler = Sampler(config.vocab_size)

View File

@ -1,7 +0,0 @@
import vllm.model_executor.parallel_utils.parallel_state
import vllm.model_executor.parallel_utils.tensor_parallel
__all__ = [
"parallel_state",
"tensor_parallel",
]

View File

@ -0,0 +1,47 @@
import torch
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
)
def tensor_model_parallel_all_reduce(input_):
"""All-reduce the input tensor across model parallel group.
Note: This operation is applied in-place on the input tensor.
"""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
return input_
def tensor_model_parallel_all_gather(input_, dim=-1):
"""All-gather the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty((world_size, ) + input_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=get_tensor_model_parallel_group())
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size * input_size[dim], ) +
input_size[dim + 1:])
return output_tensor

View File

@ -0,0 +1,303 @@
# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
from typing import Optional
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.quantization_utils import QuantizationConfig
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
from vllm.model_executor.parallel_utils.utils import (
divide,
VocabUtility,
split_tensor_along_last_dim,
)
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None):
super().__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.tp_size = get_tensor_model_parallel_world_size()
# TODO: Handle vocab padding here.
# Divide the weight matrix along the vocaburaly dimension.
self.vocab_start_index, self.vocab_end_index = (
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_tensor_model_parallel_rank(),
self.tp_size))
self.num_embeddings_per_partition = (self.vocab_end_index -
self.vocab_start_index)
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=params_dtype))
def forward(self, input_):
if self.tp_size > 1:
# Build the mask.
input_mask = ((input_ < self.vocab_start_index) |
(input_ >= self.vocab_end_index))
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight)
# Mask the output embedding.
if self.tp_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
return output
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
Keyword Arguments
bias: If true, add bias
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configuration.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
gather_output: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, self.tp_size)
self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.create_weights(params_dtype)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
else:
self.register_parameter('bias', None)
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(
torch.empty(self.output_size_per_partition,
self.input_size,
device=torch.cuda.current_device(),
dtype=dtype))
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
return F.linear(x, self.weight, bias)
def forward(self, input_):
"""Forward of ColumnParallelLinear
Args:
input_: Tensor whose last dimension is `input_size`.
Returns:
- output
- bias
"""
bias = self.bias if not self.skip_bias_add else None
input_parallel = input_
# Matrix multiply.
output_parallel = self.apply_weights(input_parallel, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
- -
| A_1 |
| . |
A = | . | X = [X_1, ..., X_p]
| . |
| A_p |
- -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
Keyword Arguments:
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
skip_bias_add: This was added to enable performance optimization where
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configuration.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
self.create_weights(params_dtype)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError('When not reduce the results, adding bias to the '
'results can lead to incorrect results')
if bias:
self.bias = Parameter(
torch.empty(self.output_size,
device=torch.cuda.current_device(),
dtype=params_dtype))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(
torch.empty(self.output_size,
self.input_size_per_partition,
device=torch.cuda.current_device(),
dtype=dtype))
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight)
def forward(self, input_):
"""Forward of RowParallelLinear
Args:
input_: tensor whose last dimension is `input_size`. If
`input_is_parallel` is set, then the last dimension
is `input_size // tp_size`.
Returns:
- output
- bias
"""
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
# TODO: simplify code below
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
output_parallel = self.apply_weights(input_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)
else:
output_ = output_parallel
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
return output, output_bias

View File

@ -1,78 +1,42 @@
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Model and data parallel groups."""
import torch
from typing import Optional
# Intra-layer model parallel group that the current rank belongs to.
# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
# Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None
def initialize_model_parallel(
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
pipeline_model_parallel_split_rank: Optional[int] = None,
) -> None:
"""
Initialize model data parallel groups.
Initialize model parallel groups.
Arguments:
tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
pipeline).
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
rank in pipeline with split point.
tensor_model_parallel_size: number of GPUs used for tensor model
parallelism.
pipeline_model_parallel_size: number of GPUs used for pipeline model
parallelism.
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
and 8 data-parallel groups as:
8 data_parallel groups:
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
8 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
4 pipeline model-parallel groups:
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
4 tensor model-parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 pipeline model-parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]
Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
@ -82,64 +46,23 @@ def initialize_model_parallel(
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
if (world_size !=
tensor_model_parallel_size * pipeline_model_parallel_size):
raise RuntimeError(
f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
)
data_parallel_size: int = world_size // (tensor_model_parallel_size *
pipeline_model_parallel_size)
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
num_data_parallel_groups: int = world_size // data_parallel_size
if virtual_pipeline_model_parallel_size is not None:
if not pipeline_model_parallel_size > 2:
raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
"interleaved schedule")
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
if pipeline_model_parallel_split_rank is not None:
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
f"world_size ({world_size}) is not equal to "
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
num_tensor_model_parallel_groups: int = (world_size //
tensor_model_parallel_size)
num_pipeline_model_parallel_groups: int = (world_size //
pipeline_model_parallel_size)
rank = torch.distributed.get_rank()
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GLOBAL_RANKS
assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
all_data_parallel_group_ranks = []
for i in range(pipeline_model_parallel_size):
start_rank = i * num_pipeline_model_parallel_groups
end_rank = (i + 1) * num_pipeline_model_parallel_groups
for j in range(tensor_model_parallel_size):
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
all_data_parallel_group_ranks.append(list(ranks))
group = torch.distributed.new_group(ranks)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GLOBAL_RANKS = ranks
# Build the model-parallel groups.
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
for i in range(data_parallel_size):
ranks = [data_parallel_group_ranks[i]
for data_parallel_group_ranks in all_data_parallel_group_ranks]
group = torch.distributed.new_group(ranks)
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
# Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
'tensor model parallel group is already initialized'
assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
"tensor model parallel group is already initialized")
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size,
(i + 1) * tensor_model_parallel_size)
@ -147,268 +70,60 @@ def initialize_model_parallel(
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
# Build the pipeline model-parallel groups and embedding groups
# (first and last rank in each pipeline model-parallel group).
# Build the pipeline model-parallel groups.
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
'pipeline model parallel group is already initialized'
global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
global _POSITION_EMBEDDING_GROUP
global _POSITION_EMBEDDING_GLOBAL_RANKS
assert _POSITION_EMBEDDING_GROUP is None, \
'position embedding group is already initialized'
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
"pipeline model parallel group is already initialized")
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_PIPELINE_MODEL_PARALLEL_GROUP = group
_PIPELINE_GLOBAL_RANKS = ranks
# Setup embedding group (to exchange gradients between
# first and last stages).
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
position_embedding_ranks = [ranks[0]]
if pipeline_model_parallel_split_rank is not None:
if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank],
ranks[-1]]
if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
position_embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank]]
else:
embedding_ranks = ranks
position_embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
group = torch.distributed.new_group(position_embedding_ranks)
if rank in position_embedding_ranks:
_POSITION_EMBEDDING_GROUP = group
if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
_PIPELINE_MODEL_PARALLEL_GROUP is None or \
_DATA_PARALLEL_GROUP is None:
return False
return True
def get_model_parallel_group():
"""Get the model parallel group the caller rank belongs to."""
assert _MODEL_PARALLEL_GROUP is not None, \
'model parallel group is not initialized'
return _MODEL_PARALLEL_GROUP
return (_TENSOR_MODEL_PARALLEL_GROUP is not None
and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
'intra_layer_model parallel group is not initialized'
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
"tenosr model parallel group is not initialized")
return _TENSOR_MODEL_PARALLEL_GROUP
def get_pipeline_model_parallel_group():
"""Get the pipeline model parallel group the caller rank belongs to."""
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
'pipeline_model parallel group is not initialized'
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
"pipeline model parallel group is not initialized")
return _PIPELINE_MODEL_PARALLEL_GROUP
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP
def get_embedding_group():
"""Get the embedding group the caller rank belongs to."""
assert _EMBEDDING_GROUP is not None, \
'embedding group is not initialized'
return _EMBEDDING_GROUP
def get_position_embedding_group():
"""Get the position embedding group the caller rank belongs to."""
assert _POSITION_EMBEDDING_GROUP is not None, \
'position embedding group is not initialized'
return _POSITION_EMBEDDING_GROUP
def set_tensor_model_parallel_world_size(world_size):
"""Set the tensor model parallel size"""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
def set_pipeline_model_parallel_world_size(world_size):
"""Set the pipeline model parallel size"""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
return torch.distributed.get_world_size(
group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_world_size():
"""Return world size for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
def set_tensor_model_parallel_rank(rank):
"""Set tensor model parallel rank."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_rank(rank):
"""Set pipeline model parallel rank."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
def set_pipeline_model_parallel_split_rank(rank):
"""Set pipeline model parallel split rank."""
global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
_MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
return torch.distributed.get_world_size(
group=get_pipeline_model_parallel_group())
def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
global _MPU_TENSOR_MODEL_PARALLEL_RANK
if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
return _MPU_TENSOR_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
def get_pipeline_model_parallel_rank():
"""Return my rank for the pipeline model parallel group."""
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
return _MPU_PIPELINE_MODEL_PARALLEL_RANK
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
def is_pipeline_first_stage(ignore_virtual=False):
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
if get_virtual_pipeline_model_parallel_world_size() is not None and \
get_virtual_pipeline_model_parallel_rank() != 0:
return False
return get_pipeline_model_parallel_rank() == 0
def is_pipeline_last_stage(ignore_virtual=False):
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
virtual_pipeline_model_parallel_world_size = \
get_virtual_pipeline_model_parallel_world_size()
if virtual_pipeline_model_parallel_world_size is not None and \
get_virtual_pipeline_model_parallel_rank() != (
virtual_pipeline_model_parallel_world_size - 1):
return False
return get_pipeline_model_parallel_rank() == (
get_pipeline_model_parallel_world_size() - 1)
def is_rank_in_embedding_group(ignore_virtual=False):
"""Return true if current rank is in embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _EMBEDDING_GLOBAL_RANKS
if ignore_virtual:
return rank in _EMBEDDING_GLOBAL_RANKS
if rank in _EMBEDDING_GLOBAL_RANKS:
if rank == _EMBEDDING_GLOBAL_RANKS[0]:
return is_pipeline_first_stage(ignore_virtual=False)
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
return is_pipeline_last_stage(ignore_virtual=False)
else:
return True
return False
def is_rank_in_position_embedding_group():
"""Return true if current rank is in position embedding group, False otherwise."""
rank = torch.distributed.get_rank()
global _POSITION_EMBEDDING_GLOBAL_RANKS
return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_after_split(rank=None):
"""Return True if pipeline stage executes decoder block for a model
with both encoder and decoder."""
if get_pipeline_model_parallel_world_size() == 1:
return True
if rank is None:
rank = get_pipeline_model_parallel_rank()
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
return True
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
return True
return False
def is_pipeline_stage_at_split():
"""Return true if pipeline stage executes decoder block and next
stage executes encoder block for a model with both encoder and
decoder."""
rank = get_pipeline_model_parallel_rank()
return is_pipeline_stage_before_split(rank) and \
is_pipeline_stage_after_split(rank+1)
def get_virtual_pipeline_model_parallel_rank():
"""Return the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
def set_virtual_pipeline_model_parallel_rank(rank):
"""Set the virtual pipeline-parallel rank."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
def get_virtual_pipeline_model_parallel_world_size():
"""Return the virtual pipeline-parallel world size."""
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
return torch.distributed.get_rank(
group=get_pipeline_model_parallel_group())
def get_tensor_model_parallel_src_rank():
@ -419,35 +134,27 @@ def get_tensor_model_parallel_src_rank():
return (global_rank // local_world_size) * local_world_size
def get_data_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the data parallel group."""
assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
"Data parallel group is not initialized"
return _DATA_PARALLEL_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_first_rank():
"""Return the global rank of the first process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
assert _PIPELINE_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
"""Return the global rank of the last process in the pipeline for the
current tensor parallel group"""
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
assert _PIPELINE_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank():
"""Return the global rank that follows the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
assert _PIPELINE_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
@ -455,45 +162,18 @@ def get_pipeline_model_parallel_next_rank():
def get_pipeline_model_parallel_prev_rank():
"""Return the global rank that preceeds the caller in the pipeline"""
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
assert _PIPELINE_GLOBAL_RANKS is not None, (
"Pipeline parallel group is not initialized")
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return torch.distributed.get_world_size(group=get_data_parallel_group())
def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=get_data_parallel_group())
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
_MODEL_PARALLEL_GROUP = None
global _TENSOR_MODEL_PARALLEL_GROUP
_TENSOR_MODEL_PARALLEL_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _DATA_PARALLEL_GROUP
_DATA_PARALLEL_GROUP = None
global _EMBEDDING_GROUP
_EMBEDDING_GROUP = None
global _POSITION_EMBEDDING_GROUP
_POSITION_EMBEDDING_GROUP = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
global _MPU_TENSOR_MODEL_PARALLEL_RANK
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None

View File

@ -1,50 +0,0 @@
from .layers import (
ColumnParallelLinear,
RowParallelLinear,
VocabParallelEmbedding,
set_tensor_model_parallel_attributes,
set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes,
param_is_not_tensor_parallel_duplicate,
)
from .mappings import (
copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
scatter_to_sequence_parallel_region,
)
from .random import (
get_cuda_rng_tracker,
model_parallel_cuda_manual_seed,
)
from .utils import (
split_tensor_along_last_dim,
)
__all__ = [
#layers.py
"ColumnParallelLinear",
"RowParallelLinear",
"VocabParallelEmbedding",
"set_tensor_model_parallel_attributes",
"set_defaults_if_not_set_tensor_model_parallel_attributes",
"copy_tensor_model_parallel_attributes",
"param_is_not_tensor_parallel_duplicate",
# mappings.py
"copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region",
"gather_from_sequence_parallel_region",
"reduce_from_tensor_model_parallel_region",
"scatter_to_tensor_model_parallel_region",
"scatter_to_sequence_parallel_region",
# random.py
"get_cuda_rng_tracker",
"model_parallel_cuda_manual_seed",
# utils.py
"split_tensor_along_last_dim",
]

View File

@ -1,366 +0,0 @@
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
from typing import Optional
import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from .mappings import (
gather_from_tensor_model_parallel_region,
reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region,
)
from .utils import (
divide,
VocabUtility,
)
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1,
'partition_stride': 1}
def param_is_not_tensor_parallel_duplicate(param):
return (hasattr(param, 'tensor_model_parallel') and
param.tensor_model_parallel) or (
get_tensor_model_parallel_rank() == 0)
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
# Make sure the attributes are not set.
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
assert not hasattr(tensor, attribute)
# Set the attributes.
setattr(tensor, 'tensor_model_parallel', is_parallel)
setattr(tensor, 'partition_dim', dim)
setattr(tensor, 'partition_stride', stride)
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
def maybe_set(attribute, value):
if not hasattr(tensor, attribute):
setattr(tensor, attribute, value)
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
def maybe_copy(attribute):
if hasattr(source_tensor, attribute):
setattr(destination_tensor, attribute,
getattr(source_tensor, attribute))
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
maybe_copy(attribute)
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
Keyword Arguments:
init_method: method to initialize weights.
params_dtype
use_cpu_initialization
perform_initialization
"""
def __init__(self, num_embeddings: int, embedding_dim: int, *,
init_method=init.xavier_normal_,
params_dtype: torch.dtype=None,
use_cpu_initialization: bool=False,
perform_initialization: bool=False):
super(VocabParallelEmbedding, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Set the defaults for compatibility.
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
# Divide the weight matrix along the vocaburaly dimension.
self.vocab_start_index, self.vocab_end_index = \
VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_tensor_model_parallel_rank(),
self.tensor_model_parallel_size)
self.num_embeddings_per_partition = self.vocab_end_index - \
self.vocab_start_index
self.weight = Parameter(torch.empty(
self.num_embeddings_per_partition, self.embedding_dim,
device=torch.cuda.current_device(), dtype=params_dtype))
def forward(self, input_):
if self.tensor_model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | \
(input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input, self.weight,
self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq,
self.sparse)
# Mask the output embedding.
if self.tensor_model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_from_tensor_model_parallel_region(output_parallel)
return output
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
Keyword Arguments
bias: If true, add bias
gather_output: If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y_i = XA_i
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip
adding bias but instead return it.
params_dtype:
use_cpu_initialization:
"""
def __init__(self, input_size, output_size, *,
bias=True, gather_output=True,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
params_dtype=None,
use_cpu_initialization=False,
perform_initialization=False,
quant_config=None,
):
super(ColumnParallelLinear, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
self.world_size = get_tensor_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, self.world_size)
self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
self.create_weights(params_dtype)
if bias:
self.bias = Parameter(torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(torch.empty(
self.output_size_per_partition, self.input_size,
device=torch.cuda.current_device(), dtype=dtype))
def apply_weights(
self,
x: torch.Tensor,
bias: Optional[torch.Tensor],
) -> torch.Tensor:
return F.linear(x, self.weight, bias)
def forward(self, input_):
"""Forward of ColumnParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
Returns:
- output
- bias
"""
bias = self.bias if not self.skip_bias_add else None
input_parallel = input_
# Matrix multiply.
output_parallel = self.apply_weights(input_parallel, bias)
if self.gather_output:
# All-gather across the partitions.
output = gather_from_tensor_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
- -
| A_1 |
| . |
A = | . | X = [X_1, ..., X_p]
| . |
| A_p |
- -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
Keyword Arguments:
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimization where bias
can be fused with other elementwise operations. We skip
adding bias but instead return it.
params_dtype:
use_cpu_initialization:
perform_initialization:
reduce_results:
"""
def __init__(self, input_size, output_size, *,
bias=True, input_is_parallel=False,
init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
params_dtype=None,
use_cpu_initialization=False,
perform_initialization=False,
reduce_results=True,
quant_config=None,
):
super(RowParallelLinear, self).__init__()
assert not perform_initialization
assert not use_cpu_initialization
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
if params_dtype is None:
params_dtype = torch.get_default_dtype()
# Divide the weight matrix along the last dimension.
self.world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.world_size)
self.skip_bias_add = skip_bias_add
self.quant_config = quant_config
self.create_weights(params_dtype)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(),
dtype=params_dtype))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
def create_weights(self, dtype: torch.dtype) -> None:
self.weight = Parameter(torch.empty(
self.output_size, self.input_size_per_partition,
device=torch.cuda.current_device(), dtype=dtype))
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, self.weight)
def forward(self, input_):
"""Forward of RowParallelLinear
Args:
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
Returns:
- output
- bias
"""
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply.
output_parallel = self.apply_weights(input_parallel)
if self.reduce_results and self.world_size > 1:
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
else:
output_ = output_parallel
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
return output, output_bias

View File

@ -1,281 +0,0 @@
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/mappings.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import torch
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
)
from .utils import split_tensor_along_last_dim
def _reduce(input_):
"""All-reduce the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size()==1:
return input_
# All-reduce.
torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
return input_
def _split_along_last_dim(input_):
"""Split the tensor along its last dimension and keep the
corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along last dimension.
input_list = split_tensor_along_last_dim(input_, world_size)
# Note: torch.split does not create contiguous tensors by default.
rank = get_tensor_model_parallel_rank()
output = input_list[rank].contiguous()
return output
def _split_along_first_dim(input_):
"""Split the tensor along its first dimension and keep the
corresponding slice."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Split along first dimension.
dim_size = input_.size()[0]
assert dim_size % world_size == 0, \
"First dimension of the tensor should be divisible by tensor parallel size"
local_dim_size = dim_size // world_size
rank = get_tensor_model_parallel_rank()
dim_offset = rank * local_dim_size
output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
return output
def _gather_along_last_dim(input_):
"""Gather tensors and concatinate along the last dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
# Size and dimension.
last_dim = input_.dim() - 1
rank = get_tensor_model_parallel_rank()
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
tensor_list[rank] = input_
torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=last_dim).contiguous()
return output
def _gather_along_first_dim(input_):
"""Gather tensors and concatinate along the first dimension."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())
torch.distributed._all_gather_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())
return output
def _reduce_scatter_along_first_dim(input_):
"""Reduce-scatter the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
dim_size = list(input_.size())
assert dim_size[0] % world_size == 0, \
"First dimension of the tensor should be divisible by tensor parallel size"
dim_size[0] = dim_size[0] // world_size
output = torch.empty(dim_size, dtype=input_.dtype,
device=torch.cuda.current_device())
torch.distributed._reduce_scatter_base(output, input_.contiguous(),
group=get_tensor_model_parallel_group())
return output
class _CopyToModelParallelRegion(torch.autograd.Function):
"""Pass the input to the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return input_
@staticmethod
def forward(ctx, input_):
return input_
@staticmethod
def backward(ctx, grad_output):
return _reduce(grad_output)
class _ReduceFromModelParallelRegion(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce(input_)
@staticmethod
def forward(ctx, input_):
return _reduce(input_)
@staticmethod
def backward(ctx, grad_output):
return grad_output
class _ScatterToModelParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_last_dim(grad_output)
class _GatherFromModelParallelRegion(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_):
return _gather_along_last_dim(input_)
@staticmethod
def forward(ctx, input_):
return _gather_along_last_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _split_along_last_dim(grad_output)
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank."""
@staticmethod
def symbolic(graph, input_):
return _split_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _split_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
"""Gather the input from sequence parallel region and concatinate."""
@staticmethod
def symbolic(graph, input_, tensor_parallel_output_grad=True):
return _gather_along_first_dim(input_)
@staticmethod
def forward(ctx, input_, tensor_parallel_output_grad=True):
ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
return _gather_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
# If the computation graph after the gather operation is
# in the tensor parallel mode, output gradients need to reduce
# scattered and whereas if the computation is duplicated,
# output gradients need to be scattered.
if tensor_parallel_output_grad:
return _reduce_scatter_along_first_dim(grad_output), None
else:
return _split_along_first_dim(grad_output), None
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
"""Reduce scatter the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def forward(ctx, input_):
return _reduce_scatter_along_first_dim(input_)
@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)
# -----------------
# Helper functions.
# -----------------
def copy_to_tensor_model_parallel_region(input_):
return _CopyToModelParallelRegion.apply(input_)
def reduce_from_tensor_model_parallel_region(input_):
return _ReduceFromModelParallelRegion.apply(input_)
def scatter_to_tensor_model_parallel_region(input_):
return _ScatterToModelParallelRegion.apply(input_)
def gather_from_tensor_model_parallel_region(input_):
return _GatherFromModelParallelRegion.apply(input_)
def scatter_to_sequence_parallel_region(input_):
return _ScatterToSequenceParallelRegion.apply(input_)
def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
def reduce_scatter_to_sequence_parallel_region(input_):
return _ReduceScatterToSequenceParallelRegion.apply(input_)

View File

@ -1,164 +0,0 @@
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import contextlib
import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
)
# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
def _set_cuda_rng_state(new_state, device=-1):
"""Sets the random number generator state of the current GPU.
Argumentss:
new_state (torch.ByteTensor): The desired state
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
with a single change: the input state is not cloned. Cloning caused
major performance issues for +4 GPU cases.
"""
if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
# older PyTorch
def cb():
with device_ctx_manager(device):
_C._cuda_setRNGState(new_state)
else:
# newer PyTorch
if device == -1:
device = torch.device('cuda')
elif isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
device = torch.device('cuda', device)
def cb():
idx = device.index
if idx is None:
idx = torch.cuda.current_device()
default_generator = torch.cuda.default_generators[idx]
default_generator.set_state(new_state)
_lazy_call(cb)
class CudaRNGStatesTracker:
"""Tracker for the cuda RNG states.
Using the `add` method, a cuda rng state is initialized based on
the input `seed` and is assigned to `name`. Later, by forking the
rng state, we can perform operations and return to our starting
cuda state.
"""
def __init__(self):
# Map from a string name to the cuda rng state.
self.states_ = {}
# Seeds are just for book keeping and ensure no seed is set twice.
self.seeds_ = set()
def reset(self):
"""Set to the initial state (no tracker)."""
self.states_ = {}
self.seeds_ = set()
def get_states(self):
"""Get rng states. Copy the dictionary so we have direct
pointers to the states, not just a pointer to the dictionary."""
states = {}
for name in self.states_:
states[name] = self.states_[name]
return states
def set_states(self, states):
"""Set the rng states. For efficiency purposes, we do not check
the size of seed for compatibility."""
self.states_ = states
def add(self, name, seed):
"""Track the rng state."""
# Check seed is not already used.
if seed in self.seeds_:
raise Exception('seed {} already exists'.format(seed))
self.seeds_.add(seed)
# Check that state is not already defined.
if name in self.states_:
raise Exception('cuda rng state {} already exists'.format(name))
# Get the current rng state.
orig_rng_state = torch.cuda.get_rng_state()
# Set the new state and store it.
torch.cuda.manual_seed(seed)
self.states_[name] = torch.cuda.get_rng_state()
# Reset rng state to what it was.
_set_cuda_rng_state(orig_rng_state)
@contextlib.contextmanager
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
"""Fork the cuda rng state, perform operations, and exit with
the original state."""
# Check if we have added the state
if name not in self.states_:
raise Exception('cuda rng state {} is not added'.format(name))
# Store current rng state.
orig_cuda_rng_state = torch.cuda.get_rng_state()
# Set rng state to the desired one
_set_cuda_rng_state(self.states_[name])
# Do the stuff we wanted to do.
try:
yield
finally:
# Update the current rng state for later use.
self.states_[name] = torch.cuda.get_rng_state()
# And set the state to the original state we started with.
_set_cuda_rng_state(orig_cuda_rng_state)
# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
def get_cuda_rng_tracker():
"""Get cuda rng tracker."""
return _CUDA_RNG_STATE_TRACKER
def model_parallel_cuda_manual_seed(seed):
"""Initialize model parallel cuda seed.
This function should be called after the model parallel is
initialized. Also, no torch.cuda.manual_seed should be called
after this function. Basically, this is replacement for that
function.
Two set of RNG states are tracked:
default state: This is for data parallelism and is the same among a
set of model parallel GPUs but different across
different model paralle groups. This is used for
example for dropout in the non-tensor-model-parallel regions.
tensor-model-parallel state: This state is different among a set of model
parallel GPUs, but the same across data parallel
groups. This is used for example for dropout in
model parallel regions.
"""
# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
# Data parallel gets the original seed.
data_parallel_seed = seed
_CUDA_RNG_STATE_TRACKER.reset()
# Set the default state.
torch.cuda.manual_seed(data_parallel_seed)
# and model parallel state.
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
tensor_model_parallel_seed)

View File

@ -1,15 +1,16 @@
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import List, Sequence
import torch
from typing import List, Sequence
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(
numerator, denominator
)
numerator, denominator)
def divide(numerator, denominator):
@ -56,15 +57,14 @@ class VocabUtility:
@staticmethod
def vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size: int, rank, world_size: int
) -> Sequence[int]:
per_partition_vocab_size: int, rank: int) -> Sequence[int]:
index_f = rank * per_partition_vocab_size
index_l = index_f + per_partition_vocab_size
return index_f, index_l
@staticmethod
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
world_size: int) -> Sequence[int]:
per_partition_vocab_size = divide(global_vocab_size, world_size)
return VocabUtility.vocab_range_from_per_partition_vocab_size(
per_partition_vocab_size, rank, world_size
)
per_partition_vocab_size, rank)

View File

@ -4,9 +4,6 @@ import random
import numpy as np
import torch
from vllm.model_executor.parallel_utils.parallel_state import model_parallel_is_initialized
from vllm.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed
def set_random_seed(seed: int) -> None:
random.seed(seed)
@ -14,6 +11,3 @@ def set_random_seed(seed: int) -> None:
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
if model_parallel_is_initialized():
model_parallel_cuda_manual_seed(seed)