[Misc] Use torch.Tensor for type annotation (#6505)

Author: Woosuk Kwon (committed by GitHub)
Date:   2024-07-17 06:01:10 -07:00
Parent: e09ce759aa
Commit: a9a2e74d21
2 changed files with 18 additions and 18 deletions
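
Background for the change: torch.Tensor is the tensor class, so it is a valid type annotation, while torch.tensor is the factory function that constructs tensors. Using the function as an annotation raises no runtime error, since Python evaluates annotations but never checks them, yet static checkers such as mypy reject it. A minimal sketch of the distinction (illustrative only, not code from this commit):

import torch

def double(x: torch.Tensor) -> torch.Tensor:  # class used as annotation: OK
    return x * 2

# torch.tensor is a callable that *returns* torch.Tensor instances, not a type.
# Annotating with it runs fine, but mypy reports something like:
#   Function "torch.tensor" is not valid as a type  [valid-type]
assert isinstance(torch.tensor([1, 2]), torch.Tensor)
print(double(torch.tensor([1, 2])))  # tensor([2, 4])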

Changed file 1 of 2:

@@ -20,18 +20,18 @@ DEFAULT_TP_SIZES = [1]
 
 # helpers
 
-def to_fp8(tensor: torch.tensor) -> torch.tensor:
+def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
     finfo = torch.finfo(torch.float8_e4m3fn)
     return torch.round(tensor.clamp(
         min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
 
-def to_int8(tensor: torch.tensor) -> torch.tensor:
+def to_int8(tensor: torch.Tensor) -> torch.Tensor:
     return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
 
 def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
-                      k: int) -> Tuple[torch.tensor, torch.tensor]:
+                      k: int) -> Tuple[torch.Tensor, torch.Tensor]:
     a = torch.randn((m, k), device='cuda') * 5
     b = torch.randn((n, k), device='cuda').t() * 5
@@ -47,15 +47,15 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
 
 # impl
 
-def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-                    scale_b: torch.tensor,
-                    out_dtype: torch.dtype) -> torch.tensor:
+def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+                    scale_b: torch.Tensor,
+                    out_dtype: torch.dtype) -> torch.Tensor:
     return torch.mm(a, b)
 
-def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-                     scale_b: torch.tensor,
-                     out_dtype: torch.dtype) -> torch.tensor:
+def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+                     scale_b: torch.Tensor,
+                     out_dtype: torch.dtype) -> torch.Tensor:
     return torch._scaled_mm(a,
                             b,
                             scale_a=scale_a,
@@ -63,9 +63,9 @@ def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
                             out_dtype=out_dtype)
 
-def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
-                                scale_a: torch.tensor, scale_b: torch.tensor,
-                                out_dtype: torch.dtype) -> torch.tensor:
+def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor,
+                                scale_a: torch.Tensor, scale_b: torch.Tensor,
+                                out_dtype: torch.dtype) -> torch.Tensor:
     return torch._scaled_mm(a,
                             b,
                             scale_a=scale_a,
@@ -74,15 +74,15 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
                             use_fast_accum=True)
 
-def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-                 scale_b: torch.tensor,
-                 out_dtype: torch.dtype) -> torch.tensor:
+def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+                 scale_b: torch.Tensor,
+                 out_dtype: torch.dtype) -> torch.Tensor:
     return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
 
 # bench
 
-def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
-             scale_b: torch.tensor, out_dtype: torch.dtype, label: str,
+def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
+             scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
              sub_label: str, fn: Callable, description: str) -> TMeasurement:
     min_run_time = 1
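
For reference, the fp8 implementations timed by this benchmark reduce to a call like the sketch below. This is an illustrative example, not code from the repo; it assumes a CUDA GPU with float8 support and a PyTorch build where the private torch._scaled_mm accepts keyword scales and out_dtype (the signature and return type of this private API have varied across releases):

import torch

m, n, k = 16, 16, 32  # shapes chosen to meet _scaled_mm's alignment requirements
a = torch.randn((m, k), device='cuda').to(torch.float8_e4m3fn)
# Cast while (n, k) is contiguous, then transpose: _scaled_mm expects B column-major.
b = torch.randn((n, k), device='cuda').to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device='cuda')  # per-tensor dequantization scales
scale_b = torch.tensor(1.0, device='cuda')
out = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b,
                       out_dtype=torch.bfloat16, use_fast_accum=True)
print(out.shape)  # torch.Size([16, 16])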

Changed file 2 of 2:

@@ -105,7 +105,7 @@ class Worker(LocalOrDistributedWorkerBase):
         # initialize_cache.
         self.cache_engine: List[CacheEngine]
         # Initialize gpu_cache as embedding models don't initialize kv_caches
-        self.gpu_cache: Optional[List[List[torch.tensor]]] = None
+        self.gpu_cache: Optional[List[List[torch.Tensor]]] = None
 
     def init_device(self) -> None:
         if self.device_config.device.type == "cuda":