[BugFix] Fix GC bug for LLM
class (#2882)
This commit is contained in:
parent
31348dff03
commit
d7afab6d3a
@ -4,6 +4,10 @@ It should include tests that are reported by users and making sure they
|
||||
will never happen again.
|
||||
|
||||
"""
|
||||
import gc
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
|
||||
@ -35,6 +39,20 @@ def test_max_tokens_none():
|
||||
assert len(prompts) == len(outputs)
|
||||
|
||||
|
||||
def test_gc():
|
||||
llm = LLM("facebook/opt-125m", enforce_eager=True)
|
||||
del llm
|
||||
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# The memory allocated for model and KV cache should be released.
|
||||
# The memory allocated for PyTorch and others should be less than 50MB.
|
||||
# Usually, it's around 10MB.
|
||||
allocated = torch.cuda.memory_allocated()
|
||||
assert allocated < 50 * 1024 * 1024
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
pytest.main([__file__])
|
||||
|
@ -4,23 +4,26 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
import_exc = None
|
||||
|
||||
try:
|
||||
import vllm._punica_C as punica_kernels
|
||||
except ImportError as e:
|
||||
import_exc = e
|
||||
def _raise_import_error(e):
|
||||
if torch.cuda.get_device_capability() < (8, 0):
|
||||
raise ImportError(
|
||||
"punica LoRA kernels require compute capability >= 8.0") from e
|
||||
else:
|
||||
raise ImportError(
|
||||
"punica LoRA kernels could not be imported. If you built vLLM "
|
||||
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
|
||||
"was set.") from e
|
||||
|
||||
if import_exc is None:
|
||||
|
||||
def bgmv(
|
||||
def bgmv(
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
indicies: torch.LongTensor,
|
||||
layer_idx: int,
|
||||
scale: float,
|
||||
):
|
||||
):
|
||||
"""
|
||||
Semantics:
|
||||
y[i] += (
|
||||
@ -38,9 +41,15 @@ if import_exc is None:
|
||||
layer_idx: Layer index of the weight matrices.
|
||||
scale: Scaling factor.
|
||||
"""
|
||||
try:
|
||||
import vllm._punica_C as punica_kernels
|
||||
except ImportError as e:
|
||||
_raise_import_error(e)
|
||||
|
||||
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
|
||||
|
||||
def add_lora(y: torch.Tensor,
|
||||
|
||||
def add_lora(y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
wa_t_all: torch.Tensor,
|
||||
wb_t_all: torch.Tensor,
|
||||
@ -70,6 +79,11 @@ if import_exc is None:
|
||||
scale: Scaling factor.
|
||||
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
|
||||
"""
|
||||
try:
|
||||
import vllm._punica_C as punica_kernels
|
||||
except ImportError as e:
|
||||
_raise_import_error(e)
|
||||
|
||||
r = wb_t_all.size(-1)
|
||||
if buffer is None:
|
||||
# We set the buffer to be float32 by default to avoid
|
||||
@ -78,12 +92,12 @@ if import_exc is None:
|
||||
buffer = torch.zeros((x.size(0), r),
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx,
|
||||
1.0)
|
||||
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
|
||||
punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
|
||||
scale)
|
||||
|
||||
def add_lora_slice(y: torch.Tensor,
|
||||
|
||||
def add_lora_slice(y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
wa_t_all: torch.Tensor,
|
||||
wb_t_all: torch.Tensor,
|
||||
@ -119,6 +133,11 @@ if import_exc is None:
|
||||
y_offset: Offset to apply to the starting column of y.
|
||||
y_slice_size: Size of the y column slice.
|
||||
"""
|
||||
try:
|
||||
import vllm._punica_C as punica_kernels
|
||||
except ImportError as e:
|
||||
_raise_import_error(e)
|
||||
|
||||
r = wb_t_all.size(-1)
|
||||
if buffer is None:
|
||||
# We set the buffer to be float32 by default to avoid
|
||||
@ -149,28 +168,3 @@ if import_exc is None:
|
||||
y_slice_size,
|
||||
y_offset,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
def _raise_exc(
|
||||
*args, # pylint: disable=unused-argument
|
||||
**kwargs # pylint: disable=unused-argument
|
||||
):
|
||||
if torch.cuda.get_device_capability() < (8, 0):
|
||||
raise ImportError("punica LoRA kernels require compute "
|
||||
"capability>=8.0") from import_exc
|
||||
else:
|
||||
raise ImportError(
|
||||
"punica LoRA kernels could not be imported. If you built vLLM "
|
||||
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
|
||||
"was set.") from import_exc
|
||||
|
||||
bgmv = _raise_exc
|
||||
add_lora = _raise_exc
|
||||
add_lora_slice = _raise_exc
|
||||
|
||||
__all__ = [
|
||||
"bgmv",
|
||||
"add_lora",
|
||||
"add_lora_slice",
|
||||
]
|
||||
|
Loading…
x
Reference in New Issue
Block a user