Add CUDA graph-based all reduce launcher (#26)

parent 21b3671bbc
commit 12659a0bd7
@@ -35,7 +35,7 @@ def main(args: argparse.Namespace):
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_batch_size=args.max_batch_size,
        max_num_batched_tokens=args.max_num_batched_tokens,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
@@ -94,6 +94,7 @@ if __name__ == '__main__':
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    args = parser.parse_args()
    args.max_batch_size = max(args.max_batch_size, args.batch_size * args.input_len)
    args.max_num_batched_tokens = max(
        args.max_num_batched_tokens, args.batch_size * args.input_len)
    print(args)
    main(args)
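For context, a quick worked example of the clamp added above. Only `--batch-size 8` and `--max-num-batched-tokens 2560` are defaults taken from this diff; the `--input-len` value is hypothetical:

```python
batch_size = 8                   # default --batch-size from this script
input_len = 512                  # hypothetical --input-len, not a value from this diff
max_num_batched_tokens = 2560    # default --max-num-batched-tokens

# The clamp guarantees the token budget can hold at least one full batch of prompts.
max_num_batched_tokens = max(max_num_batched_tokens, batch_size * input_len)
print(max_num_batched_tokens)    # 4096, since 8 * 512 = 4096 > 2560
```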
@@ -22,7 +22,7 @@ class Server:
        dtype: str,
        seed: int,
        swap_space: int,
        max_batch_size: int,
        max_num_batched_tokens: int,
        num_nodes: int,
        num_devices_per_node: int,
        distributed_init_method: str,
@@ -43,7 +43,7 @@ class Server:
            tensor_parallel_size=tensor_parallel_size,
        )
        self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
            max_num_batched_tokens=max_batch_size)
            max_num_batched_tokens=max_num_batched_tokens)
        self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
            swap_space=swap_space)
        print(f'# GPU blocks: {self.num_gpu_blocks}, '
@@ -66,6 +66,7 @@ class Server:
                dtype=dtype,
                seed=seed,
                model_path=model_path,
                max_num_batched_tokens=max_num_batched_tokens,
            )
            self.controllers.append(controller)

@@ -75,7 +76,7 @@ class Server:
            block_size=block_size,
            num_gpu_blocks=self.num_gpu_blocks,
            num_cpu_blocks=self.num_cpu_blocks,
            max_num_batched_tokens=max_batch_size,
            max_num_batched_tokens=max_num_batched_tokens,
        )
        # Connect the controllers.
        for i in range(len(self.controllers) - 1):
@@ -168,8 +169,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
    parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
                        help='model path to download and load the weights')
    # Parallel arguments
    parser.add_argument('--pipeline-parallel-size', type=int, default=1, help='number of pipeline stages')
    parser.add_argument('--tensor-parallel-size', type=int, default=1, help='number of tensor parallel replicas')
    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
    # KV cache arguments
    parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
    # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
@@ -177,5 +178,5 @@ def add_server_arguments(parser: argparse.ArgumentParser):
    # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
    parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens')
    return parser
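To illustrate the new `-pp` / `-tp` short flags, a small assumed sketch of how a caller might exercise this parser. Only the argument names come from this diff; how the parser is constructed and imported is an assumption:

```python
# Hypothetical usage sketch: argparse derives dests like `tensor_parallel_size`
# from the flag names added above.
import argparse

parser = add_server_arguments(argparse.ArgumentParser())
args = parser.parse_args(['-tp', '2', '-pp', '1', '--block-size', '16'])
print(args.tensor_parallel_size, args.pipeline_parallel_size, args.block_size)  # 2 1 16
```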
@@ -47,6 +47,7 @@ _DATA_PARALLEL_GLOBAL_RANKS = None
# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None

_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None

def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
@@ -205,6 +206,20 @@ def initialize_model_parallel(
    _set_global_memory_buffer()


def initialize_all_reduce_launcher(
    max_num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
    disable_graph: bool = False,
) -> None:
    global _ALL_REDUCE_LAUNCHER
    _ALL_REDUCE_LAUNCHER = GraphAllReduce(
        max_num_tokens=max_num_tokens,
        hidden_size=hidden_size,
        dtype=dtype,
        disable_graph=disable_graph,
    )

def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    if _TENSOR_MODEL_PARALLEL_GROUP is None or \
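A minimal sketch of the assumed call order for the new launcher: the model-parallel groups must exist before `initialize_all_reduce_launcher` runs, because `GraphAllReduce` queries the tensor-parallel world size and group in its constructor. The parameter values below are placeholders, not taken from this diff:

```python
import torch

# Assumed worker-side setup order (mirrors the Worker change later in this diff).
initialize_model_parallel(tensor_model_parallel_size=2)   # create the TP group first
initialize_all_reduce_launcher(
    max_num_tokens=2560,      # e.g. the --max-num-batched-tokens budget
    hidden_size=4096,         # placeholder; taken from the model config in practice
    dtype=torch.float16,
    disable_graph=False,
)
launcher = get_all_reduce_launcher()  # later retrieved by model-parallel layers
```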
@@ -491,6 +506,9 @@ def get_global_memory_buffer():
    assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
    return _GLOBAL_MEMORY_BUFFER

def get_all_reduce_launcher() -> 'GraphAllReduce':
    assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
    return _ALL_REDUCE_LAUNCHER

def destroy_model_parallel():
    """Set the groups to none."""
@@ -520,3 +538,56 @@ def destroy_model_parallel():
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None


class GraphAllReduce:

    def __init__(
        self,
        max_num_tokens: int,
        hidden_size: int,
        dtype: torch.dtype,
        disable_graph: bool = False,
    ) -> None:
        self.max_num_tokens = max_num_tokens
        self.hidden_size = hidden_size
        self.disable_graph = disable_graph

        tp_world_size = get_tensor_model_parallel_world_size()
        if tp_world_size == 1:
            return

        self.group = get_tensor_model_parallel_group()
        self.buffer = torch.empty(
            size=(max_num_tokens, hidden_size),
            dtype=dtype,
            device='cuda',
        )

        # Build graphs for different number of tokens.
        if not self.disable_graph:
            self.graphs = {}
            for num_tokens in range(8, max_num_tokens + 1, 8):
                self.graphs[num_tokens] = self._build_graph(num_tokens)

    def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
        # Warm up.
        torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
        torch.cuda.synchronize()

        # Build graph.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            torch.distributed.all_reduce(
                self.buffer[:num_tokens], group=self.group)
        torch.cuda.synchronize()
        return graph

    def launch(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE: x must be a slice of self.buffer.
        num_tokens = x.shape[0]
        if self.disable_graph:
            torch.distributed.all_reduce(x, group=self.group)
        else:
            self.graphs[num_tokens].replay()
        return x
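To make the intended call pattern concrete, here is a minimal, assumed usage sketch (the tensor names are placeholders; the pattern mirrors the `RowParallelLinear` change later in this diff). The key point is that a CUDA graph replays the exact kernels on the exact memory addresses it captured, so the producer must write its result directly into a slice of `launcher.buffer` before `launch()` replays the captured all-reduce on that same memory. Note also that graphs are only captured for token counts that are multiples of 8 (`range(8, max_num_tokens + 1, 8)`), so callers presumably keep `num_tokens` aligned to 8; that alignment requirement is an inference from the code above, not something stated in the diff.

```python
# Assumed caller-side pattern; `hidden_states` and `weight_t` are placeholder tensors.
launcher = get_all_reduce_launcher()
num_tokens = hidden_states.shape[0]             # assumed to already be a multiple of 8
out = launcher.buffer[:num_tokens]              # slice of the pre-allocated static buffer
torch.matmul(hidden_states, weight_t, out=out)  # write the partial result into that buffer
hidden_states = launcher.launch(out)            # replay the pre-captured all-reduce in place
```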
@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_all_reduce_launcher,
)
from .mappings import (
    copy_to_tensor_model_parallel_region,
@@ -407,8 +408,7 @@ class RowParallelLinear(torch.nn.Module):
            self.bias.zero_()
        else:
            self.register_parameter('bias', None)

        self.weight_t = self.weight.t()

    def forward(self, input_):
        """Forward of RowParallelLinear
@@ -425,11 +425,18 @@ class RowParallelLinear(torch.nn.Module):
            input_parallel = input_
        else:
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight)
        if get_tensor_model_parallel_world_size() == 1:
            # Matrix multiply.
            output_ = F.linear(input_parallel, self.weight)
        else:
            # Matrix multiply.
            all_reduce_launcher = get_all_reduce_launcher()
            num_tokens = input_parallel.shape[0]
            output_buffer = all_reduce_launcher.buffer[:num_tokens]
            torch.matmul(input_parallel, self.weight_t, out=output_buffer)
            # All-reduce across all the partitions.
            output_ = all_reduce_launcher.launch(output_buffer)

        # All-reduce across all the partitions.
        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
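A short note on why the forward pass switches from `F.linear` to a pre-transposed weight plus `torch.matmul`: `F.linear(input, weight)` computes `input @ weight.T` but has no `out=` argument, whereas `torch.matmul` can write straight into the launcher's captured buffer. A minimal equivalence sketch, with illustrative shapes that are not taken from this diff:

```python
import torch
import torch.nn.functional as F

input_parallel = torch.randn(16, 1024)   # (num_tokens, input_size_per_partition)
weight = torch.randn(4096, 1024)         # (output_size, input_size_per_partition)
weight_t = weight.t()                    # precomputed once, as in __init__ above
out = torch.empty(16, 4096)              # stands in for all_reduce_launcher.buffer[:16]

torch.matmul(input_parallel, weight_t, out=out)  # writes in place into the buffer
assert torch.allclose(out, F.linear(input_parallel, weight), atol=1e-5)  # same result
```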
@@ -27,6 +27,7 @@ class Controller:
        dtype: str,
        seed: int,
        model_path: str,
        max_num_batched_tokens: int,
    ) -> None:
        self.stage_id = stage_id
        self.stage_devices = stage_devices
@@ -57,6 +58,7 @@ class Controller:
                tensor_parallel_size=tensor_parallel_size,
                pipeline_parallel_size=pipeline_parallel_size,
                model_path=model_path,
                max_num_batched_tokens=max_num_batched_tokens,
            )
            self.workers.append(worker)
@@ -9,7 +9,9 @@ from cacheflow.sequence import SequenceGroupInputs
from cacheflow.sequence import SequenceOutputs
from cacheflow.worker.cache_engine import CacheEngine
from cacheflow.parallel_utils.parallel_state import (
    initialize_model_parallel, get_tensor_model_parallel_world_size)
    initialize_model_parallel,
    initialize_all_reduce_launcher,
    get_tensor_model_parallel_world_size)
from cacheflow.utils import set_random_seed


@@ -27,6 +29,7 @@ class Worker:
        rank: int,
        world_size: int,
        model_path: str,
        max_num_batched_tokens: int,
        tensor_parallel_size: int = 1,
        pipeline_parallel_size: int = 1,
    ) -> None:
@@ -44,6 +47,8 @@ class Worker:
        self.model = self.model.cuda()
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        initialize_all_reduce_launcher(
            max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
        self.num_layers = self.model.config.num_hidden_layers
        assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size
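For a rough sense of the memory cost of the static buffer this call allocates, a hedged back-of-the-envelope: `hidden_size=4096` is an assumed example value (the real value comes from `self.model.config`), and 2560 is the default `--max-num-batched-tokens` from this diff:

```python
import torch

max_num_batched_tokens = 2560   # default --max-num-batched-tokens
hidden_size = 4096              # assumed example value, not taken from this diff
elem_bytes = torch.finfo(torch.float16).bits // 8   # 2 bytes for half precision

buffer_bytes = max_num_batched_tokens * hidden_size * elem_bytes
print(f'{buffer_bytes / 2**20:.0f} MiB')   # 20 MiB per worker for the all-reduce buffer
```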
@@ -28,7 +28,7 @@ def main(args: argparse.Namespace):
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_batch_size=args.max_batch_size,
        max_num_batched_tokens=args.max_num_batched_tokens,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,