Optimize tensor parallel execution speed (#17)

This commit is contained in:
Zhuohan Li 2023-04-01 00:51:08 +08:00 committed by GitHub
parent 7a7929abe8
commit c45f3c3ab6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 103 additions and 287 deletions

View File

@ -0,0 +1,99 @@
import argparse
import time
from typing import List
from tqdm import tqdm
import numpy as np
import torch
from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
initialize_ray_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory
def main(args: argparse.Namespace):
# TODO(zhuohan): Support pipeline parallelism.
assert args.pipeline_parallel_size == 1, (
'Pipeline parallelism is not supported yet.')
(num_nodes, num_devices_per_node, distributed_init_method,
all_stage_devices) = (
initialize_ray_cluster(
address='local',
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size))
# Create a server.
server = Server(
model=args.model,
model_path=args.model_path,
pipeline_parallel_size=args.pipeline_parallel_size,
tensor_parallel_size=args.tensor_parallel_size,
block_size=args.block_size,
dtype=args.dtype,
seed=args.seed,
swap_space=args.swap_space,
max_batch_size=args.max_batch_size,
num_nodes=num_nodes,
num_devices_per_node=num_devices_per_node,
distributed_init_method=distributed_init_method,
all_stage_devices=all_stage_devices,
gpu_memory=get_gpu_memory(),
cpu_memory=get_cpu_memory(),
)
# Create a frontend.
frontend = SimpleFrontend(
model_name=args.model,
block_size=args.block_size,
)
sampling_params_dict = {
'n': 1,
'temperature': 0.0,
'top_p': 1.0,
'use_beam_search': False,
'stop_token_ids': set(),
'max_num_steps': args.output_len,
}
sampling_params = SamplingParams.from_dict(sampling_params_dict)
input_token_ids = [0] * args.input_len
def profile_step(profile=False):
if profile:
torch.cuda.cudart().cudaProfilerStart()
for _ in range(args.batch_size):
frontend._add_query(input_token_ids, sampling_params)
server.add_sequence_groups(frontend.get_inputs())
start_time = time.time()
while True:
server.step()
if not server.has_unfinished_requests():
break
end_time = time.time()
latency = end_time - start_time
if profile:
torch.cuda.cudart().cudaProfilerStop()
return latency
print("Warm up step")
profile_step()
# Benchmark.
latencies = []
for _ in tqdm(range(3), desc="Profile step"):
latencies.append(profile_step())
print(f'Avg latency: {np.mean(latencies)} seconds')
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='CacheFlow simple server.')
parser = add_server_arguments(parser)
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8)
args = parser.parse_args()
args.max_batch_size = max(args.max_batch_size, args.batch_size * args.input_len)
print(args)
main(args)

View File

@ -6,8 +6,6 @@ from .layers import (
set_defaults_if_not_set_tensor_model_parallel_attributes, set_defaults_if_not_set_tensor_model_parallel_attributes,
copy_tensor_model_parallel_attributes, copy_tensor_model_parallel_attributes,
param_is_not_tensor_parallel_duplicate, param_is_not_tensor_parallel_duplicate,
linear_with_grad_accumulation_and_async_allreduce
) )
from .mappings import ( from .mappings import (
@ -39,7 +37,6 @@ __all__ = [
"set_defaults_if_not_set_tensor_model_parallel_attributes", "set_defaults_if_not_set_tensor_model_parallel_attributes",
"copy_tensor_model_parallel_attributes", "copy_tensor_model_parallel_attributes",
"param_is_not_tensor_parallel_duplicate", "param_is_not_tensor_parallel_duplicate",
"linear_with_grad_accumulation_and_async_allreduce",
# mappings.py # mappings.py
"copy_to_tensor_model_parallel_region", "copy_to_tensor_model_parallel_region",
"gather_from_tensor_model_parallel_region", "gather_from_tensor_model_parallel_region",

View File

@ -3,10 +3,6 @@
# Parts of the code here are adapted from PyTorch # Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch # repo: https://github.com/pytorch/pytorch
import math
import os
from typing import Optional
import warnings
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@ -16,31 +12,20 @@ from torch.nn.parameter import Parameter
from cacheflow.parallel_utils.parallel_state import ( from cacheflow.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
get_global_memory_buffer,
) )
from .mappings import ( from .mappings import (
copy_to_tensor_model_parallel_region, copy_to_tensor_model_parallel_region,
gather_from_tensor_model_parallel_region, gather_from_tensor_model_parallel_region,
gather_from_sequence_parallel_region,
reduce_from_tensor_model_parallel_region, reduce_from_tensor_model_parallel_region,
scatter_to_tensor_model_parallel_region, scatter_to_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
) )
from .random import get_cuda_rng_tracker from .random import get_cuda_rng_tracker
from .utils import ( from .utils import (
divide, divide,
split_tensor_along_last_dim,
VocabUtility, VocabUtility,
) )
_grad_accum_fusion_available = True
try:
import fused_weight_gradient_mlp_cuda
except ImportError:
_grad_accum_fusion_available = False
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1, 'partition_dim': -1,
'partition_stride': 1} 'partition_stride': 1}
@ -216,202 +201,6 @@ class VocabParallelEmbedding(torch.nn.Module):
return output return output
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
"""See linear_with_grad_accumulation_and_async_allreduce"""
@staticmethod
def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
async_grad_allreduce, sequence_parallel):
ctx.save_for_backward(input, weight)
ctx.use_bias = bias is not None
ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
ctx.async_grad_allreduce = async_grad_allreduce
ctx.sequence_parallel = sequence_parallel
if sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = \
get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
torch.distributed._all_gather_base(
all_gather_buffer,
input,
group=get_tensor_model_parallel_group())
total_input = all_gather_buffer
else:
total_input = input
output = torch.matmul(total_input, weight.t())
if bias is not None:
output = output + bias
return output
@staticmethod
def backward(ctx, grad_output):
input, weight = ctx.saved_tensors
use_bias = ctx.use_bias
if ctx.sequence_parallel:
world_size = get_tensor_model_parallel_world_size()
dim_size = list(input.size())
dim_size[0] = dim_size[0] * world_size
all_gather_buffer = \
get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
handle = torch.distributed._all_gather_base(
all_gather_buffer,
input,
group=get_tensor_model_parallel_group(), async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# gather is scheduled before the input gradient computation
total_input = all_gather_buffer
else:
total_input = input
grad_input = grad_output.matmul(weight)
if ctx.sequence_parallel:
handle.wait()
# Convert the tensor shapes to 2D for execution compatibility
grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
grad_output.shape[2])
total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
total_input.shape[2])
if ctx.async_grad_allreduce:
# Asynchronous all-reduce
handle = torch.distributed.all_reduce(
grad_input, group=get_tensor_model_parallel_group(), async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# all-reduce is scheduled before the weight gradient computation
if ctx.sequence_parallel:
assert not ctx.async_grad_allreduce
dim_size = list(input.size())
sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
# reduce_scatter
handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
group=get_tensor_model_parallel_group(),
async_op=True)
# Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
# reduce scatter is scheduled before the weight gradient computation
if ctx.gradient_accumulation_fusion:
if weight.main_grad.dtype == torch.float32:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
elif weight.main_grad.dtype == torch.float16:
fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, weight.main_grad)
else:
raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
grad_weight = None
else:
grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
if ctx.sequence_parallel:
handle.wait()
return sub_grad_input, grad_weight, grad_bias, None, None, None
if ctx.async_grad_allreduce:
handle.wait()
return grad_input, grad_weight, grad_bias, None, None, None
def linear_with_grad_accumulation_and_async_allreduce(
input: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor],
gradient_accumulation_fusion: bool,
async_grad_allreduce: bool,
sequence_parallel_enabled: bool,
) -> torch.Tensor:
"""Linear layer execution with asynchronous communication and
gradient accumulation fusion in backprop.
This has the option to accumulate the result of backprop
calculation into an existing gradient buffer, preventing the need
to do an additional addition kernel after the gradient
calculation.
Additionally, the tensor parallel all reduce of the input
gradients can be done asynchronously with the calculation of
the weight gradients.
In the case of sequence parallelism, the reduce scatter of the
input gradients is done asynchronously with the calcluation of the
weight gradients.
Use of this module requires that the environment variable
CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
operations, noted in the code, that should be scheduled before
compute kernels to overlap the communication with the computation,
which is necessary for a speedup but not for correctness so that
ordering isn't imposed by the scheduler. Setting
CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
in the order they are called.
Arguments:
input (torch.Tensor required): input like torch.nn.functional.linear
weight (torch.Tensor required): weight like torch.nn.functional.linear
bias (torch.Tensor optional): bias like torch.nn.functional.linear
gradient_accumulation_fusion (bool required): Perform the gradient
accumulation fusion, requires the custom CUDA extension
fused_weight_gradient_mlp_cuda module. To use
gradient_accumulation_fusion you must install APEX with
--cpp_ext and --cuda_ext. For example: "pip install
--global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
" Note that the extension requires CUDA>=11. Otherwise, you
must turn off gradient accumulation fusion."
async_grad_allreduce (bool required): Do the allreduce of input
gradients asyncronously with the computation of weight
gradients. If sequence_parallel_enabled is True, this must be
False, as no all reduce is performed.
sequence_parallel_enabled (bool required): Indicates that sequence
parallelism is used and thus in the forward pass the input is
all gathered, and the backward pass the input gradients are
reduce scattered.
"""
args = [
input,
weight,
bias,
gradient_accumulation_fusion,
async_grad_allreduce,
sequence_parallel_enabled,
]
if not linear_with_grad_accumulation_and_async_allreduce.warned:
if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if sequence_parallel_enabled:
warnings.warn(
"When using sequence parallelism it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup")
linear_with_grad_accumulation_and_async_allreduce.warned = True
if async_grad_allreduce:
warnings.warn(
"When using async grad allreduce it is recommended to set the "
"environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
"maximum speedup")
linear_with_grad_accumulation_and_async_allreduce.warned = True
with torch.cuda.amp.autocast(enabled=False):
return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
linear_with_grad_accumulation_and_async_allreduce.warned = False
class ColumnParallelLinear(torch.nn.Module): class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism. """Linear layer with column parallelism.
@ -436,11 +225,8 @@ class ColumnParallelLinear(torch.nn.Module):
skip_bias_add: This was added to enable performance optimations where bias skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip can be fused with other elementwise operations. we skip
adding bias but instead return it. adding bias but instead return it.
async_tensor_model_parallel_allreduce:
params_dtype: params_dtype:
use_cpu_initialization: use_cpu_initialization:
gradient_accumulation_fusion:
sequence_parallel_enabled:
""" """
def __init__(self, input_size, output_size, *, def __init__(self, input_size, output_size, *,
@ -448,12 +234,9 @@ class ColumnParallelLinear(torch.nn.Module):
init_method=init.xavier_normal_, stride=1, init_method=init.xavier_normal_, stride=1,
keep_master_weight_for_test=False, keep_master_weight_for_test=False,
skip_bias_add=False, skip_bias_add=False,
async_tensor_model_parallel_allreduce=True,
params_dtype=None, params_dtype=None,
use_cpu_initialization=False, use_cpu_initialization=False,
perform_initialization=True, perform_initialization=True,
gradient_accumulation_fusion=False,
sequence_parallel_enabled: bool = False,
): ):
super(ColumnParallelLinear, self).__init__() super(ColumnParallelLinear, self).__init__()
@ -506,37 +289,6 @@ class ColumnParallelLinear(torch.nn.Module):
else: else:
self.register_parameter('bias', None) self.register_parameter('bias', None)
self.async_tensor_model_parallel_allreduce = (
async_tensor_model_parallel_allreduce and
world_size > 1)
if sequence_parallel_enabled:
if world_size <= 1:
warnings.warn(
f"`sequence_parallel_enabled` is set to `True`, but tensor model parallel size is {world_size}. "
f"Disabling sequence parallel."
)
sequence_parallel_enabled = False
self.sequence_parallel_enabled = sequence_parallel_enabled
if gradient_accumulation_fusion:
if not _grad_accum_fusion_available:
raise RuntimeError(
"ColumnParallelLinear was called with gradient_accumulation_fusion set "
"to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
"module is not found. To use gradient_accumulation_fusion you must "
"install APEX with --cpp_ext and --cuda_ext. For example: "
"pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
"Note that the extension requires CUDA>=11. Otherwise, you must turn off "
"gradient accumulation fusion."
)
self.gradient_accumulation_fusion = gradient_accumulation_fusion
if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
raise RuntimeError(
"`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` "
"cannot be enabled at the same time."
)
def forward(self, input_): def forward(self, input_):
"""Forward of ColumnParallelLinear """Forward of ColumnParallelLinear
@ -550,23 +302,11 @@ class ColumnParallelLinear(torch.nn.Module):
""" """
bias = self.bias if not self.skip_bias_add else None bias = self.bias if not self.skip_bias_add else None
if self.async_tensor_model_parallel_allreduce or \
self.sequence_parallel_enabled:
input_parallel = input_
else:
input_parallel = copy_to_tensor_model_parallel_region(input_) input_parallel = copy_to_tensor_model_parallel_region(input_)
# Matrix multiply. # Matrix multiply.
output_parallel = linear_with_grad_accumulation_and_async_allreduce( output_parallel = F.linear(input_parallel, self.weight, bias)
input=input_parallel,
weight=self.weight,
bias=bias,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
sequence_parallel_enabled=self.sequence_parallel_enabled,
)
if self.gather_output: if self.gather_output:
# All-gather across the partitions. # All-gather across the partitions.
assert not self.sequence_parallel_enabled
output = gather_from_tensor_model_parallel_region(output_parallel) output = gather_from_tensor_model_parallel_region(output_parallel)
else: else:
output = output_parallel output = output_parallel
@ -607,8 +347,6 @@ class RowParallelLinear(torch.nn.Module):
params_dtype: params_dtype:
use_cpu_initialization: use_cpu_initialization:
perform_initialization: perform_initialization:
gradient_accumulation_fusion:
sequence_parallel_enabled:
""" """
def __init__(self, input_size, output_size, *, def __init__(self, input_size, output_size, *,
@ -619,8 +357,6 @@ class RowParallelLinear(torch.nn.Module):
params_dtype=None, params_dtype=None,
use_cpu_initialization=False, use_cpu_initialization=False,
perform_initialization=True, perform_initialization=True,
gradient_accumulation_fusion=False,
sequence_parallel_enabled: bool = False,
): ):
super(RowParallelLinear, self).__init__() super(RowParallelLinear, self).__init__()
@ -635,10 +371,6 @@ class RowParallelLinear(torch.nn.Module):
world_size = get_tensor_model_parallel_world_size() world_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size) self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add self.skip_bias_add = skip_bias_add
self.gradient_accumulation_fusion = gradient_accumulation_fusion
self.sequence_parallel_enabled = sequence_parallel_enabled
if self.sequence_parallel_enabled and not self.input_is_parallel:
raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")
# Parameters. # Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result # Note: torch.nn.functional.linear performs XA^T + b and as a result
@ -669,7 +401,6 @@ class RowParallelLinear(torch.nn.Module):
self.bias = Parameter(torch.empty( self.bias = Parameter(torch.empty(
self.output_size, device=torch.cuda.current_device(), self.output_size, device=torch.cuda.current_device(),
dtype=params_dtype)) dtype=params_dtype))
setattr(self.bias, 'sequence_parallel', sequence_parallel_enabled)
# Always initialize bias to zero. # Always initialize bias to zero.
with torch.no_grad(): with torch.no_grad():
@ -693,22 +424,11 @@ class RowParallelLinear(torch.nn.Module):
if self.input_is_parallel: if self.input_is_parallel:
input_parallel = input_ input_parallel = input_
else: else:
assert not self.sequence_parallel_enabled
input_parallel = scatter_to_tensor_model_parallel_region(input_) input_parallel = scatter_to_tensor_model_parallel_region(input_)
# Matrix multiply. # Matrix multiply.
output_parallel = linear_with_grad_accumulation_and_async_allreduce( output_parallel = F.linear(input_parallel, self.weight)
input=input_parallel,
weight=self.weight,
bias=None,
gradient_accumulation_fusion=self.gradient_accumulation_fusion,
async_grad_allreduce=False,
sequence_parallel_enabled=False,
)
# All-reduce across all the partitions. # All-reduce across all the partitions.
if self.sequence_parallel_enabled:
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
else:
output_ = reduce_from_tensor_model_parallel_region(output_parallel) output_ = reduce_from_tensor_model_parallel_region(output_parallel)
if not self.skip_bias_add: if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_ output = output_ + self.bias if self.bias is not None else output_