Optimize tensor parallel execution speed (#17)

This commit is contained in:
parent 7a7929abe8
commit c45f3c3ab6

benchmark/benchmark_latency.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import argparse
import time
from typing import List

from tqdm import tqdm
import numpy as np
import torch

from cacheflow.master.simple_frontend import SimpleFrontend
from cacheflow.master.server import (Server, add_server_arguments,
                                     initialize_ray_cluster)
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import get_gpu_memory, get_cpu_memory


def main(args: argparse.Namespace):
    # TODO(zhuohan): Support pipeline parallelism.
    assert args.pipeline_parallel_size == 1, (
        'Pipeline parallelism is not supported yet.')

    (num_nodes, num_devices_per_node, distributed_init_method,
     all_stage_devices) = (
        initialize_ray_cluster(
            address='local',
            pipeline_parallel_size=args.pipeline_parallel_size,
            tensor_parallel_size=args.tensor_parallel_size))

    # Create a server.
    server = Server(
        model=args.model,
        model_path=args.model_path,
        pipeline_parallel_size=args.pipeline_parallel_size,
        tensor_parallel_size=args.tensor_parallel_size,
        block_size=args.block_size,
        dtype=args.dtype,
        seed=args.seed,
        swap_space=args.swap_space,
        max_batch_size=args.max_batch_size,
        num_nodes=num_nodes,
        num_devices_per_node=num_devices_per_node,
        distributed_init_method=distributed_init_method,
        all_stage_devices=all_stage_devices,
        gpu_memory=get_gpu_memory(),
        cpu_memory=get_cpu_memory(),
    )

    # Create a frontend.
    frontend = SimpleFrontend(
        model_name=args.model,
        block_size=args.block_size,
    )
    sampling_params_dict = {
        'n': 1,
        'temperature': 0.0,
        'top_p': 1.0,
        'use_beam_search': False,
        'stop_token_ids': set(),
        'max_num_steps': args.output_len,
    }
    sampling_params = SamplingParams.from_dict(sampling_params_dict)
    input_token_ids = [0] * args.input_len

    def profile_step(profile=False):
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
        for _ in range(args.batch_size):
            frontend._add_query(input_token_ids, sampling_params)
        server.add_sequence_groups(frontend.get_inputs())
        start_time = time.time()
        while True:
            server.step()
            if not server.has_unfinished_requests():
                break
        end_time = time.time()
        latency = end_time - start_time
        if profile:
            torch.cuda.cudart().cudaProfilerStop()
        return latency

    print("Warm up step")
    profile_step()

    # Benchmark.
    latencies = []
    for _ in tqdm(range(3), desc="Profile step"):
        latencies.append(profile_step())
    print(f'Avg latency: {np.mean(latencies)} seconds')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CacheFlow simple server.')
    parser = add_server_arguments(parser)
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    args = parser.parse_args()
    args.max_batch_size = max(args.max_batch_size, args.batch_size * args.input_len)
    print(args)
    main(args)
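A note on reading the benchmark output (this paragraph and the snippet are editorial, not part of the commit): the script times end-to-end decoding of output_len tokens for each of the batch_size sequences, so a rough throughput figure can be derived from the printed average latency. An illustration with a made-up latency value:

# Illustrative only: avg_latency_s stands in for the value the script prints.
batch_size, output_len = 8, 128
avg_latency_s = 4.0
generated_tokens = batch_size * output_len
print(f"~{generated_tokens / avg_latency_s:.1f} generated tokens/s")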
@@ -6,8 +6,6 @@ from .layers import (
     set_defaults_if_not_set_tensor_model_parallel_attributes,
     copy_tensor_model_parallel_attributes,
     param_is_not_tensor_parallel_duplicate,
-    linear_with_grad_accumulation_and_async_allreduce
-
 )
 
 from .mappings import (
@@ -39,7 +37,6 @@ __all__ = [
     "set_defaults_if_not_set_tensor_model_parallel_attributes",
     "copy_tensor_model_parallel_attributes",
     "param_is_not_tensor_parallel_duplicate",
-    "linear_with_grad_accumulation_and_async_allreduce",
     # mappings.py
     "copy_to_tensor_model_parallel_region",
     "gather_from_tensor_model_parallel_region",
@@ -3,10 +3,6 @@
 # Parts of the code here are adapted from PyTorch
 # repo: https://github.com/pytorch/pytorch
 
-import math
-import os
-from typing import Optional
-import warnings
 
 import torch
 import torch.nn.functional as F
@@ -16,31 +12,20 @@ from torch.nn.parameter import Parameter
 from cacheflow.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    get_tensor_model_parallel_group,
-    get_global_memory_buffer,
 )
 from .mappings import (
     copy_to_tensor_model_parallel_region,
     gather_from_tensor_model_parallel_region,
-    gather_from_sequence_parallel_region,
     reduce_from_tensor_model_parallel_region,
     scatter_to_tensor_model_parallel_region,
-    reduce_scatter_to_sequence_parallel_region,
 )
 
 from .random import get_cuda_rng_tracker
 from .utils import (
     divide,
-    split_tensor_along_last_dim,
     VocabUtility,
 )
 
-_grad_accum_fusion_available = True
-try:
-    import fused_weight_gradient_mlp_cuda
-except ImportError:
-    _grad_accum_fusion_available = False
-
 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_dim': -1,
                                       'partition_stride': 1}
@@ -216,202 +201,6 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output
 
 
-class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
-    """See linear_with_grad_accumulation_and_async_allreduce"""
-
-    @staticmethod
-    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
-                async_grad_allreduce, sequence_parallel):
-        ctx.save_for_backward(input, weight)
-        ctx.use_bias = bias is not None
-        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
-        ctx.async_grad_allreduce = async_grad_allreduce
-        ctx.sequence_parallel = sequence_parallel
-
-        if sequence_parallel:
-            world_size = get_tensor_model_parallel_world_size()
-            dim_size = list(input.size())
-            dim_size[0] = dim_size[0] * world_size
-
-            all_gather_buffer = \
-                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
-            torch.distributed._all_gather_base(
-                all_gather_buffer,
-                input,
-                group=get_tensor_model_parallel_group())
-            total_input = all_gather_buffer
-        else:
-            total_input = input
-
-        output = torch.matmul(total_input, weight.t())
-        if bias is not None:
-            output = output + bias
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input, weight = ctx.saved_tensors
-        use_bias = ctx.use_bias
-
-        if ctx.sequence_parallel:
-            world_size = get_tensor_model_parallel_world_size()
-            dim_size = list(input.size())
-            dim_size[0] = dim_size[0] * world_size
-
-            all_gather_buffer = \
-                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
-            handle = torch.distributed._all_gather_base(
-                all_gather_buffer,
-                input,
-                group=get_tensor_model_parallel_group(), async_op=True)
-
-            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
-            # gather is scheduled before the input gradient computation
-            total_input = all_gather_buffer
-        else:
-            total_input = input
-        grad_input = grad_output.matmul(weight)
-
-        if ctx.sequence_parallel:
-            handle.wait()
-
-        # Convert the tensor shapes to 2D for execution compatibility
-        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
-                                       grad_output.shape[2])
-        total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
-                                       total_input.shape[2])
-
-        if ctx.async_grad_allreduce:
-            # Asynchronous all-reduce
-            handle = torch.distributed.all_reduce(
-                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
-            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
-            # all-reduce is scheduled before the weight gradient computation
-
-        if ctx.sequence_parallel:
-            assert not ctx.async_grad_allreduce
-            dim_size = list(input.size())
-            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
-                                         device=torch.cuda.current_device(),
-                                         requires_grad=False)
-            # reduce_scatter
-            handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
-                                                            group=get_tensor_model_parallel_group(),
-                                                            async_op=True)
-            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
-            # reduce scatter is scheduled before the weight gradient computation
-
-
-        if ctx.gradient_accumulation_fusion:
-            if weight.main_grad.dtype == torch.float32:
-                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
-            elif weight.main_grad.dtype == torch.float16:
-                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(total_input, grad_output, weight.main_grad)
-            else:
-                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
-            grad_weight = None
-        else:
-            grad_weight = grad_output.t().matmul(total_input)
-        grad_bias = grad_output.sum(dim=0) if use_bias else None
-
-        if ctx.sequence_parallel:
-            handle.wait()
-            return sub_grad_input, grad_weight, grad_bias, None, None, None
-
-        if ctx.async_grad_allreduce:
-            handle.wait()
-
-        return grad_input, grad_weight, grad_bias, None, None, None
-
-
-def linear_with_grad_accumulation_and_async_allreduce(
-    input: torch.Tensor,
-    weight: torch.Tensor,
-    bias: Optional[torch.Tensor],
-    gradient_accumulation_fusion: bool,
-    async_grad_allreduce: bool,
-    sequence_parallel_enabled: bool,
-) -> torch.Tensor:
-    """Linear layer execution with asynchronous communication and
-    gradient accumulation fusion in backprop.
-
-    This has the option to accumulate the result of backprop
-    calculation into an existing gradient buffer, preventing the need
-    to do an additional addition kernel after the gradient
-    calculation.
-
-    Additionally, the tensor parallel all reduce of the input
-    gradients can be done asynchronously with the calculation of
-    the weight gradients.
-
-    In the case of sequence parallelism, the reduce scatter of the
-    input gradients is done asynchronously with the calcluation of the
-    weight gradients.
-
-    Use of this module requires that the environment variable
-    CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
-    operations, noted in the code, that should be scheduled before
-    compute kernels to overlap the communication with the computation,
-    which is necessary for a speedup but not for correctness so that
-    ordering isn't imposed by the scheduler. Setting
-    CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
-    in the order they are called.
-
-    Arguments:
-
-    input (torch.Tensor required): input like torch.nn.functional.linear
-
-    weight (torch.Tensor required): weight like torch.nn.functional.linear
-
-    bias (torch.Tensor optional): bias like torch.nn.functional.linear
-
-    gradient_accumulation_fusion (bool required): Perform the gradient
-        accumulation fusion, requires the custom CUDA extension
-        fused_weight_gradient_mlp_cuda module. To use
-        gradient_accumulation_fusion you must install APEX with
-        --cpp_ext and --cuda_ext. For example: "pip install
-        --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\"
-        " Note that the extension requires CUDA>=11. Otherwise, you
-        must turn off gradient accumulation fusion."
-
-    async_grad_allreduce (bool required): Do the allreduce of input
-        gradients asyncronously with the computation of weight
-        gradients. If sequence_parallel_enabled is True, this must be
-        False, as no all reduce is performed.
-
-    sequence_parallel_enabled (bool required): Indicates that sequence
-        parallelism is used and thus in the forward pass the input is
-        all gathered, and the backward pass the input gradients are
-        reduce scattered.
-    """
-    args = [
-        input,
-        weight,
-        bias,
-        gradient_accumulation_fusion,
-        async_grad_allreduce,
-        sequence_parallel_enabled,
-    ]
-
-    if not linear_with_grad_accumulation_and_async_allreduce.warned:
-        if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
-            if sequence_parallel_enabled:
-                warnings.warn(
-                    "When using sequence parallelism it is recommended to set the "
-                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
-                    "maximum speedup")
-                linear_with_grad_accumulation_and_async_allreduce.warned = True
-
-            if async_grad_allreduce:
-                warnings.warn(
-                    "When using async grad allreduce it is recommended to set the "
-                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
-                    "maximum speedup")
-                linear_with_grad_accumulation_and_async_allreduce.warned = True
-
-    with torch.cuda.amp.autocast(enabled=False):
-        return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)
-
-linear_with_grad_accumulation_and_async_allreduce.warned = False
-
-
 class ColumnParallelLinear(torch.nn.Module):
     """Linear layer with column parallelism.
 
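A side note on why dropping the custom autograd Function is safe for serving (this paragraph and the snippet are editorial, not part of the diff): at inference time only the forward pass runs, and the Function's forward above is exactly a linear layer, so the plain F.linear used in the hunks below preserves the math while skipping the training-only gradient plumbing. A minimal standalone check with hypothetical shapes:

import torch
import torch.nn.functional as F

# The removed forward computed matmul(input, weight.t()) plus an optional bias;
# F.linear performs the same computation.
x = torch.randn(4, 16)
w = torch.randn(32, 16)
b = torch.randn(32)
assert torch.allclose(F.linear(x, w, b), torch.matmul(x, w.t()) + b, atol=1e-6)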
@@ -436,11 +225,8 @@ class ColumnParallelLinear(torch.nn.Module):
         skip_bias_add: This was added to enable performance optimations where bias
                        can be fused with other elementwise operations. we skip
                        adding bias but instead return it.
-        async_tensor_model_parallel_allreduce:
         params_dtype:
         use_cpu_initialization:
-        gradient_accumulation_fusion:
-        sequence_parallel_enabled:
     """
 
     def __init__(self, input_size, output_size, *,
@@ -448,12 +234,9 @@ class ColumnParallelLinear(torch.nn.Module):
                  init_method=init.xavier_normal_, stride=1,
                  keep_master_weight_for_test=False,
                  skip_bias_add=False,
-                 async_tensor_model_parallel_allreduce=True,
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
-                 gradient_accumulation_fusion=False,
-                 sequence_parallel_enabled: bool = False,
                  ):
         super(ColumnParallelLinear, self).__init__()
 
@@ -506,37 +289,6 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)
 
-        self.async_tensor_model_parallel_allreduce = (
-                async_tensor_model_parallel_allreduce and
-                world_size > 1)
-        if sequence_parallel_enabled:
-            if world_size <= 1:
-                warnings.warn(
-                    f"`sequence_parallel_enabled` is set to `True`, but tensor model parallel size is {world_size}. "
-                    f"Disabling sequence parallel."
-                )
-                sequence_parallel_enabled = False
-        self.sequence_parallel_enabled = sequence_parallel_enabled
-
-        if gradient_accumulation_fusion:
-            if not _grad_accum_fusion_available:
-                raise RuntimeError(
-                    "ColumnParallelLinear was called with gradient_accumulation_fusion set "
-                    "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
-                    "module is not found. To use gradient_accumulation_fusion you must "
-                    "install APEX with --cpp_ext and --cuda_ext. For example: "
-                    "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
-                    "Note that the extension requires CUDA>=11. Otherwise, you must turn off "
-                    "gradient accumulation fusion."
-                )
-        self.gradient_accumulation_fusion = gradient_accumulation_fusion
-
-        if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
-            raise RuntimeError(
-                "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` "
-                "cannot be enabled at the same time."
-            )
-
 
     def forward(self, input_):
         """Forward of ColumnParallelLinear
@@ -550,23 +302,11 @@ class ColumnParallelLinear(torch.nn.Module):
         """
         bias = self.bias if not self.skip_bias_add else None
 
-        if self.async_tensor_model_parallel_allreduce or \
-                self.sequence_parallel_enabled:
-            input_parallel = input_
-        else:
-            input_parallel = copy_to_tensor_model_parallel_region(input_)
+        input_parallel = copy_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = linear_with_grad_accumulation_and_async_allreduce(
-            input=input_parallel,
-            weight=self.weight,
-            bias=bias,
-            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
-            async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
-            sequence_parallel_enabled=self.sequence_parallel_enabled,
-        )
+        output_parallel = F.linear(input_parallel, self.weight, bias)
         if self.gather_output:
             # All-gather across the partitions.
-            assert not self.sequence_parallel_enabled
             output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
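For readers tracing the tensor-parallel data flow, a standalone sketch of the column-parallel pattern this forward now follows (editorial, with an assumed helper name and raw torch.distributed calls rather than the cacheflow mappings): each rank computes its slice of the output features with F.linear, and gathering reassembles the full output along the last dimension.

import torch
import torch.nn.functional as F
import torch.distributed as dist

def column_parallel_forward(x, weight_shard, bias_shard, tp_group=None):
    # Local slice of the output features on this rank.
    out_shard = F.linear(x, weight_shard, bias_shard)
    if tp_group is None or dist.get_world_size(tp_group) == 1:
        return out_shard
    # Reassemble the full output along the last dim, roughly what
    # gather_from_tensor_model_parallel_region does.
    shards = [torch.empty_like(out_shard) for _ in range(dist.get_world_size(tp_group))]
    dist.all_gather(shards, out_shard, group=tp_group)
    return torch.cat(shards, dim=-1)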
@@ -607,8 +347,6 @@ class RowParallelLinear(torch.nn.Module):
         params_dtype:
         use_cpu_initialization:
         perform_initialization:
-        gradient_accumulation_fusion:
-        sequence_parallel_enabled:
     """
 
     def __init__(self, input_size, output_size, *,
@@ -619,8 +357,6 @@ class RowParallelLinear(torch.nn.Module):
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
-                 gradient_accumulation_fusion=False,
-                 sequence_parallel_enabled: bool = False,
                  ):
         super(RowParallelLinear, self).__init__()
 
@@ -635,10 +371,6 @@ class RowParallelLinear(torch.nn.Module):
         world_size = get_tensor_model_parallel_world_size()
         self.input_size_per_partition = divide(input_size, world_size)
         self.skip_bias_add = skip_bias_add
-        self.gradient_accumulation_fusion = gradient_accumulation_fusion
-        self.sequence_parallel_enabled = sequence_parallel_enabled
-        if self.sequence_parallel_enabled and not self.input_is_parallel:
-            raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")
 
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
@@ -669,7 +401,6 @@ class RowParallelLinear(torch.nn.Module):
             self.bias = Parameter(torch.empty(
                 self.output_size, device=torch.cuda.current_device(),
                 dtype=params_dtype))
-            setattr(self.bias, 'sequence_parallel', sequence_parallel_enabled)
 
             # Always initialize bias to zero.
             with torch.no_grad():
@@ -693,22 +424,11 @@ class RowParallelLinear(torch.nn.Module):
         if self.input_is_parallel:
             input_parallel = input_
         else:
-            assert not self.sequence_parallel_enabled
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = linear_with_grad_accumulation_and_async_allreduce(
-            input=input_parallel,
-            weight=self.weight,
-            bias=None,
-            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
-            async_grad_allreduce=False,
-            sequence_parallel_enabled=False,
-        )
+        output_parallel = F.linear(input_parallel, self.weight)
 
         # All-reduce across all the partitions.
-        if self.sequence_parallel_enabled:
-            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
-        else:
-            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
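The row-parallel counterpart goes the other way (again an editorial, standalone sketch assuming an already-initialized torch.distributed process group rather than the cacheflow wrappers): each rank multiplies its input shard by its weight shard, and the partial results are summed across ranks with an all-reduce, which is the role reduce_from_tensor_model_parallel_region plays above.

import torch
import torch.nn.functional as F
import torch.distributed as dist

def row_parallel_forward(input_shard, weight_shard, bias=None, tp_group=None):
    # Each rank holds a slice of the input features and the matching slice of
    # the weight, so F.linear yields a partial sum of the full output.
    partial = F.linear(input_shard, weight_shard)
    if tp_group is not None and dist.get_world_size(tp_group) > 1:
        dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=tp_group)
    return partial if bias is None else partial + bias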