Support tensor parallel (#2)
This commit is contained in:
parent
cfae35b861
commit
2f49f15585
@ -11,5 +11,6 @@ pip install -e .

## Run

```bash
python server.py
ray start --head
python server.py [--tensor-parallel-size <N>]
```

@ -1,12 +1,10 @@
|
||||
from cacheflow.models.input_metadata import InputMetadata
|
||||
from cacheflow.models.model_utils import get_memory_analyzer
|
||||
from cacheflow.models.model_utils import get_model
|
||||
from cacheflow.models.utils import set_seed
|
||||
|
||||
|
||||
__all__ = [
|
||||
'InputMetadata',
|
||||
'get_memory_analyzer',
|
||||
'get_model',
|
||||
'set_seed',
|
||||
]
|
||||
|
@ -112,7 +112,7 @@ class OPTCacheFlowAttention(nn.Module):
|
||||
output[:num_prompt_tokens],
|
||||
query[:num_prompt_tokens],
|
||||
key[:num_prompt_tokens],
|
||||
value[:num_prompt_tokens],
|
||||
value[:num_prompt_tokens],
|
||||
input_metadata.prompt_lens,
|
||||
)
|
||||
|
||||
|
@ -43,4 +43,8 @@ class InputMetadata:
|
||||
f'num_generation_tokens={self.num_generation_tokens}, '
|
||||
f'num_valid_tokens={self.num_valid_tokens}, '
|
||||
f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
|
||||
f'max_context_len={self.max_context_len})')
|
||||
f'max_context_len={self.max_context_len}), '
|
||||
f'prompt_lens={self.prompt_lens}, '
|
||||
f'slot_mapping={self.slot_mapping}, '
|
||||
f'context_lens={self.context_lens}, '
|
||||
f'block_tables={self.block_tables})')
|
||||
|
@ -31,12 +31,13 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
dtype: torch.dtype,
|
||||
tensor_parallel_size: int,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.block_size = block_size
|
||||
self.dtype = dtype
|
||||
self.tensor_parallel_size = tensor_parallel_size
|
||||
|
||||
# TODO(woosuk): Support tensor parallelism.
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
self.hidden_size = config.hidden_size
|
||||
@ -48,26 +49,25 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||
self.max_position = config.max_position_embeddings
|
||||
|
||||
def _get_param_size(self) -> int:
|
||||
# TODO(woosuk): Support tensor parallelism.
|
||||
word_embedding = self.vocab_size * self.embedding_size
|
||||
word_embedding = self.vocab_size * self.embedding_size // self.tensor_parallel_size
|
||||
if self.embedding_size != self.vocab_size:
|
||||
# Project in/out.
|
||||
word_embedding += 2 * self.embedding_size * self.vocab_size
|
||||
position_embedding = self.max_position * self.hidden_size
|
||||
|
||||
ln1 = 2 * self.hidden_size
|
||||
q = self.hidden_size * self.hidden_size + self.hidden_size
|
||||
k = self.hidden_size * self.hidden_size + self.hidden_size
|
||||
v = self.hidden_size * self.hidden_size + self.hidden_size
|
||||
out = self.hidden_size * self.hidden_size + self.hidden_size
|
||||
q = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||
k = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||
v = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||
out = self.hidden_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||
mha = ln1 + q + k + v + out
|
||||
|
||||
ln2 = 2 * self.hidden_size
|
||||
ffn1 = self.hidden_size * self.ffn_size + self.ffn_size
|
||||
ffn2 = self.ffn_size * self.hidden_size + self.hidden_size
|
||||
ffn1 = self.hidden_size * self.ffn_size // self.tensor_parallel_size + self.ffn_size
|
||||
ffn2 = self.ffn_size * self.hidden_size // self.tensor_parallel_size + self.hidden_size
|
||||
ffn = ln2 + ffn1 + ffn2
|
||||
|
||||
total = (word_embedding + position_embedding +
|
||||
total = (word_embedding + position_embedding +
|
||||
self.num_layers * (mha + ffn))
|
||||
dtype_size = get_dtype_size(self.dtype)
|
||||
return dtype_size * total
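For intuition, here is a small standalone sketch (not part of this change) that plugs hypothetical, roughly OPT-13B-sized numbers into the per-GPU parameter formula above; it assumes the embedding size equals the hidden size, so the project in/out terms drop out.

```python
# Illustrative only: hypothetical sizes, mirroring _get_param_size with fp16 weights.
hidden, ffn, vocab, max_pos, layers, tp = 5120, 20480, 50272, 2048, 40, 4

word_embedding = vocab * hidden // tp                    # sharded along the vocab dim
position_embedding = max_pos * hidden                    # replicated on every rank
mha = 2 * hidden + 4 * (hidden * hidden // tp + hidden)  # ln1 + q, k, v, out
ffn_block = 2 * hidden + (hidden * ffn // tp + ffn) + (ffn * hidden // tp + hidden)
total = word_embedding + position_embedding + layers * (mha + ffn_block)
print(f"~{2 * total / 1e9:.1f} GB of parameters per GPU")  # 2 bytes per fp16 param
```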
|
||||
@ -76,15 +76,17 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
|
||||
self,
|
||||
max_num_batched_tokens: int,
|
||||
) -> int:
|
||||
# TODO(woosuk): Support tensor parallelism.
|
||||
# NOTE: We approximately calculate the maximum activation size by
|
||||
# 1) estimating the maximum activation tensor size during inference, and
|
||||
# 2) multiplying it by 4.
|
||||
# estimating
|
||||
# 1) the maximum activation tensor size during inference
|
||||
# 2) the residual tensor size during inference
|
||||
# Here, we assume that FlashAttention is used and
|
||||
# thus the attention maps are never materialized in GPU DRAM.
|
||||
qkv = 3 * (max_num_batched_tokens * self.hidden_size)
|
||||
ffn = max_num_batched_tokens * self.ffn_size
|
||||
max_act = 4 * max(qkv, ffn)
|
||||
residual = max_num_batched_tokens * self.hidden_size
|
||||
qkv = 3 * (max_num_batched_tokens * self.hidden_size) // self.tensor_parallel_size
|
||||
ffn = max_num_batched_tokens * self.ffn_size // self.tensor_parallel_size
|
||||
# Double the activation size for input and output.
|
||||
max_act = 2 * (max(qkv, ffn) + residual)
|
||||
dtype_size = get_dtype_size(self.dtype)
|
||||
return dtype_size * max_act
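And the matching activation-side estimate, again with hypothetical sizes; only the arithmetic of the method above is reproduced.

```python
# Illustrative only: hypothetical sizes, mirroring _get_max_act_size.
hidden, ffn, tp, max_tokens, dtype_bytes = 5120, 20480, 4, 2560, 2
residual = max_tokens * hidden
qkv = 3 * (max_tokens * hidden) // tp
ffn_act = max_tokens * ffn // tp
max_act = 2 * (max(qkv, ffn_act) + residual)   # doubled for input and output
print(f"~{dtype_bytes * max_act / 1e9:.2f} GB reserved for activations")
```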
|
||||
|
||||
|
@ -1,7 +1,9 @@
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoConfig
|
||||
|
||||
from cacheflow.models.memory_analyzer import CacheFlowMemoryAnalyzer
|
||||
from cacheflow.models.memory_analyzer import OPTMemoryAnalyzer
|
||||
@ -21,13 +23,20 @@ _MEMORY_ANALYZERS = {
|
||||
def get_model(
|
||||
model_name: str,
|
||||
dtype: Union[torch.dtype, str],
|
||||
path: str,
|
||||
) -> nn.Module:
|
||||
torch_dtype = get_torch_dtype(dtype)
|
||||
for model_class, hf_model in _MODELS.items():
|
||||
if model_class in model_name:
|
||||
model = hf_model.from_pretrained(
|
||||
model_name, torch_dtype=torch_dtype)
|
||||
return model.eval()
|
||||
torch.set_default_dtype(torch_dtype)
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
for model_class_name, model_class in _MODELS.items():
|
||||
if model_class_name in model_name:
|
||||
# Download the model weights if they are not cached.
|
||||
weights_dir = model_class.download_weights(model_name, path=path)
|
||||
# Create a model instance.
|
||||
model = model_class(config)
|
||||
# Load the weights from the cached or downloaded files.
|
||||
model.load_weights(weights_dir)
|
||||
return model.eval(), torch_dtype
|
||||
raise ValueError(f'Unsupported model name: {model_name}')
|
||||
|
||||
|
||||
@ -35,10 +44,11 @@ def get_memory_analyzer(
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
dtype: Union[torch.dtype, str],
|
||||
tensor_parallel_size: int = 1,
|
||||
) -> CacheFlowMemoryAnalyzer:
|
||||
torch_dtype = get_torch_dtype(dtype)
|
||||
for model_class, memory_analyzer in _MEMORY_ANALYZERS.items():
|
||||
if model_class in model_name:
|
||||
return memory_analyzer(
|
||||
model_name, block_size, torch_dtype)
|
||||
model_name, block_size, torch_dtype, tensor_parallel_size)
|
||||
raise ValueError(f'Unsupported model name: {model_name}')
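An illustrative call site (not part of this diff) showing how the new `tensor_parallel_size` argument is threaded through; the model name and block size are placeholders.

```python
# Hypothetical usage of the updated helper; values are placeholders.
import torch
from cacheflow.models.model_utils import get_memory_analyzer

memory_analyzer = get_memory_analyzer(
    model_name='facebook/opt-13b',
    block_size=8,
    dtype=torch.float16,
    tensor_parallel_size=4,
)
```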
|
||||
|
@ -1,14 +1,24 @@
|
||||
"""1D OPT model compatible with HuggingFace weights."""
|
||||
import os
|
||||
import glob
|
||||
import filelock
|
||||
from tqdm import tqdm
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import OPTConfig
|
||||
from transformers import PreTrainedModel
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
from cacheflow.models import InputMetadata
|
||||
from cacheflow.models.attention import OPTCacheFlowAttention
|
||||
from cacheflow.models.sample import Sampler
|
||||
from cacheflow.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from cacheflow.parallel_utils.tensor_parallel import (VocabParallelEmbedding,
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from cacheflow.sequence import SequenceOutputs
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
@ -36,15 +46,26 @@ class OPTAttention(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = embed_dim // num_heads
|
||||
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
|
||||
total_num_heads = num_heads
|
||||
assert num_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = total_num_heads // tensor_model_parallel_world_size
|
||||
self.head_dim = embed_dim // total_num_heads
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
# TODO(woosuk): Fuse the three linear layers into one QKV linear layer.
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||
self.k_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.v_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.q_proj = ColumnParallelLinear(embed_dim, embed_dim, bias=bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
|
||||
self.attn = OPTCacheFlowAttention(scale=self.scaling)
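To make the sharding concrete: with tensor-parallel world size W, each rank keeps `num_heads / W` heads, so the column-parallel q/k/v projections emit `embed_dim / W` features per rank and the row-parallel out_proj reduces the partial outputs back to `embed_dim`. A minimal sketch of the arithmetic with hypothetical sizes (no distributed setup):

```python
# Hypothetical sizes; head partitioning as in OPTAttention above.
embed_dim, total_num_heads, world_size = 5120, 40, 4
head_dim = embed_dim // total_num_heads            # 128, identical on every rank
heads_per_rank = total_num_heads // world_size     # 10 heads per rank
local_proj_dim = heads_per_rank * head_dim         # 1280 features from q/k/v per rank
assert local_proj_dim == embed_dim // world_size
```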
|
||||
|
||||
@ -55,13 +76,13 @@ class OPTAttention(nn.Module):
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
q = self.q_proj(hidden_states)
|
||||
k = self.k_proj(hidden_states)
|
||||
v = self.v_proj(hidden_states)
|
||||
q, _ = self.q_proj(hidden_states)
|
||||
k, _ = self.k_proj(hidden_states)
|
||||
v, _ = self.v_proj(hidden_states)
|
||||
key_cache, value_cache = kv_cache
|
||||
attn_output = self.attn(
|
||||
q, k, v, key_cache, value_cache, input_metadata, cache_event)
|
||||
output = self.out_proj(attn_output)
|
||||
output, _ = self.out_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
@ -69,6 +90,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: OPTConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.self_attn = OPTAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
@ -81,9 +103,16 @@ class OPTDecoderLayer(nn.Module):
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
|
||||
self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias)
|
||||
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias)
|
||||
self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
|
||||
self.fc1 = ColumnParallelLinear(self.embed_dim, config.ffn_dim,
|
||||
bias=config.enable_bias,
|
||||
gather_output=False,
|
||||
perform_initialization=False)
|
||||
self.fc2 = RowParallelLinear(config.ffn_dim, self.embed_dim,
|
||||
bias=config.enable_bias,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
self.final_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
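The fc1/fc2 pairing follows the usual Megatron pattern: the column-parallel fc1 keeps the FFN intermediate sharded, and the row-parallel fc2 consumes that shard directly, so only fc2's output needs an all-reduce. A shape-only sketch with hypothetical sizes (the all-reduce itself is omitted):

```python
import torch

# Hypothetical sizes; shows why the fc1 output shard feeds fc2 without communication.
tokens, hidden, ffn_dim, world_size = 4, 8, 32, 4
x = torch.randn(tokens, hidden)
w1_shard = torch.randn(ffn_dim // world_size, hidden)   # fc1: output dim is split
w2_shard = torch.randn(hidden, ffn_dim // world_size)   # fc2: input dim is split
partial = torch.relu(x @ w1_shard.t()) @ w2_shard.t()   # per-rank partial sum, shape (4, 8)
```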
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -112,9 +141,9 @@ class OPTDecoderLayer(nn.Module):
|
||||
# 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
|
||||
if self.do_layer_norm_before:
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states, _ = self.fc2(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
# 350m applies layer norm AFTER attention
|
||||
if not self.do_layer_norm_before:
|
||||
@ -122,29 +151,23 @@ class OPTDecoderLayer(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OPTPreTrainedModel(PreTrainedModel):
|
||||
config_class = OPTConfig
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = ["OPTDecoderLayer"]
|
||||
_keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
|
||||
|
||||
def _init_weights(self, module) -> None:
|
||||
del module # unused
|
||||
return
|
||||
|
||||
|
||||
class OPTDecoder(OPTPreTrainedModel):
|
||||
class OPTDecoder(nn.Module):
|
||||
|
||||
def __init__(self, config: OPTConfig):
|
||||
super().__init__(config)
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.max_target_positions = config.max_position_embeddings
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
|
||||
self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
|
||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
|
||||
config.word_embed_proj_dim,
|
||||
perform_initialization=False)
|
||||
# Positional embeddings are replicated (not sharded).
|
||||
self.embed_positions = OPTLearnedPositionalEmbedding(
|
||||
config.max_position_embeddings, config.hidden_size)
|
||||
|
||||
# Project out & in will be replicated if they exist.
|
||||
if config.word_embed_proj_dim != config.hidden_size:
|
||||
self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
|
||||
else:
|
||||
@ -167,9 +190,6 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
|
||||
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
@ -200,13 +220,11 @@ class OPTDecoder(OPTPreTrainedModel):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class OPTModel(OPTPreTrainedModel):
|
||||
class OPTModel(nn.Module):
|
||||
|
||||
def __init__(self, config: OPTConfig):
|
||||
super().__init__(config)
|
||||
super().__init__()
|
||||
self.decoder = OPTDecoder(config)
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -220,41 +238,17 @@ class OPTModel(OPTPreTrainedModel):
|
||||
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||
|
||||
|
||||
class OPTForCausalLM(OPTPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]
|
||||
class OPTForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.model = OPTModel(config)
|
||||
# the lm_head weight is automatically tied to the embed tokens weight
|
||||
self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
|
||||
# TODO(zhuohan): create a new weight after implementing pipeline
|
||||
# parallelism
|
||||
self.lm_head_weight = self.model.decoder.embed_tokens.weight
|
||||
self.sampler = Sampler()
|
||||
|
||||
# Initialize weights and apply final processing
|
||||
self.post_init()
|
||||
|
||||
# NOTE(woosuk): While the following methods are not called in the model code,
|
||||
# they may be internally used by the transformers library.
|
||||
# For example, tie_weights() does not work without these methods.
|
||||
# Thus, do not delete these methods.
|
||||
def get_input_embeddings(self):
|
||||
return self.model.decoder.embed_tokens
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.model.decoder.embed_tokens = value
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.lm_head = new_embeddings
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.model.decoder = decoder
|
||||
|
||||
def get_decoder(self):
|
||||
return self.model.decoder
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
@ -266,5 +260,72 @@ class OPTForCausalLM(OPTPreTrainedModel):
|
||||
hidden_states = self.model(
|
||||
input_ids, positions, kv_caches, input_metadata, cache_events)
|
||||
next_tokens = self.sampler(
|
||||
self.lm_head.weight, hidden_states, input_metadata)
|
||||
self.lm_head_weight, hidden_states, input_metadata)
|
||||
return next_tokens
|
||||
|
||||
_column_parallel_weights = ["embed_tokens.weight",
|
||||
"q_proj.weight", "k_proj.weight",
|
||||
"v_proj.weight", "fc1.weight"]
|
||||
_column_parallel_biases = ["q_proj.bias", "k_proj.bias",
|
||||
"v_proj.bias", "fc1.bias"]
|
||||
_row_parallel_weights = ["out_proj.weight", "fc2.weight"]
|
||||
|
||||
def load_weights(self, weights_path: str):
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
state_dict = self.state_dict()
|
||||
for name, param in state_dict.items():
|
||||
if "lm_head_weight" in name:
|
||||
continue
|
||||
loaded_weight = torch.from_numpy(np.load(os.path.join(weights_path,
|
||||
name)))
|
||||
for p in (self._column_parallel_weights
|
||||
+ self._column_parallel_biases):
|
||||
if p in name:
|
||||
shard_size = param.shape[0]
|
||||
loaded_weight = loaded_weight[
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
for p in self._row_parallel_weights:
|
||||
if p in name:
|
||||
shard_size = param.shape[1]
|
||||
loaded_weight = loaded_weight[
|
||||
:,
|
||||
shard_size * tensor_model_parallel_rank
|
||||
:shard_size * (tensor_model_parallel_rank + 1)]
|
||||
break
|
||||
|
||||
assert param.shape == loaded_weight.shape
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
@staticmethod
|
||||
def download_weights(model_name: str, path: str):
|
||||
path = os.path.join(path, f"{model_name}-np")
|
||||
path = os.path.abspath(os.path.expanduser(path))
|
||||
os.makedirs(path, exist_ok=True)
|
||||
lock_path = os.path.join(path, "file_lock")
|
||||
lock = filelock.FileLock(lock_path)
|
||||
|
||||
with lock:
|
||||
test_weight_path = os.path.join(
|
||||
path, "model.decoder.embed_positions.weight")
|
||||
if os.path.exists(test_weight_path):
|
||||
return path
|
||||
|
||||
folder = snapshot_download(model_name, allow_patterns="*.bin",
|
||||
cache_dir=os.path.join(path, "cache"))
|
||||
bin_files = glob.glob(os.path.join(folder, "*.bin"))
|
||||
|
||||
if "/" in model_name:
|
||||
model_name = model_name.split("/")[1].lower()
|
||||
|
||||
for bin_file in tqdm(bin_files, desc="Convert format"):
|
||||
state = torch.load(bin_file)
|
||||
for name, param in tqdm(state.items(), leave=False):
|
||||
if name.startswith("decoder."):
|
||||
name = "model." + name
|
||||
param_path = os.path.join(path, name)
|
||||
with open(param_path, "wb") as f:
|
||||
np.save(f, param.cpu().detach().numpy())
|
||||
|
||||
return path
|
||||
|
@ -6,7 +6,7 @@ import torch.nn as nn
|
||||
from cacheflow.models import InputMetadata
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import SequenceOutputs
|
||||
|
||||
from cacheflow.parallel_utils.tensor_parallel import gather_from_tensor_model_parallel_region
|
||||
|
||||
class Sampler(nn.Module):
|
||||
|
||||
@ -24,6 +24,7 @@ class Sampler(nn.Module):
|
||||
|
||||
# Get the logits for the next tokens.
|
||||
logits = torch.matmul(hidden_states, embedding.t())
|
||||
logits = gather_from_tensor_model_parallel_region(logits)
|
||||
|
||||
# Apply temperature scaling.
|
||||
temperatures = _get_temperatures(input_metadata)
|
||||
|
@ -27,14 +27,6 @@ def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
|
||||
return torch.tensor([], dtype=torch_dtype).element_size()
|
||||
|
||||
|
||||
def set_seed(seed: int) -> None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def get_gpu_memory(gpu: int = 0) -> int:
|
||||
return torch.cuda.get_device_properties(gpu).total_memory
|
||||
|
||||
|
1
cacheflow/parallel_utils/README.md
Normal file
@ -0,0 +1 @@
|
||||
The files in this folder are ported from [Megatron-LM](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core). We only keep the code that is used for inference.
|
12
cacheflow/parallel_utils/__init__.py
Normal file
@ -0,0 +1,12 @@
|
||||
import cacheflow.parallel_utils.parallel_state
|
||||
import cacheflow.parallel_utils.tensor_parallel
|
||||
import cacheflow.parallel_utils.utils
|
||||
|
||||
# Alias parallel_state as mpu, its legacy name
|
||||
mpu = parallel_state
|
||||
|
||||
__all__ = [
|
||||
"parallel_state",
|
||||
"tensor_parallel",
|
||||
"utils",
|
||||
]
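A hedged usage sketch (not in the diff): the `mpu` alias exposes the same functions as `parallel_state`, and they can only be queried once `torch.distributed` is initialized and `initialize_model_parallel` has been called.

```python
# Hypothetical usage; assumes a torch.distributed process group already exists.
import torch.distributed as dist
from cacheflow import parallel_utils
from cacheflow.parallel_utils.parallel_state import initialize_model_parallel

if dist.is_initialized():
    initialize_model_parallel(tensor_model_parallel_size=2)
    tp_rank = parallel_utils.mpu.get_tensor_model_parallel_rank()
    tp_size = parallel_utils.mpu.get_tensor_model_parallel_world_size()
```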
|
522
cacheflow/parallel_utils/parallel_state.py
Normal file
@ -0,0 +1,522 @@
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
"""Model and data parallel groups."""
|
||||
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from .utils import GlobalMemoryBuffer
|
||||
|
||||
# Intra-layer model parallel group that the current rank belongs to.
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = None
|
||||
# Inter-layer model parallel group that the current rank belongs to.
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
|
||||
_MODEL_PARALLEL_GROUP = None
|
||||
# Embedding group.
|
||||
_EMBEDDING_GROUP = None
|
||||
# Position embedding group.
|
||||
_POSITION_EMBEDDING_GROUP = None
|
||||
# Data parallel group that the current rank belongs to.
|
||||
_DATA_PARALLEL_GROUP = None
|
||||
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
|
||||
|
||||
# These values enable us to change the mpu sizes on the fly.
|
||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
|
||||
# A list of ranks that have a copy of the embedding.
|
||||
_EMBEDDING_GLOBAL_RANKS = None
|
||||
|
||||
# A list of ranks that have a copy of the position embedding.
|
||||
_POSITION_EMBEDDING_GLOBAL_RANKS = None
|
||||
|
||||
# A list of global ranks for each pipeline group to ease calculation of the source
|
||||
# rank when broadcasting from the first or last pipeline stage.
|
||||
_PIPELINE_GLOBAL_RANKS = None
|
||||
|
||||
# A list of global ranks for each data parallel group to ease calculation of the source
|
||||
# rank when broadcasting weights from src to all other data parallel ranks
|
||||
_DATA_PARALLEL_GLOBAL_RANKS = None
|
||||
|
||||
# Memory buffers to avoid dynamic memory allocation
|
||||
_GLOBAL_MEMORY_BUFFER = None
|
||||
|
||||
|
||||
def initialize_model_parallel(
|
||||
tensor_model_parallel_size: int = 1,
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
virtual_pipeline_model_parallel_size: Optional[int] = None,
|
||||
pipeline_model_parallel_split_rank: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize model data parallel groups.
|
||||
|
||||
Arguments:
|
||||
tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
|
||||
pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
|
||||
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
|
||||
pipeline).
|
||||
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
|
||||
rank in pipeline with split point.
|
||||
|
||||
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
|
||||
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
|
||||
the model pipeline. The present function will
|
||||
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
|
||||
and 8 data-parallel groups as:
|
||||
8 data_parallel groups:
|
||||
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
|
||||
8 tensor model-parallel groups:
|
||||
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
|
||||
4 pipeline model-parallel groups:
|
||||
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
|
||||
Note that for efficiency, the caller should make sure adjacent ranks
|
||||
are on the same DGX box. For example if we are using 2 DGX-1 boxes
|
||||
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
|
||||
ranks 8 to 15 belong to the second box.
|
||||
"""
|
||||
# Get world size and rank. Ensure some consistencies.
|
||||
assert torch.distributed.is_initialized()
|
||||
world_size: int = torch.distributed.get_world_size()
|
||||
|
||||
if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
|
||||
raise RuntimeError(
|
||||
f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
|
||||
f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
|
||||
)
|
||||
|
||||
data_parallel_size: int = world_size // (tensor_model_parallel_size *
|
||||
pipeline_model_parallel_size)
|
||||
|
||||
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
|
||||
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
|
||||
num_data_parallel_groups: int = world_size // data_parallel_size
|
||||
|
||||
if virtual_pipeline_model_parallel_size is not None:
|
||||
if not pipeline_model_parallel_size > 2:
|
||||
raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
|
||||
"interleaved schedule")
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
|
||||
|
||||
if pipeline_model_parallel_split_rank is not None:
|
||||
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
||||
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
# Build the data-parallel groups.
|
||||
global _DATA_PARALLEL_GROUP
|
||||
global _DATA_PARALLEL_GLOBAL_RANKS
|
||||
assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
|
||||
all_data_parallel_group_ranks = []
|
||||
for i in range(pipeline_model_parallel_size):
|
||||
start_rank = i * num_pipeline_model_parallel_groups
|
||||
end_rank = (i + 1) * num_pipeline_model_parallel_groups
|
||||
for j in range(tensor_model_parallel_size):
|
||||
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
|
||||
all_data_parallel_group_ranks.append(list(ranks))
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_DATA_PARALLEL_GROUP = group
|
||||
_DATA_PARALLEL_GLOBAL_RANKS = ranks
|
||||
|
||||
# Build the model-parallel groups.
|
||||
global _MODEL_PARALLEL_GROUP
|
||||
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
|
||||
for i in range(data_parallel_size):
|
||||
ranks = [data_parallel_group_ranks[i]
|
||||
for data_parallel_group_ranks in all_data_parallel_group_ranks]
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_MODEL_PARALLEL_GROUP = group
|
||||
|
||||
# Build the tensor model-parallel groups.
|
||||
global _TENSOR_MODEL_PARALLEL_GROUP
|
||||
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
|
||||
'tensor model parallel group is already initialized'
|
||||
for i in range(num_tensor_model_parallel_groups):
|
||||
ranks = range(i * tensor_model_parallel_size,
|
||||
(i + 1) * tensor_model_parallel_size)
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = group
|
||||
|
||||
# Build the pipeline model-parallel groups and embedding groups
|
||||
# (first and last rank in each pipeline model-parallel group).
|
||||
global _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
global _PIPELINE_GLOBAL_RANKS
|
||||
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
|
||||
'pipeline model parallel group is already initialized'
|
||||
global _EMBEDDING_GROUP
|
||||
global _EMBEDDING_GLOBAL_RANKS
|
||||
assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
|
||||
global _POSITION_EMBEDDING_GROUP
|
||||
global _POSITION_EMBEDDING_GLOBAL_RANKS
|
||||
assert _POSITION_EMBEDDING_GROUP is None, \
|
||||
'position embedding group is already initialized'
|
||||
for i in range(num_pipeline_model_parallel_groups):
|
||||
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = group
|
||||
_PIPELINE_GLOBAL_RANKS = ranks
|
||||
# Setup embedding group (to exchange gradients between
|
||||
# first and last stages).
|
||||
if len(ranks) > 1:
|
||||
embedding_ranks = [ranks[0], ranks[-1]]
|
||||
position_embedding_ranks = [ranks[0]]
|
||||
if pipeline_model_parallel_split_rank is not None:
|
||||
if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
|
||||
embedding_ranks = [ranks[0],
|
||||
ranks[pipeline_model_parallel_split_rank],
|
||||
ranks[-1]]
|
||||
if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
|
||||
position_embedding_ranks = [ranks[0],
|
||||
ranks[pipeline_model_parallel_split_rank]]
|
||||
else:
|
||||
embedding_ranks = ranks
|
||||
position_embedding_ranks = ranks
|
||||
|
||||
group = torch.distributed.new_group(embedding_ranks)
|
||||
if rank in embedding_ranks:
|
||||
_EMBEDDING_GROUP = group
|
||||
if rank in ranks:
|
||||
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
|
||||
|
||||
group = torch.distributed.new_group(position_embedding_ranks)
|
||||
if rank in position_embedding_ranks:
|
||||
_POSITION_EMBEDDING_GROUP = group
|
||||
if rank in ranks:
|
||||
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
|
||||
|
||||
# Initialize global memory buffer
|
||||
# This isn't really "parallel state" but there isn't another good place to
|
||||
# put this. If we end up with a more generic initialization of megatron-core
|
||||
# we could stick it there
|
||||
_set_global_memory_buffer()
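The 16-GPU example in the docstring can be re-derived offline with the same index arithmetic the function uses; this standalone sketch only enumerates the rank lists and creates no process groups.

```python
# Reproduces the docstring example: 16 GPUs, tensor parallel 2, pipeline parallel 4.
world_size, tp, pp = 16, 2, 4

tensor_groups = [list(range(i * tp, (i + 1) * tp))
                 for i in range(world_size // tp)]            # 8 groups of 2
pipeline_groups = [list(range(i, world_size, world_size // pp))
                   for i in range(world_size // pp)]          # 4 groups of 4
data_groups = []
for i in range(pp):
    start, end = i * (world_size // pp), (i + 1) * (world_size // pp)
    for j in range(tp):
        data_groups.append(list(range(start + j, end, tp)))   # 8 groups of 2

print(tensor_groups[0], pipeline_groups[0], data_groups[0])
# [0, 1] [0, 4, 8, 12] [0, 2]
```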
|
||||
|
||||
|
||||
def model_parallel_is_initialized():
|
||||
"""Check if model and data parallel groups are initialized."""
|
||||
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP is None or \
|
||||
_DATA_PARALLEL_GROUP is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_model_parallel_group():
|
||||
"""Get the model parallel group the caller rank belongs to."""
|
||||
assert _MODEL_PARALLEL_GROUP is not None, \
|
||||
'model parallel group is not initialized'
|
||||
return _MODEL_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_tensor_model_parallel_group():
|
||||
"""Get the tensor model parallel group the caller rank belongs to."""
|
||||
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
|
||||
'intra_layer_model parallel group is not initialized'
|
||||
return _TENSOR_MODEL_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_group():
|
||||
"""Get the pipeline model parallel group the caller rank belongs to."""
|
||||
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
|
||||
'pipeline_model parallel group is not initialized'
|
||||
return _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_data_parallel_group():
|
||||
"""Get the data parallel group the caller rank belongs to."""
|
||||
assert _DATA_PARALLEL_GROUP is not None, \
|
||||
'data parallel group is not initialized'
|
||||
return _DATA_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_embedding_group():
|
||||
"""Get the embedding group the caller rank belongs to."""
|
||||
assert _EMBEDDING_GROUP is not None, \
|
||||
'embedding group is not initialized'
|
||||
return _EMBEDDING_GROUP
|
||||
|
||||
|
||||
def get_position_embedding_group():
|
||||
"""Get the position embedding group the caller rank belongs to."""
|
||||
assert _POSITION_EMBEDDING_GROUP is not None, \
|
||||
'position embedding group is not initialized'
|
||||
return _POSITION_EMBEDDING_GROUP
|
||||
|
||||
|
||||
def set_tensor_model_parallel_world_size(world_size):
|
||||
"""Set the tensor model parallel size"""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
|
||||
|
||||
|
||||
def set_pipeline_model_parallel_world_size(world_size):
|
||||
"""Set the pipeline model parallel size"""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
|
||||
|
||||
|
||||
def get_tensor_model_parallel_world_size():
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
|
||||
return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_world_size():
|
||||
"""Return world size for the pipeline model parallel group."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
|
||||
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
|
||||
|
||||
|
||||
def set_tensor_model_parallel_rank(rank):
|
||||
"""Set tensor model parallel rank."""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = rank
|
||||
|
||||
|
||||
def set_pipeline_model_parallel_rank(rank):
|
||||
"""Set pipeline model parallel rank."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
|
||||
|
||||
|
||||
def set_pipeline_model_parallel_split_rank(rank):
|
||||
"""Set pipeline model parallel split rank."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
|
||||
|
||||
|
||||
def get_tensor_model_parallel_rank():
|
||||
"""Return my rank for the tensor model parallel group."""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
|
||||
return _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_rank():
|
||||
"""Return my rank for the pipeline model parallel group."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
|
||||
return _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
|
||||
|
||||
|
||||
|
||||
def is_pipeline_first_stage(ignore_virtual=False):
|
||||
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
|
||||
if not ignore_virtual:
|
||||
if get_virtual_pipeline_model_parallel_world_size() is not None and \
|
||||
get_virtual_pipeline_model_parallel_rank() != 0:
|
||||
return False
|
||||
return get_pipeline_model_parallel_rank() == 0
|
||||
|
||||
|
||||
def is_pipeline_last_stage(ignore_virtual=False):
|
||||
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
|
||||
if not ignore_virtual:
|
||||
virtual_pipeline_model_parallel_world_size = \
|
||||
get_virtual_pipeline_model_parallel_world_size()
|
||||
if virtual_pipeline_model_parallel_world_size is not None and \
|
||||
get_virtual_pipeline_model_parallel_rank() != (
|
||||
virtual_pipeline_model_parallel_world_size - 1):
|
||||
return False
|
||||
return get_pipeline_model_parallel_rank() == (
|
||||
get_pipeline_model_parallel_world_size() - 1)
|
||||
|
||||
|
||||
def is_rank_in_embedding_group(ignore_virtual=False):
|
||||
"""Return true if current rank is in embedding group, False otherwise."""
|
||||
rank = torch.distributed.get_rank()
|
||||
global _EMBEDDING_GLOBAL_RANKS
|
||||
if ignore_virtual:
|
||||
return rank in _EMBEDDING_GLOBAL_RANKS
|
||||
if rank in _EMBEDDING_GLOBAL_RANKS:
|
||||
if rank == _EMBEDDING_GLOBAL_RANKS[0]:
|
||||
return is_pipeline_first_stage(ignore_virtual=False)
|
||||
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
|
||||
return is_pipeline_last_stage(ignore_virtual=False)
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_rank_in_position_embedding_group():
|
||||
"""Return true if current rank is in position embedding group, False otherwise."""
|
||||
rank = torch.distributed.get_rank()
|
||||
global _POSITION_EMBEDDING_GLOBAL_RANKS
|
||||
return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
|
||||
|
||||
|
||||
def is_pipeline_stage_before_split(rank=None):
|
||||
"""Return True if pipeline stage executes encoder block for a model
|
||||
with both encoder and decoder."""
|
||||
if get_pipeline_model_parallel_world_size() == 1:
|
||||
return True
|
||||
if rank is None:
|
||||
rank = get_pipeline_model_parallel_rank()
|
||||
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
||||
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
|
||||
return True
|
||||
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_pipeline_stage_after_split(rank=None):
|
||||
"""Return True if pipeline stage executes decoder block for a model
|
||||
with both encoder and decoder."""
|
||||
if get_pipeline_model_parallel_world_size() == 1:
|
||||
return True
|
||||
if rank is None:
|
||||
rank = get_pipeline_model_parallel_rank()
|
||||
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
||||
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
|
||||
return True
|
||||
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_pipeline_stage_at_split():
|
||||
"""Return true if pipeline stage executes decoder block and next
|
||||
stage executes encoder block for a model with both encoder and
|
||||
decoder."""
|
||||
rank = get_pipeline_model_parallel_rank()
|
||||
return is_pipeline_stage_before_split(rank) and \
|
||||
is_pipeline_stage_after_split(rank+1)
|
||||
|
||||
|
||||
def get_virtual_pipeline_model_parallel_rank():
|
||||
"""Return the virtual pipeline-parallel rank."""
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
|
||||
|
||||
def set_virtual_pipeline_model_parallel_rank(rank):
|
||||
"""Set the virtual pipeline-parallel rank."""
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
|
||||
|
||||
|
||||
def get_virtual_pipeline_model_parallel_world_size():
|
||||
"""Return the virtual pipeline-parallel world size."""
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
|
||||
|
||||
def get_tensor_model_parallel_src_rank():
|
||||
"""Calculate the global rank corresponding to the first local rank
|
||||
in the tensor model parallel group."""
|
||||
global_rank = torch.distributed.get_rank()
|
||||
local_world_size = get_tensor_model_parallel_world_size()
|
||||
return (global_rank // local_world_size) * local_world_size
|
||||
|
||||
|
||||
def get_data_parallel_src_rank():
|
||||
"""Calculate the global rank corresponding to the first local rank
|
||||
in the data parallel group."""
|
||||
assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
|
||||
"Data parallel group is not initialized"
|
||||
return _DATA_PARALLEL_GLOBAL_RANKS[0]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_first_rank():
|
||||
"""Return the global rank of the first process in the pipeline for the
|
||||
current tensor parallel group"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||
"Pipeline parallel group is not initialized"
|
||||
return _PIPELINE_GLOBAL_RANKS[0]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_last_rank():
|
||||
"""Return the global rank of the last process in the pipeline for the
|
||||
current tensor parallel group"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||
"Pipeline parallel group is not initialized"
|
||||
last_rank_local = get_pipeline_model_parallel_world_size() - 1
|
||||
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
|
||||
|
||||
def get_pipeline_model_parallel_next_rank():
|
||||
"""Return the global rank that follows the caller in the pipeline"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||
"Pipeline parallel group is not initialized"
|
||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
||||
world_size = get_pipeline_model_parallel_world_size()
|
||||
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_prev_rank():
|
||||
"""Return the global rank that preceeds the caller in the pipeline"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||
"Pipeline parallel group is not initialized"
|
||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
||||
world_size = get_pipeline_model_parallel_world_size()
|
||||
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
|
||||
|
||||
|
||||
def get_data_parallel_world_size():
|
||||
"""Return world size for the data parallel group."""
|
||||
return torch.distributed.get_world_size(group=get_data_parallel_group())
|
||||
|
||||
|
||||
def get_data_parallel_rank():
|
||||
"""Return my rank for the data parallel group."""
|
||||
return torch.distributed.get_rank(group=get_data_parallel_group())
|
||||
|
||||
def _set_global_memory_buffer():
|
||||
"""Initialize global buffer"""
|
||||
global _GLOBAL_MEMORY_BUFFER
|
||||
assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
|
||||
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
|
||||
|
||||
def get_global_memory_buffer():
|
||||
"""Return the global GlobalMemoryBuffer object"""
|
||||
assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
|
||||
return _GLOBAL_MEMORY_BUFFER
|
||||
|
||||
|
||||
def destroy_model_parallel():
|
||||
"""Set the groups to none."""
|
||||
global _MODEL_PARALLEL_GROUP
|
||||
_MODEL_PARALLEL_GROUP = None
|
||||
global _TENSOR_MODEL_PARALLEL_GROUP
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = None
|
||||
global _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
global _DATA_PARALLEL_GROUP
|
||||
_DATA_PARALLEL_GROUP = None
|
||||
global _EMBEDDING_GROUP
|
||||
_EMBEDDING_GROUP = None
|
||||
global _POSITION_EMBEDDING_GROUP
|
||||
_POSITION_EMBEDDING_GROUP = None
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
global _GLOBAL_MEMORY_BUFFER
|
||||
_GLOBAL_MEMORY_BUFFER = None
|
58
cacheflow/parallel_utils/tensor_parallel/__init__.py
Normal file
@ -0,0 +1,58 @@
|
||||
from .layers import (
|
||||
ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
VocabParallelEmbedding,
|
||||
set_tensor_model_parallel_attributes,
|
||||
set_defaults_if_not_set_tensor_model_parallel_attributes,
|
||||
copy_tensor_model_parallel_attributes,
|
||||
param_is_not_tensor_parallel_duplicate,
|
||||
linear_with_grad_accumulation_and_async_allreduce
|
||||
|
||||
)
|
||||
|
||||
from .mappings import (
|
||||
copy_to_tensor_model_parallel_region,
|
||||
gather_from_tensor_model_parallel_region,
|
||||
gather_from_sequence_parallel_region,
|
||||
scatter_to_tensor_model_parallel_region,
|
||||
scatter_to_sequence_parallel_region,
|
||||
)
|
||||
|
||||
from .random import (
|
||||
checkpoint,
|
||||
get_cuda_rng_tracker,
|
||||
model_parallel_cuda_manual_seed,
|
||||
)
|
||||
|
||||
from .utils import (
|
||||
split_tensor_along_last_dim,
|
||||
split_tensor_into_1d_equal_chunks,
|
||||
gather_split_1d_tensor,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
#layers.py
|
||||
"ColumnParallelLinear",
|
||||
"RowParallelLinear",
|
||||
"VocabParallelEmbedding",
|
||||
"set_tensor_model_parallel_attributes",
|
||||
"set_defaults_if_not_set_tensor_model_parallel_attributes",
|
||||
"copy_tensor_model_parallel_attributes",
|
||||
"param_is_not_tensor_parallel_duplicate",
|
||||
"linear_with_grad_accumulation_and_async_allreduce",
|
||||
# mappings.py
|
||||
"copy_to_tensor_model_parallel_region",
|
||||
"gather_from_tensor_model_parallel_region",
|
||||
"gather_from_sequence_parallel_region",
|
||||
# "reduce_from_tensor_model_parallel_region",
|
||||
"scatter_to_tensor_model_parallel_region",
|
||||
"scatter_to_sequence_parallel_region",
|
||||
# random.py
|
||||
"checkpoint",
|
||||
"get_cuda_rng_tracker",
|
||||
"model_parallel_cuda_manual_seed",
|
||||
# utils.py
|
||||
"split_tensor_along_last_dim",
|
||||
"split_tensor_into_1d_equal_chunks",
|
||||
"gather_split_1d_tensor",
|
||||
]
|
719
cacheflow/parallel_utils/tensor_parallel/layers.py
Normal file
@ -0,0 +1,719 @@
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Parts of the code here are adapted from PyTorch
|
||||
# repo: https://github.com/pytorch/pytorch
|
||||
|
||||
import math
|
||||
import os
|
||||
from typing import Optional
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn.init as init
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from cacheflow.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_group,
|
||||
get_global_memory_buffer,
|
||||
)
|
||||
from .mappings import (
|
||||
copy_to_tensor_model_parallel_region,
|
||||
gather_from_tensor_model_parallel_region,
|
||||
gather_from_sequence_parallel_region,
|
||||
reduce_from_tensor_model_parallel_region,
|
||||
scatter_to_tensor_model_parallel_region,
|
||||
reduce_scatter_to_sequence_parallel_region,
|
||||
)
|
||||
|
||||
from .random import get_cuda_rng_tracker
|
||||
from .utils import (
|
||||
divide,
|
||||
split_tensor_along_last_dim,
|
||||
VocabUtility,
|
||||
)
|
||||
|
||||
_grad_accum_fusion_available = True
|
||||
try:
|
||||
import fused_weight_gradient_mlp_cuda
|
||||
except ImportError:
|
||||
_grad_accum_fusion_available = False
|
||||
|
||||
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
|
||||
'partition_dim': -1,
|
||||
'partition_stride': 1}
|
||||
|
||||
def param_is_not_tensor_parallel_duplicate(param):
|
||||
return (hasattr(param, 'tensor_model_parallel') and
|
||||
param.tensor_model_parallel) or (
|
||||
get_tensor_model_parallel_rank() == 0)
|
||||
|
||||
|
||||
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
|
||||
# Make sure the attributes are not set.
|
||||
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
|
||||
assert not hasattr(tensor, attribute)
|
||||
# Set the attributes.
|
||||
setattr(tensor, 'tensor_model_parallel', is_parallel)
|
||||
setattr(tensor, 'partition_dim', dim)
|
||||
setattr(tensor, 'partition_stride', stride)
|
||||
|
||||
|
||||
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
|
||||
def maybe_set(attribute, value):
|
||||
if not hasattr(tensor, attribute):
|
||||
setattr(tensor, attribute, value)
|
||||
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
|
||||
maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])
|
||||
|
||||
|
||||
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
|
||||
def maybe_copy(attribute):
|
||||
if hasattr(source_tensor, attribute):
|
||||
setattr(destination_tensor, attribute,
|
||||
getattr(source_tensor, attribute))
|
||||
for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
|
||||
maybe_copy(attribute)
|
||||
|
||||
|
||||
def _initialize_affine_weight_gpu(weight, init_method,
|
||||
partition_dim, stride=1):
|
||||
"""Initialize affine weight for model parallel on GPU."""
|
||||
|
||||
set_tensor_model_parallel_attributes(tensor=weight,
|
||||
is_parallel=True,
|
||||
dim=partition_dim,
|
||||
stride=stride)
|
||||
|
||||
with get_cuda_rng_tracker().fork():
|
||||
init_method(weight)
|
||||
|
||||
|
||||
def _initialize_affine_weight_cpu(weight, output_size, input_size,
|
||||
per_partition_size, partition_dim,
|
||||
init_method, stride=1,
|
||||
return_master_weight=False,
|
||||
*, params_dtype=None):
|
||||
"""Initialize affine weight for model parallel.
|
||||
|
||||
Build the master weight on all processes and scatter
|
||||
the relevant chunk."""
|
||||
|
||||
set_tensor_model_parallel_attributes(tensor=weight,
|
||||
is_parallel=True,
|
||||
dim=partition_dim,
|
||||
stride=stride)
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Initialize master weight
|
||||
master_weight = torch.empty(output_size, input_size,
|
||||
dtype=torch.float,
|
||||
requires_grad=False)
|
||||
init_method(master_weight)
|
||||
master_weight = master_weight.to(dtype=params_dtype)
|
||||
|
||||
# Split and copy
|
||||
per_partition_per_stride_size = divide(per_partition_size, stride)
|
||||
weight_list = torch.split(master_weight, per_partition_per_stride_size,
|
||||
dim=partition_dim)
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
my_weight_list = weight_list[rank::world_size]
|
||||
|
||||
with torch.no_grad():
|
||||
torch.cat(my_weight_list, dim=partition_dim, out=weight)
|
||||
if return_master_weight:
|
||||
return master_weight
|
||||
return None
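A tiny offline sketch of the split-and-copy above, with hypothetical shapes: for `stride=1` the master weight is chunked along `partition_dim` and rank r keeps chunk r; for a larger stride each rank concatenates every `world_size`-th strided chunk.

```python
import torch

# Hypothetical (8, 1) master weight, 4-way split along dim 0, stride=1.
master = torch.arange(8.0).reshape(8, 1)
world_size, rank, per_partition, stride = 4, 1, 2, 1
chunks = torch.split(master, per_partition // stride, dim=0)
mine = torch.cat(chunks[rank::world_size], dim=0)
assert torch.equal(mine, master[2:4])   # rank 1 ends up with rows 2..3
```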
|
||||
|
||||
|
||||
class VocabParallelEmbedding(torch.nn.Module):
|
||||
"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
This is mainly adapted from torch.nn.Embedding and all the default
|
||||
values are kept.
|
||||
Arguments:
|
||||
num_embeddings: vocabulary size.
|
||||
embedding_dim: size of hidden state.
|
||||
|
||||
Keyword Arguments:
|
||||
init_method: method to initialize weights.
|
||||
params_dtype
|
||||
use_cpu_initialization
|
||||
perform_initialization
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, *,
|
||||
init_method=init.xavier_normal_,
|
||||
params_dtype: torch.dtype=None,
|
||||
use_cpu_initialization: bool=False,
|
||||
perform_initialization: bool=True):
|
||||
super(VocabParallelEmbedding, self).__init__()
|
||||
# Keep the input dimensions.
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Set the defaults for compatibility.
|
||||
self.padding_idx = None
|
||||
self.max_norm = None
|
||||
self.norm_type = 2.
|
||||
self.scale_grad_by_freq = False
|
||||
self.sparse = False
|
||||
self._weight = None
|
||||
self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
|
||||
# Divide the weight matrix along the vocabulary dimension.
|
||||
self.vocab_start_index, self.vocab_end_index = \
|
||||
VocabUtility.vocab_range_from_global_vocab_size(
|
||||
self.num_embeddings, get_tensor_model_parallel_rank(),
|
||||
self.tensor_model_parallel_size)
|
||||
self.num_embeddings_per_partition = self.vocab_end_index - \
|
||||
self.vocab_start_index
|
||||
|
||||
# Allocate weights and initialize.
|
||||
if use_cpu_initialization:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.num_embeddings_per_partition, self.embedding_dim,
|
||||
dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
_initialize_affine_weight_cpu(
|
||||
self.weight, self.num_embeddings, self.embedding_dim,
|
||||
self.num_embeddings_per_partition, 0, init_method,
|
||||
params_dtype=params_dtype)
|
||||
else:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.num_embeddings_per_partition, self.embedding_dim,
|
||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
||||
partition_dim=0, stride=1)
|
||||
|
||||
def forward(self, input_):
|
||||
if self.tensor_model_parallel_size > 1:
|
||||
# Build the mask.
|
||||
input_mask = (input_ < self.vocab_start_index) | \
|
||||
(input_ >= self.vocab_end_index)
|
||||
# Mask the input.
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
masked_input[input_mask] = 0
|
||||
else:
|
||||
masked_input = input_
|
||||
# Get the embeddings.
|
||||
output_parallel = F.embedding(masked_input, self.weight,
|
||||
self.padding_idx, self.max_norm,
|
||||
self.norm_type, self.scale_grad_by_freq,
|
||||
self.sparse)
|
||||
# Mask the output embedding.
|
||||
if self.tensor_model_parallel_size > 1:
|
||||
output_parallel[input_mask, :] = 0.0
|
||||
# Reduce across all the model parallel GPUs.
|
||||
output = reduce_from_tensor_model_parallel_region(output_parallel)
|
||||
return output
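The masking above assumes each rank owns a contiguous, evenly sized vocabulary range. A minimal sketch of that partitioning and masking with hypothetical sizes (pure tensor ops; the final all-reduce across ranks is omitted):

```python
import torch

# Hypothetical: vocab of 16 split across 4 ranks; rank 1 owns token ids 4..7.
vocab, world_size, rank, dim = 16, 4, 1, 8
per_rank = vocab // world_size
start, end = rank * per_rank, (rank + 1) * per_rank

input_ids = torch.tensor([3, 4, 7, 12])
mask = (input_ids < start) | (input_ids >= end)      # ids this rank does not own
local_ids = (input_ids - start).masked_fill(mask, 0)

local_table = torch.randn(per_rank, dim)             # this rank's embedding shard
out = local_table[local_ids]
out[mask] = 0.0   # zeroed rows are supplied by other ranks via the all-reduce
```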
|
||||
|
||||
|
||||
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
|
||||
"""See linear_with_grad_accumulation_and_async_allreduce"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
|
||||
async_grad_allreduce, sequence_parallel):
|
||||
ctx.save_for_backward(input, weight)
|
||||
ctx.use_bias = bias is not None
|
||||
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
|
||||
ctx.async_grad_allreduce = async_grad_allreduce
|
||||
ctx.sequence_parallel = sequence_parallel
|
||||
|
||||
if sequence_parallel:
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
dim_size = list(input.size())
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
all_gather_buffer = \
|
||||
get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
|
||||
torch.distributed._all_gather_base(
|
||||
all_gather_buffer,
|
||||
input,
|
||||
group=get_tensor_model_parallel_group())
|
||||
total_input = all_gather_buffer
|
||||
else:
|
||||
total_input = input
|
||||
|
||||
output = torch.matmul(total_input, weight.t())
|
||||
if bias is not None:
|
||||
output = output + bias
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, weight = ctx.saved_tensors
|
||||
use_bias = ctx.use_bias
|
||||
|
||||
if ctx.sequence_parallel:
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
dim_size = list(input.size())
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
all_gather_buffer = \
|
||||
get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
|
||||
handle = torch.distributed._all_gather_base(
|
||||
all_gather_buffer,
|
||||
input,
|
||||
group=get_tensor_model_parallel_group(), async_op=True)
|
||||
|
||||
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
|
||||
# gather is scheduled before the input gradient computation
|
||||
total_input = all_gather_buffer
|
||||
else:
|
||||
total_input = input
|
||||
grad_input = grad_output.matmul(weight)
|
||||
|
||||
if ctx.sequence_parallel:
|
||||
handle.wait()
|
||||
|
||||
# Convert the tensor shapes to 2D for execution compatibility
|
||||
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
|
||||
grad_output.shape[2])
|
||||
total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
|
||||
total_input.shape[2])
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
# Asynchronous all-reduce
|
||||
handle = torch.distributed.all_reduce(
|
||||
grad_input, group=get_tensor_model_parallel_group(), async_op=True)
|
||||
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
|
||||
# all-reduce is scheduled before the weight gradient computation
|
||||
|
||||
if ctx.sequence_parallel:
|
||||
assert not ctx.async_grad_allreduce
|
||||
dim_size = list(input.size())
|
||||
sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
requires_grad=False)
|
||||
# reduce_scatter
|
||||
handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
|
||||
group=get_tensor_model_parallel_group(),
|
||||
async_op=True)
|
||||
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
|
||||
# reduce scatter is scheduled before the weight gradient computation
|
||||
|
||||
|
||||
if ctx.gradient_accumulation_fusion:
|
||||
if weight.main_grad.dtype == torch.float32:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
|
||||
elif weight.main_grad.dtype == torch.float16:
|
||||
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, weight.main_grad)
|
||||
else:
|
||||
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
|
||||
grad_weight = None
|
||||
else:
|
||||
grad_weight = grad_output.t().matmul(total_input)
|
||||
grad_bias = grad_output.sum(dim=0) if use_bias else None
|
||||
|
||||
if ctx.sequence_parallel:
|
||||
handle.wait()
|
||||
return sub_grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
if ctx.async_grad_allreduce:
|
||||
handle.wait()
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
def linear_with_grad_accumulation_and_async_allreduce(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: Optional[torch.Tensor],
|
||||
gradient_accumulation_fusion: bool,
|
||||
async_grad_allreduce: bool,
|
||||
sequence_parallel_enabled: bool,
|
||||
) -> torch.Tensor:
|
||||
"""Linear layer execution with asynchronous communication and
|
||||
gradient accumulation fusion in backprop.
|
||||
|
||||
This has the option to accumulate the result of backprop
|
||||
calculation into an existing gradient buffer, preventing the need
|
||||
to do an additional addition kernel after the gradient
|
||||
calculation.
|
||||
|
||||
Additionally, the tensor parallel all reduce of the input
|
||||
gradients can be done asynchronously with the calculation of
|
||||
the weight gradients.
|
||||
|
||||
In the case of sequence parallelism, the reduce scatter of the
|
||||
input gradients is done asynchronously with the calculation of the
|
||||
weight gradients.
|
||||
|
||||
Use of this module requires that the environment variable
|
||||
CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
|
||||
operations, noted in the code, that should be scheduled before
|
||||
compute kernels to overlap the communication with the computation,
|
||||
which is necessary for a speedup but not for correctness, so the
|
||||
scheduler does not impose that ordering on its own. Setting
|
||||
CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
|
||||
in the order they are called.
|
||||
|
||||
Arguments:
|
||||
|
||||
input (torch.Tensor required): input like torch.nn.functional.linear
|
||||
|
||||
weight (torch.Tensor required): weight like torch.nn.functional.linear
|
||||
|
||||
bias (torch.Tensor optional): bias like torch.nn.functional.linear
|
||||
|
||||
gradient_accumulation_fusion (bool required): Perform the gradient
|
||||
accumulation fusion, requires the custom CUDA extension
|
||||
fused_weight_gradient_mlp_cuda module. To use
|
||||
gradient_accumulation_fusion you must install APEX with
|
||||
--cpp_ext and --cuda_ext. For example: "pip install
|
||||
--global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
|
||||
" Note that the extension requires CUDA>=11. Otherwise, you
|
||||
must turn off gradient accumulation fusion."
|
||||
|
||||
async_grad_allreduce (bool required): Do the allreduce of input
|
||||
gradients asynchronously with the computation of weight
|
||||
gradients. If sequence_parallel_enabled is True, this must be
|
||||
False, as no all reduce is performed.
|
||||
|
||||
sequence_parallel_enabled (bool required): Indicates that sequence
|
||||
parallelism is used and thus in the forward pass the input is
|
||||
all gathered, and in the backward pass the input gradients are
|
||||
reduce scattered.
|
||||
"""
|
||||
args = [
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
gradient_accumulation_fusion,
|
||||
async_grad_allreduce,
|
||||
sequence_parallel_enabled,
|
||||
]
|
||||
|
||||
if not linear_with_grad_accumulation_and_async_allreduce.warned:
|
||||
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
|
||||
if sequence_parallel_enabled:
|
||||
warnings.warn(
|
||||
"When using sequence parallelism it is recommended to set the "
|
||||
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
|
||||
"maximum speedup")
|
||||
linear_with_grad_accumulation_and_async_allreduce.warned = True
|
||||
|
||||
if async_grad_allreduce:
|
||||
warnings.warn(
|
||||
"When using async grad allreduce it is recommended to set the "
|
||||
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
|
||||
"maximum speedup")
|
||||
linear_with_grad_accumulation_and_async_allreduce.warned = True
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
|
||||
linear_with_grad_accumulation_and_async_allreduce.warned = False
|
||||
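For reference, here is a minimal, single-device way to exercise this helper. The import path is an assumption: the diff adds mappings.py, random.py, and utils.py under cacheflow/parallel_utils/tensor_parallel, and the layer code is presumed to live alongside them in layers.py. With all three switches off there is no collective and no fused kernel, so no process group is needed and the call degenerates to a plain autograd matmul.

```python
import os
import torch

# Assumed module path (not shown in this part of the diff).
from cacheflow.parallel_utils.tensor_parallel.layers import (
    linear_with_grad_accumulation_and_async_allreduce)

# The helper warns if this is unset when the async or sequence-parallel paths
# are used; on multi-GPU runs it is what makes the overlap actually happen.
os.environ.setdefault("CUDA_DEVICE_MAX_CONNECTIONS", "1")

x = torch.randn(4, 2, 16, requires_grad=True)   # [sequence, batch, hidden]
w = torch.randn(32, 16, requires_grad=True)     # [out_features, in_features]
b = torch.zeros(32, requires_grad=True)

y = linear_with_grad_accumulation_and_async_allreduce(
    input=x, weight=w, bias=b,
    gradient_accumulation_fusion=False,
    async_grad_allreduce=False,
    sequence_parallel_enabled=False,
)
y.sum().backward()
print(x.grad.shape, w.grad.shape, b.grad.shape)
# torch.Size([4, 2, 16]) torch.Size([32, 16]) torch.Size([32])
```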
|
||||
class ColumnParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its second dimension as A = [A_1, ..., A_p].
|
||||
|
||||
Arguments:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
|
||||
Keyword Arguments
|
||||
bias: If true, add bias
|
||||
gather_output: If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is Y_i = XA_i
|
||||
init_method: method to initialize weights. Note that bias is always set
|
||||
to zero.
|
||||
stride: For the strided linear layers.
|
||||
keep_master_weight_for_test: This was added for testing and should be
|
||||
set to False. It returns the master weights
|
||||
used for initialization.
|
||||
skip_bias_add: This was added to enable performance optimizations where bias
|
||||
can be fused with other elementwise operations. We skip
|
||||
adding bias but instead return it.
|
||||
async_tensor_model_parallel_allreduce: If true and world size > 1, overlap the all-reduce of the input gradients with the weight gradient computation.
|
||||
params_dtype: dtype of the parameters; defaults to torch.get_default_dtype().
|
||||
use_cpu_initialization: If true, allocate and initialize the weight on CPU instead of the current CUDA device.
|
||||
gradient_accumulation_fusion: If true, accumulate weight gradients directly into weight.main_grad via the fused_weight_gradient_mlp_cuda extension.
|
||||
sequence_parallel_enabled: If true, the input is all-gathered along the sequence dimension in the forward pass and its gradient is reduce-scattered in the backward pass.
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, output_size, *,
|
||||
bias=True, gather_output=True,
|
||||
init_method=init.xavier_normal_, stride=1,
|
||||
keep_master_weight_for_test=False,
|
||||
skip_bias_add=False,
|
||||
async_tensor_model_parallel_allreduce=True,
|
||||
params_dtype=None,
|
||||
use_cpu_initialization=False,
|
||||
perform_initialization=True,
|
||||
gradient_accumulation_fusion=False,
|
||||
sequence_parallel_enabled: bool = False,
|
||||
):
|
||||
super(ColumnParallelLinear, self).__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.gather_output = gather_output
|
||||
# Divide the weight matrix along the last dimension.
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
self.output_size_per_partition = divide(output_size, world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Parameters.
|
||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
||||
# we allocate the transpose.
|
||||
# Initialize weight.
|
||||
if use_cpu_initialization:
|
||||
self.weight = Parameter(torch.empty(self.output_size_per_partition,
|
||||
self.input_size,
|
||||
dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
self.master_weight = _initialize_affine_weight_cpu(
|
||||
self.weight, self.output_size, self.input_size,
|
||||
self.output_size_per_partition, 0, init_method,
|
||||
stride=stride, return_master_weight=keep_master_weight_for_test)
|
||||
else:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size_per_partition, self.input_size,
|
||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
||||
partition_dim=0, stride=stride)
|
||||
|
||||
if bias:
|
||||
if use_cpu_initialization:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size_per_partition, dtype=params_dtype))
|
||||
else:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size_per_partition,
|
||||
device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
self.bias.zero_()
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
self.async_tensor_model_parallel_allreduce = (
|
||||
async_tensor_model_parallel_allreduce and
|
||||
world_size > 1)
|
||||
if sequence_parallel_enabled:
|
||||
if world_size <= 1:
|
||||
warnings.warn(
|
||||
f"`sequence_parallel_enabled` is set to `True`, but tensor model parallel size is {world_size}. "
|
||||
f"Disabling sequence parallel."
|
||||
)
|
||||
sequence_parallel_enabled = False
|
||||
self.sequence_parallel_enabled = sequence_parallel_enabled
|
||||
|
||||
if gradient_accumulation_fusion:
|
||||
if not _grad_accum_fusion_available:
|
||||
raise RuntimeError(
|
||||
"ColumnParallelLinear was called with gradient_accumulation_fusion set "
|
||||
"to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
|
||||
"module is not found. To use gradient_accumulation_fusion you must "
|
||||
"install APEX with --cpp_ext and --cuda_ext. For example: "
|
||||
"pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
|
||||
"Note that the extension requires CUDA>=11. Otherwise, you must turn off "
|
||||
"gradient accumulation fusion."
|
||||
)
|
||||
self.gradient_accumulation_fusion = gradient_accumulation_fusion
|
||||
|
||||
if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
|
||||
raise RuntimeError(
|
||||
"`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` "
|
||||
"cannot be enabled at the same time."
|
||||
)
|
||||
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of ColumnParallelLinear
|
||||
|
||||
Args:
|
||||
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
if self.async_tensor_model_parallel_allreduce or \
|
||||
self.sequence_parallel_enabled:
|
||||
input_parallel = input_
|
||||
else:
|
||||
input_parallel = copy_to_tensor_model_parallel_region(input_)
|
||||
# Matrix multiply.
|
||||
output_parallel = linear_with_grad_accumulation_and_async_allreduce(
|
||||
input=input_parallel,
|
||||
weight=self.weight,
|
||||
bias=bias,
|
||||
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
|
||||
async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
|
||||
sequence_parallel_enabled=self.sequence_parallel_enabled,
|
||||
)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
assert not self.sequence_parallel_enabled
|
||||
output = gather_from_tensor_model_parallel_region(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
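The column split described in the docstring can be checked with plain tensors, without instantiating the module or initializing model parallel state. A small numeric sketch (sizes are made up):

```python
import torch

torch.manual_seed(0)
X = torch.randn(5, 16)                 # [tokens, input_size]
A = torch.randn(16, 32)                # full weight, Y = XA
world_size = 2

# Rank i of ColumnParallelLinear holds A_i, a column block of A (stored
# transposed as an [output_size/p, input_size] Parameter), and computes
# Y_i = X @ A_i with no communication in the forward pass.
partials = [X @ A_i for A_i in torch.chunk(A, world_size, dim=1)]

# gather_output=True concatenates the partials along the last dimension and
# reproduces the unsharded result; gather_output=False leaves Y_i sharded
# for a following parallel layer.
assert torch.allclose(torch.cat(partials, dim=-1), X @ A, atol=1e-5)
```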
|
||||
|
||||
class RowParallelLinear(torch.nn.Module):
|
||||
"""Linear layer with row parallelism.
|
||||
|
||||
The linear layer is defined as Y = XA + b. A is parallelized along
|
||||
its first dimension and X along its second dimension as:
|
||||
- -
|
||||
| A_1 |
|
||||
| . |
|
||||
A = | . | X = [X_1, ..., X_p]
|
||||
| . |
|
||||
| A_p |
|
||||
- -
|
||||
Arguments:
|
||||
input_size: first dimension of matrix A.
|
||||
output_size: second dimension of matrix A.
|
||||
|
||||
Keyword Arguments:
|
||||
bias: If true, add bias. Note that bias is not parallelized.
|
||||
input_is_parallel: If true, we assume that the input is already
|
||||
split across the GPUs and we do not split
|
||||
again.
|
||||
init_method: method to initialize weights. Note that bias is always set
|
||||
to zero.
|
||||
stride: For the strided linear layers.
|
||||
keep_master_weight_for_test: This was added for testing and should be
|
||||
set to False. It returns the master weights
|
||||
used for initialization.
|
||||
skip_bias_add: This was added to enable performance optimization where bias
|
||||
can be fused with other elementwise operations. We skip
|
||||
adding bias but instead return it.
|
||||
params_dtype: dtype of the parameters; defaults to torch.get_default_dtype().
|
||||
use_cpu_initialization: If true, allocate and initialize the weight on CPU instead of the current CUDA device.
|
||||
perform_initialization: If true, initialize the weight with init_method at construction time.
|
||||
gradient_accumulation_fusion: If true, accumulate weight gradients directly into weight.main_grad via the fused CUDA extension.
|
||||
sequence_parallel_enabled: If true, reduce-scatter the output along the sequence dimension instead of all-reducing it; requires input_is_parallel=True.
|
||||
"""
|
||||
|
||||
def __init__(self, input_size, output_size, *,
|
||||
bias=True, input_is_parallel=False,
|
||||
init_method=init.xavier_normal_, stride=1,
|
||||
keep_master_weight_for_test=False,
|
||||
skip_bias_add=False,
|
||||
params_dtype=None,
|
||||
use_cpu_initialization=False,
|
||||
perform_initialization=True,
|
||||
gradient_accumulation_fusion=False,
|
||||
sequence_parallel_enabled: bool = False,
|
||||
):
|
||||
super(RowParallelLinear, self).__init__()
|
||||
|
||||
# Keep input parameters
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.input_is_parallel = input_is_parallel
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
|
||||
# Divide the weight matrix along the last dimension.
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, world_size)
|
||||
self.skip_bias_add = skip_bias_add
|
||||
self.gradient_accumulation_fusion = gradient_accumulation_fusion
|
||||
self.sequence_parallel_enabled = sequence_parallel_enabled
|
||||
if self.sequence_parallel_enabled and not self.input_is_parallel:
|
||||
raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")
|
||||
|
||||
# Parameters.
|
||||
# Note: torch.nn.functional.linear performs XA^T + b and as a result
|
||||
# we allocate the transpose.
|
||||
# Initialize weight.
|
||||
if use_cpu_initialization:
|
||||
self.weight = Parameter(torch.empty(self.output_size,
|
||||
self.input_size_per_partition,
|
||||
dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
self.master_weight = _initialize_affine_weight_cpu(
|
||||
self.weight, self.output_size, self.input_size,
|
||||
self.input_size_per_partition, 1, init_method,
|
||||
stride=stride, return_master_weight=keep_master_weight_for_test,
|
||||
params_dtype=params_dtype)
|
||||
else:
|
||||
self.weight = Parameter(torch.empty(
|
||||
self.output_size, self.input_size_per_partition,
|
||||
device=torch.cuda.current_device(), dtype=params_dtype))
|
||||
if perform_initialization:
|
||||
_initialize_affine_weight_gpu(self.weight, init_method,
|
||||
partition_dim=1, stride=stride)
|
||||
if bias:
|
||||
if use_cpu_initialization:
|
||||
self.bias = Parameter(torch.empty(self.output_size,
|
||||
dtype=params_dtype))
|
||||
else:
|
||||
self.bias = Parameter(torch.empty(
|
||||
self.output_size, device=torch.cuda.current_device(),
|
||||
dtype=params_dtype))
|
||||
setattr(self.bias, 'sequence_parallel', sequence_parallel_enabled)
|
||||
|
||||
# Always initialize bias to zero.
|
||||
with torch.no_grad():
|
||||
self.bias.zero_()
|
||||
else:
|
||||
self.register_parameter('bias', None)
|
||||
|
||||
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of RowParallelLinear
|
||||
|
||||
Args:
|
||||
input_: 3D tensor whose order of dimension is [sequence, batch, hidden]
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
# Set up backprop all-reduce.
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
assert not self.sequence_parallel_enabled
|
||||
input_parallel = scatter_to_tensor_model_parallel_region(input_)
|
||||
# Matrix multiply.
|
||||
output_parallel = linear_with_grad_accumulation_and_async_allreduce(
|
||||
input=input_parallel,
|
||||
weight=self.weight,
|
||||
bias=None,
|
||||
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
|
||||
async_grad_allreduce=False,
|
||||
sequence_parallel_enabled=False,
|
||||
)
|
||||
|
||||
# All-reduce across all the partitions.
|
||||
if self.sequence_parallel_enabled:
|
||||
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
|
||||
else:
|
||||
output_ = reduce_from_tensor_model_parallel_region(output_parallel)
|
||||
if not self.skip_bias_add:
|
||||
output = output_ + self.bias if self.bias is not None else output_
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.bias
|
||||
return output, output_bias
|
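The two layers are designed to be chained. One common pairing, not spelled out in this diff but implied by the gather_output/input_is_parallel options, is a column-parallel GEMM followed by a row-parallel GEMM, so an entire MLP needs only one all-reduce. A single-process sketch of that identity, with made-up sizes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
X = torch.randn(5, 16)
A = torch.randn(16, 64)                # first GEMM: column-parallel
B = torch.randn(64, 16)                # second GEMM: row-parallel
world_size = 2

A_shards = torch.chunk(A, world_size, dim=1)   # columns of A
B_shards = torch.chunk(B, world_size, dim=0)   # rows of B

# Each rank applies the elementwise activation to its own output shard (valid
# because the column split keeps whole output features together), multiplies
# by its B_i, and a single all-reduce sums the partial results -- which is what
# reduce_from_tensor_model_parallel_region does in RowParallelLinear.forward.
partials = [F.gelu(X @ A_i) @ B_i for A_i, B_i in zip(A_shards, B_shards)]
assert torch.allclose(sum(partials), F.gelu(X @ A) @ B, atol=1e-4)
```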
279
cacheflow/parallel_utils/tensor_parallel/mappings.py
Normal file
@ -0,0 +1,279 @@
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
import torch
|
||||
|
||||
from cacheflow.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tensor_model_parallel_group,
|
||||
)
|
||||
from .utils import split_tensor_along_last_dim
|
||||
|
||||
|
||||
def _reduce(input_):
|
||||
"""All-reduce the input tensor across model parallel group."""
|
||||
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if get_tensor_model_parallel_world_size()==1:
|
||||
return input_
|
||||
|
||||
# All-reduce.
|
||||
torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())
|
||||
|
||||
return input_
|
||||
|
||||
|
||||
def _split_along_last_dim(input_):
|
||||
"""Split the tensor along its last dimension and keep the
|
||||
corresponding slice."""
|
||||
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# Split along last dimension.
|
||||
input_list = split_tensor_along_last_dim(input_, world_size)
|
||||
|
||||
# Note: torch.split does not create contiguous tensors by default.
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
output = input_list[rank].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _split_along_first_dim(input_):
|
||||
"""Split the tensor along its first dimension and keep the
|
||||
corresponding slice."""
|
||||
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# Split along first dimension.
|
||||
dim_size = input_.size()[0]
|
||||
assert dim_size % world_size == 0, \
|
||||
"First dimension of the tensor should be divisible by tensor parallel size"
|
||||
local_dim_size = dim_size // world_size
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
dim_offset = rank * local_dim_size
|
||||
|
||||
output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _gather_along_last_dim(input_):
|
||||
"""Gather tensors and concatinate along the last dimension."""
|
||||
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
# Size and dimension.
|
||||
last_dim = input_.dim() - 1
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
|
||||
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
||||
tensor_list[rank] = input_
|
||||
torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())
|
||||
|
||||
# Note: torch.cat already creates a contiguous tensor.
|
||||
output = torch.cat(tensor_list, dim=last_dim).contiguous()
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _gather_along_first_dim(input_):
|
||||
"""Gather tensors and concatinate along the first dimension."""
|
||||
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
dim_size = list(input_.size())
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
output = torch.empty(dim_size, dtype=input_.dtype,
|
||||
device=torch.cuda.current_device())
|
||||
torch.distributed._all_gather_base(output, input_.contiguous(),
|
||||
group=get_tensor_model_parallel_group())
|
||||
|
||||
return output
|
||||
|
||||
def _reduce_scatter_along_first_dim(input_):
|
||||
"""Reduce-scatter the input tensor across model parallel group."""
|
||||
world_size = get_tensor_model_parallel_world_size()
|
||||
# Bypass the function if we are using only 1 GPU.
|
||||
if world_size == 1:
|
||||
return input_
|
||||
|
||||
dim_size = list(input_.size())
|
||||
assert dim_size[0] % world_size == 0, \
|
||||
"First dimension of the tensor should be divisible by tensor parallel size"
|
||||
|
||||
dim_size[0] = dim_size[0] // world_size
|
||||
|
||||
output = torch.empty(dim_size, dtype=input_.dtype,
|
||||
device=torch.cuda.current_device())
|
||||
torch.distributed._reduce_scatter_base(output, input_.contiguous(),
|
||||
group=get_tensor_model_parallel_group())
|
||||
return output
|
||||
|
||||
|
||||
class _CopyToModelParallelRegion(torch.autograd.Function):
|
||||
"""Pass the input to the model parallel region."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_):
|
||||
return input_
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _reduce(grad_output)
|
||||
|
||||
|
||||
class _ReduceFromModelParallelRegion(torch.autograd.Function):
|
||||
"""All-reduce the input from the model parallel region."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _reduce(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_):
|
||||
return _reduce(input_)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output
|
||||
|
||||
|
||||
class _ScatterToModelParallelRegion(torch.autograd.Function):
|
||||
"""Split the input and keep only the corresponding chuck to the rank."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _split_along_last_dim(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_):
|
||||
return _split_along_last_dim(input_)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _gather_along_last_dim(grad_output)
|
||||
|
||||
|
||||
class _GatherFromModelParallelRegion(torch.autograd.Function):
|
||||
"""Gather the input from model parallel region and concatinate."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _gather_along_last_dim(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_):
|
||||
return _gather_along_last_dim(input_)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _split_along_last_dim(grad_output)
|
||||
|
||||
|
||||
class _ScatterToSequenceParallelRegion(torch.autograd.Function):
|
||||
"""Split the input and keep only the corresponding chuck to the rank."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _split_along_first_dim(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_):
|
||||
return _split_along_first_dim(input_)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _gather_along_first_dim(grad_output)
|
||||
|
||||
|
||||
class _GatherFromSequenceParallelRegion(torch.autograd.Function):
|
||||
"""Gather the input from sequence parallel region and concatinate."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_, tensor_parallel_output_grad=True):
|
||||
return _gather_along_first_dim(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_, tensor_parallel_output_grad=True):
|
||||
ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
|
||||
return _gather_along_first_dim(input_)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
|
||||
|
||||
# If the computation graph after the gather operation is
|
||||
# in tensor parallel mode, the output gradients need to be
|
||||
# reduce-scattered, whereas if the computation is duplicated
|
||||
# across ranks, the output gradients only need to be split.
|
||||
if tensor_parallel_output_grad:
|
||||
return _reduce_scatter_along_first_dim(grad_output), None
|
||||
else:
|
||||
return _split_along_first_dim(grad_output), None
|
||||
|
||||
|
||||
class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
|
||||
"""Reduce scatter the input from the model parallel region."""
|
||||
|
||||
@staticmethod
|
||||
def symbolic(graph, input_):
|
||||
return _reduce_scatter_along_first_dim(input_)
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_):
|
||||
return _reduce_scatter_along_first_dim(input_)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return _gather_along_first_dim(grad_output)
|
||||
|
||||
|
||||
# -----------------
|
||||
# Helper functions.
|
||||
# -----------------
|
||||
|
||||
def copy_to_tensor_model_parallel_region(input_):
|
||||
return _CopyToModelParallelRegion.apply(input_)
|
||||
|
||||
|
||||
def reduce_from_tensor_model_parallel_region(input_):
|
||||
return _ReduceFromModelParallelRegion.apply(input_)
|
||||
|
||||
|
||||
def scatter_to_tensor_model_parallel_region(input_):
|
||||
return _ScatterToModelParallelRegion.apply(input_)
|
||||
|
||||
|
||||
def gather_from_tensor_model_parallel_region(input_):
|
||||
return _GatherFromModelParallelRegion.apply(input_)
|
||||
|
||||
|
||||
def scatter_to_sequence_parallel_region(input_):
|
||||
return _ScatterToSequenceParallelRegion.apply(input_)
|
||||
|
||||
|
||||
def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
|
||||
return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
|
||||
|
||||
|
||||
def reduce_scatter_to_sequence_parallel_region(input_):
|
||||
return _ReduceScatterToSequenceParallelRegion.apply(input_)
|
||||
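The first-dimension mappings above implement the reduce-scatter/all-gather decomposition of an all-reduce that sequence parallelism relies on. A single-process sketch emulating two ranks (shapes are illustrative):

```python
import torch

torch.manual_seed(0)
world_size = 2
# Per-rank partial activations with the sequence dimension first
# ([sequence, batch, hidden]), as the first-dim mappings in this file assume.
partials = [torch.randn(4, 2, 8) for _ in range(world_size)]

# Tensor parallelism: every rank all-reduces and keeps the full sum.
all_reduced = sum(partials)

# Sequence parallelism: _reduce_scatter_along_first_dim leaves each rank with
# only its slice of that sum, and _gather_along_first_dim reassembles it later.
def reduce_scatter(parts, rank):
    return torch.chunk(sum(parts), world_size, dim=0)[rank]

reassembled = torch.cat(
    [reduce_scatter(partials, r) for r in range(world_size)], dim=0)
assert torch.equal(reassembled, all_reduced)
```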
|
253
cacheflow/parallel_utils/tensor_parallel/random.py
Normal file
@ -0,0 +1,253 @@
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
# Parts of the code here are adapted from PyTorch
|
||||
# repo: https://github.com/pytorch/pytorch
|
||||
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
from torch import _C
|
||||
from torch.cuda import _lazy_call, device as device_ctx_manager
|
||||
from torch.utils.checkpoint import detach_variable
|
||||
|
||||
from cacheflow.parallel_utils.parallel_state import (
|
||||
get_data_parallel_rank,
|
||||
get_tensor_model_parallel_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
|
||||
from .utils import (
|
||||
split_tensor_into_1d_equal_chunks,
|
||||
gather_split_1d_tensor,
|
||||
)
|
||||
|
||||
from cacheflow.parallel_utils.utils import safely_set_viewless_tensor_data
|
||||
|
||||
# Default name for the model parallel rng tracker.
|
||||
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'
|
||||
|
||||
|
||||
def _set_cuda_rng_state(new_state, device=-1):
|
||||
"""Sets the random number generator state of the current GPU.
|
||||
|
||||
Arguments:
|
||||
new_state (torch.ByteTensor): The desired state
|
||||
This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
|
||||
with a single change: the input state is not cloned. Cloning caused
|
||||
major performance issues for 4+ GPU cases.
|
||||
"""
|
||||
if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
|
||||
# older PyTorch
|
||||
def cb():
|
||||
with device_ctx_manager(device):
|
||||
_C._cuda_setRNGState(new_state)
|
||||
else:
|
||||
# newer PyTorch
|
||||
if device == -1:
|
||||
device = torch.device('cuda')
|
||||
elif isinstance(device, str):
|
||||
device = torch.device(device)
|
||||
elif isinstance(device, int):
|
||||
device = torch.device('cuda', device)
|
||||
|
||||
def cb():
|
||||
idx = device.index
|
||||
if idx is None:
|
||||
idx = torch.cuda.current_device()
|
||||
default_generator = torch.cuda.default_generators[idx]
|
||||
default_generator.set_state(new_state)
|
||||
|
||||
_lazy_call(cb)
|
||||
|
||||
|
||||
|
||||
class CudaRNGStatesTracker:
|
||||
"""Tracker for the cuda RNG states.
|
||||
|
||||
Using the `add` method, a cuda rng state is initialized based on
|
||||
the input `seed` and is assigned to `name`. Later, by forking the
|
||||
rng state, we can perform operations and return to our starting
|
||||
cuda state.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Map from a string name to the cuda rng state.
|
||||
self.states_ = {}
|
||||
# Seeds are just for bookkeeping and to ensure no seed is set twice.
|
||||
self.seeds_ = set()
|
||||
|
||||
def reset(self):
|
||||
"""Set to the initial state (no tracker)."""
|
||||
self.states_ = {}
|
||||
self.seeds_ = set()
|
||||
|
||||
def get_states(self):
|
||||
"""Get rng states. Copy the dictionary so we have direct
|
||||
pointers to the states, not just a pointer to the dictionary."""
|
||||
states = {}
|
||||
for name in self.states_:
|
||||
states[name] = self.states_[name]
|
||||
return states
|
||||
|
||||
def set_states(self, states):
|
||||
"""Set the rng states. For efficiency purposes, we do not check
|
||||
the size of seed for compatibility."""
|
||||
self.states_ = states
|
||||
|
||||
def add(self, name, seed):
|
||||
"""Track the rng state."""
|
||||
# Check seed is not already used.
|
||||
if seed in self.seeds_:
|
||||
raise Exception('seed {} already exists'.format(seed))
|
||||
self.seeds_.add(seed)
|
||||
# Check that state is not already defined.
|
||||
if name in self.states_:
|
||||
raise Exception('cuda rng state {} already exists'.format(name))
|
||||
# Get the current rng state.
|
||||
orig_rng_state = torch.cuda.get_rng_state()
|
||||
# Set the new state and store it.
|
||||
torch.cuda.manual_seed(seed)
|
||||
self.states_[name] = torch.cuda.get_rng_state()
|
||||
# Reset rng state to what it was.
|
||||
_set_cuda_rng_state(orig_rng_state)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
|
||||
"""Fork the cuda rng state, perform operations, and exit with
|
||||
the original state."""
|
||||
# Check if we have added the state
|
||||
if name not in self.states_:
|
||||
raise Exception('cuda rng state {} is not added'.format(name))
|
||||
# Store current rng state.
|
||||
orig_cuda_rng_state = torch.cuda.get_rng_state()
|
||||
# Set rng state to the desired one
|
||||
_set_cuda_rng_state(self.states_[name])
|
||||
# Do the stuff we wanted to do.
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Update the current rng state for later use.
|
||||
self.states_[name] = torch.cuda.get_rng_state()
|
||||
# And set the state to the original state we started with.
|
||||
_set_cuda_rng_state(orig_cuda_rng_state)
|
||||
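A minimal sketch of the fork/restore behaviour, assuming a CUDA device is available. In the real code path model_parallel_cuda_manual_seed() registers the 'model-parallel-rng' state; here the tracker is seeded by hand so the model parallel groups do not have to be initialized.

```python
import torch
from cacheflow.parallel_utils.tensor_parallel.random import get_cuda_rng_tracker

tracker = get_cuda_rng_tracker()
tracker.add('model-parallel-rng', 1234)   # seed value chosen arbitrarily

before = torch.cuda.get_rng_state()
with tracker.fork():                      # switch to the tracked RNG state
    _ = torch.rand(2, device='cuda')      # consumes only the tracked stream
after = torch.cuda.get_rng_state()

# The default CUDA RNG state is untouched by whatever ran inside the fork.
assert torch.equal(before, after)
```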
|
||||
|
||||
# RNG tracker object.
|
||||
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
|
||||
|
||||
|
||||
def get_cuda_rng_tracker():
|
||||
"""Get cuda rng tracker."""
|
||||
return _CUDA_RNG_STATE_TRACKER
|
||||
|
||||
|
||||
def model_parallel_cuda_manual_seed(seed):
|
||||
"""Initialize model parallel cuda seed.
|
||||
|
||||
This function should be called after the model parallel is
|
||||
initialized. Also, no torch.cuda.manual_seed should be called
|
||||
after this function. Basically, this is a replacement for that
|
||||
function.
|
||||
Two sets of RNG states are tracked:
|
||||
default state: This is for data parallelism and is the same among a
|
||||
set of model parallel GPUs but different across
|
||||
different model parallel groups. This is used for
|
||||
example for dropout in the non-tensor-model-parallel regions.
|
||||
tensor-model-parallel state: This state is different among a set of model
|
||||
parallel GPUs, but the same across data parallel
|
||||
groups. This is used for example for dropout in
|
||||
model parallel regions.
|
||||
"""
|
||||
# 2718 is just for fun and any POSITIVE value will work.
|
||||
offset = seed + 2718
|
||||
tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
|
||||
# Data parallel gets the original seed.
|
||||
data_parallel_seed = seed
|
||||
|
||||
_CUDA_RNG_STATE_TRACKER.reset()
|
||||
# Set the default state.
|
||||
torch.cuda.manual_seed(data_parallel_seed)
|
||||
# and model parallel state.
|
||||
_CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
|
||||
tensor_model_parallel_seed)
|
||||
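For concreteness, the seed bookkeeping above works out as follows (the seed value is chosen arbitrarily):

```python
# seed=42 with 4 tensor-parallel ranks:
seed = 42
offset = seed + 2718
tensor_model_parallel_seeds = [offset + rank for rank in range(4)]  # [2760, 2761, 2762, 2763]
data_parallel_seed = seed                                           # 42 on every rank
```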
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
"""This function is adapted from torch.utils.checkpoint with
|
||||
two main changes:
|
||||
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
|
||||
2) the states in the model parallel tracker are also properly
|
||||
tracked/set/reset.
|
||||
"""
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, distribute_saved_activations, *args):
|
||||
ctx.run_function = run_function
|
||||
ctx.distribute_saved_activations \
|
||||
= distribute_saved_activations
|
||||
|
||||
# Copy the rng states.
|
||||
ctx.fwd_cpu_rng_state = torch.get_rng_state()
|
||||
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
|
||||
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = run_function(*args)
|
||||
|
||||
# Divide hidden states across model parallel group and only keep
|
||||
# the chunk corresponding to the current rank.
|
||||
if distribute_saved_activations:
|
||||
ctx.input_0_shape = args[0].data.shape
|
||||
safely_set_viewless_tensor_data(
|
||||
args[0],
|
||||
split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True))
|
||||
|
||||
# Store everything.
|
||||
ctx.save_for_backward(*args)
|
||||
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *args):
|
||||
if not torch.autograd._is_checkpoint_valid():
|
||||
raise RuntimeError("Checkpointing is not compatible with .grad(), "
|
||||
"please use .backward() if possible")
|
||||
inputs = ctx.saved_tensors
|
||||
if ctx.distribute_saved_activations:
|
||||
safely_set_viewless_tensor_data(
|
||||
inputs[0],
|
||||
gather_split_1d_tensor(inputs[0].data).view(ctx.input_0_shape))
|
||||
|
||||
# Store the current states.
|
||||
bwd_cpu_rng_state = torch.get_rng_state()
|
||||
bwd_cuda_rng_state = torch.cuda.get_rng_state()
|
||||
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
|
||||
|
||||
# Set the states to what they were before the forward pass.
|
||||
torch.set_rng_state(ctx.fwd_cpu_rng_state)
|
||||
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
|
||||
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
|
||||
|
||||
# Compute the forward pass.
|
||||
detached_inputs = detach_variable(inputs)
|
||||
with torch.enable_grad():
|
||||
outputs = ctx.run_function(*detached_inputs)
|
||||
|
||||
# Set the states back to what they were at the start of this function.
|
||||
torch.set_rng_state(bwd_cpu_rng_state)
|
||||
_set_cuda_rng_state(bwd_cuda_rng_state)
|
||||
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
|
||||
|
||||
if isinstance(outputs, torch.Tensor):
|
||||
outputs = (outputs,)
|
||||
torch.autograd.backward(outputs, args)
|
||||
grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
|
||||
for inp in detached_inputs)
|
||||
return (None, None) + grads
|
||||
|
||||
|
||||
def checkpoint(function, distribute_saved_activations, *args):
|
||||
"""Checkpoint a model or part of the model.
|
||||
This has been directly copied from torch.utils.checkpoint."""
|
||||
return CheckpointFunction.apply(function,
|
||||
distribute_saved_activations, *args)
|
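A minimal usage sketch for the checkpoint wrapper, assuming a CUDA device (the function snapshots the CUDA RNG state either way); the wrapped function and shapes are made up.

```python
import torch
import torch.nn.functional as F
from cacheflow.parallel_utils.tensor_parallel.random import checkpoint

def block(x):
    return F.gelu(x @ x.t())

x = torch.randn(8, 8, device='cuda', requires_grad=True)

# distribute_saved_activations=False keeps the whole saved input on this rank;
# True would shard it with split_tensor_into_1d_equal_chunks and therefore
# needs the model parallel state to be initialized.
y = checkpoint(block, False, x)
y.sum().backward()        # recomputes block(x) under the saved RNG states
print(x.grad.shape)       # torch.Size([8, 8])
```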
108
cacheflow/parallel_utils/tensor_parallel/utils.py
Normal file
@ -0,0 +1,108 @@
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
import torch
|
||||
from typing import List, Sequence
|
||||
|
||||
from cacheflow.parallel_utils.utils import divide
|
||||
from cacheflow.parallel_utils import parallel_state
|
||||
|
||||
def split_tensor_along_last_dim(
|
||||
tensor: torch.Tensor,
|
||||
num_partitions: int,
|
||||
contiguous_split_chunks: bool = False,
|
||||
) -> List[torch.Tensor]:
|
||||
""" Split a tensor along its last dimension.
|
||||
|
||||
Arguments:
|
||||
tensor: input tensor.
|
||||
num_partitions: number of partitions to split the tensor
|
||||
contiguous_split_chunks: If True, make each chunk contiguous
|
||||
in memory.
|
||||
|
||||
Returns:
|
||||
A list of Tensors
|
||||
"""
|
||||
# Get the size and dimension.
|
||||
last_dim = tensor.dim() - 1
|
||||
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
|
||||
# Split.
|
||||
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
|
||||
# Note: torch.split does not create contiguous tensors by default.
|
||||
if contiguous_split_chunks:
|
||||
return tuple(chunk.contiguous() for chunk in tensor_list)
|
||||
|
||||
return tensor_list
|
||||
|
||||
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
|
||||
""" Break a tensor into equal 1D chunks across tensor parallel ranks.
|
||||
|
||||
Returns a Tensor or View with this rank's portion of the data.
|
||||
|
||||
Arguments:
|
||||
tensor: The tensor to split
|
||||
|
||||
Keyword Arguments:
|
||||
new_buffer (bool): If True, returns a new Tensor.
|
||||
If False, returns a view into the existing Tensor.
|
||||
Default is False
|
||||
|
||||
"""
|
||||
partition_size = torch.numel(tensor) // \
|
||||
parallel_state.get_tensor_model_parallel_world_size()
|
||||
start_index = partition_size * parallel_state.get_tensor_model_parallel_rank()
|
||||
end_index = start_index + partition_size
|
||||
if new_buffer:
|
||||
data = torch.empty(partition_size, dtype=tensor.dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
requires_grad=False)
|
||||
data.copy_(tensor.view(-1)[start_index:end_index])
|
||||
else:
|
||||
data = tensor.view(-1)[start_index:end_index]
|
||||
return data
|
||||
|
||||
|
||||
def gather_split_1d_tensor(tensor):
|
||||
""" Opposite of split_tensor_into_1d_equal_chunks. Gather values from tensor
|
||||
model parallel ranks.
|
||||
|
||||
Returns a new Tensor with the gathered data.
|
||||
|
||||
Arguments:
|
||||
tensor: A Tensor or view of this rank's portion of the data.
|
||||
"""
|
||||
numel_gathered = torch.numel(tensor) * \
|
||||
parallel_state.get_tensor_model_parallel_world_size()
|
||||
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
requires_grad=False)
|
||||
# TODO: This API is experimental in pytorch (as of Feb 2022) and
|
||||
# this might break in future pytorch releases. We chose this API
|
||||
# as opposed to torch.distributed.all_gather for efficiency reasons.
|
||||
# This API calls directly NCCL all-gather versus the former does
|
||||
# internal copies and can potentially cause slow down.
|
||||
torch.distributed._all_gather_base(gathered, tensor,
|
||||
group=parallel_state.get_tensor_model_parallel_group())
|
||||
return gathered
|
||||
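The two helpers above are inverses of each other. The following single-process sketch shows the flat slicing they perform, with no process group involved (the tensor shape is arbitrary):

```python
import torch

world_size = 4
t = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# What split_tensor_into_1d_equal_chunks keeps on rank r: an equally sized,
# flat slice of the flattened tensor.
flat = t.view(-1)
partition = flat.numel() // world_size
chunks = [flat[r * partition:(r + 1) * partition] for r in range(world_size)]

# gather_split_1d_tensor is the inverse: an all-gather of those slices that
# the caller then reshapes back to the original shape.
assert torch.equal(torch.cat(chunks).view(t.shape), t)
```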
|
||||
|
||||
class VocabUtility:
|
||||
""" Split the vocabulary into `world_size` chunks and return the first
|
||||
and last index of the vocabulary belonging to the `rank`
|
||||
partition. Note that indices are in [first, last).
|
||||
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def vocab_range_from_per_partition_vocab_size(
|
||||
per_partition_vocab_size: int, rank, world_size: int
|
||||
) -> Sequence[int]:
|
||||
index_f = rank * per_partition_vocab_size
|
||||
index_l = index_f + per_partition_vocab_size
|
||||
return index_f, index_l
|
||||
|
||||
@staticmethod
|
||||
def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, world_size: int) -> Sequence[int]:
|
||||
per_partition_vocab_size = divide(global_vocab_size, world_size)
|
||||
return VocabUtility.vocab_range_from_per_partition_vocab_size(
|
||||
per_partition_vocab_size, rank, world_size
|
||||
)
|
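A quick check of the partition arithmetic, with an illustrative (padded, GPT-style) vocabulary size:

```python
# 50,304 vocabulary entries split across 4 tensor-parallel ranks:
global_vocab_size, world_size = 50304, 4
per_partition = global_vocab_size // world_size            # 12576
ranges = [(r * per_partition, (r + 1) * per_partition) for r in range(world_size)]
# [(0, 12576), (12576, 25152), (25152, 37728), (37728, 50304)], half-open [first, last)
```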
120
cacheflow/parallel_utils/utils.py
Normal file
@ -0,0 +1,120 @@
|
||||
"""Utility functions used throughout Megatron core"""
|
||||
from functools import reduce
|
||||
import operator
|
||||
|
||||
import torch
|
||||
|
||||
from cacheflow.parallel_utils import parallel_state
|
||||
|
||||
|
||||
def ensure_divisibility(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator."""
|
||||
assert numerator % denominator == 0, "{} is not divisible by {}".format(
|
||||
numerator, denominator
|
||||
)
|
||||
|
||||
|
||||
def divide(numerator, denominator):
|
||||
"""Ensure that numerator is divisible by the denominator and return
|
||||
the division value."""
|
||||
ensure_divisibility(numerator, denominator)
|
||||
return numerator // denominator
|
||||
|
||||
|
||||
class GlobalMemoryBuffer:
|
||||
"""Global buffer to avoid dynamic memory allocations.
|
||||
Caller should ensure that buffers of the same name
|
||||
are not used concurrently."""
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = {}
|
||||
|
||||
def get_tensor(self, tensor_shape, dtype, name):
|
||||
required_len = reduce(operator.mul, tensor_shape, 1)
|
||||
if self.buffer.get((name, dtype), None) is None or \
|
||||
self.buffer[(name, dtype)].numel() < required_len:
|
||||
self.buffer[(name, dtype)] = \
|
||||
torch.empty(required_len,
|
||||
dtype=dtype,
|
||||
device=torch.cuda.current_device(),
|
||||
requires_grad=False)
|
||||
|
||||
return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
|
||||
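A small usage sketch, assuming a CUDA device (tensors are allocated on torch.cuda.current_device()). Note that the two returned views alias the same storage, which is exactly the "not used concurrently" caveat in the docstring.

```python
import torch
from cacheflow.parallel_utils.utils import GlobalMemoryBuffer

buf = GlobalMemoryBuffer()
a = buf.get_tensor([4, 8], torch.float16, "mpu")   # allocates 32 elements
b = buf.get_tensor([2, 8], torch.float16, "mpu")   # fits in the same storage
assert a.data_ptr() == b.data_ptr()                # both views alias one buffer

# Because they alias, `a` must no longer be in use when `b` is handed out
# under the same name.
```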
|
||||
def _kernel_make_viewless_tensor(inp, requires_grad):
|
||||
'''Make a viewless tensor.
|
||||
|
||||
View tensors have the undesirable side-effect of retaining a reference
|
||||
to the originally-viewed tensor, even after manually setting the '.data'
|
||||
field. This method creates a new tensor that links to the old tensor's
|
||||
data, without linking the viewed tensor, referenced via the '._base'
|
||||
field.
|
||||
'''
|
||||
out = torch.empty(
|
||||
(1,),
|
||||
dtype = inp.dtype,
|
||||
device = inp.device,
|
||||
requires_grad = requires_grad,
|
||||
)
|
||||
out.data = inp.data
|
||||
return out
|
||||
|
||||
class MakeViewlessTensor(torch.autograd.Function):
|
||||
'''
|
||||
Autograd function to make a viewless tensor.
|
||||
|
||||
This function should be used in cases where the computation graph needs
|
||||
to be propagated, but we only want a viewless tensor (e.g.,
|
||||
ParallelTransformer's hidden_states). Call this function by passing
|
||||
'keep_graph = True' to 'make_viewless_tensor()'.
|
||||
'''
|
||||
@staticmethod
|
||||
def forward(ctx, inp, requires_grad):
|
||||
return _kernel_make_viewless_tensor(inp, requires_grad)
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
return grad_output, None
|
||||
|
||||
def make_viewless_tensor(inp, requires_grad, keep_graph):
|
||||
'''
|
||||
Entry-point for creating viewless tensors.
|
||||
|
||||
This method should be used, rather than calling 'MakeViewlessTensor'
|
||||
or '_kernel_make_viewless_tensor' directly. This method acts as a
|
||||
switch for determining if an autograd function or a regular method
|
||||
should be used to create the tensor.
|
||||
'''
|
||||
|
||||
# return tensor as-is, if not a 'view'
|
||||
if inp._base is None:
|
||||
return inp
|
||||
|
||||
# create viewless tensor
|
||||
if keep_graph:
|
||||
return MakeViewlessTensor.apply(inp, requires_grad)
|
||||
else:
|
||||
return _kernel_make_viewless_tensor(inp, requires_grad)
|
||||
|
||||
def assert_viewless_tensor(tensor, extra_msg = None):
|
||||
'''Assert that a tensor is not a view (i.e., its '._base' field is
|
||||
not set).'''
|
||||
if isinstance(tensor, list):
|
||||
[ assert_viewless_tensor(t) for t in tensor ]
|
||||
return tensor
|
||||
if not isinstance(tensor, torch.Tensor):
|
||||
return tensor
|
||||
assert tensor._base is None, (
|
||||
"Ensure tensor._base is None before setting tensor.data or storing "
|
||||
"tensor to memory buffer. Otherwise, a memory leak will occur (and "
|
||||
"likely accumulate over iterations). %s"
|
||||
) % extra_msg
|
||||
return tensor
|
||||
|
||||
def safely_set_viewless_tensor_data(tensor, new_data_tensor):
|
||||
'''Safely set tensor's '.data' field.
|
||||
|
||||
Check first that the tensor is viewless (i.e., '._base' not set). If not,
|
||||
raise an exception.
|
||||
'''
|
||||
assert_viewless_tensor(tensor, extra_msg = "FYI, tensor._base has shape %s, and new_data_tensor has shape %s." % ("--" if tensor._base is None else tensor._base.shape, new_data_tensor.shape))
|
||||
tensor.data = new_data_tensor
|
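A short sketch of what these viewless-tensor helpers guarantee, runnable on CPU (tensor shapes are arbitrary); the `._base` behaviour shown is the property the docstrings describe.

```python
import torch
from cacheflow.parallel_utils.utils import (
    make_viewless_tensor, safely_set_viewless_tensor_data)

base = torch.randn(4, 4)
view = base.view(16)                     # keeps a reference to `base` in ._base
assert view._base is base

# make_viewless_tensor shares the same data but drops the ._base link, so a
# later .data swap cannot keep the original storage alive by accident.
flat = make_viewless_tensor(view, requires_grad=False, keep_graph=False)
assert flat._base is None

safely_set_viewless_tensor_data(flat, torch.zeros(16))   # passes the viewless check
```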
@ -158,3 +158,9 @@ class SequenceOutputs:
|
||||
f'parent_seq_id={self.parent_seq_id}, '
|
||||
f'output_token={self.output_token}), '
|
||||
f'logprobs={self.logprobs}')
|
||||
|
||||
def __eq__(self, other: 'SequenceOutputs') -> bool:
|
||||
return (self.seq_id == other.seq_id and
|
||||
self.parent_seq_id == other.parent_seq_id and
|
||||
self.output_token == other.output_token and
|
||||
self.logprobs == other.logprobs)
|
||||
|
@ -1,4 +1,11 @@
|
||||
import enum
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from cacheflow.parallel_utils.parallel_state import model_parallel_is_initialized
|
||||
from cacheflow.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed
|
||||
|
||||
|
||||
class Device(enum.Enum):
|
||||
@ -18,3 +25,13 @@ class Counter:
|
||||
|
||||
def reset(self) -> None:
|
||||
self.counter = 0
|
||||
|
||||
def set_random_seed(seed: int):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
if model_parallel_is_initialized():
|
||||
model_parallel_cuda_manual_seed(seed)
|
||||
|
@ -11,7 +11,6 @@ class CacheEngine:
|
||||
def __init__(
|
||||
self,
|
||||
worker_id: int,
|
||||
gpu_id: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
@ -25,7 +24,6 @@ class CacheEngine:
|
||||
f'head_size ({head_size}) must be a multiple of 16.')
|
||||
|
||||
self.worker_id = worker_id
|
||||
self.gpu_id = gpu_id
|
||||
self.num_layers = num_layers
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
@ -39,8 +37,8 @@ class CacheEngine:
|
||||
self.cpu_cache = self.allocate_cpu_cache()
|
||||
|
||||
# Initialize the stream for caching operations.
|
||||
self.cache_stream = torch.cuda.Stream(device=gpu_id)
|
||||
assert self.cache_stream != torch.cuda.current_stream(device=gpu_id)
|
||||
self.cache_stream = torch.cuda.Stream()
|
||||
assert self.cache_stream != torch.cuda.current_stream()
|
||||
# Initialize the events for stream synchronization.
|
||||
self.events = [torch.cuda.Event() for _ in range(num_layers)]
|
||||
|
||||
@ -69,12 +67,12 @@ class CacheEngine:
|
||||
key_blocks = torch.empty(
|
||||
size=(self.num_gpu_blocks, *key_block_shape),
|
||||
dtype=self.dtype,
|
||||
device=self.gpu_id,
|
||||
device="cuda",
|
||||
)
|
||||
value_blocks = torch.empty(
|
||||
size=(self.num_gpu_blocks, *value_block_shape),
|
||||
dtype=self.dtype,
|
||||
device=self.gpu_id,
|
||||
device="cuda",
|
||||
)
|
||||
gpu_cache.append((key_blocks, value_blocks))
|
||||
return gpu_cache
|
||||
|
@ -1,45 +1,62 @@
|
||||
from typing import Dict, List, Union
|
||||
from typing import Dict, List, Union, Tuple
|
||||
|
||||
import ray
|
||||
|
||||
from cacheflow.master.scheduler import Scheduler
|
||||
from cacheflow.sequence import SequenceGroupInputs
|
||||
from cacheflow.worker.worker import Worker
|
||||
|
||||
|
||||
DeviceID = Tuple[int, str, int] # rank, node resource (node IP), device id
|
||||
|
||||
|
||||
class Controller:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_id: int,
|
||||
num_workers: int,
|
||||
stage_id: int,
|
||||
stage_devices: List[DeviceID],
|
||||
world_size: int,
|
||||
tensor_parallel_size: int,
|
||||
pipeline_parallel_size: int,
|
||||
distributed_init_method: str,
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
dtype: str,
|
||||
seed: int,
|
||||
model_path: str,
|
||||
) -> None:
|
||||
self.node_id = node_id
|
||||
self.num_workers = num_workers
|
||||
self.stage_id = stage_id
|
||||
self.stage_devices = stage_devices
|
||||
self.model_name = model_name
|
||||
self.block_size = block_size
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
# Which pipeline stage is this node assigned to?
|
||||
self.is_first_stage = node_id == 0
|
||||
self.is_first_stage = stage_id == 0
|
||||
self.is_last_stage = False
|
||||
|
||||
self.workers: List[Worker] = []
|
||||
for i in range(num_workers):
|
||||
worker = Worker(
|
||||
worker_id=node_id + i,
|
||||
gpu_id=i,
|
||||
for rank, node_resource, device_id in stage_devices:
|
||||
worker_cls = ray.remote(num_cpus=0,
|
||||
num_gpus=1,
|
||||
resources={node_resource: 1e-5})(Worker)
|
||||
worker = worker_cls.remote(
|
||||
model_name=model_name,
|
||||
block_size=block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
dtype=dtype,
|
||||
seed=seed,
|
||||
distributed_init_method=distributed_init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
pipeline_parallel_size=pipeline_parallel_size,
|
||||
model_path=model_path,
|
||||
)
|
||||
self.workers.append(worker)
|
||||
|
||||
@ -57,15 +74,21 @@ class Controller:
|
||||
blocks_to_swap_out: Dict[int, int],
|
||||
blocks_to_copy: Dict[int, List[int]],
|
||||
) -> None:
|
||||
# FIXME: Support tensor parallelism.
|
||||
assert len(self.workers) == 1
|
||||
worker = self.workers[0]
|
||||
output = worker.execute_stage(
|
||||
input_seq_groups,
|
||||
blocks_to_swap_in,
|
||||
blocks_to_swap_out,
|
||||
blocks_to_copy,
|
||||
)
|
||||
futures = []
|
||||
for worker in self.workers:
|
||||
future = worker.execute_stage.remote(
|
||||
input_seq_groups,
|
||||
blocks_to_swap_in,
|
||||
blocks_to_swap_out,
|
||||
blocks_to_copy,
|
||||
)
|
||||
futures.append(future)
|
||||
|
||||
all_outputs = ray.get(futures)
|
||||
# Make sure all workers have the same results.
|
||||
output = all_outputs[0]
|
||||
for other_output in all_outputs[1:]:
|
||||
assert output == other_output
|
||||
|
||||
if self.is_last_stage:
|
||||
self.next_node.post_step(output)
|
||||
|
@ -3,49 +3,58 @@ from typing import Dict, List, Tuple
|
||||
import torch
|
||||
|
||||
from cacheflow.models import get_model
|
||||
from cacheflow.models import set_seed
|
||||
from cacheflow.models import InputMetadata
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import SequenceGroupInputs
|
||||
from cacheflow.sequence import SequenceOutputs
|
||||
from cacheflow.worker.cache_engine import CacheEngine
|
||||
from cacheflow.parallel_utils.parallel_state import (
|
||||
initialize_model_parallel, get_tensor_model_parallel_world_size)
|
||||
from cacheflow.utils import set_random_seed
|
||||
|
||||
|
||||
class Worker:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
worker_id: int,
|
||||
gpu_id: int,
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
dtype: str,
|
||||
seed: int,
|
||||
distributed_init_method: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
model_path: str,
|
||||
tensor_parallel_size: int = 1,
|
||||
pipeline_parallel_size: int = 1,
|
||||
) -> None:
|
||||
self.worker_id = worker_id
|
||||
self.gpu_id = gpu_id
|
||||
self.init_distributed_environment(distributed_init_method,
|
||||
rank,
|
||||
world_size,
|
||||
tensor_parallel_size,
|
||||
pipeline_parallel_size)
|
||||
self.worker_id = rank
|
||||
self.block_size = block_size
|
||||
|
||||
self.device = torch.device('cuda', index=gpu_id)
|
||||
set_random_seed(seed)
|
||||
|
||||
# Initialize the model.
|
||||
# FIXME(woosuk): This is a hack.
|
||||
self.model = get_model(model_name, dtype=dtype).to(device=self.device)
|
||||
self.model, self.dtype = get_model(model_name, dtype=dtype, path=model_path)
|
||||
self.model = self.model.cuda()
|
||||
tensor_model_parallel_world_size = (
|
||||
get_tensor_model_parallel_world_size())
|
||||
self.num_layers = self.model.config.num_hidden_layers
|
||||
self.num_heads = self.model.config.num_attention_heads
|
||||
self.head_size = self.model.config.hidden_size // self.num_heads
|
||||
self.dtype = self.model.dtype
|
||||
assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0
|
||||
self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size
|
||||
self.head_size = self.model.config.hidden_size // (self.num_heads * tensor_model_parallel_world_size)
|
||||
|
||||
# Set the seed.
|
||||
# We set the seed after initializing the model to ensure that
|
||||
# We reset the seed after initializing the model to ensure that
|
||||
# the random state is not affected by the model initialization.
|
||||
set_seed(seed)
|
||||
set_random_seed(seed)
|
||||
|
||||
self.cache_engine = CacheEngine(
|
||||
worker_id=worker_id,
|
||||
gpu_id=gpu_id,
|
||||
worker_id=self.worker_id,
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads,
|
||||
head_size=self.head_size,
|
||||
@ -57,6 +66,26 @@ class Worker:
|
||||
self.cache_events = self.cache_engine.events
|
||||
self.gpu_cache = self.cache_engine.gpu_cache
|
||||
|
||||
|
||||
def init_distributed_environment(self,
|
||||
distributed_init_method: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
tensor_parallel_size: int = 1,
|
||||
pipeline_parallel_size: int = 1) -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
torch.distributed.init_process_group(
|
||||
backend='nccl',
|
||||
init_method=distributed_init_method,
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
)
|
||||
# A small all_reduce for warmup.
|
||||
torch.distributed.all_reduce(torch.zeros(1).cuda())
|
||||
initialize_model_parallel(tensor_parallel_size,
|
||||
pipeline_parallel_size)
|
||||
|
||||
|
||||
def prepare_inputs(
|
||||
self,
|
||||
input_seq_groups: List[SequenceGroupInputs],
|
||||
@ -142,18 +171,18 @@ class Worker:
|
||||
|
||||
# Convert to tensors.
|
||||
tokens_tensor = torch.tensor(
|
||||
input_tokens, dtype=torch.long, device=self.device)
|
||||
input_tokens, dtype=torch.long, device='cuda')
|
||||
positions_tensor = torch.tensor(
|
||||
input_positions, dtype=torch.long, device=self.device)
|
||||
input_positions, dtype=torch.long, device='cuda')
|
||||
slot_mapping_tensor = torch.tensor(
|
||||
slot_mapping, dtype=torch.int, device=self.device)
|
||||
slot_mapping, dtype=torch.int, device='cuda')
|
||||
context_lens_tensor = torch.tensor(
|
||||
context_lens, dtype=torch.int, device=self.device)
|
||||
context_lens, dtype=torch.int, device='cuda')
|
||||
padded_block_tables = [
|
||||
_pad_to_max(block_table, max_num_blocks_per_seq)
|
||||
for block_table in generation_block_tables]
|
||||
block_tables_tensor = torch.tensor(
|
||||
padded_block_tables, dtype=torch.int, device=self.device)
|
||||
padded_block_tables, dtype=torch.int, device='cuda')
|
||||
|
||||
input_metadata = InputMetadata(
|
||||
seq_groups=seq_groups,
|
||||
|
134
server.py
@ -1,30 +1,99 @@
|
||||
import argparse
|
||||
from typing import List
|
||||
import random
|
||||
from typing import List, Tuple, Dict
|
||||
|
||||
import ray
|
||||
|
||||
from cacheflow.master.frontend import Frontend
|
||||
from cacheflow.master.scheduler import Scheduler
|
||||
from cacheflow.models import get_memory_analyzer
|
||||
from cacheflow.worker.controller import Controller
|
||||
|
||||
parser = argparse.ArgumentParser(description='CacheFlow server')
|
||||
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
|
||||
parser.add_argument('--num-nodes', type=int, default=1, help='number of nodes')
|
||||
parser.add_argument('--num-workers', type=int, default=1, help='number of workers per node')
|
||||
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
|
||||
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
|
||||
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
|
||||
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
|
||||
parser.add_argument('--seed', type=int, default=0, help='random seed')
|
||||
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
|
||||
parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
|
||||
args = parser.parse_args()
|
||||
from cacheflow.worker.controller import Controller, DeviceID
|
||||
|
||||
|
||||
def main():
|
||||
def initialize_ray_cluster(
|
||||
address: str = 'auto',
|
||||
pipeline_parallel_size: int = 1,
|
||||
tensor_parallel_size: int = 1,
|
||||
) -> Tuple[int, int, str, List[List[DeviceID]]]:
|
||||
# Connect to a ray cluster.
|
||||
ray.init(address=address)
|
||||
|
||||
# Assume we have a uniform cluster that each node has the same number of
|
||||
# GPUs for now.
|
||||
valid_node_resources = []
|
||||
num_devices_per_node = None
|
||||
for node in ray.nodes():
|
||||
if (not node['Alive']) or node['Resources']['GPU'] <= 0:
|
||||
continue
|
||||
if num_devices_per_node is None:
|
||||
num_devices_per_node = node['Resources']['GPU']
|
||||
else:
|
||||
assert num_devices_per_node == node['Resources']['GPU'], (
|
||||
"The number of GPUs per node is not uniform.")
|
||||
for key in node['Resources']:
|
||||
if key.startswith('node:'):
|
||||
valid_node_resources.append(key)
|
||||
|
||||
num_nodes = len(valid_node_resources)
|
||||
|
||||
assert (pipeline_parallel_size * tensor_parallel_size
|
||||
<= num_nodes * num_devices_per_node), (
|
||||
"The number of required GPUs exceeds the total number of "
|
||||
"available GPUs.")
|
||||
if tensor_parallel_size >= num_devices_per_node:
|
||||
assert tensor_parallel_size % num_devices_per_node == 0, (
|
||||
"The number of tensor parallelism is not divisible by the "
|
||||
"number of GPUs per node.")
|
||||
else:
|
||||
assert num_devices_per_node % tensor_parallel_size == 0, (
|
||||
"The number of GPUs per node is not divisible by the number "
|
||||
"of tensor parallelism.")
|
||||
|
||||
# Assign GPUs to pipeline stages.
|
||||
rank = 0
|
||||
current_node_id = 0
|
||||
current_device_id = 0
|
||||
distributed_init_method = None
|
||||
all_stage_devices = []
|
||||
|
||||
for i in range(pipeline_parallel_size):
|
||||
stage_devices = []
|
||||
for j in range(tensor_parallel_size):
|
||||
node_resource = valid_node_resources[current_node_id]
|
||||
stage_devices.append((rank, node_resource, current_device_id))
|
||||
if distributed_init_method is None:
|
||||
ip = node_resource.split("node:")[-1]
|
||||
port = random.randint(10000, 20000)
|
||||
distributed_init_method = f"tcp://{ip}:{port}"
|
||||
rank += 1
|
||||
current_device_id += 1
|
||||
if current_device_id >= num_devices_per_node:
|
||||
current_node_id += 1
|
||||
current_device_id = 0
|
||||
all_stage_devices.append(stage_devices)
|
||||
|
||||
return (num_nodes, num_devices_per_node, distributed_init_method,
|
||||
all_stage_devices)
|
||||
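A hedged usage sketch of `initialize_ray_cluster`: it returns one device list per pipeline stage, each entry being a `(rank, node resource, local device id)` tuple. The call and the values in the comments below are invented for a single 2-GPU node; real results depend on the Ray cluster.

```python
# Illustrative only; requires a running Ray cluster.
num_nodes, num_devices_per_node, distributed_init_method, all_stage_devices = (
    initialize_ray_cluster(pipeline_parallel_size=1, tensor_parallel_size=2))

# Hypothetical result for one node with 2 GPUs:
#   num_nodes == 1, num_devices_per_node == 2
#   distributed_init_method == 'tcp://192.168.0.10:17561'
#   all_stage_devices == [[(0, 'node:192.168.0.10', 0), (1, 'node:192.168.0.10', 1)]]
```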


def main(args: argparse.Namespace):
# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')

(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))

world_size = args.pipeline_parallel_size * args.tensor_parallel_size

memory_analyzer = get_memory_analyzer(
model_name=args.model,
block_size=args.block_size,
dtype=args.dtype,
tensor_parallel_size=args.tensor_parallel_size,
)
num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks(
max_num_batched_tokens=args.max_batch_size)
@ -32,18 +101,23 @@ def main():
swap_space=args.swap_space)
print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}')

# Create a controller for each node.
# Create a controller for each pipeline stage.
controllers: List[Controller] = []
for i in range(args.num_nodes):
for i in range(args.pipeline_parallel_size):
controller = Controller(
node_id=i,
num_workers=args.num_workers,
stage_id=i,
stage_devices=all_stage_devices[i],
world_size=world_size,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
distributed_init_method=distributed_init_method,
model_name=args.model,
block_size=args.block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
dtype=args.dtype,
seed=args.seed,
model_path=args.model_path,
)
controllers.append(controller)

@ -83,4 +157,22 @@ def main():


if __name__ == '__main__':
main()
parser = argparse.ArgumentParser(description='CacheFlow server')
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name')
parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
help='path to download and load the model weights')
# Parallel arguments
parser.add_argument('--pipeline-parallel-size', type=int, default=1, help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', type=int, default=1, help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
# NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
parser.add_argument('--seed', type=int, default=0, help='random seed')
parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
args = parser.parse_args()

main(args)
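Purely as a hypothetical illustration (not part of the commit), the entry point could also be driven programmatically, e.g. from a test, by building a `Namespace` that mirrors the flags shown above:

```python
# Hypothetical: construct the arguments in code instead of via the CLI.
# Only the flags shown in this diff are mirrored here.
args = argparse.Namespace(
    model='facebook/opt-125m',
    model_path='~/.cacheflow/model_weights',
    pipeline_parallel_size=1,
    tensor_parallel_size=2,
    block_size=8,
    dtype='half',
    seed=0,
    swap_space=20,
    max_batch_size=2560,
)
main(args)
```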