[BugFix] Fix GC bug for LLM class (#2882)
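Deleting an `LLM` instance could leave the GPU memory for the model weights and KV cache allocated. This commit adds a `test_gc` regression test asserting that `torch.cuda.memory_allocated()` falls below 50MB after the `LLM` object is deleted and garbage-collected, and reworks `vllm/lora/punica.py` to import the `vllm._punica_C` extension lazily inside each kernel wrapper instead of at module import time.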

commit d7afab6d3a
parent 31348dff03
Author: Woosuk Kwon
Date: 2024-02-14 22:17:44 -08:00 (committed by GitHub)
2 changed files with 169 additions and 157 deletions

tests/test_regression.py

@@ -4,6 +4,10 @@ It should include tests that are reported by users and making sure they
 will never happen again.
 """
+import gc
+
+import torch
+
 from vllm import LLM, SamplingParams
@@ -35,6 +39,20 @@ def test_max_tokens_none():
     assert len(prompts) == len(outputs)
 
 
+def test_gc():
+    llm = LLM("facebook/opt-125m", enforce_eager=True)
+    del llm
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # The memory allocated for model and KV cache should be released.
+    # The memory allocated for PyTorch and others should be less than 50MB.
+    # Usually, it's around 10MB.
+    allocated = torch.cuda.memory_allocated()
+    assert allocated < 50 * 1024 * 1024
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])
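For reference, a standalone sketch of the check the new test performs (not part of the diff; it assumes a CUDA GPU and that the `facebook/opt-125m` weights are available):

```python
import gc

import torch

from vllm import LLM

# Load a small model, then drop the only reference to it.
llm = LLM("facebook/opt-125m", enforce_eager=True)
del llm

# Collect cycle-involved objects and return cached blocks to the driver.
gc.collect()
torch.cuda.empty_cache()

# With the fix, only PyTorch's own bookkeeping (roughly 10MB) remains.
print(f"still allocated: {torch.cuda.memory_allocated() / 1024**2:.1f} MiB")
```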

vllm/lora/punica.py

@@ -4,23 +4,26 @@ from typing import Optional
 
 import torch
 
-import_exc = None
-
-try:
-    import vllm._punica_C as punica_kernels
-except ImportError as e:
-    import_exc = e
-
-if import_exc is None:
+def _raise_import_error(e):
+    if torch.cuda.get_device_capability() < (8, 0):
+        raise ImportError(
+            "punica LoRA kernels require compute capability >= 8.0") from e
+    else:
+        raise ImportError(
+            "punica LoRA kernels could not be imported. If you built vLLM "
+            "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
+            "was set.") from e
 
 def bgmv(
     y: torch.Tensor,
     x: torch.Tensor,
     w_t_all: torch.Tensor,
     indicies: torch.LongTensor,
     layer_idx: int,
     scale: float,
 ):
     """
     Semantics:
       y[i] += (
@@ -38,9 +41,15 @@ if import_exc is None:
         layer_idx: Layer index of the weight matrices.
         scale: Scaling factor.
     """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
     punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
 
-    def add_lora(y: torch.Tensor,
+def add_lora(y: torch.Tensor,
              x: torch.Tensor,
             wa_t_all: torch.Tensor,
             wb_t_all: torch.Tensor,
@@ -70,6 +79,11 @@ if import_exc is None:
         scale: Scaling factor.
         buffer: Optional. Shape: `[B, R]`. Temporary buffer.
     """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
     r = wb_t_all.size(-1)
     if buffer is None:
         # We set the buffer to be float32 by default to avoid
@@ -78,12 +92,12 @@ if import_exc is None:
     buffer = torch.zeros((x.size(0), r),
                          dtype=torch.float32,
                          device=x.device)
-    punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx,
-                                 1.0)
+    punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
     punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
                                  scale)
 
-    def add_lora_slice(y: torch.Tensor,
+def add_lora_slice(y: torch.Tensor,
                    x: torch.Tensor,
                    wa_t_all: torch.Tensor,
                    wb_t_all: torch.Tensor,
@@ -119,6 +133,11 @@ if import_exc is None:
         y_offset: Offset to apply to the starting column of y.
         y_slice_size: Size of the y column slice.
     """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
     r = wb_t_all.size(-1)
     if buffer is None:
         # We set the buffer to be float32 by default to avoid
@@ -149,28 +168,3 @@ if import_exc is None:
         y_slice_size,
         y_offset,
     )
-
-else:
-
-    def _raise_exc(
-        *args,  # pylint: disable=unused-argument
-        **kwargs  # pylint: disable=unused-argument
-    ):
-        if torch.cuda.get_device_capability() < (8, 0):
-            raise ImportError("punica LoRA kernels require compute "
-                              "capability>=8.0") from import_exc
-        else:
-            raise ImportError(
-                "punica LoRA kernels could not be imported. If you built vLLM "
-                "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
-                "was set.") from import_exc
-
-    bgmv = _raise_exc
-    add_lora = _raise_exc
-    add_lora_slice = _raise_exc
-
-__all__ = [
-    "bgmv",
-    "add_lora",
-    "add_lora_slice",
-]
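
With this restructuring, `import vllm.lora.punica` always succeeds; the `ImportError` surfaces only on the first kernel call. A hypothetical caller-side sketch (tensor shapes are illustrative, not taken from the diff):

```python
import torch

from vllm.lora import punica

# Illustrative shapes only: a batch of 8 tokens, hidden size 1024,
# 4 LoRA adapters, one layer of transposed weight matrices.
y = torch.zeros(8, 1024, dtype=torch.float16, device="cuda")
x = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
w_t_all = torch.randn(4, 1, 1024, 1024, dtype=torch.float16, device="cuda")
indicies = torch.zeros(8, dtype=torch.long, device="cuda")

try:
    punica.bgmv(y, x, w_t_all, indicies, layer_idx=0, scale=1.0)
except ImportError as e:
    # Raised via _raise_import_error: either the GPU is below compute
    # capability 8.0, or vLLM was built without VLLM_INSTALL_PUNICA_KERNELS=1.
    print(e)
```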