[BugFix] Fix GC bug for LLM class (#2882)

Author: Woosuk Kwon, 2024-02-14 22:17:44 -08:00 (committed by GitHub)
parent 31348dff03
commit d7afab6d3a
2 changed files with 169 additions and 157 deletions
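
What changed, in brief: the punica LoRA wrapper used to import the compiled extension once at module scope, cache any `ImportError` in a global (`import_exc`), and rebind `bgmv`/`add_lora`/`add_lora_slice` to a stub that re-raised it. The rewrite drops the global and imports `vllm._punica_C` inside each wrapper at call time, routing failures through a shared `_raise_import_error` helper, so no exception object lingers at module scope (an exception's traceback pins the frames that raised it). A regression test is added asserting that deleting an `LLM` instance releases its GPU memory. Below is a minimal sketch of the lazy-import pattern, with hypothetical names (`my_ext_C`, `dispatch`, `fused_op`) standing in for the real extension:

    # Sketch of the lazy-import pattern this commit adopts; my_ext_C and
    # dispatch are hypothetical stand-ins for vllm._punica_C's interface.
    def _raise_import_error(e):
        # Re-raise with an actionable message; the exception is not cached.
        raise ImportError("extension kernels could not be imported; "
                          "rebuild with the kernel build flag set") from e


    def fused_op(*args):
        try:
            import my_ext_C as ext  # hypothetical compiled extension
        except ImportError as e:
            _raise_import_error(e)
        return ext.dispatch(*args)

Repeated calls stay cheap: after the first success, `import` is just a `sys.modules` lookup.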

tests/test_regression.py

@@ -4,6 +4,10 @@ It should include tests that are reported by users and making sure they
 will never happen again.
 """
+import gc
+
+import torch
+
 from vllm import LLM, SamplingParams
@@ -35,6 +39,20 @@ def test_max_tokens_none():
     assert len(prompts) == len(outputs)
 
 
+def test_gc():
+    llm = LLM("facebook/opt-125m", enforce_eager=True)
+    del llm
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # The memory allocated for the model and KV cache should be released.
+    # The memory allocated for PyTorch and others should be less than 50 MB.
+    # Usually, it's around 10 MB.
+    allocated = torch.cuda.memory_allocated()
+    assert allocated < 50 * 1024 * 1024
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])
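
For reference on what `test_gc` measures: `torch.cuda.memory_allocated()` counts only the bytes backing live tensors, `torch.cuda.memory_reserved()` additionally counts blocks the caching allocator keeps for reuse, and `torch.cuda.empty_cache()` returns unused cached blocks to the driver. A small standalone illustration of the same pattern (assumes a CUDA device is available):

    import gc

    import torch

    x = torch.zeros(1024, 1024, device="cuda")  # ~4 MB of float32
    print(torch.cuda.memory_allocated())   # includes the live tensor
    del x
    gc.collect()
    print(torch.cuda.memory_allocated())   # drops once the tensor is collected
    print(torch.cuda.memory_reserved())    # allocator may still cache the block
    torch.cuda.empty_cache()               # return cached blocks to the driver

The test's 50 MB bound leaves headroom for small long-lived allocations that legitimately outlive the `LLM` object.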

vllm/lora/punica.py

@@ -4,173 +4,167 @@ from typing import Optional
 
 import torch
 
-import_exc = None
-
-try:
-    import vllm._punica_C as punica_kernels
-except ImportError as e:
-    import_exc = e
-
-if import_exc is None:
-
-    def bgmv(
-        y: torch.Tensor,
-        x: torch.Tensor,
-        w_t_all: torch.Tensor,
-        indicies: torch.LongTensor,
-        layer_idx: int,
-        scale: float,
-    ):
-        """
-        Semantics:
-          y[i] += (
-              x[i].unsqueeze(0)
-              @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              * scale
-            ).squeeze(0)
-
-        Args:
-          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-          x: Shape: `[B, H1]`. Input vectors.
-          w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
-            matrices.
-          indicies: Shape: `[B]`. Indices of the weight matrices.
-          layer_idx: Layer index of the weight matrices.
-          scale: Scaling factor.
-        """
-        punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
-
-    def add_lora(y: torch.Tensor,
-                 x: torch.Tensor,
-                 wa_t_all: torch.Tensor,
-                 wb_t_all: torch.Tensor,
-                 indicies: torch.LongTensor,
-                 layer_idx: int,
-                 scale: float,
-                 *,
-                 buffer: Optional[torch.Tensor] = None):
-        """
-        Semantics:
-          y[i] += (
-              x[i].unsqueeze(0)
-              @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              * scale
-            ).squeeze(0)
-
-        Args:
-          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-          x: Shape: `[B, H1]`. Input vectors.
-          wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
-            LoRA A matrices.
-          wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
-            LoRA B matrices.
-          indicies: Shape: `[B]`. Indices of the LoRA weights.
-          layer_idx: Layer index of LoRA weights.
-          scale: Scaling factor.
-          buffer: Optional. Shape: `[B, R]`. Temporary buffer.
-        """
-        r = wb_t_all.size(-1)
-        if buffer is None:
-            # We set the buffer to be float32 by default to avoid
-            # numerical inaccuracies that would otherwise happen
-            # due to downcasting.
-            buffer = torch.zeros((x.size(0), r),
-                                 dtype=torch.float32,
-                                 device=x.device)
-        punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx,
-                                     1.0)
-        punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
-                                     scale)
-
-    def add_lora_slice(y: torch.Tensor,
-                       x: torch.Tensor,
-                       wa_t_all: torch.Tensor,
-                       wb_t_all: torch.Tensor,
-                       indicies: torch.LongTensor,
-                       layer_idx: int,
-                       scale: float,
-                       y_offset: int,
-                       y_slice_size: int,
-                       *,
-                       buffer: Optional[torch.Tensor] = None):
-        """
-        Same as `add_lora` but you can operate on slices of y.
-        Pass whole y, define y_offset and y_slice_size.
-
-        Semantics:
-          y[i] += (
-              x[i].unsqueeze(0)
-              @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              * scale
-            ).squeeze(0)
-
-        Args:
-          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-          x: Shape: `[B, H1]`. Input vectors.
-          wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
-            LoRA A matrices.
-          wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
-            LoRA B matrices.
-          indicies: Shape: `[B]`. Indices of the LoRA weights.
-          layer_idx: Layer index of LoRA weights.
-          scale: Scaling factor.
-          y_offset: Offset to apply to the starting column of y.
-          y_slice_size: Size of the y column slice.
-        """
-        r = wb_t_all.size(-1)
-        if buffer is None:
-            # We set the buffer to be float32 by default to avoid
-            # numerical inaccuracies that would otherwise happen
-            # due to downcasting.
-            buffer = torch.zeros((x.size(0), r),
-                                 dtype=torch.float32,
-                                 device=x.device)
-        punica_kernels.dispatch_bgmv_low_level(
-            buffer,
-            x,
-            wa_t_all,
-            indicies,
-            layer_idx,
-            1.0,
-            x.size(1),
-            buffer.size(1),
-            0,
-        )
-        punica_kernels.dispatch_bgmv_low_level(
-            y,
-            buffer,
-            wb_t_all,
-            indicies,
-            layer_idx,
-            scale,
-            buffer.size(1),
-            y_slice_size,
-            y_offset,
-        )
-
-else:
-
-    def _raise_exc(
-            *args,  # pylint: disable=unused-argument
-            **kwargs  # pylint: disable=unused-argument
-    ):
-        if torch.cuda.get_device_capability() < (8, 0):
-            raise ImportError("punica LoRA kernels require compute "
-                              "capability>=8.0") from import_exc
-        else:
-            raise ImportError(
-                "punica LoRA kernels could not be imported. If you built vLLM "
-                "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
-                "was set.") from import_exc
-
-    bgmv = _raise_exc
-    add_lora = _raise_exc
-    add_lora_slice = _raise_exc
-
-
-__all__ = [
-    "bgmv",
-    "add_lora",
-    "add_lora_slice",
-]
+
+def _raise_import_error(e):
+    if torch.cuda.get_device_capability() < (8, 0):
+        raise ImportError(
+            "punica LoRA kernels require compute capability >= 8.0") from e
+    else:
+        raise ImportError(
+            "punica LoRA kernels could not be imported. If you built vLLM "
+            "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
+            "was set.") from e
+
+
+def bgmv(
+    y: torch.Tensor,
+    x: torch.Tensor,
+    w_t_all: torch.Tensor,
+    indicies: torch.LongTensor,
+    layer_idx: int,
+    scale: float,
+):
+    """
+    Semantics:
+      y[i] += (
+          x[i].unsqueeze(0)
+          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          * scale
+        ).squeeze(0)
+
+    Args:
+      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+      x: Shape: `[B, H1]`. Input vectors.
+      w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
+        matrices.
+      indicies: Shape: `[B]`. Indices of the weight matrices.
+      layer_idx: Layer index of the weight matrices.
+      scale: Scaling factor.
+    """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+    punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
+
+
+def add_lora(y: torch.Tensor,
+             x: torch.Tensor,
+             wa_t_all: torch.Tensor,
+             wb_t_all: torch.Tensor,
+             indicies: torch.LongTensor,
+             layer_idx: int,
+             scale: float,
+             *,
+             buffer: Optional[torch.Tensor] = None):
+    """
+    Semantics:
+      y[i] += (
+          x[i].unsqueeze(0)
+          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          * scale
+        ).squeeze(0)
+
+    Args:
+      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+      x: Shape: `[B, H1]`. Input vectors.
+      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
+        LoRA A matrices.
+      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
+        LoRA B matrices.
+      indicies: Shape: `[B]`. Indices of the LoRA weights.
+      layer_idx: Layer index of LoRA weights.
+      scale: Scaling factor.
+      buffer: Optional. Shape: `[B, R]`. Temporary buffer.
+    """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
+    r = wb_t_all.size(-1)
+    if buffer is None:
+        # We set the buffer to be float32 by default to avoid
+        # numerical inaccuracies that would otherwise happen
+        # due to downcasting.
+        buffer = torch.zeros((x.size(0), r),
+                             dtype=torch.float32,
+                             device=x.device)
+    punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
+    punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
+                                 scale)
+
+
+def add_lora_slice(y: torch.Tensor,
+                   x: torch.Tensor,
+                   wa_t_all: torch.Tensor,
+                   wb_t_all: torch.Tensor,
+                   indicies: torch.LongTensor,
+                   layer_idx: int,
+                   scale: float,
+                   y_offset: int,
+                   y_slice_size: int,
+                   *,
+                   buffer: Optional[torch.Tensor] = None):
+    """
+    Same as `add_lora` but you can operate on slices of y.
+    Pass whole y, define y_offset and y_slice_size.
+
+    Semantics:
+      y[i] += (
+          x[i].unsqueeze(0)
+          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          * scale
+        ).squeeze(0)
+
+    Args:
+      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+      x: Shape: `[B, H1]`. Input vectors.
+      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
+        LoRA A matrices.
+      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
+        LoRA B matrices.
+      indicies: Shape: `[B]`. Indices of the LoRA weights.
+      layer_idx: Layer index of LoRA weights.
+      scale: Scaling factor.
+      y_offset: Offset to apply to the starting column of y.
+      y_slice_size: Size of the y column slice.
+    """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
+    r = wb_t_all.size(-1)
+    if buffer is None:
+        # We set the buffer to be float32 by default to avoid
+        # numerical inaccuracies that would otherwise happen
+        # due to downcasting.
+        buffer = torch.zeros((x.size(0), r),
+                             dtype=torch.float32,
+                             device=x.device)
+    punica_kernels.dispatch_bgmv_low_level(
+        buffer,
+        x,
+        wa_t_all,
+        indicies,
+        layer_idx,
+        1.0,
+        x.size(1),
+        buffer.size(1),
+        0,
+    )
+    punica_kernels.dispatch_bgmv_low_level(
+        y,
+        buffer,
+        wb_t_all,
+        indicies,
+        layer_idx,
+        scale,
+        buffer.size(1),
+        y_slice_size,
+        y_offset,
+    )
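
The docstrings above specify the kernel semantics exactly, so a slow pure-PyTorch reference is easy to state. A sketch for illustration only (not part of the commit): the real kernels fuse these per-row loops on the GPU, and `add_lora` stages the rank-R intermediate in a float32 buffer to limit downcasting error.

    import torch

    def bgmv_ref(y, x, w_t_all, indicies, layer_idx, scale):
        # y[i] += x[i] @ w_t_all[indicies[i], layer_idx].T * scale
        for i in range(x.size(0)):
            w = w_t_all[indicies[i], layer_idx]            # [H2, H1]
            y[i] += (x[i].unsqueeze(0) @ w.transpose(-1, -2) * scale).squeeze(0)

    def add_lora_ref(y, x, wa_t_all, wb_t_all, indicies, layer_idx, scale):
        # Shrink to rank R with LoRA A, then expand back to H2 with LoRA B.
        for i in range(x.size(0)):
            a = wa_t_all[indicies[i], layer_idx]           # [R, H1]
            b = wb_t_all[indicies[i], layer_idx]           # [H2, R]
            tmp = x[i].unsqueeze(0) @ a.transpose(-1, -2)  # [1, R]
            y[i] += (tmp @ b.transpose(-1, -2) * scale).squeeze(0)

`add_lora_slice` is the same computation with the second product written into `y[:, y_offset : y_offset + y_slice_size]` instead of all of `y`.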