TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)
parent 84e4e37d14
commit ba0bfd40e2
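
This refactor collapses the Megatron-derived tensor_parallel package into a single parallel_utils.layers module and drops the perform_initialization flag from every ColumnParallelLinear / RowParallelLinear / VocabParallelEmbedding call site, since the weights are overwritten when a checkpoint is loaded anyway. As orientation, the sketch below emulates in a single process what the two parallel linear layer types compute; the shapes and the plain-torch emulation are illustrative assumptions, not code from this change.

import torch

# Single-process emulation of column-/row-parallel linear layers, checked
# against the unsharded computation. Sizes are made up for illustration.
torch.manual_seed(0)
tp = 2                          # pretend tensor-parallel world size
x = torch.randn(3, 8)           # (tokens, hidden)
w1 = torch.randn(16, 8)         # column-parallel weight: output dim is sharded
w2 = torch.randn(8, 16)         # row-parallel weight: input dim is sharded

# ColumnParallelLinear with gather_output=False: each rank produces its own
# slice of the output features (16 / tp = 8 columns per rank).
y_shards = [x @ w.t() for w in torch.chunk(w1, tp, dim=0)]

# RowParallelLinear with input_is_parallel=True: each rank consumes its slice
# of the input features; summing the partial results is the all-reduce.
z = sum(y @ w.t() for y, w in zip(y_shards, torch.chunk(w2, tp, dim=1)))

# Together the pair reproduces the unsharded x @ w1.T @ w2.T.
assert torch.allclose(z, x @ w1.t() @ w2.t(), atol=1e-5)
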
.github/workflows/pylint.yml (2 changes, vendored)
@@ -28,4 +28,4 @@ jobs:
          pip install pylint==2.8.2
      - name: Analysing the code with pylint
        run: |
-          pylint vllm
+          pylint vllm tests

.github/workflows/yapf.yml (2 changes, vendored)
@@ -28,4 +28,4 @@ jobs:
          pip install toml==0.10.2
      - name: Running yapf
        run: |
-          yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
+          yapf --diff --recursive vllm tests
@@ -8,7 +8,7 @@
[MASTER]

# Files or directories to be skipped. They should be base names, not paths.
-ignore=docs,parallel_utils
+ignore=docs

# Files or directories matching the regex patterns are skipped. The regex
# matches against base names, not paths.
@@ -44,7 +44,6 @@ YAPF_FLAGS=(

YAPF_EXCLUDES=(
    '--exclude' 'build/**'
-    '--exclude' 'vllm/model_executor/parallel_utils/**'
)

# Format specified files
@@ -72,7 +71,7 @@ format_changed() {

# Format all files
format_all() {
-    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
+    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm tests
}

## This flag formats individual files. --files *must* be the first command line
@@ -96,7 +95,7 @@ echo 'vLLM yapf: Done'

# Run Pylint
echo 'vLLM Pylint:'
-pylint vllm
+pylint vllm tests

if ! git diff --quiet &>/dev/null; then
    echo 'Reformatted files. Please review and stage the changes.'
@@ -14,6 +14,7 @@ app = vllm.entrypoints.api_server.app

class AsyncLLMEngineWithStats(AsyncLLMEngine):

+    # pylint: disable=redefined-outer-name
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._num_aborts = 0

@@ -24,6 +24,7 @@ def _query_server(prompt: str) -> dict:
def api_server():
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
+    # pylint: disable=consider-using-with
    uvicorn_process = subprocess.Popen([
        sys.executable, "-u",
        str(script_path), "--model", "facebook/opt-125m"
@@ -32,6 +33,7 @@ def api_server():
    uvicorn_process.terminate()


+# pylint: disable=redefined-outer-name, unused-argument
def test_api_server(api_server):
    """
    Run the API server and test it.
@@ -47,6 +49,7 @@ def test_api_server(api_server):
    prompts = ["Hello world"] * 1
    result = None
    while not result:
+        # pylint: disable=bare-except
        try:
            for result in pool.map(_query_server, prompts):
                break
@ -32,12 +32,12 @@ class MockEngine:
|
||||
self.request_id = None
|
||||
|
||||
def add_request(self, **kwargs):
|
||||
del kwargs # Unused
|
||||
self.add_request_calls += 1
|
||||
return
|
||||
|
||||
def abort_request(self, request_id):
|
||||
del request_id # Unused
|
||||
self.abort_request_calls += 1
|
||||
return
|
||||
|
||||
|
||||
class MockAsyncLLMEngine(AsyncLLMEngine):
|
||||
|
@ -7,22 +7,22 @@ from vllm.outputs import RequestOutput
|
||||
class DummyEvent:
|
||||
|
||||
def __init__(self):
|
||||
self._flag = False
|
||||
self.flag = False
|
||||
|
||||
def set(self):
|
||||
self._flag = True
|
||||
self.flag = True
|
||||
|
||||
def clear(self):
|
||||
self._flag = False
|
||||
self.flag = False
|
||||
|
||||
|
||||
def test_request_tracker():
|
||||
tracker = RequestTracker()
|
||||
tracker.new_requests_event = DummyEvent()
|
||||
stream_1 = tracker.add_request("1")
|
||||
assert tracker.new_requests_event._flag
|
||||
assert tracker.new_requests_event.flag
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert not tracker.new_requests_event._flag
|
||||
assert not tracker.new_requests_event.flag
|
||||
assert len(new) == 1
|
||||
assert new[0]["request_id"] == "1"
|
||||
assert not finished
|
||||
@ -30,9 +30,9 @@ def test_request_tracker():
|
||||
|
||||
stream_2 = tracker.add_request("2")
|
||||
stream_3 = tracker.add_request("3")
|
||||
assert tracker.new_requests_event._flag
|
||||
assert tracker.new_requests_event.flag
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert not tracker.new_requests_event._flag
|
||||
assert not tracker.new_requests_event.flag
|
||||
assert len(new) == 2
|
||||
assert new[0]["request_id"] == "2"
|
||||
assert new[1]["request_id"] == "3"
|
||||
@ -43,7 +43,7 @@ def test_request_tracker():
|
||||
# request_ids must be unique
|
||||
with pytest.raises(KeyError):
|
||||
tracker.add_request("1")
|
||||
assert not tracker.new_requests_event._flag
|
||||
assert not tracker.new_requests_event.flag
|
||||
|
||||
tracker.abort_request("1")
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
@ -54,7 +54,7 @@ def test_request_tracker():
|
||||
|
||||
stream_4 = tracker.add_request("4")
|
||||
tracker.abort_request("4")
|
||||
assert tracker.new_requests_event._flag
|
||||
assert tracker.new_requests_event.flag
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(finished) == 1
|
||||
assert "4" in finished
|
||||
@ -62,11 +62,11 @@ def test_request_tracker():
|
||||
assert stream_4.finished
|
||||
|
||||
stream_5 = tracker.add_request("5")
|
||||
assert tracker.new_requests_event._flag
|
||||
assert tracker.new_requests_event.flag
|
||||
tracker.process_request_output(
|
||||
RequestOutput("2", "output", [], [], finished=True))
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert not tracker.new_requests_event._flag
|
||||
assert not tracker.new_requests_event.flag
|
||||
assert len(finished) == 1
|
||||
assert "2" in finished
|
||||
assert len(new) == 1
|
||||
|
@@ -8,6 +8,7 @@ from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

_TEST_PROMPTS = [
+    # pylint: disable=line-too-long
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
    "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
    "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
tests/distributed/test_comm_ops.py (new file, 82 lines)
@@ -0,0 +1,82 @@
"""Test the communication operators.

Run `pytest tests/distributed/test_comm_ops.py --forked`.
"""
from multiprocessing import Process

import pytest
import torch

from vllm.config import ParallelConfig
from vllm.engine.ray_utils import get_open_port
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_all_gather,
)
from vllm.worker.worker import _init_distributed_environment


def init_test_distributed_environment(pipeline_parallel_size: int,
                                      tensor_parallel_size: int, rank: int,
                                      distributed_init_port: str):
    parallel_config = ParallelConfig(pipeline_parallel_size,
                                     tensor_parallel_size,
                                     worker_use_ray=True)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    torch.cuda.set_device(rank)
    _init_distributed_environment(parallel_config, rank,
                                  distributed_init_method)


def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_elements = 8
    all_tensors = [
        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
        (r + 1) for r in range(tensor_parallel_size)
    ]
    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
    t = all_tensors[rank]
    t = tensor_model_parallel_all_reduce(t)
    assert torch.allclose(t, expected)


def all_gather_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_dimensions = 3
    tensor_size = list(range(2, num_dimensions + 2))
    total_size = 1
    for s in tensor_size:
        total_size *= s
    for all_gather_dimension in range(num_dimensions):
        all_tensors = [
            torch.arange(total_size, dtype=torch.float32,
                         device="cuda").reshape(tensor_size) * (r + 1)
            for r in range(tensor_parallel_size)
        ]
        expected = torch.cat(all_tensors, dim=all_gather_dimension)
        t = all_tensors[rank]
        t = tensor_model_parallel_all_gather(t, all_gather_dimension)
        assert torch.allclose(t, expected)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("test_target",
                         [all_reduce_test_worker, all_gather_test_worker])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
    distributed_init_port = get_open_port()
    processes = []
    for rank in range(tensor_parallel_size):
        p = Process(target=test_target,
                    args=(tensor_parallel_size, rank, distributed_init_port))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    assert all(p.exitcode == 0 for p in processes)
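
The all-reduce worker above asserts against a plain-torch reference: rank r contributes arange(8) * (r + 1), so the reduced tensor must equal arange(8) scaled by the sum 1 + 2 + ... + world_size. A CPU-only sketch of that arithmetic (world_size fixed to 2 here to match the test's parametrization, purely for illustration):

import torch

world_size = 2
num_elements = 8
all_tensors = [
    torch.arange(num_elements, dtype=torch.float32) * (r + 1)
    for r in range(world_size)
]
expected = torch.stack(all_tensors).sum(dim=0)
closed_form = (torch.arange(num_elements, dtype=torch.float32) *
               (world_size * (world_size + 1) / 2))
assert torch.allclose(expected, closed_form)
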
@@ -5,6 +5,7 @@ from transformers import AutoTokenizer
from vllm.transformers_utils.tokenizer import detokenize_incrementally

TRUTH = [
+    # pylint: disable=line-too-long
    "Hello here, this is a simple test",
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
    "我很感谢你的热情"
@ -29,8 +29,8 @@ def test_silu_and_mul(
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||
x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
|
||||
activation_ops.silu_and_mul(out, x)
|
||||
ref_out = ref_silu_and_mul(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
@ -49,8 +49,8 @@ def test_gelu_new(
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
|
||||
activation_ops.gelu_new(out, x)
|
||||
ref_out = get_activation("gelu_new")(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
@ -68,8 +68,8 @@ def test_gelu_fast(
|
||||
) -> None:
|
||||
torch.random.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
|
||||
x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
|
||||
out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
|
||||
activation_ops.gelu_fast(out, x)
|
||||
ref_out = get_activation("gelu_fast")(x)
|
||||
assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
|
||||
|
@ -106,14 +106,14 @@ def test_reshape_and_cache(
|
||||
# Create a random slot mapping.
|
||||
num_slots = block_size * num_blocks
|
||||
slot_mapping = random.sample(range(num_slots), num_tokens)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda")
|
||||
|
||||
qkv = torch.randn(num_tokens,
|
||||
3,
|
||||
num_heads,
|
||||
head_size,
|
||||
dtype=dtype,
|
||||
device='cuda')
|
||||
device="cuda")
|
||||
_, key, value = qkv.unbind(dim=1)
|
||||
|
||||
# Create the KV caches.
|
||||
@ -132,7 +132,7 @@ def test_reshape_and_cache(
|
||||
|
||||
# Run the reference implementation.
|
||||
reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
|
||||
block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
|
||||
block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
|
||||
block_indicies = block_indicies.cpu().tolist()
|
||||
block_offsets = slot_mapping % block_size
|
||||
block_offsets = block_offsets.cpu().tolist()
|
||||
|
@ -140,7 +140,7 @@ def test_rotary_embedding(
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cos_sin_cache = torch.cat((cos, sin), dim=-1)
|
||||
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
|
||||
cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
|
||||
|
||||
# Run the kernel. The kernel is in-place, so we need to clone the inputs.
|
||||
out_query = query.clone()
|
||||
|
@@ -1,3 +1,4 @@
+# pylint: disable=protected-access
import pytest
import random
from typing import Tuple
@@ -108,7 +109,7 @@ def test_sampler_all_random(seed: int):
def test_sampler_all_beam(seed: int):
    set_random_seed(seed)
    batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, _, sampler, worker = _prepare_test(batch_size)

    seq_group_metadata_list = []
    for i in range(batch_size):
@@ -1,7 +1,7 @@
from vllm.model_executor.layers.quantized_linear.awq import (
    AWQColumnParallelLinear, AWQRowParallelLinear)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
+                                                       RowParallelLinear)

_QUANTIZED_LINEAR_REGISTRY = {
    "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
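
_QUANTIZED_LINEAR_REGISTRY maps a quantization method name to its (column, row) parallel linear subclasses; the llama/mistral diffs below construct layers through a ParallelLinear.column / ParallelLinear.row helper that presumably consults this table and falls back to the plain classes when no quant_config is given. A small standalone sketch of that kind of dispatch; the fallback logic and names here are assumptions for illustration, not the exact vLLM helper.

from typing import Optional, Tuple

# String stand-ins for the real classes, so the sketch runs without vLLM.
_REGISTRY: dict = {"awq": ("AWQColumnParallelLinear", "AWQRowParallelLinear")}
_DEFAULT: Tuple[str, str] = ("ColumnParallelLinear", "RowParallelLinear")


def resolve_parallel_linear(quant_method: Optional[str]) -> Tuple[str, str]:
    """Pick the (column, row) parallel linear classes for a quant method."""
    if quant_method is None:
        return _DEFAULT
    if quant_method not in _REGISTRY:
        raise ValueError(f"Unsupported quantization method: {quant_method}")
    return _REGISTRY[quant_method]


assert resolve_parallel_linear(None) == _DEFAULT
assert resolve_parallel_linear("awq")[0] == "AWQColumnParallelLinear"
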
@@ -4,8 +4,8 @@ import torch
from torch.nn.parameter import Parameter

from vllm import quantization_ops
-from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
-    ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
+                                                       RowParallelLinear)


class AWQColumnParallelLinear(ColumnParallelLinear):
@@ -5,8 +5,8 @@ import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    gather_from_tensor_model_parallel_region)
+from vllm.model_executor.parallel_utils.communication_op import (
+    tensor_model_parallel_all_gather)
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs

@@ -92,7 +92,7 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
    logits = torch.matmul(hidden_states, embedding.t())
    if embedding_bias is not None:
        logits += embedding_bias
-    logits = gather_from_tensor_model_parallel_region(logits)
+    logits = tensor_model_parallel_all_gather(logits)
    # Remove paddings in vocab (if any).
    logits = logits[:, :vocab_size]
    return logits
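
The sampler change swaps the Megatron-style gather_from_tensor_model_parallel_region for the new tensor_model_parallel_all_gather: with a column-parallel lm_head, each rank only computes logits for its shard of the (padded) vocabulary, so the shards have to be gathered along the last dimension before the padding columns are sliced off. A single-process sketch of that shape logic, with made-up sizes:

import torch

# Hypothetical 2-way tensor-parallel setup: each rank holds the logits for
# half of a padded vocabulary of 8 tokens, for a batch of 3 tokens.
batch, vocab, tp = 3, 8, 2
per_rank_logits = [torch.randn(batch, vocab // tp) for _ in range(tp)]

# tensor_model_parallel_all_gather(logits) concatenates the shards along the
# last dim on every rank; torch.cat reproduces the tensor it would return.
full_logits = torch.cat(per_rank_logits, dim=-1)
assert full_logits.shape == (batch, vocab)

# Padding columns beyond the true vocab size are then sliced off, mirroring
# `logits = logits[:, :vocab_size]` in the diff above.
true_vocab_size = 7  # illustrative
full_logits = full_logits[:, :true_vocab_size]
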
@@ -39,8 +39,9 @@ from vllm.model_executor.weight_utils import (
    load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
from vllm.sequence import SamplerOutput
from vllm.transformers_utils.configs.aquila import AquilaConfig

@@ -56,16 +57,18 @@ class AquilaMLP(nn.Module):
        hidden_act: str,
    ):
        super().__init__()
-        self.gate_up_proj = ColumnParallelLinear(hidden_size,
-                                                 2 * intermediate_size,
-                                                 bias=False,
-                                                 gather_output=False,
-                                                 perform_initialization=False)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           input_is_parallel=True,
-                                           perform_initialization=False)
+        self.gate_up_proj = ColumnParallelLinear(
+            hidden_size,
+            2 * intermediate_size,
+            bias=False,
+            gather_output=False,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            input_is_parallel=True,
+        )
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
@@ -130,14 +133,12 @@ class AquilaAttention(nn.Module):
            self.head_dim,
            bias=False,
            gather_output=False,
-            perform_initialization=False,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            input_is_parallel=True,
-            perform_initialization=False,
        )
        self.attn = PagedAttentionWithRoPE(
            self.num_heads,
@@ -230,7 +231,7 @@ class AquilaModel(nn.Module):
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
-            perform_initialization=False)
+        )
        self.layers = nn.ModuleList([
            AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)
        ])
@@ -270,11 +271,12 @@ class AquilaForCausalLM(nn.Module):
        self.config = config
        self.model = AquilaModel(config)
        vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            vocab_size,
-                                            bias=False,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        self.lm_head = ColumnParallelLinear(
+            config.hidden_size,
+            vocab_size,
+            bias=False,
+            gather_output=False,
+        )
        self.sampler = Sampler(config.vocab_size)

    def forward(
@ -39,8 +39,9 @@ from vllm.model_executor.weight_utils import (
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
|
||||
|
||||
@ -81,16 +82,18 @@ class BaiChuanMLP(nn.Module):
|
||||
hidden_act: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.gate_up_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -133,14 +136,12 @@ class BaiChuanAttention(nn.Module):
|
||||
3 * hidden_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
# Create the alibi slopes and slice them.
|
||||
if self.postion_embedding == "ALIBI":
|
||||
@ -249,7 +250,7 @@ class BaiChuanModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
perform_initialization=False)
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
BaiChuanDecoderLayer(config, position_embedding)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
@ -288,11 +289,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = BaiChuanModel(config, position_embedding)
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
|
@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -85,14 +86,12 @@ class BloomAttention(nn.Module):
|
||||
3 * self.hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
|
||||
# Create the alibi slopes and slice them.
|
||||
@ -129,15 +128,17 @@ class BloomMLP(nn.Module):
|
||||
def __init__(self, config: BloomConfig):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
|
||||
4 * hidden_size,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.dense_h_to_4h = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
4 * hidden_size,
|
||||
gather_output=False,
|
||||
)
|
||||
self.act = get_act_fn("gelu")
|
||||
self.dense_4h_to_h = RowParallelLinear(4 * hidden_size,
|
||||
hidden_size,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
4 * hidden_size,
|
||||
hidden_size,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x, _ = self.dense_h_to_4h(x)
|
||||
@ -208,7 +209,9 @@ class BloomModel(nn.Module):
|
||||
|
||||
# Embedding + LN Embedding
|
||||
self.word_embeddings = VocabParallelEmbedding(
|
||||
config.vocab_size, self.embed_dim, perform_initialization=False)
|
||||
config.vocab_size,
|
||||
self.embed_dim,
|
||||
)
|
||||
self.word_embeddings_layernorm = nn.LayerNorm(
|
||||
self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
|
||||
|
@ -36,9 +36,11 @@ from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
|
||||
reduce_from_tensor_model_parallel_region)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs import RWConfig
|
||||
|
||||
@ -109,7 +111,6 @@ class FalconAttention(nn.Module):
|
||||
self.head_dim,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
)
|
||||
elif self.multi_query:
|
||||
@ -120,7 +121,6 @@ class FalconAttention(nn.Module):
|
||||
self.total_num_heads * self.head_dim,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
)
|
||||
self.key_value = FalconLinear(self.hidden_size,
|
||||
@ -135,7 +135,6 @@ class FalconAttention(nn.Module):
|
||||
self.head_dim,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
)
|
||||
|
||||
@ -151,7 +150,6 @@ class FalconAttention(nn.Module):
|
||||
self.hidden_size,
|
||||
bias=config.bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
reduce_results=self.reduce_row_parallel_results)
|
||||
|
||||
@ -231,7 +229,6 @@ class FalconMLP(nn.Module):
|
||||
4 * hidden_size,
|
||||
bias=config.bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True)
|
||||
self.act = nn.GELU()
|
||||
self.reduce_row_parallel_results = not (config.new_decoder_architecture
|
||||
@ -241,7 +238,6 @@ class FalconMLP(nn.Module):
|
||||
hidden_size,
|
||||
bias=config.bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
skip_bias_add=True,
|
||||
reduce_results=self.reduce_row_parallel_results)
|
||||
|
||||
@ -325,7 +321,7 @@ class FalconDecoderLayer(nn.Module):
|
||||
# only one all-reduce operator to reduce the results from
|
||||
# both MLP and Attention layers.
|
||||
mlp_output += attention_output
|
||||
mlp_output = reduce_from_tensor_model_parallel_region(mlp_output)
|
||||
mlp_output = tensor_model_parallel_all_reduce(mlp_output)
|
||||
if attention_bias is not None:
|
||||
mlp_output += attention_bias
|
||||
if mlp_bias is not None:
|
||||
@ -347,7 +343,9 @@ class FalconModel(nn.Module):
|
||||
|
||||
# Embedding + LN Embedding
|
||||
self.word_embeddings = VocabParallelEmbedding(
|
||||
config.vocab_size, self.embed_dim, perform_initialization=False)
|
||||
config.vocab_size,
|
||||
self.embed_dim,
|
||||
)
|
||||
|
||||
# Transformer blocks
|
||||
self.h = nn.ModuleList([
|
||||
@ -389,11 +387,12 @@ class FalconForCausalLM(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.transformer = FalconModel(config)
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
|
@ -36,8 +36,9 @@ from vllm.model_executor.weight_utils import (
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -56,16 +57,18 @@ class GPT2Attention(nn.Module):
|
||||
self.head_dim = self.hidden_size // total_num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
|
||||
self.c_attn = ColumnParallelLinear(self.hidden_size,
|
||||
3 * self.hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_proj = RowParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.c_attn = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
3 * self.hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
scale=self.scale)
|
||||
@ -95,16 +98,18 @@ class GPT2MLP(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.c_fc = ColumnParallelLinear(hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
self.act = get_act_fn(config.activation_function)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -37,8 +37,9 @@ from vllm.model_executor.weight_utils import (
|
||||
load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -62,29 +63,31 @@ class GPTBigCodeAttention(nn.Module):
|
||||
if self.multi_query:
|
||||
self.num_kv_heads = 1
|
||||
self.kv_dim = self.head_dim
|
||||
self.c_attn_q = ColumnParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_attn_q = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
)
|
||||
self.c_attn_kv = nn.Linear(self.hidden_size,
|
||||
2 * self.kv_dim,
|
||||
bias=True)
|
||||
else:
|
||||
self.num_kv_heads = self.num_heads
|
||||
self.kv_dim = self.num_kv_heads * self.head_dim
|
||||
self.c_attn = ColumnParallelLinear(self.hidden_size,
|
||||
self.hidden_size +
|
||||
2 * self.kv_dim,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_attn = ColumnParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size + 2 * self.kv_dim,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
)
|
||||
|
||||
self.c_proj = RowParallelLinear(self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.hidden_size,
|
||||
self.hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
scale=self.scale,
|
||||
@ -124,16 +127,18 @@ class GPTBigMLP(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
hidden_size = config.hidden_size
|
||||
self.c_fc = ColumnParallelLinear(hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.c_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.c_fc = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
self.act = get_act_fn(config.activation_function)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -34,8 +34,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -49,16 +50,18 @@ class GPTJAttention(nn.Module):
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_size = self.hidden_size // self.total_num_heads
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(config.hidden_size,
|
||||
3 * config.hidden_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.out_proj = RowParallelLinear(config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
3 * config.hidden_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
assert self.total_num_heads % tp_world_size == 0
|
||||
@ -102,14 +105,16 @@ class GPTJMLP(nn.Module):
|
||||
def __init__(self, intermediate_size: int, config: GPTJConfig):
|
||||
super().__init__()
|
||||
hidden_size = config.n_embd
|
||||
self.fc_in = ColumnParallelLinear(hidden_size,
|
||||
intermediate_size,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.fc_out = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.fc_in = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
gather_output=False,
|
||||
)
|
||||
self.fc_out = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
self.act = get_act_fn(config.activation_function)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
@ -159,9 +164,10 @@ class GPTJModel(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.n_embd
|
||||
self.wte = VocabParallelEmbedding(config.vocab_size,
|
||||
self.embed_dim,
|
||||
perform_initialization=False)
|
||||
self.wte = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
self.embed_dim,
|
||||
)
|
||||
self.h = nn.ModuleList(
|
||||
[GPTJBlock(config) for _ in range(config.n_layer)])
|
||||
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
|
||||
@ -199,10 +205,11 @@ class GPTJForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
assert not config.tie_word_embeddings
|
||||
self.transformer = GPTJModel(config)
|
||||
self.lm_head = ColumnParallelLinear(config.n_embd,
|
||||
config.vocab_size,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.n_embd,
|
||||
config.vocab_size,
|
||||
gather_output=False,
|
||||
)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
|
@ -34,8 +34,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -59,11 +60,12 @@ class GPTNeoXAttention(nn.Module):
|
||||
config.hidden_size,
|
||||
3 * config.hidden_size,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.dense = RowParallelLinear(config.hidden_size,
|
||||
config.hidden_size,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
)
|
||||
self.dense = RowParallelLinear(
|
||||
config.hidden_size,
|
||||
config.hidden_size,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
|
||||
scaling = self.head_size**-0.5
|
||||
rotary_dim = int(self.head_size * config.rotary_pct)
|
||||
@ -100,14 +102,16 @@ class GPTNeoXMLP(nn.Module):
|
||||
|
||||
def __init__(self, config: GPTNeoXConfig):
|
||||
super().__init__()
|
||||
self.dense_h_to_4h = ColumnParallelLinear(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.dense_h_to_4h = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
gather_output=False,
|
||||
)
|
||||
self.dense_4h_to_h = RowParallelLinear(
|
||||
config.intermediate_size,
|
||||
config.hidden_size,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
self.act = get_act_fn(config.hidden_act)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
@ -169,9 +173,10 @@ class GPTNeoXModel(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.embed_in = VocabParallelEmbedding(config.vocab_size,
|
||||
config.hidden_size,
|
||||
perform_initialization=False)
|
||||
self.embed_in = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size,
|
||||
@ -209,11 +214,12 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.gpt_neox = GPTNeoXModel(config)
|
||||
self.embed_out = ColumnParallelLinear(config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.embed_out = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
config.vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
|
@ -12,8 +12,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
ColumnParallelLinear, RowParallelLinear, VocabParallelEmbedding)
|
||||
from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.weight_utils import (
|
||||
hf_model_weights_iterator, load_padded_tensor_parallel_vocab,
|
||||
load_tensor_parallel_weights)
|
||||
@ -31,16 +32,18 @@ class InternLMMLP(nn.Module):
|
||||
hidden_act: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_up_proj = ColumnParallelLinear(hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.gate_up_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
@ -80,14 +83,12 @@ class InternLMAttention(nn.Module):
|
||||
3 * self.total_num_heads * self.head_dim,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=True,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
self.num_heads,
|
||||
@ -176,7 +177,9 @@ class InternLMModel(nn.Module):
|
||||
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
vocab_size, config.hidden_size, perform_initialization=False)
|
||||
vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
InternLMDecoderLayer(config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
@ -216,11 +219,12 @@ class InternLMForCausalLM(nn.Module):
|
||||
self.config = config
|
||||
self.model = InternLMModel(config)
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.lm_head = ColumnParallelLinear(config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.lm_head = ColumnParallelLinear(
|
||||
config.hidden_size,
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
|
@ -39,8 +39,7 @@ from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.quantized_linear import ParallelLinear
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
|
||||
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
@ -64,13 +63,11 @@ class LlamaMLP(nn.Module):
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config)
|
||||
self.down_proj = ParallelLinear.row(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
@ -127,7 +124,6 @@ class LlamaAttention(nn.Module):
|
||||
self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = ParallelLinear.row(
|
||||
@ -135,7 +131,6 @@ class LlamaAttention(nn.Module):
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
@ -241,7 +236,9 @@ class LlamaModel(nn.Module):
|
||||
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
vocab_size, config.hidden_size, perform_initialization=False)
|
||||
vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
LlamaDecoderLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
@ -291,7 +288,6 @@ class LlamaForCausalLM(nn.Module):
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=None)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
|
@ -38,8 +38,7 @@ from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.quantized_linear import ParallelLinear
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding
|
||||
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||
from vllm.model_executor.weight_utils import (
|
||||
convert_pyslice_to_tensor, hf_model_weights_iterator,
|
||||
@ -64,13 +63,11 @@ class MistralMLP(nn.Module):
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config)
|
||||
self.down_proj = ParallelLinear.row(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
@ -116,7 +113,6 @@ class MistralAttention(nn.Module):
|
||||
self.head_dim,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.o_proj = ParallelLinear.row(
|
||||
@ -124,7 +120,6 @@ class MistralAttention(nn.Module):
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.attn = PagedAttentionWithRoPE(self.num_heads,
|
||||
@ -225,7 +220,9 @@ class MistralModel(nn.Module):
|
||||
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
vocab_size, config.hidden_size, perform_initialization=False)
|
||||
vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
MistralDecoderLayer(config, quant_config)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
@ -275,7 +272,6 @@ class MistralForCausalLM(nn.Module):
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
quant_config=None)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
|
@ -15,8 +15,9 @@ from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.sequence import SamplerOutput
|
||||
from vllm.transformers_utils.configs.mpt import MPTConfig
|
||||
|
||||
@ -53,7 +54,6 @@ class MPTAttention(nn.Module):
|
||||
3 * self.d_model,
|
||||
bias=not config.no_bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
if self.qk_ln:
|
||||
self.q_ln = nn.LayerNorm(self.d_model)
|
||||
@ -63,7 +63,6 @@ class MPTAttention(nn.Module):
|
||||
self.d_model,
|
||||
bias=not config.no_bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
@ -113,17 +112,19 @@ class MPTMLP(nn.Module):
|
||||
hidden_size = config.d_model
|
||||
expansion_ratio = config.expansion_ratio
|
||||
intermediate_size = expansion_ratio * hidden_size
|
||||
self.up_proj = ColumnParallelLinear(hidden_size,
|
||||
intermediate_size,
|
||||
bias=not config.no_bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.up_proj = ColumnParallelLinear(
|
||||
hidden_size,
|
||||
intermediate_size,
|
||||
bias=not config.no_bias,
|
||||
gather_output=False,
|
||||
)
|
||||
self.act = get_act_fn("gelu")
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=not config.no_bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.down_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=not config.no_bias,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x, _ = self.up_proj(x)
|
||||
@ -172,9 +173,10 @@ class MPTModel(nn.Module):
|
||||
assert config.embedding_fraction == 1.0
|
||||
assert config.norm_type == "low_precision_layernorm"
|
||||
|
||||
self.wte = VocabParallelEmbedding(config.vocab_size,
|
||||
config.d_model,
|
||||
perform_initialization=False)
|
||||
self.wte = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.d_model,
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[MPTBlock(config) for _ in range(config.n_layers)])
|
||||
self.norm_f = nn.LayerNorm(config.d_model)
|
||||
|
@ -35,8 +35,9 @@ from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
load_tensor_parallel_weights)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
|
||||
from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -73,16 +74,18 @@ class OPTAttention(nn.Module):
|
||||
self.head_dim = embed_dim // total_num_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.qkv_proj = ColumnParallelLinear(embed_dim,
|
||||
3 * embed_dim,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.out_proj = RowParallelLinear(embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.qkv_proj = ColumnParallelLinear(
|
||||
embed_dim,
|
||||
3 * embed_dim,
|
||||
bias=bias,
|
||||
gather_output=False,
|
||||
)
|
||||
self.out_proj = RowParallelLinear(
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
scale=self.scaling)
|
||||
@ -120,16 +123,18 @@ class OPTDecoderLayer(nn.Module):
|
||||
self.self_attn_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim,
|
||||
elementwise_affine=config.layer_norm_elementwise_affine)
|
||||
self.fc1 = ColumnParallelLinear(self.embed_dim,
|
||||
config.ffn_dim,
|
||||
bias=config.enable_bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.fc2 = RowParallelLinear(config.ffn_dim,
|
||||
self.embed_dim,
|
||||
bias=config.enable_bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.fc1 = ColumnParallelLinear(
|
||||
self.embed_dim,
|
||||
config.ffn_dim,
|
||||
bias=config.enable_bias,
|
||||
gather_output=False,
|
||||
)
|
||||
self.fc2 = RowParallelLinear(
|
||||
config.ffn_dim,
|
||||
self.embed_dim,
|
||||
bias=config.enable_bias,
|
||||
input_is_parallel=True,
|
||||
)
|
||||
self.final_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim,
|
||||
elementwise_affine=config.layer_norm_elementwise_affine)
|
||||
@ -182,7 +187,7 @@ class OPTDecoder(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.word_embed_proj_dim,
|
||||
perform_initialization=False)
|
||||
)
|
||||
# Positional embeddings are replicated (not sharded).
|
||||
self.embed_positions = OPTLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings, config.hidden_size)
|
||||
|
@ -28,7 +28,7 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
from vllm.model_executor.parallel_utils.layers import (
|
||||
VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
@ -53,14 +53,12 @@ class QWenMLP(nn.Module):
|
||||
2 * intermediate_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
@ -98,14 +96,12 @@ class QWenAttention(nn.Module):
|
||||
3 * hidden_size,
|
||||
bias=True,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.c_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.attn = PagedAttentionWithRoPE(
|
||||
@ -190,9 +186,10 @@ class QWenModel(nn.Module):
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
vocab_size = ((config.vocab_size + 63) // 64) * 64
|
||||
self.wte = VocabParallelEmbedding(vocab_size,
|
||||
config.hidden_size,
|
||||
perform_initialization=False)
|
||||
self.wte = VocabParallelEmbedding(
|
||||
vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.h = nn.ModuleList(
|
||||
[QWenBlock(config) for _ in range(config.num_hidden_layers)])
|
||||
self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
|
||||
@ -235,7 +232,6 @@ class QWenLMHeadModel(nn.Module):
|
||||
vocab_size,
|
||||
bias=False,
|
||||
gather_output=False,
|
||||
perform_initialization=False,
|
||||
)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
|
@@ -1,7 +0,0 @@
-import vllm.model_executor.parallel_utils.parallel_state
-import vllm.model_executor.parallel_utils.tensor_parallel
-
-__all__ = [
-    "parallel_state",
-    "tensor_parallel",
-]
47
vllm/model_executor/parallel_utils/communication_op.py
Normal file
47
vllm/model_executor/parallel_utils/communication_op.py
Normal file
@ -0,0 +1,47 @@
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_group,
|
||||
)
|
||||
|
||||
|
||||
def tensor_model_parallel_all_reduce(input_):
|
||||
"""All-reduce the input tensor across model parallel group.
|
||||
|
||||
Note: This operation is applied in-place on the input tensor.
|
||||
"""
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if get_tensor_model_parallel_world_size() == 1:
|
||||
return input_
|
||||
# All-reduce.
|
||||
torch.distributed.all_reduce(input_,
|
||||
group=get_tensor_model_parallel_group())
|
||||
return input_
|
||||
|
||||
|
||||
def tensor_model_parallel_all_gather(input_, dim=-1):
|
||||
"""All-gather the input tensor across model parallel group."""
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= dim < input_.dim(), (
|
||||
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
|
||||
if dim < 0:
|
||||
# Convert negative dim to positive.
|
||||
dim += input_.dim()
|
||||
input_size = input_.size()
|
||||
# Allocate output tensor.
|
||||
output_tensor = torch.empty((world_size, ) + input_size,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device)
|
||||
# All-gather.
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
output_tensor, input_, group=get_tensor_model_parallel_group())
|
||||
# Reshape
|
||||
output_tensor = output_tensor.movedim(0, dim)
|
||||
output_tensor = output_tensor.reshape(input_size[:dim] +
|
||||
(world_size * input_size[dim], ) +
|
||||
input_size[dim + 1:])
|
||||
return output_tensor
|
303
vllm/model_executor/parallel_utils/layers.py
Normal file
@ -0,0 +1,303 @@
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Parts of the code here are adapted from PyTorch
|
||||
# repo: https://github.com/pytorch/pytorch
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.quantization_utils import QuantizationConfig
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather)
|
||||
|
||||
from vllm.model_executor.parallel_utils.utils import (
|
||||
divide,
|
||||
VocabUtility,
|
||||
split_tensor_along_last_dim,
|
||||
)
|
||||
|
||||
|
||||
class VocabParallelEmbedding(torch.nn.Module):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
This is mainly adapted from torch.nn.Embedding and all the default
|
||||
values are kept.
|
||||
Arguments:
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
params_dtype: type of the parameters.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
params_dtype: Optional[torch.dtype] = None):
|
||||
super().__init__()
|
||||
|
||||
# Keep the input dimensions.
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
# TODO: Handle vocab padding here.
|
||||
# Divide the weight matrix along the vocabulary dimension.
|
||||
self.vocab_start_index, self.vocab_end_index = (
|
||||
VocabUtility.vocab_range_from_global_vocab_size(
|
||||
self.num_embeddings, get_tensor_model_parallel_rank(),
|
||||
self.tp_size))
|
||||
self.num_embeddings_per_partition = (self.vocab_end_index -
|
||||
self.vocab_start_index)
|
||||
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.num_embeddings_per_partition,
|
||||
self.embedding_dim,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tp_size > 1:
|
||||
# Build the mask.
|
||||
input_mask = ((input_ < self.vocab_start_index) |
|
||||
(input_ >= self.vocab_end_index))
|
||||
# Mask the input.
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
masked_input[input_mask] = 0
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = F.embedding(masked_input, self.weight)
|
||||
# Mask the output embedding.
|
||||
if self.tp_size > 1:
|
||||
output_parallel[input_mask, :] = 0.0
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
return output
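A minimal single-process sketch of the vocab-parallel lookup above, simulating `tp_size == 2` on CPU: each simulated rank owns a contiguous vocab slice, zeroes the tokens it does not own, and summing the partial outputs (the role of the all-reduce) recovers the full embedding lookup. All sizes are illustrative:

import torch
import torch.nn.functional as F

vocab_size, hidden, tp_size = 8, 4, 2
full_weight = torch.randn(vocab_size, hidden)
token_ids = torch.tensor([0, 3, 5, 7])

per_rank = vocab_size // tp_size
partials = []
for rank in range(tp_size):
    start, end = rank * per_rank, (rank + 1) * per_rank
    shard = full_weight[start:end]                 # rows owned by this rank
    mask = (token_ids < start) | (token_ids >= end)
    local_ids = token_ids - start                  # shift into the local range
    local_ids[mask] = 0                            # park unowned tokens on row 0
    out = F.embedding(local_ids, shard)
    out[mask, :] = 0.0                             # zero rows for unowned tokens
    partials.append(out)

# Summing the per-rank partial lookups plays the role of the all-reduce.
assert torch.allclose(sum(partials), F.embedding(token_ids, full_weight))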
|
||||
|
||||
|
||||
class ColumnParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its second dimension as A = [A_1, ..., A_p].
|
||||
|
||||
Arguments:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
|
||||
Keyword Arguments:
|
||||
bias: If true, add bias
|
||||
gather_output: If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is Y_i = XA_i
|
||||
skip_bias_add: This was added to enable performance optimizations where
|
||||
bias can be fused with other element-wise operations. We
|
||||
skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configuration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
gather_output: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.gather_output = gather_output
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.quant_config = quant_config
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Parameters.
|
||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
||||
# we allocate the transpose.
|
||||
self.create_weights(params_dtype)
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
self.input_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=dtype))
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
return F.linear(x, self.weight, bias)
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of ColumnParallelLinear
|
||||
|
||||
Args:
|
||||
input_: Tensor whose last dimension is `input_size`.
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
input_parallel = input_
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply_weights(input_parallel, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
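A minimal single-process sketch of the column-parallel math, with illustrative sizes: the weight of an ordinary linear layer is split along its output dimension, each simulated rank computes its slice, and concatenating the slices (what the all-gather does when `gather_output=True`) reproduces the unsharded result:

import torch
import torch.nn.functional as F

tp_size, in_features, out_features = 2, 8, 6
x = torch.randn(3, in_features)
full_weight = torch.randn(out_features, in_features)    # F.linear layout: (out, in)

out_per_rank = out_features // tp_size
shards = [full_weight[r * out_per_rank:(r + 1) * out_per_rank]
          for r in range(tp_size)]
partial_outputs = [F.linear(x, w) for w in shards]       # each: (3, out_per_rank)
gathered = torch.cat(partial_outputs, dim=-1)            # "all-gather" across ranks

assert torch.allclose(gathered, F.linear(x, full_weight))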
|
||||
|
||||
|
||||
class RowParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with row parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its first dimension and X along its second dimension as:
|
||||
- -
|
||||
| A_1 |
|
||||
| . |
|
||||
A = | . | X = [X_1, ..., X_p]
|
||||
| . |
|
||||
| A_p |
|
||||
- -
|
||||
Arguments:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
|
||||
Keyword Arguments:
|
||||
bias: If true, add bias. Note that bias is not parallelized.
|
||||
input_is_parallel: If true, we assume that the input is already
|
||||
split across the GPUs and we do not split
|
||||
again.
|
||||
skip_bias_add: This was added to enable performance optimization where
|
||||
bias can be fused with other element-wise operations.
|
||||
We skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configuration.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.create_weights(params_dtype)
|
||||
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError('When not reducing the results, adding bias to the '
|
||||
'results can lead to incorrect results')
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
self.bias.zero_()
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
self.weight = Parameter(
|
||||
torch.empty(self.output_size,
|
||||
self.input_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=dtype))
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(x, self.weight)
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of RowParallelLinear
|
||||
|
||||
Args:
|
||||
input_: tensor whose last dimension is `input_size`. If
|
||||
`input_is_parallel` is set, then the last dimension
|
||||
is `input_size // tp_size`.
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
# Set up backprop all-reduce.
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
# TODO: simplify code below
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply_weights(input_parallel)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.skip_bias_add:
|
||||
output = output_ + self.bias if self.bias is not None else output_
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.bias
|
||||
return output, output_bias
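A matching single-process sketch of the row-parallel math, again with illustrative sizes: the weight is split along its input dimension and the activation along its last dimension, each simulated rank computes a partial product, and summing the partials (the role of the all-reduce) recovers the full result. This is also why a `RowParallelLinear(input_is_parallel=True)` composes with a preceding `ColumnParallelLinear(gather_output=False)` without any intermediate communication, as in the QWen MLP and attention layers above:

import torch
import torch.nn.functional as F

tp_size, in_features, out_features = 2, 8, 6
x = torch.randn(3, in_features)
full_weight = torch.randn(out_features, in_features)     # F.linear layout: (out, in)

in_per_rank = in_features // tp_size
x_shards = torch.split(x, in_per_rank, dim=-1)            # shard the activation
w_shards = torch.split(full_weight, in_per_rank, dim=-1)  # shard the weight rows of A
partials = [F.linear(xs, ws) for xs, ws in zip(x_shards, w_shards)]

# Summing the partial products plays the role of the all-reduce.
assert torch.allclose(sum(partials), F.linear(x, full_weight))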
|
@ -1,78 +1,42 @@
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
||||
# Adapted from
|
||||
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
"""Model and data parallel groups."""
|
||||
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
# Intra-layer model parallel group that the current rank belongs to.
|
||||
# Tensor model parallel group that the current rank belongs to.
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = None
|
||||
# Inter-layer model parallel group that the current rank belongs to.
|
||||
# Pipeline model parallel group that the current rank belongs to.
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
|
||||
_MODEL_PARALLEL_GROUP = None
|
||||
# Embedding group.
|
||||
_EMBEDDING_GROUP = None
|
||||
# Position embedding group.
|
||||
_POSITION_EMBEDDING_GROUP = None
|
||||
# Data parallel group that the current rank belongs to.
|
||||
_DATA_PARALLEL_GROUP = None
|
||||
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
|
||||
|
||||
# These values enable us to change the mpu sizes on the fly.
|
||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
|
||||
# A list of ranks that have a copy of the embedding.
|
||||
_EMBEDDING_GLOBAL_RANKS = None
|
||||
|
||||
# A list of ranks that have a copy of the position embedding.
|
||||
_POSITION_EMBEDDING_GLOBAL_RANKS = None
|
||||
|
||||
# A list of global ranks for each pipeline group to ease calculation of the source
|
||||
# rank when broadcasting from the first or last pipeline stage.
|
||||
# A list of global ranks for each pipeline group to ease calculation of the
|
||||
# source rank when broadcasting from the first or last pipeline stage.
|
||||
_PIPELINE_GLOBAL_RANKS = None
|
||||
|
||||
# A list of global ranks for each data parallel group to ease calculation of the source
|
||||
# rank when broadcasting weights from src to all other data parallel ranks
|
||||
_DATA_PARALLEL_GLOBAL_RANKS = None
|
||||
|
||||
|
||||
def initialize_model_parallel(
|
||||
tensor_model_parallel_size: int = 1,
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
virtual_pipeline_model_parallel_size: Optional[int] = None,
|
||||
pipeline_model_parallel_split_rank: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize model data parallel groups.
|
||||
Initialize model parallel groups.
|
||||
|
||||
Arguments:
|
||||
tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
|
||||
pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
|
||||
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
|
||||
pipeline).
|
||||
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
|
||||
rank in pipeline with split point.
|
||||
tensor_model_parallel_size: number of GPUs used for tensor model
|
||||
parallelism.
|
||||
pipeline_model_parallel_size: number of GPUs used for pipeline model
|
||||
parallelism.
|
||||
|
||||
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
|
||||
Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
|
||||
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
|
||||
the model pipeline. The present function will
|
||||
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
|
||||
and 8 data-parallel groups as:
|
||||
8 data_parallel groups:
|
||||
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
|
||||
8 tensor model-parallel groups:
|
||||
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
|
||||
4 pipeline model-parallel groups:
|
||||
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
|
||||
create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
|
||||
4 tensor model-parallel groups:
|
||||
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
|
||||
2 pipeline model-parallel groups:
|
||||
[g0, g2, g4, g6], [g1, g3, g5, g7]
|
||||
Note that for efficiency, the caller should make sure adjacent ranks
|
||||
are on the same DGX box. For example if we are using 2 DGX-1 boxes
|
||||
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
|
||||
@ -82,64 +46,23 @@ def initialize_model_parallel(
|
||||
assert torch.distributed.is_initialized()
|
||||
world_size: int = torch.distributed.get_world_size()
|
||||
|
||||
if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
|
||||
if (world_size !=
|
||||
tensor_model_parallel_size * pipeline_model_parallel_size):
|
||||
raise RuntimeError(
|
||||
f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
|
||||
f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
|
||||
)
|
||||
|
||||
data_parallel_size: int = world_size // (tensor_model_parallel_size *
|
||||
pipeline_model_parallel_size)
|
||||
|
||||
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
|
||||
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
|
||||
num_data_parallel_groups: int = world_size // data_parallel_size
|
||||
|
||||
if virtual_pipeline_model_parallel_size is not None:
|
||||
if not pipeline_model_parallel_size > 2:
|
||||
raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
|
||||
"interleaved schedule")
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
|
||||
|
||||
if pipeline_model_parallel_split_rank is not None:
|
||||
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
||||
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
|
||||
f"world_size ({world_size}) is not equal to "
|
||||
f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
|
||||
f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")
|
||||
|
||||
num_tensor_model_parallel_groups: int = (world_size //
|
||||
tensor_model_parallel_size)
|
||||
num_pipeline_model_parallel_groups: int = (world_size //
|
||||
pipeline_model_parallel_size)
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
# Build the data-parallel groups.
|
||||
global _DATA_PARALLEL_GROUP
|
||||
global _DATA_PARALLEL_GLOBAL_RANKS
|
||||
assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
|
||||
all_data_parallel_group_ranks = []
|
||||
for i in range(pipeline_model_parallel_size):
|
||||
start_rank = i * num_pipeline_model_parallel_groups
|
||||
end_rank = (i + 1) * num_pipeline_model_parallel_groups
|
||||
for j in range(tensor_model_parallel_size):
|
||||
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
|
||||
all_data_parallel_group_ranks.append(list(ranks))
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_DATA_PARALLEL_GROUP = group
|
||||
_DATA_PARALLEL_GLOBAL_RANKS = ranks
|
||||
|
||||
# Build the model-parallel groups.
|
||||
global _MODEL_PARALLEL_GROUP
|
||||
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
|
||||
for i in range(data_parallel_size):
|
||||
ranks = [data_parallel_group_ranks[i]
|
||||
for data_parallel_group_ranks in all_data_parallel_group_ranks]
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_MODEL_PARALLEL_GROUP = group
|
||||
|
||||
# Build the tensor model-parallel groups.
|
||||
global _TENSOR_MODEL_PARALLEL_GROUP
|
||||
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
|
||||
'tensor model parallel group is already initialized'
|
||||
assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
|
||||
"tensor model parallel group is already initialized")
|
||||
for i in range(num_tensor_model_parallel_groups):
|
||||
ranks = range(i * tensor_model_parallel_size,
|
||||
(i + 1) * tensor_model_parallel_size)
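A pure-Python sketch of how these loops enumerate the groups for the docstring's 8-GPU example (`tensor_model_parallel_size=2`, `pipeline_model_parallel_size=4`); no `torch.distributed` calls are involved:

world_size, tp, pp = 8, 2, 4
num_tp_groups = world_size // tp          # 4
num_pp_groups = world_size // pp          # 2

tp_groups = [list(range(i * tp, (i + 1) * tp)) for i in range(num_tp_groups)]
pp_groups = [list(range(i, world_size, num_pp_groups)) for i in range(num_pp_groups)]

assert tp_groups == [[0, 1], [2, 3], [4, 5], [6, 7]]
assert pp_groups == [[0, 2, 4, 6], [1, 3, 5, 7]]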
|
||||
@ -147,268 +70,60 @@ def initialize_model_parallel(
|
||||
if rank in ranks:
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = group
|
||||
|
||||
# Build the pipeline model-parallel groups and embedding groups
|
||||
# (first and last rank in each pipeline model-parallel group).
|
||||
# Build the pipeline model-parallel groups.
|
||||
global _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
global _PIPELINE_GLOBAL_RANKS
|
||||
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
|
||||
'pipeline model parallel group is already initialized'
|
||||
global _EMBEDDING_GROUP
|
||||
global _EMBEDDING_GLOBAL_RANKS
|
||||
assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
|
||||
global _POSITION_EMBEDDING_GROUP
|
||||
global _POSITION_EMBEDDING_GLOBAL_RANKS
|
||||
assert _POSITION_EMBEDDING_GROUP is None, \
|
||||
'position embedding group is already initialized'
|
||||
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
|
||||
"pipeline model parallel group is already initialized")
|
||||
for i in range(num_pipeline_model_parallel_groups):
|
||||
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = group
|
||||
_PIPELINE_GLOBAL_RANKS = ranks
|
||||
# Setup embedding group (to exchange gradients between
|
||||
# first and last stages).
|
||||
if len(ranks) > 1:
|
||||
embedding_ranks = [ranks[0], ranks[-1]]
|
||||
position_embedding_ranks = [ranks[0]]
|
||||
if pipeline_model_parallel_split_rank is not None:
|
||||
if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
|
||||
embedding_ranks = [ranks[0],
|
||||
ranks[pipeline_model_parallel_split_rank],
|
||||
ranks[-1]]
|
||||
if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
|
||||
position_embedding_ranks = [ranks[0],
|
||||
ranks[pipeline_model_parallel_split_rank]]
|
||||
else:
|
||||
embedding_ranks = ranks
|
||||
position_embedding_ranks = ranks
|
||||
|
||||
group = torch.distributed.new_group(embedding_ranks)
|
||||
if rank in embedding_ranks:
|
||||
_EMBEDDING_GROUP = group
|
||||
if rank in ranks:
|
||||
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
|
||||
|
||||
group = torch.distributed.new_group(position_embedding_ranks)
|
||||
if rank in position_embedding_ranks:
|
||||
_POSITION_EMBEDDING_GROUP = group
|
||||
if rank in ranks:
|
||||
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
|
||||
|
||||
def model_parallel_is_initialized():
|
||||
"""Check if model and data parallel groups are initialized."""
|
||||
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP is None or \
|
||||
_DATA_PARALLEL_GROUP is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_model_parallel_group():
|
||||
"""Get the model parallel group the caller rank belongs to."""
|
||||
assert _MODEL_PARALLEL_GROUP is not None, \
|
||||
'model parallel group is not initialized'
|
||||
return _MODEL_PARALLEL_GROUP
|
||||
return (_TENSOR_MODEL_PARALLEL_GROUP is not None
|
||||
and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
|
||||
|
||||
|
||||
def get_tensor_model_parallel_group():
|
||||
"""Get the tensor model parallel group the caller rank belongs to."""
|
||||
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
|
||||
'intra_layer_model parallel group is not initialized'
|
||||
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
|
||||
"tenosr model parallel group is not initialized")
|
||||
return _TENSOR_MODEL_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_group():
|
||||
"""Get the pipeline model parallel group the caller rank belongs to."""
|
||||
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
|
||||
'pipeline_model parallel group is not initialized'
|
||||
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
|
||||
"pipeline model parallel group is not initialized")
|
||||
return _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_data_parallel_group():
|
||||
"""Get the data parallel group the caller rank belongs to."""
|
||||
assert _DATA_PARALLEL_GROUP is not None, \
|
||||
'data parallel group is not initialized'
|
||||
return _DATA_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_embedding_group():
|
||||
"""Get the embedding group the caller rank belongs to."""
|
||||
assert _EMBEDDING_GROUP is not None, \
|
||||
'embedding group is not initialized'
|
||||
return _EMBEDDING_GROUP
|
||||
|
||||
|
||||
def get_position_embedding_group():
|
||||
"""Get the position embedding group the caller rank belongs to."""
|
||||
assert _POSITION_EMBEDDING_GROUP is not None, \
|
||||
'position embedding group is not initialized'
|
||||
return _POSITION_EMBEDDING_GROUP
|
||||
|
||||
|
||||
def set_tensor_model_parallel_world_size(world_size):
|
||||
"""Set the tensor model parallel size"""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
|
||||
|
||||
|
||||
def set_pipeline_model_parallel_world_size(world_size):
|
||||
"""Set the pipeline model parallel size"""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
|
||||
|
||||
|
||||
def get_tensor_model_parallel_world_size():
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
|
||||
return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
|
||||
return torch.distributed.get_world_size(
|
||||
group=get_tensor_model_parallel_group())
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_world_size():
|
||||
"""Return world size for the pipeline model parallel group."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
|
||||
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
|
||||
|
||||
|
||||
def set_tensor_model_parallel_rank(rank):
|
||||
"""Set tensor model parallel rank."""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = rank
|
||||
|
||||
|
||||
def set_pipeline_model_parallel_rank(rank):
|
||||
"""Set pipeline model parallel rank."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
|
||||
|
||||
|
||||
def set_pipeline_model_parallel_split_rank(rank):
|
||||
"""Set pipeline model parallel split rank."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
|
||||
return torch.distributed.get_world_size(
|
||||
group=get_pipeline_model_parallel_group())
|
||||
|
||||
|
||||
def get_tensor_model_parallel_rank():
|
||||
"""Return my rank for the tensor model parallel group."""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
|
||||
return _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_rank():
|
||||
"""Return my rank for the pipeline model parallel group."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
|
||||
return _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
|
||||
|
||||
|
||||
|
||||
def is_pipeline_first_stage(ignore_virtual=False):
|
||||
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
|
||||
if not ignore_virtual:
|
||||
if get_virtual_pipeline_model_parallel_world_size() is not None and \
|
||||
get_virtual_pipeline_model_parallel_rank() != 0:
|
||||
return False
|
||||
return get_pipeline_model_parallel_rank() == 0
|
||||
|
||||
|
||||
def is_pipeline_last_stage(ignore_virtual=False):
|
||||
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
|
||||
if not ignore_virtual:
|
||||
virtual_pipeline_model_parallel_world_size = \
|
||||
get_virtual_pipeline_model_parallel_world_size()
|
||||
if virtual_pipeline_model_parallel_world_size is not None and \
|
||||
get_virtual_pipeline_model_parallel_rank() != (
|
||||
virtual_pipeline_model_parallel_world_size - 1):
|
||||
return False
|
||||
return get_pipeline_model_parallel_rank() == (
|
||||
get_pipeline_model_parallel_world_size() - 1)
|
||||
|
||||
|
||||
def is_rank_in_embedding_group(ignore_virtual=False):
|
||||
"""Return true if current rank is in embedding group, False otherwise."""
|
||||
rank = torch.distributed.get_rank()
|
||||
global _EMBEDDING_GLOBAL_RANKS
|
||||
if ignore_virtual:
|
||||
return rank in _EMBEDDING_GLOBAL_RANKS
|
||||
if rank in _EMBEDDING_GLOBAL_RANKS:
|
||||
if rank == _EMBEDDING_GLOBAL_RANKS[0]:
|
||||
return is_pipeline_first_stage(ignore_virtual=False)
|
||||
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
|
||||
return is_pipeline_last_stage(ignore_virtual=False)
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_rank_in_position_embedding_group():
|
||||
"""Return true if current rank is in position embedding group, False otherwise."""
|
||||
rank = torch.distributed.get_rank()
|
||||
global _POSITION_EMBEDDING_GLOBAL_RANKS
|
||||
return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
|
||||
|
||||
|
||||
def is_pipeline_stage_before_split(rank=None):
|
||||
"""Return True if pipeline stage executes encoder block for a model
|
||||
with both encoder and decoder."""
|
||||
if get_pipeline_model_parallel_world_size() == 1:
|
||||
return True
|
||||
if rank is None:
|
||||
rank = get_pipeline_model_parallel_rank()
|
||||
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
||||
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
|
||||
return True
|
||||
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_pipeline_stage_after_split(rank=None):
|
||||
"""Return True if pipeline stage executes decoder block for a model
|
||||
with both encoder and decoder."""
|
||||
if get_pipeline_model_parallel_world_size() == 1:
|
||||
return True
|
||||
if rank is None:
|
||||
rank = get_pipeline_model_parallel_rank()
|
||||
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
||||
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
|
||||
return True
|
||||
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_pipeline_stage_at_split():
|
||||
"""Return true if pipeline stage executes decoder block and next
|
||||
stage executes encoder block for a model with both encoder and
|
||||
decoder."""
|
||||
rank = get_pipeline_model_parallel_rank()
|
||||
return is_pipeline_stage_before_split(rank) and \
|
||||
is_pipeline_stage_after_split(rank+1)
|
||||
|
||||
|
||||
def get_virtual_pipeline_model_parallel_rank():
|
||||
"""Return the virtual pipeline-parallel rank."""
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
|
||||
|
||||
def set_virtual_pipeline_model_parallel_rank(rank):
|
||||
"""Set the virtual pipeline-parallel rank."""
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
|
||||
|
||||
|
||||
def get_virtual_pipeline_model_parallel_world_size():
|
||||
"""Return the virtual pipeline-parallel world size."""
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
return torch.distributed.get_rank(
|
||||
group=get_pipeline_model_parallel_group())
|
||||
|
||||
|
||||
def get_tensor_model_parallel_src_rank():
|
||||
@ -419,35 +134,27 @@ def get_tensor_model_parallel_src_rank():
|
||||
return (global_rank // local_world_size) * local_world_size
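The returned expression maps every rank to the first rank of its tensor-parallel group. A tiny sketch of the arithmetic, assuming an illustrative tensor-parallel size of 2:

# With a tensor-parallel size of 2, ranks {4, 5} share a group and both map
# to source rank 4.
local_world_size = 2
for global_rank in range(8):
    src = (global_rank // local_world_size) * local_world_size
    assert src == global_rank - global_rank % local_world_size
assert (5 // local_world_size) * local_world_size == 4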
|
||||
|
||||
|
||||
def get_data_parallel_src_rank():
|
||||
"""Calculate the global rank corresponding to the first local rank
|
||||
in the data parallel group."""
|
||||
assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
|
||||
"Data parallel group is not initialized"
|
||||
return _DATA_PARALLEL_GLOBAL_RANKS[0]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_first_rank():
|
||||
"""Return the global rank of the first process in the pipeline for the
|
||||
current tensor parallel group"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||
"Pipeline parallel group is not initialized"
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, (
|
||||
"Pipeline parallel group is not initialized")
|
||||
return _PIPELINE_GLOBAL_RANKS[0]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_last_rank():
|
||||
"""Return the global rank of the last process in the pipeline for the
|
||||
current tensor parallel group"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||
"Pipeline parallel group is not initialized"
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, (
|
||||
"Pipeline parallel group is not initialized")
|
||||
last_rank_local = get_pipeline_model_parallel_world_size() - 1
|
||||
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_next_rank():
|
||||
"""Return the global rank that follows the caller in the pipeline"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||
"Pipeline parallel group is not initialized"
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, (
|
||||
"Pipeline parallel group is not initialized")
|
||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
||||
world_size = get_pipeline_model_parallel_world_size()
|
||||
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
|
||||
@ -455,45 +162,18 @@ def get_pipeline_model_parallel_next_rank():
|
||||
|
||||
def get_pipeline_model_parallel_prev_rank():
|
||||
"""Return the global rank that preceeds the caller in the pipeline"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||
"Pipeline parallel group is not initialized"
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, (
|
||||
"Pipeline parallel group is not initialized")
|
||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
||||
world_size = get_pipeline_model_parallel_world_size()
|
||||
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
|
||||
|
||||
|
||||
def get_data_parallel_world_size():
|
||||
"""Return world size for the data parallel group."""
|
||||
return torch.distributed.get_world_size(group=get_data_parallel_group())
|
||||
|
||||
|
||||
def get_data_parallel_rank():
|
||||
"""Return my rank for the data parallel group."""
|
||||
return torch.distributed.get_rank(group=get_data_parallel_group())
|
||||
|
||||
def destroy_model_parallel():
|
||||
"""Set the groups to none."""
|
||||
global _MODEL_PARALLEL_GROUP
|
||||
_MODEL_PARALLEL_GROUP = None
|
||||
global _TENSOR_MODEL_PARALLEL_GROUP
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = None
|
||||
global _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
global _DATA_PARALLEL_GROUP
|
||||
_DATA_PARALLEL_GROUP = None
|
||||
global _EMBEDDING_GROUP
|
||||
_EMBEDDING_GROUP = None
|
||||
global _POSITION_EMBEDDING_GROUP
|
||||
_POSITION_EMBEDDING_GROUP = None
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
global _PIPELINE_GLOBAL_RANKS
|
||||
_PIPELINE_GLOBAL_RANKS = None
|
||||
|
@ -1,50 +0,0 @@
|
||||
from .layers import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
VocabParallelEmbedding,
|
||||
set_tensor_model_parallel_attributes,
|
||||
set_defaults_if_not_set_tensor_model_parallel_attributes,
|
||||
copy_tensor_model_parallel_attributes,
|
||||
param_is_not_tensor_parallel_duplicate,
|
||||
)
|
||||
|
||||
from .mappings import (
|
||||
copy_to_tensor_model_parallel_region,
|
||||
gather_from_tensor_model_parallel_region,
|
||||
gather_from_sequence_parallel_region,
|
||||
reduce_from_tensor_model_parallel_region,
|
||||
scatter_to_tensor_model_parallel_region,
|
||||
scatter_to_sequence_parallel_region,
|
||||
)
|
||||
|
||||
from .random import (
|
||||
get_cuda_rng_tracker,
|
||||
model_parallel_cuda_manual_seed,
|
||||
)
|
||||
|
||||
from .utils import (
|
||||
split_tensor_along_last_dim,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
#layers.py
|
||||
"ColumnParallelLinear",
|
||||
"RowParallelLinear",
|
||||
"VocabParallelEmbedding",
|
||||
"set_tensor_model_parallel_attributes",
|
||||
"set_defaults_if_not_set_tensor_model_parallel_attributes",
|
||||
"copy_tensor_model_parallel_attributes",
|
||||
"param_is_not_tensor_parallel_duplicate",
|
||||
# mappings.py
|
||||
"copy_to_tensor_model_parallel_region",
|
||||
"gather_from_tensor_model_parallel_region",
|
||||
"gather_from_sequence_parallel_region",
|
||||
"reduce_from_tensor_model_parallel_region",
|
||||
"scatter_to_tensor_model_parallel_region",
|
||||
"scatter_to_sequence_parallel_region",
|
||||
# random.py
|
||||
"get_cuda_rng_tracker",
|
||||
"model_parallel_cuda_manual_seed",
|
||||
# utils.py
|
||||
"split_tensor_along_last_dim",
|
||||
]
|
@ -1,366 +0,0 @@
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Parts of the code here are adapted from PyTorch
|
||||
# repo: https://github.com/pytorch/pytorch
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from .mappings import (
|
||||
gather_from_tensor_model_parallel_region,
|
||||
reduce_from_tensor_model_parallel_region,
|
||||
scatter_to_tensor_model_parallel_region,
|
||||
)
|
||||
|
||||
from .utils import (
|
||||
divide,
|
||||
VocabUtility,
|
||||
)
|
||||
|
||||
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
|
||||
'partition_dim': -1,
|
||||
'partition_stride': 1}
|
||||
|
||||
def param_is_not_tensor_parallel_duplicate(param):
|
||||
return (hasattr(param, 'tensor_model_parallel') and
|
||||
param.tensor_model_parallel) or (
|
||||
get_tensor_model_parallel_rank() == 0)
|
||||
|
||||
|
||||
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
|
||||
# Make sure the attributes are not set.
|
||||
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
|
||||
assert not hasattr(tensor, attribute)
|
||||
# Set the attributes.
|
||||
setattr(tensor, 'tensor_model_parallel', is_parallel)
|
||||
setattr(tensor, 'partition_dim', dim)
|
||||
setattr(tensor, 'partition_stride', stride)
|
||||
|
||||
|
||||
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
|
||||
def maybe_set(attribute, value):
|
||||
if not hasattr(tensor, attribute):
|
||||
setattr(tensor, attribute, value)
|
||||
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
|
||||
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
|
||||
|
||||
|
||||
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
|
||||
def maybe_copy(attribute):
|
||||
if hasattr(source_tensor, attribute):
|
||||
setattr(destination_tensor, attribute,
|
||||
getattr(source_tensor, attribute))
|
||||
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
|
||||
maybe_copy(attribute)
|
||||
|
||||
|
||||
class VocabParallelEmbedding(torch.nn.Module):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
This is mainly adapted from torch.nn.Embedding and all the default
|
||||
values are kept.
|
||||
Arguments:
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
|
||||
Keyword Arguments:
|
||||
init_method: method to initialize weights.
|
||||
params_dtype
|
||||
use_cpu_initialization
|
||||
perform_initialization
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, *,
|
||||
init_method=init.xavier_normal_,
|
||||
params_dtype: torch.dtype=None,
|
||||
use_cpu_initialization: bool=False,
|
||||
perform_initialization: bool=False):
|
||||
super(VocabParallelEmbedding, self).__init__()
|
||||
assert not perform_initialization
|
||||
assert not use_cpu_initialization
|
||||
|
||||
# Keep the input dimensions.
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Set the defaults for compatibility.
|
||||
self.padding_idx = None
|
||||
self.max_norm = None
|
||||
self.norm_type = 2.
|
||||
self.scale_grad_by_freq = False
|
||||
self.sparse = False
|
||||
self._weight = None
|
||||
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
|
||||
# Divide the weight matrix along the vocaburaly dimension.
|
||||
self.vocab_start_index, self.vocab_end_index = \
|
||||
VocabUtility.vocab_range_from_global_vocab_size(
|
||||
self.num_embeddings, get_tensor_model_parallel_rank(),
|
||||
self.tensor_model_parallel_size)
|
||||
self.num_embeddings_per_partition = self.vocab_end_index - \
|
||||
self.vocab_start_index
|
||||
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.num_embeddings_per_partition, self.embedding_dim,
|
||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tensor_model_parallel_size > 1:
|
||||
# Build the mask.
|
||||
input_mask = (input_ < self.vocab_start_index) | \
|
||||
(input_ >= self.vocab_end_index)
|
||||
# Mask the input.
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
masked_input[input_mask] = 0
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = F.embedding(masked_input, self.weight,
|
||||
self.padding_idx, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq,
|
||||
self.sparse)
|
||||
# Mask the output embedding.
|
||||
if self.tensor_model_parallel_size > 1:
|
||||
output_parallel[input_mask, :] = 0.0
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = reduce_from_tensor_model_parallel_region(output_parallel)
|
||||
return output
|
||||
|
||||
|
||||
class ColumnParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its second dimension as A = [A_1, ..., A_p].
|
||||
|
||||
Arguments:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
|
||||
Keyword Arguments
|
||||
bias: If true, add bias
|
||||
gather_output: If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is Y_i = XA_i
|
||||
init_method: method to initialize weights. Note that bias is always set
|
||||
to zero.
|
||||
stride: For the strided linear layers.
|
||||
keep_master_weight_for_test: This was added for testing and should be
|
||||
set to False. It returns the master weights
|
||||
used for initialization.
|
||||
skip_bias_add: This was added to enable performance optimations where bias
|
||||
can be fused with other elementwise operations. we skip
|
||||
adding bias but instead return it.
|
||||
params_dtype:
|
||||
use_cpu_initialization:
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, output_size, *,
|
||||
bias=True, gather_output=True,
|
||||
init_method=init.xavier_normal_, stride=1,
|
||||
keep_master_weight_for_test=False,
|
||||
skip_bias_add=False,
|
||||
params_dtype=None,
|
||||
use_cpu_initialization=False,
|
||||
perform_initialization=False,
|
||||
quant_config=None,
|
||||
):
|
||||
super(ColumnParallelLinear, self).__init__()
|
||||
assert not perform_initialization
|
||||
assert not use_cpu_initialization
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.gather_output = gather_output
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.world_size = get_tensor_model_parallel_world_size()
|
||||
self.output_size_per_partition = divide(output_size, self.world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.quant_config = quant_config
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Parameters.
|
||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
||||
# we allocate the transpose.
|
||||
self.create_weights(params_dtype)
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
self.bias.zero_()
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size_per_partition, self.input_size,
|
||||
device=torch.cuda.current_device(), dtype=dtype))
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
return F.linear(x, self.weight, bias)
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of ColumnParallelLinear
|
||||
|
||||
Args:
|
||||
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
input_parallel = input_
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply_weights(input_parallel, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = gather_from_tensor_model_parallel_region(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
|
||||
class RowParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with row parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its first dimension and X along its second dimension as:
|
||||
- -
|
||||
| A_1 |
|
||||
| . |
|
||||
A = | . | X = [X_1, ..., X_p]
|
||||
| . |
|
||||
| A_p |
|
||||
- -
|
||||
Arguments:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
|
||||
Keyword Arguments:
|
||||
bias: If true, add bias. Note that bias is not parallelized.
|
||||
input_is_parallel: If true, we assume that the input is already
|
||||
split across the GPUs and we do not split
|
||||
again.
|
||||
init_method: method to initialize weights. Note that bias is always set
|
||||
to zero.
|
||||
stride: For the strided linear layers.
|
||||
keep_master_weight_for_test: This was added for testing and should be
|
||||
set to False. It returns the master weights
|
||||
used for initialization.
|
||||
skip_bias_add: This was added to enable performance optimization where bias
|
||||
can be fused with other elementwise operations. We skip
|
||||
adding bias but instead return it.
|
||||
params_dtype:
|
||||
use_cpu_initialization:
|
||||
perform_initialization:
|
||||
reduce_results:
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, output_size, *,
|
||||
bias=True, input_is_parallel=False,
|
||||
init_method=init.xavier_normal_, stride=1,
|
||||
keep_master_weight_for_test=False,
|
||||
skip_bias_add=False,
|
||||
params_dtype=None,
|
||||
use_cpu_initialization=False,
|
||||
perform_initialization=False,
|
||||
reduce_results=True,
|
||||
quant_config=None,
|
||||
):
|
||||
super(RowParallelLinear, self).__init__()
|
||||
assert not perform_initialization
|
||||
assert not use_cpu_initialization
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.world_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.create_weights(params_dtype)
|
||||
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
"results can lead to incorrect results")
|
||||
|
||||
if bias:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size, device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
self.bias.zero_()
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
def create_weights(self, dtype: torch.dtype) -> None:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size, self.input_size_per_partition,
|
||||
device=torch.cuda.current_device(), dtype=dtype))
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return F.linear(x, self.weight)
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of RowParallelLinear
|
||||
|
||||
Args:
|
||||
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
# Set up backprop all-reduce.
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
input_parallel = scatter_to_tensor_model_parallel_region(input_)
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply_weights(input_parallel)
|
||||
if self.reduce_results and self.world_size > 1:
|
||||
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.skip_bias_add:
|
||||
output = output_ + self.bias if self.bias is not None else output_
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.bias
|
||||
return output, output_bias
|
@ -1,281 +0,0 @@
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/mappings.py
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_group,
|
||||
)
|
||||
from .utils import split_tensor_along_last_dim
|
||||
|
||||
|
||||
def _reduce(input_):
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if get_tensor_model_parallel_world_size()==1:
|
||||
return input_
|
||||
|
||||
# All-reduce.
|
||||
torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
|
||||
|
||||
return input_
|
||||
|
||||
|
||||
def _split_along_last_dim(input_):
|
||||
"""Split the tensor along its last dimension and keep the
|
||||
corresponding slice."""
|
||||
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# Split along last dimension.
|
||||
input_list = split_tensor_along_last_dim(input_, world_size)
|
||||
|
||||
# Note: torch.split does not create contiguous tensors by default.
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
output = input_list[rank].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _split_along_first_dim(input_):
|
||||
"""Split the tensor along its first dimension and keep the
|
||||
corresponding slice."""
|
||||
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# Split along first dimension.
|
||||
dim_size = input_.size()[0]
|
||||
assert dim_size % world_size == 0, \
|
||||
"First dimension of the tensor should be divisible by tensor parallel size"
|
||||
local_dim_size = dim_size // world_size
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
dim_offset = rank * local_dim_size
|
||||
|
||||
output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _gather_along_last_dim(input_):
|
||||
"""Gather tensors and concatinate along the last dimension."""
|
||||
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# Size and dimension.
|
||||
last_dim = input_.dim() - 1
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
tensor_list[rank] = input_
|
||||
torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
|
||||
|
||||
# Note: torch.cat already creates a contiguous tensor.
|
||||
output = torch.cat(tensor_list, dim=last_dim).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _gather_along_first_dim(input_):
|
||||
"""Gather tensors and concatinate along the first dimension."""
|
||||
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
dim_size = list(input_.size())
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
output = torch.empty(dim_size, dtype=input_.dtype,
|
||||
device=torch.cuda.current_device())
|
||||
torch.distributed._all_gather_base(output, input_.contiguous(),
|
||||
group=get_tensor_model_parallel_group())
|
||||
|
||||
return output
|
||||
|
||||
def _reduce_scatter_along_first_dim(input_):
|
||||
"""Reduce-scatter the input tensor across model parallel group."""
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
dim_size = list(input_.size())
|
||||
assert dim_size[0] % world_size == 0, \
|
||||
"First dimension of the tensor should be divisible by tensor parallel size"
|
||||
|
||||
dim_size[0] = dim_size[0] // world_size
|
||||
|
||||
output = torch.empty(dim_size, dtype=input_.dtype,
|
||||
device=torch.cuda.current_device())
|
||||
torch.distributed._reduce_scatter_base(output, input_.contiguous(),
|
||||
group=get_tensor_model_parallel_group())
|
||||
return output
|
||||
|
||||
|
||||
class _CopyToModelParallelRegion(torch.autograd.Function):
|
||||
"""Pass the input to the model parallel region."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_):
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _reduce(grad_output)
|
||||
|
||||
|
||||
class _ReduceFromModelParallelRegion(torch.autograd.Function):
|
||||
"""All-reduce the input from the model parallel region."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _reduce(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_):
|
||||
return _reduce(input_)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output
|
||||
|
||||
|
||||
class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_last_dim(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather_along_last_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split_along_last_dim(grad_output)


class _ScatterToSequenceParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)


class _GatherFromSequenceParallelRegion(torch.autograd.Function):
    """Gather the input from the sequence parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_, tensor_parallel_output_grad=True):
        return _gather_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_, tensor_parallel_output_grad=True):
        ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        tensor_parallel_output_grad = ctx.tensor_parallel_output_grad

        # If the computation graph after the gather operation is in tensor
        # parallel mode, the output gradients need to be reduce-scattered,
        # whereas if the computation is duplicated, the output gradients
        # only need to be scattered.
        if tensor_parallel_output_grad:
            return _reduce_scatter_along_first_dim(grad_output), None
        else:
            return _split_along_first_dim(grad_output), None


class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
    """Reduce-scatter the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce_scatter_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)


# -----------------
# Helper functions.
# -----------------


def copy_to_tensor_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_tensor_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_tensor_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_tensor_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)


def scatter_to_sequence_parallel_region(input_):
    return _ScatterToSequenceParallelRegion.apply(input_)


def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
    return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)


def reduce_scatter_to_sequence_parallel_region(input_):
    return _ReduceScatterToSequenceParallelRegion.apply(input_)

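For context, these helpers are the building blocks of Megatron-style column- and row-parallel linear layers. The sketch below is illustrative only: it assumes the tensor model parallel group has already been initialized, that the mapping helpers defined above are in scope, and that w_col / w_row hold this rank's column and row shards of the weight matrices; it is not vLLM's parallel linear implementation.

    import torch.nn.functional as F


    def column_parallel_forward(x, w_col, gather_output=True):
        # Identity in forward, all-reduce of the gradient in backward.
        x = copy_to_tensor_model_parallel_region(x)
        y_local = F.linear(x, w_col)          # [..., out_features // world_size]
        if gather_output:
            # Concatenate the per-rank output shards along the last dimension.
            y_local = gather_from_tensor_model_parallel_region(y_local)
        return y_local


    def row_parallel_forward(x, w_row, input_is_parallel=False):
        if not input_is_parallel:
            # Split the input along the last dimension across ranks.
            x = scatter_to_tensor_model_parallel_region(x)
        y_partial = F.linear(x, w_row)        # partial sums on every rank
        # All-reduce in forward, identity in backward.
        return reduce_from_tensor_model_parallel_region(y_partial)
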
@ -1,164 +0,0 @@
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

import contextlib

import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager

from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
)

# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'


def _set_cuda_rng_state(new_state, device=-1):
    """Sets the random number generator state of the current GPU.

    Arguments:
        new_state (torch.ByteTensor): The desired state
    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for +4 GPU cases.
    """
    if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
        # older PyTorch
        def cb():
            with device_ctx_manager(device):
                _C._cuda_setRNGState(new_state)
    else:
        # newer PyTorch
        if device == -1:
            device = torch.device('cuda')
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device('cuda', device)

        def cb():
            idx = device.index
            if idx is None:
                idx = torch.cuda.current_device()
            default_generator = torch.cuda.default_generators[idx]
            default_generator.set_state(new_state)

    _lazy_call(cb)


class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.

    Using the `add` method, a cuda rng state is initialized based on
    the input `seed` and is assigned to `name`. Later, by forking the
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for book keeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """Set to the initial state (no tracker)."""
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self):
        """Get rng states. Copy the dictionary so we have direct
        pointers to the states, not just a pointer to the dictionary."""
        states = {}
        for name in self.states_:
            states[name] = self.states_[name]
        return states

    def set_states(self, states):
        """Set the rng states. For efficiency purposes, we do not check
        the size of seed for compatibility."""
        self.states_ = states

    def add(self, name, seed):
        """Track the rng state."""
        # Check seed is not already used.
        if seed in self.seeds_:
            raise Exception('seed {} already exists'.format(seed))
        self.seeds_.add(seed)
        # Check that state is not already defined.
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
        # Get the current rng state.
        orig_rng_state = torch.cuda.get_rng_state()
        # Set the new state and store it.
        torch.cuda.manual_seed(seed)
        self.states_[name] = torch.cuda.get_rng_state()
        # Reset rng state to what it was.
        _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        """Fork the cuda rng state, perform operations, and exit with
        the original state."""
        # Check if we have added the state
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Store current rng state.
        orig_cuda_rng_state = torch.cuda.get_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Update the current rng state for later use.
            self.states_[name] = torch.cuda.get_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)


# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()


def get_cuda_rng_tracker():
    """Get cuda rng tracker."""
    return _CUDA_RNG_STATE_TRACKER


def model_parallel_cuda_manual_seed(seed):
    """Initialize model parallel cuda seed.

    This function should be called after the model parallel is
    initialized. Also, no torch.cuda.manual_seed should be called
    after this function. Basically, this is a replacement for that
    function.
    Two sets of RNG states are tracked:
        default state: This is for data parallelism and is the same among a
            set of model parallel GPUs but different across different model
            parallel groups. This is used, for example, for dropout in the
            non-tensor-model-parallel regions.
        tensor-model-parallel state: This state is different among a set of
            model parallel GPUs, but the same across data parallel groups.
            This is used, for example, for dropout in model parallel regions.
    """
    # 2718 is just for fun and any POSITIVE value will work.
    offset = seed + 2718
    tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
    # Data parallel gets the original seed.
    data_parallel_seed = seed

    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)
    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
                                tensor_model_parallel_seed)
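A small usage sketch of the tracker, under the assumptions that model parallelism is initialized, model_parallel_cuda_manual_seed(...) has been called, and get_cuda_rng_tracker from the code above is in scope: dropout inside a tensor-parallel region forks into the per-rank 'model-parallel-rng' state, while code outside the fork keeps the shared default state.

    import torch
    import torch.nn.functional as F


    def dropout_in_tensor_parallel_region(x: torch.Tensor, p: float = 0.1) -> torch.Tensor:
        # Fork into the model-parallel RNG state so each tensor-parallel rank
        # draws a different dropout mask; the default state is restored on exit.
        with get_cuda_rng_tracker().fork():
            return F.dropout(x, p=p, training=True)
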
@ -1,15 +1,16 @@
# Copyright 2023 The vLLM team.
# Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from typing import List, Sequence

import torch
from typing import List, Sequence


def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator
    )
        numerator, denominator)


def divide(numerator, denominator):
@ -56,15 +57,14 @@ class VocabUtility:

    @staticmethod
    def vocab_range_from_per_partition_vocab_size(
        per_partition_vocab_size: int, rank, world_size: int
    ) -> Sequence[int]:
            per_partition_vocab_size: int, rank: int) -> Sequence[int]:
        index_f = rank * per_partition_vocab_size
        index_l = index_f + per_partition_vocab_size
        return index_f, index_l

    @staticmethod
    def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
    def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int,
                                           world_size: int) -> Sequence[int]:
        per_partition_vocab_size = divide(global_vocab_size, world_size)
        return VocabUtility.vocab_range_from_per_partition_vocab_size(
            per_partition_vocab_size, rank, world_size
        )
            per_partition_vocab_size, rank)
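A worked example (made-up sizes) of the partitioning arithmetic after the signature change: world_size only determines the per-partition size, while the range for a rank depends only on that size and the rank itself.

    global_vocab_size, world_size = 32000, 4
    per_partition = global_vocab_size // world_size   # divide(...) -> 8000
    for rank in range(world_size):
        index_f = rank * per_partition
        index_l = index_f + per_partition
        print(rank, (index_f, index_l))
    # 0 (0, 8000)
    # 1 (8000, 16000)
    # 2 (16000, 24000)
    # 3 (24000, 32000)
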
@ -4,9 +4,6 @@ import random
import numpy as np
import torch

from vllm.model_executor.parallel_utils.parallel_state import model_parallel_is_initialized
from vllm.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed


def set_random_seed(seed: int) -> None:
    random.seed(seed)
@ -14,6 +11,3 @@ def set_random_seed(seed: int) -> None:
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    if model_parallel_is_initialized():
        model_parallel_cuda_manual_seed(seed)