[Bugfix] fix automatic prefix args and add log info (#3608)

Author: TianYu GUO, 2024-03-25 20:35:22 +08:00 (committed by GitHub)
Parent: 925f3332ca
Commit: e67c295b0c
2 changed files with 7 additions and 1 deletion

File: vllm/core/block_manager.py

@@ -9,6 +9,9 @@ from vllm.block import BlockTable, PhysicalTokenBlock
 from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
 from vllm.utils import Device
 from vllm.core.evictor import Evictor, EvictionPolicy, make_evictor
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 class BlockAllocatorBase(ABC):

@@ -241,11 +244,13 @@ class BlockSpaceManager:
         self.watermark_blocks = int(watermark * num_gpu_blocks)
 
         if self.enable_caching:
+            logger.info("enable automatic prefix caching")
             self.gpu_allocator = CachedBlockAllocator(Device.GPU, block_size,
                                                       num_gpu_blocks)
             self.cpu_allocator = CachedBlockAllocator(Device.CPU, block_size,
                                                       num_cpu_blocks)
         else:
+            logger.info("disable automatic prefix caching")
             self.gpu_allocator = UncachedBlockAllocator(
                 Device.GPU, block_size, num_gpu_blocks)
             self.cpu_allocator = UncachedBlockAllocator(
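
For context, a minimal sketch of the effect of this file's change: constructing the block manager now emits one of the two new log lines. The constructor arguments below (block_size, num_gpu_blocks, num_cpu_blocks, enable_caching) are illustrative values matching the BlockSpaceManager signature as of this commit, not part of the diff:

    # Sketch only: assumes a vLLM checkout that includes this commit.
    from vllm.core.block_manager import BlockSpaceManager

    manager = BlockSpaceManager(
        block_size=16,        # tokens per KV-cache block
        num_gpu_blocks=1024,  # illustrative capacities
        num_cpu_blocks=256,
        enable_caching=True,  # logs "enable automatic prefix caching"
    )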

File: vllm/engine/arg_utils.py

@@ -337,7 +337,8 @@ class EngineArgs:
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space, self.kv_cache_dtype,
-                                   model_config.get_sliding_window())
+                                   model_config.get_sliding_window(),
+                                   self.enable_prefix_caching)
         parallel_config = ParallelConfig(
             self.pipeline_parallel_size, self.tensor_parallel_size,
             self.worker_use_ray, self.max_parallel_loading_workers,
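
With this second change, the user-facing flag is actually forwarded into CacheConfig; previously CacheConfig was built without it, so enable_prefix_caching appears to have been silently ignored. A hedged end-to-end usage sketch (model name and prompts are illustrative):

    from vllm import LLM, SamplingParams

    # enable_prefix_caching is passed through EngineArgs to CacheConfig
    # once this fix is in place.
    llm = LLM(model="facebook/opt-125m", enable_prefix_caching=True)

    # Requests sharing a prompt prefix can now reuse cached KV blocks.
    prompts = [
        "You are a helpful assistant. Q: What is paged attention?",
        "You are a helpful assistant. Q: What is a KV cache?",
    ]
    outputs = llm.generate(prompts, SamplingParams(max_tokens=32))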