diff --git a/benchmark/benchmark_latency.py b/benchmark/benchmark_latency.py
new file mode 100644
index 00000000..a18ef98f
--- /dev/null
+++ b/benchmark/benchmark_latency.py
@@ -0,0 +1,99 @@
+import argparse
+import time
+from typing import List
+
+from tqdm import tqdm
+import numpy as np
+import torch
+
+from cacheflow.master.simple_frontend import SimpleFrontend
+from cacheflow.master.server import (Server, add_server_arguments,
+                                     initialize_ray_cluster)
+from cacheflow.sampling_params import SamplingParams
+from cacheflow.utils import get_gpu_memory, get_cpu_memory
+
+
+def main(args: argparse.Namespace):
+    # TODO(zhuohan): Support pipeline parallelism.
+    assert args.pipeline_parallel_size == 1, (
+        'Pipeline parallelism is not supported yet.')
+
+    (num_nodes, num_devices_per_node, distributed_init_method,
+     all_stage_devices) = (
+        initialize_ray_cluster(
+            address='local',
+            pipeline_parallel_size=args.pipeline_parallel_size,
+            tensor_parallel_size=args.tensor_parallel_size))
+
+    # Create a server.
+    server = Server(
+        model=args.model,
+        model_path=args.model_path,
+        pipeline_parallel_size=args.pipeline_parallel_size,
+        tensor_parallel_size=args.tensor_parallel_size,
+        block_size=args.block_size,
+        dtype=args.dtype,
+        seed=args.seed,
+        swap_space=args.swap_space,
+        max_batch_size=args.max_batch_size,
+        num_nodes=num_nodes,
+        num_devices_per_node=num_devices_per_node,
+        distributed_init_method=distributed_init_method,
+        all_stage_devices=all_stage_devices,
+        gpu_memory=get_gpu_memory(),
+        cpu_memory=get_cpu_memory(),
+    )
+
+    # Create a frontend.
+    frontend = SimpleFrontend(
+        model_name=args.model,
+        block_size=args.block_size,
+    )
+    sampling_params_dict = {
+        'n': 1,
+        'temperature': 0.0,
+        'top_p': 1.0,
+        'use_beam_search': False,
+        'stop_token_ids': set(),
+        'max_num_steps': args.output_len,
+    }
+    sampling_params = SamplingParams.from_dict(sampling_params_dict)
+    input_token_ids = [0] * args.input_len
+
+    def profile_step(profile=False):
+        if profile:
+            torch.cuda.cudart().cudaProfilerStart()
+        for _ in range(args.batch_size):
+            frontend._add_query(input_token_ids, sampling_params)
+        server.add_sequence_groups(frontend.get_inputs())
+        start_time = time.time()
+        while True:
+            server.step()
+            if not server.has_unfinished_requests():
+                break
+        end_time = time.time()
+        latency = end_time - start_time
+        if profile:
+            torch.cuda.cudart().cudaProfilerStop()
+        return latency
+
+    print("Warm up step")
+    profile_step()
+
+    # Benchmark.
+    latencies = []
+    for _ in tqdm(range(3), desc="Profile step"):
+        latencies.append(profile_step())
+    print(f'Avg latency: {np.mean(latencies)} seconds')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
+    parser = add_server_arguments(parser)
+    parser.add_argument('--input-len', type=int, default=32)
+    parser.add_argument('--output-len', type=int, default=128)
+    parser.add_argument('--batch-size', type=int, default=8)
+    args = parser.parse_args()
+    args.max_batch_size = max(args.max_batch_size, args.batch_size * args.input_len)
+    print(args)
+    main(args)
diff --git a/cacheflow/parallel_utils/tensor_parallel/__init__.py b/cacheflow/parallel_utils/tensor_parallel/__init__.py
index fba4f9ab..246f5f6f 100644
--- a/cacheflow/parallel_utils/tensor_parallel/__init__.py
+++ b/cacheflow/parallel_utils/tensor_parallel/__init__.py
@@ -6,8 +6,6 @@ from .layers import (
     set_defaults_if_not_set_tensor_model_parallel_attributes,
     copy_tensor_model_parallel_attributes,
     param_is_not_tensor_parallel_duplicate,
-    linear_with_grad_accumulation_and_async_allreduce
-
 )
 
 from .mappings import (
@@ -39,7 +37,6 @@ __all__ = [
     "set_defaults_if_not_set_tensor_model_parallel_attributes",
     "copy_tensor_model_parallel_attributes",
     "param_is_not_tensor_parallel_duplicate",
-    "linear_with_grad_accumulation_and_async_allreduce",
     # mappings.py
     "copy_to_tensor_model_parallel_region",
     "gather_from_tensor_model_parallel_region",
diff --git a/cacheflow/parallel_utils/tensor_parallel/layers.py b/cacheflow/parallel_utils/tensor_parallel/layers.py
index 978ca04e..f9ba8385 100644
--- a/cacheflow/parallel_utils/tensor_parallel/layers.py
+++ b/cacheflow/parallel_utils/tensor_parallel/layers.py
@@ -3,10 +3,6 @@
 # Parts of the code here are adapted from PyTorch
 # repo: https://github.com/pytorch/pytorch
 
-import math
-import os
-from typing import Optional
-import warnings
 
 import torch
 import torch.nn.functional as F
@@ -16,31 +12,20 @@ from torch.nn.parameter import Parameter
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    get_tensor_model_parallel_group,
-    get_global_memory_buffer,
 )
 from .mappings import (
     copy_to_tensor_model_parallel_region,
     gather_from_tensor_model_parallel_region,
-    gather_from_sequence_parallel_region,
     reduce_from_tensor_model_parallel_region,
     scatter_to_tensor_model_parallel_region,
-    reduce_scatter_to_sequence_parallel_region,
 )
 
 from .random import get_cuda_rng_tracker
 from .utils import (
     divide,
-    split_tensor_along_last_dim,
     VocabUtility,
 )
 
-_grad_accum_fusion_available = True
-try:
-    import fused_weight_gradient_mlp_cuda
-except ImportError:
-    _grad_accum_fusion_available = False
-
 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_dim': -1,
                                       'partition_stride': 1}
@@ -216,202 +201,6 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output
 
 
-class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
-    """See linear_with_grad_accumulation_and_async_allreduce"""
-
-    @staticmethod
-    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
-                async_grad_allreduce, sequence_parallel):
-        ctx.save_for_backward(input, weight)
-        ctx.use_bias = bias is not None
-        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
-        ctx.async_grad_allreduce = async_grad_allreduce
-        ctx.sequence_parallel = sequence_parallel
-
-        if sequence_parallel:
-            world_size = get_tensor_model_parallel_world_size()
-            dim_size = list(input.size())
-            dim_size[0] = dim_size[0] * world_size
-
-            all_gather_buffer = \
-                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
-            torch.distributed._all_gather_base(
-                all_gather_buffer,
-                input,
-                group=get_tensor_model_parallel_group())
-            total_input = all_gather_buffer
-        else:
-            total_input = input
-
-        output = torch.matmul(total_input, weight.t())
-        if bias is not None:
-            output = output + bias
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, weight = ctx.saved_tensors
-        use_bias = ctx.use_bias
-
-        if ctx.sequence_parallel:
-            world_size = get_tensor_model_parallel_world_size()
-            dim_size = list(input.size())
-            dim_size[0] = dim_size[0] * world_size
-
-            all_gather_buffer = \
-                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
-            handle = torch.distributed._all_gather_base(
-                all_gather_buffer,
-                input,
-                group=get_tensor_model_parallel_group(), async_op=True)
-
-            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
-            # gather is scheduled before the input gradient computation
-            total_input = all_gather_buffer
-        else:
-            total_input = input
-        grad_input = grad_output.matmul(weight)
-
-        if ctx.sequence_parallel:
-            handle.wait()
-
-        # Convert the tensor shapes to 2D for execution compatibility
-        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
-                                       grad_output.shape[2])
-        total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
-                                       total_input.shape[2])
-
-        if ctx.async_grad_allreduce:
-            # Asynchronous all-reduce
-            handle = torch.distributed.all_reduce(
-                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
-            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
-            # all-reduce is scheduled before the weight gradient computation
-
-        if ctx.sequence_parallel:
-            assert not ctx.async_grad_allreduce
-            dim_size = list(input.size())
-            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
-                                         device=torch.cuda.current_device(),
-                                         requires_grad=False)
-            # reduce_scatter
-            handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
-                                                            group=get_tensor_model_parallel_group(),
-                                                            async_op=True)
-            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
-            # reduce scatter is scheduled before the weight gradient computation
-
-
-        if ctx.gradient_accumulation_fusion:
-            if weight.main_grad.dtype == torch.float32:
-                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
-            elif weight.main_grad.dtype == torch.float16:
-                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, weight.main_grad)
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            grad_weight = None
-        else:
-            grad_weight = grad_output.t().matmul(total_input)
-        grad_bias = grad_output.sum(dim=0) if use_bias else None
-
-        if ctx.sequence_parallel:
-            handle.wait()
-            return sub_grad_input, grad_weight, grad_bias, None, None, None
-
-        if ctx.async_grad_allreduce:
-            handle.wait()
-
-        return grad_input, grad_weight, grad_bias, None, None, None
-
-def linear_with_grad_accumulation_and_async_allreduce(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    bias: Optional[torch.Tensor],
-    gradient_accumulation_fusion: bool,
-    async_grad_allreduce: bool,
-    sequence_parallel_enabled: bool,
-) -> torch.Tensor:
-    """Linear layer execution with asynchronous communication and
-    gradient accumulation fusion in backprop.
-
-    This has the option to accumulate the result of backprop
-    calculation into an existing gradient buffer, preventing the need
-    to do an additional addition kernel after the gradient
-    calculation.
-
-    Additionally, the tensor parallel all reduce of the input
-    gradients can be done asynchronously with the calculation of
-    the weight gradients.
-
-    In the case of sequence parallelism, the reduce scatter of the
-    input gradients is done asynchronously with the calcluation of the
-    weight gradients.
-
-    Use of this module requires that the environment variable
-    CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
-    operations, noted in the code, that should be scheduled before
-    compute kernels to overlap the communication with the computation,
-    which is necessary for a speedup but not for correctness so that
-    ordering isn't imposed by the scheduler. Setting
-    CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
-    in the order they are called.
-
-    Arguments:
-
-    input (torch.Tensor required): input like torch.nn.functional.linear
-
-    weight (torch.Tensor required): weight like torch.nn.functional.linear
-
-    bias (torch.Tensor optional): bias like torch.nn.functional.linear
-
-    gradient_accumulation_fusion (bool required): Perform the gradient
-        accumulation fusion, requires the custom CUDA extension
-        fused_weight_gradient_mlp_cuda module. To use
-        gradient_accumulation_fusion you must install APEX with
-        --cpp_ext and --cuda_ext. For example: "pip install
-        --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
-        " Note that the extension requires CUDA>=11. Otherwise, you
-        must turn off gradient accumulation fusion."
-
-    async_grad_allreduce (bool required): Do the allreduce of input
-        gradients asyncronously with the computation of weight
-        gradients. If sequence_parallel_enabled is True, this must be
-        False, as no all reduce is performed.
-
-    sequence_parallel_enabled (bool required): Indicates that sequence
-        parallelism is used and thus in the forward pass the input is
-        all gathered, and the backward pass the input gradients are
-        reduce scattered.
-    """
-    args = [
-        input,
-        weight,
-        bias,
-        gradient_accumulation_fusion,
-        async_grad_allreduce,
-        sequence_parallel_enabled,
-    ]
-
-    if not linear_with_grad_accumulation_and_async_allreduce.warned:
-        if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
-            if sequence_parallel_enabled:
-                warnings.warn(
-                    "When using sequence parallelism it is recommended to set the "
-                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
-                    "maximum speedup")
-                linear_with_grad_accumulation_and_async_allreduce.warned = True
-
-            if async_grad_allreduce:
-                warnings.warn(
-                    "When using async grad allreduce it is recommended to set the "
-                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
-                    "maximum speedup")
-                linear_with_grad_accumulation_and_async_allreduce.warned = True
-
-    with torch.cuda.amp.autocast(enabled=False):
-        return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
-linear_with_grad_accumulation_and_async_allreduce.warned = False
-
 
 class ColumnParallelLinear(torch.nn.Module):
     """Linear layer with column parallelism.
@@ -436,11 +225,8 @@ class ColumnParallelLinear(torch.nn.Module):
         skip_bias_add: This was added to enable performance optimations where bias
                        can be fused with other elementwise operations. we skip
                        adding bias but instead return it.
-        async_tensor_model_parallel_allreduce:
         params_dtype:
         use_cpu_initialization:
-        gradient_accumulation_fusion:
-        sequence_parallel_enabled:
     """
 
     def __init__(self, input_size, output_size, *,
@@ -448,12 +234,9 @@ class ColumnParallelLinear(torch.nn.Module):
                  init_method=init.xavier_normal_, stride=1,
                  keep_master_weight_for_test=False,
                  skip_bias_add=False,
-                 async_tensor_model_parallel_allreduce=True,
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
-                 gradient_accumulation_fusion=False,
-                 sequence_parallel_enabled: bool = False,
                  ):
         super(ColumnParallelLinear, self).__init__()
 
@@ -506,37 +289,6 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)
 
-        self.async_tensor_model_parallel_allreduce = (
-            async_tensor_model_parallel_allreduce and
-            world_size > 1)
-        if sequence_parallel_enabled:
-            if world_size <= 1:
-                warnings.warn(
-                    f"`sequence_parallel_enabled` is set to `True`, but tensor model parallel size is {world_size}. "
-                    f"Disabling sequence parallel."
-                )
-            sequence_parallel_enabled = False
-        self.sequence_parallel_enabled = sequence_parallel_enabled
-
-        if gradient_accumulation_fusion:
-            if not _grad_accum_fusion_available:
-                raise RuntimeError(
-                    "ColumnParallelLinear was called with gradient_accumulation_fusion set "
-                    "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
-                    "module is not found. To use gradient_accumulation_fusion you must "
-                    "install APEX with --cpp_ext and --cuda_ext. For example: "
-                    "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
-                    "Note that the extension requires CUDA>=11. Otherwise, you must turn off "
-                    "gradient accumulation fusion."
-                )
-        self.gradient_accumulation_fusion = gradient_accumulation_fusion
-
-        if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
-            raise RuntimeError(
-                "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` "
-                "cannot be enabled at the same time."
-            )
-
     def forward(self, input_):
         """Forward of ColumnParallelLinear
 
@@ -550,23 +302,11 @@ class ColumnParallelLinear(torch.nn.Module):
         """
         bias = self.bias if not self.skip_bias_add else None
 
-        if self.async_tensor_model_parallel_allreduce or \
-                self.sequence_parallel_enabled:
-            input_parallel = input_
-        else:
-            input_parallel = copy_to_tensor_model_parallel_region(input_)
+        input_parallel = copy_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = linear_with_grad_accumulation_and_async_allreduce(
-            input=input_parallel,
-            weight=self.weight,
-            bias=bias,
-            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
-            async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
-            sequence_parallel_enabled=self.sequence_parallel_enabled,
-        )
+        output_parallel = F.linear(input_parallel, self.weight, bias)
         if self.gather_output:
             # All-gather across the partitions.
-            assert not self.sequence_parallel_enabled
             output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
@@ -607,8 +347,6 @@ class RowParallelLinear(torch.nn.Module):
         params_dtype:
         use_cpu_initialization:
         perform_initialization:
-        gradient_accumulation_fusion:
-        sequence_parallel_enabled:
     """
 
     def __init__(self, input_size, output_size, *,
@@ -619,8 +357,6 @@ class RowParallelLinear(torch.nn.Module):
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
-                 gradient_accumulation_fusion=False,
-                 sequence_parallel_enabled: bool = False,
                  ):
         super(RowParallelLinear, self).__init__()
 
@@ -635,10 +371,6 @@ class RowParallelLinear(torch.nn.Module):
         world_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, world_size)
         self.skip_bias_add = skip_bias_add
-        self.gradient_accumulation_fusion = gradient_accumulation_fusion
-        self.sequence_parallel_enabled = sequence_parallel_enabled
-        if self.sequence_parallel_enabled and not self.input_is_parallel:
-            raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")
 
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -669,7 +401,6 @@ class RowParallelLinear(torch.nn.Module):
             self.bias = Parameter(torch.empty(
                 self.output_size, device=torch.cuda.current_device(),
                 dtype=params_dtype))
-            setattr(self.bias, 'sequence_parallel', sequence_parallel_enabled)
 
             # Always initialize bias to zero.
             with torch.no_grad():
@@ -693,23 +424,12 @@ class RowParallelLinear(torch.nn.Module):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            assert not self.sequence_parallel_enabled
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = linear_with_grad_accumulation_and_async_allreduce(
-            input=input_parallel,
-            weight=self.weight,
-            bias=None,
-            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
-            async_grad_allreduce=False,
-            sequence_parallel_enabled=False,
-        )
+        output_parallel = F.linear(input_parallel, self.weight)
 
         # All-reduce across all the partitions.
-        if self.sequence_parallel_enabled:
-            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
-        else:
-            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
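
Note on benchmark/benchmark_latency.py: the script measures latency with one untimed warm-up call to profile_step(), then reports the mean of three timed repetitions, and can bracket a run with torch.cuda.cudart().cudaProfilerStart()/Stop() so an external profiler captures only that range. Below is a minimal, generic sketch of the same warm-up/repeat/mean pattern, not the script itself: run_step is a hypothetical stand-in for the request-enqueue plus server.step() loop, and the explicit torch.cuda.synchronize() calls are an extra precaution (my addition) for workloads whose step function returns before the GPU work has finished.

import time

import numpy as np
import torch


def benchmark_latency(run_step, num_iters=3, warmup=True):
    # Generic version of the timing pattern used by benchmark_latency.py.
    # `run_step` is a hypothetical zero-argument callable.
    if warmup:
        run_step()  # untimed warm-up: excludes one-time allocation/setup costs
    latencies = []
    for _ in range(num_iters):
        torch.cuda.synchronize()  # make sure earlier GPU work is not timed
        start = time.time()
        run_step()
        torch.cuda.synchronize()  # wait for this step's kernels to complete
        latencies.append(time.time() - start)
    return float(np.mean(latencies))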
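
Note on the simplified layers in cacheflow/parallel_utils/tensor_parallel/layers.py: with the training-only paths removed, ColumnParallelLinear and RowParallelLinear reduce to a local F.linear plus one collective. The sketch below restates that forward-only math using plain torch.distributed collectives in place of the cacheflow mapping helpers (copy_to/gather_from/reduce_from_tensor_model_parallel_region); it assumes the default process group is already initialized and is an illustration under those assumptions, not the module code itself.

import torch
import torch.distributed as dist
import torch.nn.functional as F


def column_parallel_forward(x, weight_shard, bias_shard=None, gather_output=True):
    # Column parallelism: each rank holds a slice of the weight (and bias)
    # along the output dimension; the input is replicated, so the matmul is
    # purely local and the optional gather concatenates the output slices.
    out = F.linear(x, weight_shard, bias_shard)
    if gather_output:
        parts = [torch.empty_like(out) for _ in range(dist.get_world_size())]
        dist.all_gather(parts, out)
        out = torch.cat(parts, dim=-1)
    return out


def row_parallel_forward(x_shard, weight_shard, bias=None):
    # Row parallelism: each rank holds a slice of the weight along the input
    # dimension and sees the matching slice of the activation, so the partial
    # products are summed across ranks before the full bias is added once.
    out = F.linear(x_shard, weight_shard)
    dist.all_reduce(out)
    if bias is not None:
        out = out + bias
    return out

In the real modules the mapping helpers wrap these collectives in autograd functions so gradients can still flow through them; the sketch only covers the forward math that F.linear now performs.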