Add CUDA graph-based all reduce launcher (#26)
commit 12659a0bd7
parent 21b3671bbc
@@ -35,7 +35,7 @@ def main(args: argparse.Namespace):
         dtype=args.dtype,
         seed=args.seed,
         swap_space=args.swap_space,
-        max_batch_size=args.max_batch_size,
+        max_num_batched_tokens=args.max_num_batched_tokens,
         num_nodes=num_nodes,
         num_devices_per_node=num_devices_per_node,
         distributed_init_method=distributed_init_method,
@@ -94,6 +94,7 @@ if __name__ == '__main__':
     parser.add_argument('--output-len', type=int, default=128)
     parser.add_argument('--batch-size', type=int, default=8)
     args = parser.parse_args()
-    args.max_batch_size = max(args.max_batch_size, args.batch_size * args.input_len)
+    args.max_num_batched_tokens = max(
+        args.max_num_batched_tokens, args.batch_size * args.input_len)
     print(args)
     main(args)
@@ -22,7 +22,7 @@ class Server:
         dtype: str,
         seed: int,
         swap_space: int,
-        max_batch_size: int,
+        max_num_batched_tokens: int,
         num_nodes: int,
         num_devices_per_node: int,
         distributed_init_method: str,
@@ -43,7 +43,7 @@ class Server:
             tensor_parallel_size=tensor_parallel_size,
         )
         self.num_gpu_blocks = self.memory_analyzer.get_max_num_gpu_blocks(
-            max_num_batched_tokens=max_batch_size)
+            max_num_batched_tokens=max_num_batched_tokens)
         self.num_cpu_blocks = self.memory_analyzer.get_max_num_cpu_blocks(
             swap_space=swap_space)
         print(f'# GPU blocks: {self.num_gpu_blocks}, '
@@ -66,6 +66,7 @@ class Server:
                 dtype=dtype,
                 seed=seed,
                 model_path=model_path,
+                max_num_batched_tokens=max_num_batched_tokens,
             )
             self.controllers.append(controller)
 
@@ -75,7 +76,7 @@ class Server:
             block_size=block_size,
             num_gpu_blocks=self.num_gpu_blocks,
             num_cpu_blocks=self.num_cpu_blocks,
-            max_num_batched_tokens=max_batch_size,
+            max_num_batched_tokens=max_num_batched_tokens,
         )
         # Connect the controllers.
         for i in range(len(self.controllers) - 1):
@@ -168,8 +169,8 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     parser.add_argument('--model-path', type=str, default='~/.cacheflow/model_weights',
                         help='model path to download and load the weights')
     # Parallel arguments
-    parser.add_argument('--pipeline-parallel-size', type=int, default=1, help='number of pipeline stages')
-    parser.add_argument('--tensor-parallel-size', type=int, default=1, help='number of tensor parallel replicas')
+    parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages')
+    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas')
     # KV cache arguments
     parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='token block size')
     # NOTE(woosuk): If FlashAttention is used, the float data type is not supported.
@@ -177,5 +178,5 @@ def add_server_arguments(parser: argparse.ArgumentParser):
     # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
     parser.add_argument('--seed', type=int, default=0, help='random seed')
     parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
-    parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
+    parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens')
     return parser
@@ -47,6 +47,7 @@ _DATA_PARALLEL_GLOBAL_RANKS = None
 # Memory buffers to avoid dynamic memory allocation
 _GLOBAL_MEMORY_BUFFER = None
 
+_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None
 
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
@@ -205,6 +206,20 @@ def initialize_model_parallel(
     _set_global_memory_buffer()
 
 
+def initialize_all_reduce_launcher(
+    max_num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    disable_graph: bool = False,
+) -> None:
+    global _ALL_REDUCE_LAUNCHER
+    _ALL_REDUCE_LAUNCHER = GraphAllReduce(
+        max_num_tokens=max_num_tokens,
+        hidden_size=hidden_size,
+        dtype=dtype,
+        disable_graph=disable_graph,
+    )
+
 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
     if _TENSOR_MODEL_PARALLEL_GROUP is None or \
@@ -491,6 +506,9 @@ def get_global_memory_buffer():
     assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
     return _GLOBAL_MEMORY_BUFFER
 
+def get_all_reduce_launcher() -> 'GraphAllReduce':
+    assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
+    return _ALL_REDUCE_LAUNCHER
 
 def destroy_model_parallel():
     """Set the groups to none."""
@@ -520,3 +538,56 @@ def destroy_model_parallel():
     _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
     global _GLOBAL_MEMORY_BUFFER
     _GLOBAL_MEMORY_BUFFER = None
+
+
+class GraphAllReduce:
+
+    def __init__(
+        self,
+        max_num_tokens: int,
+        hidden_size: int,
+        dtype: torch.dtype,
+        disable_graph: bool = False,
+    ) -> None:
+        self.max_num_tokens = max_num_tokens
+        self.hidden_size = hidden_size
+        self.disable_graph = disable_graph
+
+        tp_world_size = get_tensor_model_parallel_world_size()
+        if tp_world_size == 1:
+            return
+
+        self.group = get_tensor_model_parallel_group()
+        self.buffer = torch.empty(
+            size=(max_num_tokens, hidden_size),
+            dtype=dtype,
+            device='cuda',
+        )
+
+        # Build graphs for different number of tokens.
+        if not self.disable_graph:
+            self.graphs = {}
+            for num_tokens in range(8, max_num_tokens + 1, 8):
+                self.graphs[num_tokens] = self._build_graph(num_tokens)
+
+    def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
+        # Warm up.
+        torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
+        torch.cuda.synchronize()
+
+        # Build graph.
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph):
+            torch.distributed.all_reduce(
+                self.buffer[:num_tokens], group=self.group)
+        torch.cuda.synchronize()
+        return graph
+
+    def launch(self, x: torch.Tensor) -> torch.Tensor:
+        # NOTE: x must be a slice of self.buffer.
+        num_tokens = x.shape[0]
+        if self.disable_graph:
+            torch.distributed.all_reduce(x, group=self.group)
+        else:
+            self.graphs[num_tokens].replay()
+        return x
@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    get_all_reduce_launcher,
 )
 from .mappings import (
     copy_to_tensor_model_parallel_region,
@@ -407,8 +408,7 @@ class RowParallelLinear(torch.nn.Module):
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
-
-
+        self.weight_t = self.weight.t()
 
     def forward(self, input_):
         """Forward of RowParallelLinear
@@ -425,11 +425,18 @@ class RowParallelLinear(torch.nn.Module):
             input_parallel = input_
         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
-        # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight)
+        if get_tensor_model_parallel_world_size() == 1:
+            # Matrix multiply.
+            output_ = F.linear(input_parallel, self.weight)
+        else:
+            # Matrix multiply.
+            all_reduce_launcher = get_all_reduce_launcher()
+            num_tokens = input_parallel.shape[0]
+            output_buffer = all_reduce_launcher.buffer[:num_tokens]
+            torch.matmul(input_parallel, self.weight_t, out=output_buffer)
+            # All-reduce across all the partitions.
+            output_ = all_reduce_launcher.launch(output_buffer)
 
-        # All-reduce across all the partitions.
-        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
@@ -27,6 +27,7 @@ class Controller:
         dtype: str,
         seed: int,
         model_path: str,
+        max_num_batched_tokens: int,
     ) -> None:
         self.stage_id = stage_id
         self.stage_devices = stage_devices
@@ -57,6 +58,7 @@ class Controller:
                 tensor_parallel_size=tensor_parallel_size,
                 pipeline_parallel_size=pipeline_parallel_size,
                 model_path=model_path,
+                max_num_batched_tokens=max_num_batched_tokens,
             )
             self.workers.append(worker)
 
@@ -9,7 +9,9 @@ from cacheflow.sequence import SequenceGroupInputs
 from cacheflow.sequence import SequenceOutputs
 from cacheflow.worker.cache_engine import CacheEngine
 from cacheflow.parallel_utils.parallel_state import (
-    initialize_model_parallel, get_tensor_model_parallel_world_size)
+    initialize_model_parallel,
+    initialize_all_reduce_launcher,
+    get_tensor_model_parallel_world_size)
 from cacheflow.utils import set_random_seed
 
 
@@ -27,6 +29,7 @@ class Worker:
         rank: int,
         world_size: int,
         model_path: str,
+        max_num_batched_tokens: int,
         tensor_parallel_size: int = 1,
         pipeline_parallel_size: int = 1,
     ) -> None:
@@ -44,6 +47,8 @@ class Worker:
         self.model = self.model.cuda()
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
+        initialize_all_reduce_launcher(
+            max_num_batched_tokens, self.model.config.hidden_size, self.dtype)
         self.num_layers = self.model.config.num_hidden_layers
         assert self.model.config.num_attention_heads % tensor_model_parallel_world_size == 0
         self.num_heads = self.model.config.num_attention_heads // tensor_model_parallel_world_size
@@ -28,7 +28,7 @@ def main(args: argparse.Namespace):
         dtype=args.dtype,
         seed=args.seed,
         swap_space=args.swap_space,
-        max_batch_size=args.max_batch_size,
+        max_num_batched_tokens=args.max_num_batched_tokens,
         num_nodes=num_nodes,
         num_devices_per_node=num_devices_per_node,
         distributed_init_method=distributed_init_method,
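
Usage sketch (not part of the commit itself): the diff wires the launcher up so that each Worker calls initialize_all_reduce_launcher once after the model-parallel groups exist, and RowParallelLinear.forward then writes its partial matmul result into the launcher's pre-allocated buffer and replays a captured CUDA graph instead of issuing an eager all-reduce. A minimal sketch of that call pattern, assuming torch.distributed and initialize_model_parallel have already been set up on CUDA devices with a tensor-parallel world size greater than 1 (with a single GPU the launcher skips buffer and graph setup); the hidden size and token count below are illustrative, not values taken from the commit.

import torch

from cacheflow.parallel_utils.parallel_state import (
    get_all_reduce_launcher,
    initialize_all_reduce_launcher,
)

# One-time setup per worker (mirrors Worker.__init__ in this diff).
# max_num_tokens follows --max-num-batched-tokens (default 2560);
# hidden_size=4096 is an illustrative model dimension.
initialize_all_reduce_launcher(
    max_num_tokens=2560,
    hidden_size=4096,
    dtype=torch.float16,
)

# Per forward pass (mirrors RowParallelLinear.forward in this diff).
# Graphs are captured only for token counts that are multiples of 8,
# so num_tokens must be one of 8, 16, ..., max_num_tokens.
launcher = get_all_reduce_launcher()
num_tokens = 8
output_buffer = launcher.buffer[:num_tokens]
# ... compute the partial result directly into the buffer slice, e.g.
# torch.matmul(input_parallel, weight_t, out=output_buffer) ...
output = launcher.launch(output_buffer)  # replays the captured all-reduce graph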