""" Utility functions used throughout Megatron core """
from functools import reduce
import operator
import torch
from cacheflow.model_executor.parallel_utils import parallel_state

def ensure_divisibility(numerator, denominator):
    """Ensure that numerator is divisible by the denominator."""
    assert numerator % denominator == 0, "{} is not divisible by {}".format(
        numerator, denominator)


def divide(numerator, denominator):
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
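
# Illustrative usage of divide() (a sketch, not part of the original module):
#
#     divide(1024, 8)   # -> 128
#     divide(1024, 3)   # raises AssertionError: 1024 is not divisible by 3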

class GlobalMemoryBuffer:
    """Global buffer to avoid dynamic memory allocations.

    Caller should ensure that buffers of the same name
    are not used concurrently."""

    def __init__(self):
        self.buffer = {}

    def get_tensor(self, tensor_shape, dtype, name):
        # Total number of elements the caller needs.
        required_len = reduce(operator.mul, tensor_shape, 1)
        # (Re)allocate the flat backing buffer only if no buffer exists for
        # this (name, dtype) key or the cached one is too small.
        if self.buffer.get((name, dtype), None) is None or \
                self.buffer[(name, dtype)].numel() < required_len:
            self.buffer[(name, dtype)] = \
                torch.empty(required_len,
                            dtype=dtype,
                            device=torch.cuda.current_device(),
                            requires_grad=False)
        # Hand back the first required_len elements, viewed as the
        # requested shape.
        return self.buffer[(name, dtype)][0:required_len].view(*tensor_shape)
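
# Illustrative usage of GlobalMemoryBuffer (a sketch, not part of the original
# module; assumes a CUDA device is available, and the buffer name
# "attn_scores" is just an example):
#
#     _BUFFER = GlobalMemoryBuffer()
#     scores = _BUFFER.get_tensor((8, 1024), torch.float16, "attn_scores")
#     # A later request with the same (name, dtype) and no more elements
#     # reuses the cached storage instead of allocating a new tensor:
#     scores2 = _BUFFER.get_tensor((4, 1024), torch.float16, "attn_scores")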

def _kernel_make_viewless_tensor(inp, requires_grad):
    '''Make a viewless tensor.

    View tensors have the undesirable side-effect of retaining a reference
    to the originally-viewed tensor, even after manually setting the '.data'
    field. This method creates a new tensor that links to the old tensor's
    data, without linking the viewed tensor, referenced via the '._base'
    field.
    '''
    # Allocate a fresh 1-element tensor with the desired dtype, device, and
    # requires_grad flag, then point its '.data' at the input's storage.
    # The result shares memory with 'inp' but carries no '._base' reference.
    out = torch.empty(
        (1,),
        dtype=inp.dtype,
        device=inp.device,
        requires_grad=requires_grad,
    )
    out.data = inp.data
    return out
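
# Illustrative behavior (a sketch, not part of the original module):
#
#     x = torch.ones(4, 4)
#     v = x.view(16)
#     assert v._base is x       # a view keeps a reference to its base
#     w = _kernel_make_viewless_tensor(v, requires_grad=False)
#     assert w._base is None    # same underlying data, no '._base' reference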

class MakeViewlessTensor(torch.autograd.Function):
    '''
    Autograd function to make a viewless tensor.

    This function should be used in cases where the computation graph needs
    to be propagated, but we only want a viewless tensor (e.g.,
    ParallelTransformer's hidden_states). Call this function by passing
    'keep_graph = True' to 'make_viewless_tensor()'.
    '''

    @staticmethod
    def forward(ctx, inp, requires_grad):
        return _kernel_make_viewless_tensor(inp, requires_grad)

    @staticmethod
    def backward(ctx, grad_output):
        # The gradient passes straight through to 'inp'; 'requires_grad'
        # is a non-tensor argument, so its gradient slot is None.
        return grad_output, None

def make_viewless_tensor(inp, requires_grad, keep_graph):
    '''
    Entry-point for creating viewless tensors.

    This method should be used, rather than calling 'MakeViewlessTensor'
    or '_kernel_make_viewless_tensor' directly. This method acts as a
    switch for determining if an autograd function or a regular method
    should be used to create the tensor.
    '''
    # return tensor as-is, if not a 'view'
    if inp._base is None:
        return inp

    # create viewless tensor
    if keep_graph:
        return MakeViewlessTensor.apply(inp, requires_grad)
    else:
        return _kernel_make_viewless_tensor(inp, requires_grad)
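
# Illustrative usage (a sketch, not part of the original module; the names
# 'hidden_view' and 'cache_view' are hypothetical):
#
#     # Keep autograd history so gradients still flow to the base tensor:
#     hidden = make_viewless_tensor(hidden_view, requires_grad=True,
#                                   keep_graph=True)
#     # For tensors outside the graph, skip the autograd machinery:
#     cache = make_viewless_tensor(cache_view, requires_grad=False,
#                                  keep_graph=False)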

def assert_viewless_tensor(tensor, extra_msg=None):
    '''Assert that a tensor is not a view (i.e., its '._base' field is
    not set).'''
    if isinstance(tensor, list):
        for t in tensor:
            assert_viewless_tensor(t)
        return tensor
    if not isinstance(tensor, torch.Tensor):
        return tensor
    assert tensor._base is None, (
        "Ensure tensor._base is None before setting tensor.data or storing "
        "tensor to memory buffer. Otherwise, a memory leak will occur (and "
        "likely accumulate over iterations). %s"
    ) % extra_msg
    return tensor

def safely_set_viewless_tensor_data(tensor, new_data_tensor):
    '''Safely set tensor's '.data' field.

    Check first that the tensor is viewless (i.e., '._base' not set). If not,
    raise an exception.
    '''
    assert_viewless_tensor(
        tensor,
        extra_msg="FYI, tensor._base has shape %s, and new_data_tensor has "
        "shape %s." % (
            "--" if tensor._base is None else tensor._base.shape,
            new_data_tensor.shape,
        ),
    )
    tensor.data = new_data_tensor
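
# Illustrative usage (a sketch, not part of the original module):
#
#     t = make_viewless_tensor(hidden_view, requires_grad=False,
#                              keep_graph=False)
#     safely_set_viewless_tensor_data(t, torch.zeros_like(t))
#     # Calling this on a raw view would trip the assertion instead of
#     # silently retaining the base tensor:
#     # safely_set_viewless_tensor_data(x.view(-1), torch.zeros(x.numel()))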