42 lines
1.1 KiB
Python
42 lines
1.1 KiB
Python
![]() |
import random
|
||
|
from typing import Union
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
|
||
|
from cacheflow.model_executor.parallel_utils.parallel_state import model_parallel_is_initialized
|
||
|
from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed
|
||
|
|
||
|
|
||
|
# Lowercase string names accepted for dtypes, mapped to the torch dtype
# objects they stand for. 'half' and 'float' are alternate spellings of
# 'float16' and 'float32' respectively.
_STR_DTYPE_TO_TORCH_DTYPE = dict(
    half=torch.half,
    float=torch.float,
    float16=torch.float16,
    float32=torch.float32,
    bfloat16=torch.bfloat16,
)
|
||
|
|
||
|
|
||
|
def get_torch_dtype(dtype: Union[torch.dtype, str]) -> torch.dtype:
    """Resolve a dtype given either as a ``torch.dtype`` or as a string name.

    String names are lowercased and looked up in
    ``_STR_DTYPE_TO_TORCH_DTYPE``; a name that is missing from that table
    raises ``KeyError``. Any non-string argument is returned unchanged.
    """
    if not isinstance(dtype, str):
        # Not a name, so there is nothing to translate.
        return dtype
    return _STR_DTYPE_TO_TORCH_DTYPE[dtype.lower()]
|
||
|
|
||
|
|
||
|
def get_dtype_size(dtype: Union[torch.dtype, str]) -> int:
    """Return the number of bytes occupied by one element of ``dtype``.

    ``dtype`` is first resolved through ``get_torch_dtype``, so both
    ``torch.dtype`` objects and string names are accepted.
    """
    resolved = get_torch_dtype(dtype)
    # An empty tensor of the target dtype is enough to query the
    # per-element byte width.
    probe = torch.tensor([], dtype=resolved)
    return probe.element_size()
|
||
|
|
||
|
|
||
|
def set_random_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available the CUDA generators are seeded on all devices as
    well, and when model parallelism has been initialized
    ``model_parallel_cuda_manual_seed`` is called with the same seed.
    """
    # Host-side generators, seeded in the order: random, numpy, torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.manual_seed_all(seed)

    parallel_ready = model_parallel_is_initialized()
    if parallel_ready:
        model_parallel_cuda_manual_seed(seed)
|