[tpu][misc] fix typo (#8260)
This commit is contained in:
parent
795b662cff
commit
ce2702a923
@ -2,7 +2,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
|
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||||
|
|
||||||
|
|
||||||
class MyMod(torch.nn.Module):
|
class MyMod(torch.nn.Module):
|
||||||
@ -13,7 +13,7 @@ class MyMod(torch.nn.Module):
|
|||||||
return x * 2
|
return x * 2
|
||||||
|
|
||||||
|
|
||||||
class MyWrapper(TorchCompileWrapperWithCustomDispacther):
|
class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
|
||||||
|
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
@ -10,7 +10,7 @@ import torch
|
|||||||
import vllm.envs as envs
|
import vllm.envs as envs
|
||||||
|
|
||||||
|
|
||||||
class TorchCompileWrapperWithCustomDispacther:
|
class TorchCompileWrapperWithCustomDispatcher:
|
||||||
"""
|
"""
|
||||||
A wrapper class for torch.compile, with a custom dispatch logic.
|
A wrapper class for torch.compile, with a custom dispatch logic.
|
||||||
Subclasses should:
|
Subclasses should:
|
||||||
|
@ -11,7 +11,7 @@ import torch_xla.core.xla_model as xm
|
|||||||
import torch_xla.runtime as xr
|
import torch_xla.runtime as xr
|
||||||
|
|
||||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispacther
|
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
|
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig,
|
||||||
ParallelConfig, SchedulerConfig)
|
ParallelConfig, SchedulerConfig)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
@ -611,7 +611,7 @@ class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
|||||||
return [SamplerOutput(sampler_outputs)]
|
return [SamplerOutput(sampler_outputs)]
|
||||||
|
|
||||||
|
|
||||||
class ModelWrapper(TorchCompileWrapperWithCustomDispacther):
|
class ModelWrapper(TorchCompileWrapperWithCustomDispatcher):
|
||||||
|
|
||||||
def __init__(self, model: nn.Module):
|
def __init__(self, model: nn.Module):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user