# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

import math
import os
from typing import Optional
import warnings

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_group,
    get_global_memory_buffer,
)
from .mappings import (
    copy_to_tensor_model_parallel_region,
    gather_from_tensor_model_parallel_region,
    gather_from_sequence_parallel_region,
    reduce_from_tensor_model_parallel_region,
    scatter_to_tensor_model_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
)

from .random import get_cuda_rng_tracker
from .utils import (
    divide,
    split_tensor_along_last_dim,
    VocabUtility,
)

_grad_accum_fusion_available = True
try:
    import fused_weight_gradient_mlp_cuda
except ImportError:
    _grad_accum_fusion_available = False

_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                      'partition_dim': -1,
                                      'partition_stride': 1}


def param_is_not_tensor_parallel_duplicate(param):
    return (hasattr(param, 'tensor_model_parallel') and
            param.tensor_model_parallel) or (
                get_tensor_model_parallel_rank() == 0)


def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    # Make sure the attributes are not set.
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        assert not hasattr(tensor, attribute)
    # Set the attributes.
    setattr(tensor, 'tensor_model_parallel', is_parallel)
    setattr(tensor, 'partition_dim', dim)
    setattr(tensor, 'partition_stride', stride)


def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
    def maybe_set(attribute, value):
        if not hasattr(tensor, attribute):
            setattr(tensor, attribute, value)
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])


def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
    def maybe_copy(attribute):
        if hasattr(source_tensor, attribute):
            setattr(destination_tensor, attribute,
                    getattr(source_tensor, attribute))
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        maybe_copy(attribute)


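# Illustrative sketch (not part of the upstream code): how the helpers above
# are typically used to tag a partitioned parameter. Names prefixed with
# `_example_` are hypothetical and exist only for documentation purposes.
def _example_tag_partitioned_param():
    """Tag a parameter as tensor-parallel and copy its attributes to a clone."""
    weight = Parameter(torch.empty(16, 32))
    # Mark the parameter as partitioned along dim 0 with stride 1.
    set_tensor_model_parallel_attributes(weight, is_parallel=True, dim=0,
                                         stride=1)
    # A tagged parameter is never treated as a tensor-parallel duplicate.
    assert param_is_not_tensor_parallel_duplicate(weight)
    # Attributes can be propagated to a derived tensor (e.g. an fp32 copy).
    weight_fp32 = weight.detach().float()
    copy_tensor_model_parallel_attributes(weight_fp32, weight)
    return weight_fp32

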
def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1):
    """Initialize affine weight for model parallel on GPU."""

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    with get_cuda_rng_tracker().fork():
        init_method(weight)


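# Illustrative sketch (assumes the model-parallel CUDA RNG tracker has been
# seeded via the helpers in .random and a CUDA device is available; the
# `_example_` name is hypothetical): each rank initializes only its own shard
# of a column-partitioned weight, under the tensor-model-parallel RNG stream
# forked inside _initialize_affine_weight_gpu so shard initialization is
# reproducible across ranks.
def _example_gpu_shard_init(rows_per_rank: int = 128, cols: int = 64):
    weight = Parameter(torch.empty(rows_per_rank, cols,
                                   device=torch.cuda.current_device(),
                                   dtype=torch.half))
    _initialize_affine_weight_gpu(weight, init.xavier_normal_,
                                  partition_dim=0, stride=1)
    return weight

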
def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False,
                                  *, params_dtype=None):
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter
    the relevant chunk."""

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    if params_dtype is None:
        params_dtype = torch.get_default_dtype()

    # Initialize master weight
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    master_weight = master_weight.to(dtype=params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size,
                              dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None


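# Illustrative sketch of the split-and-copy logic above, with a hard-coded
# rank/world_size instead of the distributed helpers so it runs without any
# process-group setup. The `_example_` name is hypothetical.
def _example_stride_split(rank: int = 0, world_size: int = 2, stride: int = 1):
    """Show which interleaved chunks of the master weight a rank keeps."""
    output_size, input_size = 8, 4
    per_partition_size = divide(output_size, world_size)       # rows per rank
    per_partition_per_stride_size = divide(per_partition_size, stride)
    master_weight = torch.arange(output_size * input_size,
                                 dtype=torch.float).view(output_size, input_size)
    # Chunks are taken with a step of `world_size`, so a rank owns every
    # world_size-th block rather than one contiguous slab when stride > 1.
    weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=0)
    my_weight_list = weight_list[rank::world_size]
    return torch.cat(my_weight_list, dim=0)

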
class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.

    Keyword Arguments:
        init_method: method to initialize weights.
        params_dtype
        use_cpu_initialization
        perform_initialization
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, *,
                 init_method=init.xavier_normal_,
                 params_dtype: torch.dtype = None,
                 use_cpu_initialization: bool = False,
                 perform_initialization: bool = True):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        # Set the defaults for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension.
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_tensor_model_parallel_rank(),
                self.tensor_model_parallel_size)
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index

        # Allocate weights and initialize.
        if use_cpu_initialization:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                dtype=params_dtype))
            if perform_initialization:
                _initialize_affine_weight_cpu(
                    self.weight, self.num_embeddings, self.embedding_dim,
                    self.num_embeddings_per_partition, 0, init_method,
                    params_dtype=params_dtype)
        else:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                device=torch.cuda.current_device(), dtype=params_dtype))
            if perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=0, stride=1)

    def forward(self, input_):
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Mask the input.
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output


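# Illustrative usage sketch (not executed on import): assumes torch.distributed
# and the tensor-model-parallel groups have already been initialized through
# cacheflow.parallel_utils.parallel_state, and that a CUDA device is available.
# The `_example_` name is hypothetical.
def _example_vocab_parallel_embedding(vocab_size: int = 50304,
                                      hidden_size: int = 1024):
    embedding = VocabParallelEmbedding(vocab_size, hidden_size,
                                       params_dtype=torch.half)
    # Token ids laid out as [sequence, batch]; ids outside this rank's
    # vocabulary shard are masked internally and restored by the all-reduce.
    token_ids = torch.randint(0, vocab_size, (32, 2),
                              device=torch.cuda.current_device())
    hidden_states = embedding(token_ids)    # [sequence, batch, hidden]
    return hidden_states

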
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    """See linear_with_grad_accumulation_and_async_allreduce"""

    @staticmethod
    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
                async_grad_allreduce, sequence_parallel):
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        ctx.async_grad_allreduce = async_grad_allreduce
        ctx.sequence_parallel = sequence_parallel

        if sequence_parallel:
            world_size = get_tensor_model_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = \
                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
            torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
                group=get_tensor_model_parallel_group())
            total_input = all_gather_buffer
        else:
            total_input = input

        output = torch.matmul(total_input, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias

        if ctx.sequence_parallel:
            world_size = get_tensor_model_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = \
                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
            handle = torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
                group=get_tensor_model_parallel_group(), async_op=True)

            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # gather is scheduled before the input gradient computation
            total_input = all_gather_buffer
        else:
            total_input = input
        grad_input = grad_output.matmul(weight)

        if ctx.sequence_parallel:
            handle.wait()

        # Convert the tensor shapes to 2D for execution compatibility
        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
                                       grad_output.shape[2])
        total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
                                       total_input.shape[2])

        if ctx.async_grad_allreduce:
            # Asynchronous all-reduce
            handle = torch.distributed.all_reduce(
                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # all-reduce is scheduled before the weight gradient computation

        if ctx.sequence_parallel:
            assert not ctx.async_grad_allreduce
            dim_size = list(input.size())
            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
                                         device=torch.cuda.current_device(),
                                         requires_grad=False)
            # reduce_scatter
            handle = torch.distributed._reduce_scatter_base(
                sub_grad_input, grad_input,
                group=get_tensor_model_parallel_group(),
                async_op=True)
            # Here we rely on CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that the
            # reduce scatter is scheduled before the weight gradient computation

        if ctx.gradient_accumulation_fusion:
            if weight.main_grad.dtype == torch.float32:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp32(
                    total_input, grad_output, weight.main_grad)
            elif weight.main_grad.dtype == torch.float16:
                fused_weight_gradient_mlp_cuda.wgrad_gemm_accum_fp16(
                    total_input, grad_output, weight.main_grad)
            else:
                raise RuntimeError("Unsupported gradient type for gradient accumulation fusion")
            grad_weight = None
        else:
            grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.sequence_parallel:
            handle.wait()
            return sub_grad_input, grad_weight, grad_bias, None, None, None

        if ctx.async_grad_allreduce:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None


def linear_with_grad_accumulation_and_async_allreduce(
    input: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor],
    gradient_accumulation_fusion: bool,
    async_grad_allreduce: bool,
    sequence_parallel_enabled: bool,
) -> torch.Tensor:
    """Linear layer execution with asynchronous communication and
    gradient accumulation fusion in backprop.

    This has the option to accumulate the result of backprop
    calculation into an existing gradient buffer, preventing the need
    to do an additional addition kernel after the gradient
    calculation.

    Additionally, the tensor parallel all-reduce of the input
    gradients can be done asynchronously with the calculation of
    the weight gradients.

    In the case of sequence parallelism, the reduce-scatter of the
    input gradients is done asynchronously with the calculation of the
    weight gradients.

    Use of this module requires that the environment variable
    CUDA_DEVICE_MAX_CONNECTIONS=1. There are a few collective
    operations, noted in the code, that should be scheduled before
    compute kernels to overlap the communication with the computation.
    This overlap is necessary for a speedup but not for correctness,
    so the scheduler does not enforce that ordering on its own. Setting
    CUDA_DEVICE_MAX_CONNECTIONS=1 forces the kernels to be scheduled
    in the order they are called.

    Arguments:

    input (torch.Tensor required): input like torch.nn.functional.linear

    weight (torch.Tensor required): weight like torch.nn.functional.linear

    bias (torch.Tensor optional): bias like torch.nn.functional.linear

    gradient_accumulation_fusion (bool required): Perform the gradient
        accumulation fusion, requires the custom CUDA extension
        fused_weight_gradient_mlp_cuda module. To use
        gradient_accumulation_fusion you must install APEX with
        --cpp_ext and --cuda_ext. For example:
        pip install --global-option="--cpp_ext" --global-option="--cuda_ext ."
        Note that the extension requires CUDA>=11. Otherwise, you
        must turn off gradient accumulation fusion.

    async_grad_allreduce (bool required): Do the allreduce of input
        gradients asynchronously with the computation of weight
        gradients. If sequence_parallel_enabled is True, this must be
        False, as no all-reduce is performed.

    sequence_parallel_enabled (bool required): Indicates that sequence
        parallelism is used and thus in the forward pass the input is
        all-gathered, and in the backward pass the input gradients are
        reduce-scattered.
    """
    args = [
        input,
        weight,
        bias,
        gradient_accumulation_fusion,
        async_grad_allreduce,
        sequence_parallel_enabled,
    ]

    if not linear_with_grad_accumulation_and_async_allreduce.warned:
        if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
            if sequence_parallel_enabled:
                warnings.warn(
                    "When using sequence parallelism it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup")
                linear_with_grad_accumulation_and_async_allreduce.warned = True

            if async_grad_allreduce:
                warnings.warn(
                    "When using async grad allreduce it is recommended to set the "
                    "environment variable CUDA_DEVICE_MAX_CONNECTIONS to 1 for "
                    "maximum speedup")
                linear_with_grad_accumulation_and_async_allreduce.warned = True

    with torch.cuda.amp.autocast(enabled=False):
        return LinearWithGradAccumulationAndAsyncCommunication.apply(*args)

linear_with_grad_accumulation_and_async_allreduce.warned = False


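# Illustrative usage sketch (assumes an initialized tensor-model-parallel group
# and CUDA_DEVICE_MAX_CONNECTIONS=1 in the environment). The `_example_` name
# is hypothetical; ColumnParallelLinear and RowParallelLinear below call the
# helper the same way.
def _example_linear_with_async_allreduce(input_parallel: torch.Tensor,
                                         weight: torch.Tensor,
                                         bias: Optional[torch.Tensor] = None):
    # Plain tensor parallelism: overlap the input-gradient all-reduce with the
    # weight-gradient GEMM; no sequence parallelism, no fused wgrad accumulation.
    return linear_with_grad_accumulation_and_async_allreduce(
        input=input_parallel,
        weight=weight,
        bias=bias,
        gradient_accumulation_fusion=False,
        async_grad_allreduce=True,
        sequence_parallel_enabled=False,
    )

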
class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.

    Keyword Arguments:
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other elementwise operations. We
                       skip adding bias but instead return it.
        async_tensor_model_parallel_allreduce:
        params_dtype:
        use_cpu_initialization:
        gradient_accumulation_fusion:
        sequence_parallel_enabled:
    """

    def __init__(self, input_size, output_size, *,
                 bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False,
                 async_tensor_model_parallel_allreduce=True,
                 params_dtype=None,
                 use_cpu_initialization=False,
                 perform_initialization=True,
                 gradient_accumulation_fusion=False,
                 sequence_parallel_enabled: bool = False,
                 ):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        if use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                self.input_size,
                                                dtype=params_dtype))
            if perform_initialization:
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight, self.output_size, self.input_size,
                    self.output_size_per_partition, 0, init_method,
                    stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size_per_partition, self.input_size,
                device=torch.cuda.current_device(), dtype=params_dtype))
            if perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=0, stride=stride)

        if bias:
            if use_cpu_initialization:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition, dtype=params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=params_dtype))
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

        self.async_tensor_model_parallel_allreduce = (
            async_tensor_model_parallel_allreduce and
            world_size > 1)
        if sequence_parallel_enabled:
            if world_size <= 1:
                warnings.warn(
                    f"`sequence_parallel_enabled` is set to `True`, but tensor model parallel size is {world_size}. "
                    f"Disabling sequence parallel."
                )
                sequence_parallel_enabled = False
        self.sequence_parallel_enabled = sequence_parallel_enabled

        if gradient_accumulation_fusion:
            if not _grad_accum_fusion_available:
                raise RuntimeError(
                    "ColumnParallelLinear was called with gradient_accumulation_fusion set "
                    "to True but the custom CUDA extension fused_weight_gradient_mlp_cuda "
                    "module is not found. To use gradient_accumulation_fusion you must "
                    "install APEX with --cpp_ext and --cuda_ext. For example: "
                    "pip install --global-option=\"--cpp_ext\" --global-option=\"--cuda_ext .\" "
                    "Note that the extension requires CUDA>=11. Otherwise, you must turn off "
                    "gradient accumulation fusion."
                )
        self.gradient_accumulation_fusion = gradient_accumulation_fusion

        if self.async_tensor_model_parallel_allreduce and self.sequence_parallel_enabled:
            raise RuntimeError(
                "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` "
                "cannot be enabled at the same time."
            )

    def forward(self, input_):
        """Forward of ColumnParallelLinear

        Args:
            input_: 3D tensor whose order of dimensions is [sequence, batch, hidden]

        Returns:
            - output
            - bias
        """
        bias = self.bias if not self.skip_bias_add else None

        if self.async_tensor_model_parallel_allreduce or \
                self.sequence_parallel_enabled:
            input_parallel = input_
        else:
            input_parallel = copy_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = linear_with_grad_accumulation_and_async_allreduce(
            input=input_parallel,
            weight=self.weight,
            bias=bias,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            async_grad_allreduce=self.async_tensor_model_parallel_allreduce,
            sequence_parallel_enabled=self.sequence_parallel_enabled,
        )
        if self.gather_output:
            # All-gather across the partitions.
            assert not self.sequence_parallel_enabled
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


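# Illustrative usage sketch (assumes initialized tensor-model-parallel groups
# and a CUDA device; the `_example_` name is hypothetical). With
# gather_output=False each rank keeps its own output shard, which is the
# typical setting when a RowParallelLinear follows.
def _example_column_parallel_forward(hidden_size: int = 1024):
    layer = ColumnParallelLinear(hidden_size, 4 * hidden_size,
                                 bias=True, gather_output=False,
                                 params_dtype=torch.half)
    x = torch.randn(16, 2, hidden_size, dtype=torch.half,
                    device=torch.cuda.current_device())
    # Output shard is [sequence, batch, 4 * hidden / world_size].
    output_parallel, output_bias = layer(x)
    return output_parallel, output_bias

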
class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.

    Keyword Arguments:
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimization where
                       bias can be fused with other elementwise operations. We
                       skip adding bias but instead return it.
        params_dtype:
        use_cpu_initialization:
        perform_initialization:
        gradient_accumulation_fusion:
        sequence_parallel_enabled:
    """

    def __init__(self, input_size, output_size, *,
                 bias=True, input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False,
                 params_dtype=None,
                 use_cpu_initialization=False,
                 perform_initialization=True,
                 gradient_accumulation_fusion=False,
                 sequence_parallel_enabled: bool = False,
                 ):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        # Divide the weight matrix along the last dimension.
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add
        self.gradient_accumulation_fusion = gradient_accumulation_fusion
        self.sequence_parallel_enabled = sequence_parallel_enabled
        if self.sequence_parallel_enabled and not self.input_is_parallel:
            raise RuntimeError("To enable `sequence_parallel_enabled`, `input_is_parallel` must be `True`")

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        if use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size,
                                                self.input_size_per_partition,
                                                dtype=params_dtype))
            if perform_initialization:
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight, self.output_size, self.input_size,
                    self.input_size_per_partition, 1, init_method,
                    stride=stride, return_master_weight=keep_master_weight_for_test,
                    params_dtype=params_dtype)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size, self.input_size_per_partition,
                device=torch.cuda.current_device(), dtype=params_dtype))
            if perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=1, stride=stride)
        if bias:
            if use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size,
                                                  dtype=params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size, device=torch.cuda.current_device(),
                    dtype=params_dtype))
            setattr(self.bias, 'sequence_parallel', sequence_parallel_enabled)

            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        """Forward of RowParallelLinear

        Args:
            input_: 3D tensor whose order of dimensions is [sequence, batch, hidden]

        Returns:
            - output
            - bias
        """
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            assert not self.sequence_parallel_enabled
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = linear_with_grad_accumulation_and_async_allreduce(
            input=input_parallel,
            weight=self.weight,
            bias=None,
            gradient_accumulation_fusion=self.gradient_accumulation_fusion,
            async_grad_allreduce=False,
            sequence_parallel_enabled=False,
        )

        # All-reduce across all the partitions.
        if self.sequence_parallel_enabled:
            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
        else:
            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias


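# Illustrative end-to-end sketch (assumes initialized tensor-model-parallel
# groups and a CUDA device; all `_example_` names are hypothetical): a
# Megatron-style MLP block built from the two layers above. The
# ColumnParallelLinear shards the intermediate dimension, so its output feeds
# RowParallelLinear directly with input_is_parallel=True and no communication
# in between; the row layer's forward ends with the all-reduce.
def _example_parallel_mlp(hidden_size: int = 1024):
    dense_h_to_4h = ColumnParallelLinear(hidden_size, 4 * hidden_size,
                                         gather_output=False,
                                         params_dtype=torch.half)
    dense_4h_to_h = RowParallelLinear(4 * hidden_size, hidden_size,
                                      input_is_parallel=True,
                                      params_dtype=torch.half)

    def mlp(hidden_states: torch.Tensor) -> torch.Tensor:
        intermediate, _ = dense_h_to_4h(hidden_states)
        intermediate = F.gelu(intermediate)
        output, _ = dense_4h_to_h(intermediate)
        return output

    x = torch.randn(16, 2, hidden_size, dtype=torch.half,
                    device=torch.cuda.current_device())
    return mlp(x)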