[BugFix] fix wrong output when using lora and num_scheduler_steps=8 (#11161)
Fixes https://github.com/vllm-project/vllm/issues/9688 and https://github.com/vllm-project/vllm/issues/11086 (see also #12487).

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: weilong.yu <weilong.yu@shopee.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
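For context, the linked issues report that LoRA outputs become wrong once multi-step scheduling is enabled. A minimal reproduction along those lines, sketched with placeholder model and adapter paths (not taken verbatim from the issues):

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Placeholder base model and adapter; any LoRA-compatible pair reproduces the setup.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    enable_lora=True,
    num_scheduler_steps=8,  # multi-step scheduling is what triggered the bug
)
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(temperature=0.0, max_tokens=32),  # greedy, for a stable comparison
    lora_request=LoRARequest("my_adapter", 1, "/path/to/lora"),
)
print(outputs[0].outputs[0].text)

Before this fix, a run like this could produce different (wrong) output compared to the same prompt with num_scheduler_steps=1.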
Parent: b1340f9d55
Commit: cb3e73e4c8
@@ -1346,6 +1346,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
         self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
+        if self.lora_config:
+            # Remove dummy loras.
+            assert self.lora_manager is not None
+            self.remove_all_loras()
         return
 
     def remove_all_loras(self):
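The hunk above moves the dummy-LoRA cleanup into profile_run itself: the adapters registered purely for memory profiling are removed as soon as the profiling forward pass finishes, instead of lingering until the worker's later cleanup. With multi-step scheduling the runner state is reused across steps, so leftover dummy adapters could interfere with subsequent real LoRA requests. A toy illustration of that leak pattern (the AdapterRegistry class below is a made-up stand-in, not vLLM code):

class AdapterRegistry:
    """Toy stand-in for a LoRA manager that tracks active adapters."""

    def __init__(self):
        self.active = {}

    def add(self, adapter_id, weights):
        self.active[adapter_id] = weights

    def remove_all(self):
        self.active.clear()


registry = AdapterRegistry()

# Profiling phase: dummy adapters are registered only to measure peak memory.
registry.add("dummy-0", object())

# The fix: clear the dummy adapters right after profiling, before any real
# request can run against the same registry.
registry.remove_all()
assert not registry.active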
@@ -264,10 +264,7 @@ class Worker(LocalOrDistributedWorkerBase):
             f"{(available_kv_cache_memory / GiB_bytes):.2f}GiB.")
         logger.info(msg)
 
         # Final cleanup
-        if self.model_runner.lora_manager:
-            self.model_runner.remove_all_loras()
-
         gc.collect()
 
         return num_gpu_blocks, num_cpu_blocks
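With the worker-side removal above, the dummy-LoRA cleanup now happens in one place, inside the runner's profile_run. A quick way to sanity-check the behavior is to compare greedy outputs between single-step and multi-step scheduling, which should agree after this change. The sketch below uses placeholder paths; in practice each engine should be built in its own process to avoid holding two models on the GPU at once:

from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

def greedy_completion(num_scheduler_steps: int) -> str:
    # Placeholder model/adapter; run each engine in a separate process in practice.
    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        enable_lora=True,
        num_scheduler_steps=num_scheduler_steps,
    )
    outputs = llm.generate(
        ["The capital of France is"],
        SamplingParams(temperature=0.0, max_tokens=16),
        lora_request=LoRARequest("my_adapter", 1, "/path/to/lora"),
    )
    return outputs[0].outputs[0].text

# After the fix, multi-step scheduling should match the single-step baseline.
assert greedy_completion(1) == greedy_completion(8)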