[BugFix] Fix GC bug for LLM
class (#2882)
This commit is contained in:
parent
31348dff03
commit
d7afab6d3a
@ -4,6 +4,10 @@ It should include tests that are reported by users and making sure they
|
|||||||
will never happen again.
|
will never happen again.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
import gc
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
|
||||||
@ -35,6 +39,20 @@ def test_max_tokens_none():
|
|||||||
assert len(prompts) == len(outputs)
|
assert len(prompts) == len(outputs)
|
||||||
|
|
||||||
|
|
||||||
|
def test_gc():
|
||||||
|
llm = LLM("facebook/opt-125m", enforce_eager=True)
|
||||||
|
del llm
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
# The memory allocated for model and KV cache should be released.
|
||||||
|
# The memory allocated for PyTorch and others should be less than 50MB.
|
||||||
|
# Usually, it's around 10MB.
|
||||||
|
allocated = torch.cuda.memory_allocated()
|
||||||
|
assert allocated < 50 * 1024 * 1024
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import pytest
|
import pytest
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
@ -4,23 +4,26 @@ from typing import Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import_exc = None
|
|
||||||
|
|
||||||
try:
|
def _raise_import_error(e):
|
||||||
import vllm._punica_C as punica_kernels
|
if torch.cuda.get_device_capability() < (8, 0):
|
||||||
except ImportError as e:
|
raise ImportError(
|
||||||
import_exc = e
|
"punica LoRA kernels require compute capability >= 8.0") from e
|
||||||
|
else:
|
||||||
|
raise ImportError(
|
||||||
|
"punica LoRA kernels could not be imported. If you built vLLM "
|
||||||
|
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
|
||||||
|
"was set.") from e
|
||||||
|
|
||||||
if import_exc is None:
|
|
||||||
|
|
||||||
def bgmv(
|
def bgmv(
|
||||||
y: torch.Tensor,
|
y: torch.Tensor,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
w_t_all: torch.Tensor,
|
w_t_all: torch.Tensor,
|
||||||
indicies: torch.LongTensor,
|
indicies: torch.LongTensor,
|
||||||
layer_idx: int,
|
layer_idx: int,
|
||||||
scale: float,
|
scale: float,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Semantics:
|
Semantics:
|
||||||
y[i] += (
|
y[i] += (
|
||||||
@ -38,9 +41,15 @@ if import_exc is None:
|
|||||||
layer_idx: Layer index of the weight matrices.
|
layer_idx: Layer index of the weight matrices.
|
||||||
scale: Scaling factor.
|
scale: Scaling factor.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
import vllm._punica_C as punica_kernels
|
||||||
|
except ImportError as e:
|
||||||
|
_raise_import_error(e)
|
||||||
|
|
||||||
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
|
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
|
||||||
|
|
||||||
def add_lora(y: torch.Tensor,
|
|
||||||
|
def add_lora(y: torch.Tensor,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
wa_t_all: torch.Tensor,
|
wa_t_all: torch.Tensor,
|
||||||
wb_t_all: torch.Tensor,
|
wb_t_all: torch.Tensor,
|
||||||
@ -70,6 +79,11 @@ if import_exc is None:
|
|||||||
scale: Scaling factor.
|
scale: Scaling factor.
|
||||||
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
|
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
import vllm._punica_C as punica_kernels
|
||||||
|
except ImportError as e:
|
||||||
|
_raise_import_error(e)
|
||||||
|
|
||||||
r = wb_t_all.size(-1)
|
r = wb_t_all.size(-1)
|
||||||
if buffer is None:
|
if buffer is None:
|
||||||
# We set the buffer to be float32 by default to avoid
|
# We set the buffer to be float32 by default to avoid
|
||||||
@ -78,12 +92,12 @@ if import_exc is None:
|
|||||||
buffer = torch.zeros((x.size(0), r),
|
buffer = torch.zeros((x.size(0), r),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
device=x.device)
|
device=x.device)
|
||||||
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx,
|
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
|
||||||
1.0)
|
|
||||||
punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
|
punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
|
||||||
scale)
|
scale)
|
||||||
|
|
||||||
def add_lora_slice(y: torch.Tensor,
|
|
||||||
|
def add_lora_slice(y: torch.Tensor,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
wa_t_all: torch.Tensor,
|
wa_t_all: torch.Tensor,
|
||||||
wb_t_all: torch.Tensor,
|
wb_t_all: torch.Tensor,
|
||||||
@ -119,6 +133,11 @@ if import_exc is None:
|
|||||||
y_offset: Offset to apply to the starting column of y.
|
y_offset: Offset to apply to the starting column of y.
|
||||||
y_slice_size: Size of the y column slice.
|
y_slice_size: Size of the y column slice.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
|
import vllm._punica_C as punica_kernels
|
||||||
|
except ImportError as e:
|
||||||
|
_raise_import_error(e)
|
||||||
|
|
||||||
r = wb_t_all.size(-1)
|
r = wb_t_all.size(-1)
|
||||||
if buffer is None:
|
if buffer is None:
|
||||||
# We set the buffer to be float32 by default to avoid
|
# We set the buffer to be float32 by default to avoid
|
||||||
@ -149,28 +168,3 @@ if import_exc is None:
|
|||||||
y_slice_size,
|
y_slice_size,
|
||||||
y_offset,
|
y_offset,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def _raise_exc(
|
|
||||||
*args, # pylint: disable=unused-argument
|
|
||||||
**kwargs # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
if torch.cuda.get_device_capability() < (8, 0):
|
|
||||||
raise ImportError("punica LoRA kernels require compute "
|
|
||||||
"capability>=8.0") from import_exc
|
|
||||||
else:
|
|
||||||
raise ImportError(
|
|
||||||
"punica LoRA kernels could not be imported. If you built vLLM "
|
|
||||||
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
|
|
||||||
"was set.") from import_exc
|
|
||||||
|
|
||||||
bgmv = _raise_exc
|
|
||||||
add_lora = _raise_exc
|
|
||||||
add_lora_slice = _raise_exc
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"bgmv",
|
|
||||||
"add_lora",
|
|
||||||
"add_lora_slice",
|
|
||||||
]
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user