[BugFix] Fix GC bug for LLM class (#2882)
parent 31348dff03
commit d7afab6d3a

@@ -4,6 +4,10 @@ It should include tests that are reported by users and making sure they
 will never happen again.
 
 """
+import gc
+
+import torch
+
 from vllm import LLM, SamplingParams
 
 
@@ -35,6 +39,20 @@ def test_max_tokens_none():
     assert len(prompts) == len(outputs)
 
 
+def test_gc():
+    llm = LLM("facebook/opt-125m", enforce_eager=True)
+    del llm
+
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # The memory allocated for model and KV cache should be released.
+    # The memory allocated for PyTorch and others should be less than 50MB.
+    # Usually, it's around 10MB.
+    allocated = torch.cuda.memory_allocated()
+    assert allocated < 50 * 1024 * 1024
+
+
 if __name__ == "__main__":
     import pytest
     pytest.main([__file__])
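For context on what the new test_gc guards: once the last reference to an LLM is dropped, a gc.collect() followed by torch.cuda.empty_cache() should leave almost no GPU memory allocated, because the model weights and KV cache are released together with the engine. Below is a minimal standalone sketch of the same check; the model name and the 50 MB bound come from the test above, while the helper name and the before/after delta are illustrative only.

import gc

import torch

from vllm import LLM


def check_llm_is_collected() -> None:
    baseline = torch.cuda.memory_allocated()

    llm = LLM("facebook/opt-125m", enforce_eager=True)
    del llm

    # Break any remaining reference cycles and release the caching
    # allocator's unused blocks, as the regression test does.
    gc.collect()
    torch.cuda.empty_cache()

    leaked = torch.cuda.memory_allocated() - baseline
    # Mirrors the bound used in test_gc: anything well beyond ~50 MB of
    # PyTorch bookkeeping suggests the engine was not collected.
    assert leaked < 50 * 1024 * 1024, f"{leaked / 1024**2:.1f} MiB still allocated"


if __name__ == "__main__":
    check_llm_is_collected()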
@@ -4,173 +4,167 @@ from typing import Optional
 
 import torch
 
-import_exc = None
-
-try:
-    import vllm._punica_C as punica_kernels
-except ImportError as e:
-    import_exc = e
-
-if import_exc is None:
-
-    def bgmv(
-        y: torch.Tensor,
-        x: torch.Tensor,
-        w_t_all: torch.Tensor,
-        indicies: torch.LongTensor,
-        layer_idx: int,
-        scale: float,
-    ):
-        """
-        Semantics:
-          y[i] += (
-              x[i].unsqueeze(0)
-              @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              * scale
-            ).squeeze(0)
-
-        Args:
-          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-          x: Shape: `[B, H1]`. Input vectors.
-          w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
-            matrices.
-          indicies: Shape: `[B]`. Indices of the weight matrices.
-          layer_idx: Layer index of the weight matrices.
-          scale: Scaling factor.
-        """
-        punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
-
-    def add_lora(y: torch.Tensor,
-                 x: torch.Tensor,
-                 wa_t_all: torch.Tensor,
-                 wb_t_all: torch.Tensor,
-                 indicies: torch.LongTensor,
-                 layer_idx: int,
-                 scale: float,
-                 *,
-                 buffer: Optional[torch.Tensor] = None):
-        """
-        Semantics:
-          y[i] += (
-              x[i].unsqueeze(0)
-              @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              * scale
-            ).squeeze(0)
-
-        Args:
-          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-          x: Shape: `[B, H1]`. Input vectors.
-          wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
-            LoRA A matrices.
-          wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
-            LoRA B matrices.
-          indicies: Shape: `[B]`. Indices of the LoRA weights.
-          layer_idx: Layer index of LoRA weights.
-          scale: Scaling factor.
-          buffer: Optional. Shape: `[B, R]`. Temporary buffer.
-        """
-        r = wb_t_all.size(-1)
-        if buffer is None:
-            # We set the buffer to be float32 by default to avoid
-            # numerical innacuracies that would otherwise happen
-            # due to downcasting.
-            buffer = torch.zeros((x.size(0), r),
-                                 dtype=torch.float32,
-                                 device=x.device)
-        punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx,
-                                     1.0)
-        punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
-                                     scale)
-
-    def add_lora_slice(y: torch.Tensor,
-                       x: torch.Tensor,
-                       wa_t_all: torch.Tensor,
-                       wb_t_all: torch.Tensor,
-                       indicies: torch.LongTensor,
-                       layer_idx: int,
-                       scale: float,
-                       y_offset: int,
-                       y_slice_size: int,
-                       *,
-                       buffer: Optional[torch.Tensor] = None):
-        """
-        Same as `add_lora` but you can operate on slices of y.
-        Pass whole y, define y_offset and y_slice_size.
-
-        Semantics:
-          y[i] += (
-              x[i].unsqueeze(0)
-              @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
-              * scale
-            ).squeeze(0)
-
-        Args:
-          y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
-          x: Shape: `[B, H1]`. Input vectors.
-          wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
-            LoRA A matrices.
-          wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
-            LoRA B matrices.
-          indicies: Shape: `[B]`. Indices of the LoRA weights.
-          layer_idx: Layer index of LoRA weights.
-          scale: Scaling factor.
-          y_offset: Offset to apply to the starting column of y.
-          y_slice_size: Size of the y column slice.
-        """
-        r = wb_t_all.size(-1)
-        if buffer is None:
-            # We set the buffer to be float32 by default to avoid
-            # numerical inaccuracies that would otherwise happen
-            # due to downcasting.
-            buffer = torch.zeros((x.size(0), r),
-                                 dtype=torch.float32,
-                                 device=x.device)
-        punica_kernels.dispatch_bgmv_low_level(
-            buffer,
-            x,
-            wa_t_all,
-            indicies,
-            layer_idx,
-            1.0,
-            x.size(1),
-            buffer.size(1),
-            0,
-        )
-        punica_kernels.dispatch_bgmv_low_level(
-            y,
-            buffer,
-            wb_t_all,
-            indicies,
-            layer_idx,
-            scale,
-            buffer.size(1),
-            y_slice_size,
-            y_offset,
-        )
-
-else:
-
-    def _raise_exc(
-            *args,  # pylint: disable=unused-argument
-            **kwargs  # pylint: disable=unused-argument
-    ):
-        if torch.cuda.get_device_capability() < (8, 0):
-            raise ImportError("punica LoRA kernels require compute "
-                              "capability>=8.0") from import_exc
-        else:
-            raise ImportError(
-                "punica LoRA kernels could not be imported. If you built vLLM "
-                "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
-                "was set.") from import_exc
-
-    bgmv = _raise_exc
-    add_lora = _raise_exc
-    add_lora_slice = _raise_exc
-
-__all__ = [
-    "bgmv",
-    "add_lora",
-    "add_lora_slice",
-]
+
+def _raise_import_error(e):
+    if torch.cuda.get_device_capability() < (8, 0):
+        raise ImportError(
+            "punica LoRA kernels require compute capability >= 8.0") from e
+    else:
+        raise ImportError(
+            "punica LoRA kernels could not be imported. If you built vLLM "
+            "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
+            "was set.") from e
+
+
+def bgmv(
+    y: torch.Tensor,
+    x: torch.Tensor,
+    w_t_all: torch.Tensor,
+    indicies: torch.LongTensor,
+    layer_idx: int,
+    scale: float,
+):
+    """
+    Semantics:
+      y[i] += (
+          x[i].unsqueeze(0)
+          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          * scale
+        ).squeeze(0)
+
+    Args:
+      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+      x: Shape: `[B, H1]`. Input vectors.
+      w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
+        matrices.
+      indicies: Shape: `[B]`. Indices of the weight matrices.
+      layer_idx: Layer index of the weight matrices.
+      scale: Scaling factor.
+    """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
+    punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
+
+
+def add_lora(y: torch.Tensor,
+             x: torch.Tensor,
+             wa_t_all: torch.Tensor,
+             wb_t_all: torch.Tensor,
+             indicies: torch.LongTensor,
+             layer_idx: int,
+             scale: float,
+             *,
+             buffer: Optional[torch.Tensor] = None):
+    """
+    Semantics:
+      y[i] += (
+          x[i].unsqueeze(0)
+          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          * scale
+        ).squeeze(0)
+
+    Args:
+      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+      x: Shape: `[B, H1]`. Input vectors.
+      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
+        LoRA A matrices.
+      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
+        LoRA B matrices.
+      indicies: Shape: `[B]`. Indices of the LoRA weights.
+      layer_idx: Layer index of LoRA weights.
+      scale: Scaling factor.
+      buffer: Optional. Shape: `[B, R]`. Temporary buffer.
+    """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
+    r = wb_t_all.size(-1)
+    if buffer is None:
+        # We set the buffer to be float32 by default to avoid
+        # numerical innacuracies that would otherwise happen
+        # due to downcasting.
+        buffer = torch.zeros((x.size(0), r),
+                             dtype=torch.float32,
+                             device=x.device)
+    punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
+    punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
+                                 scale)
+
+
+def add_lora_slice(y: torch.Tensor,
+                   x: torch.Tensor,
+                   wa_t_all: torch.Tensor,
+                   wb_t_all: torch.Tensor,
+                   indicies: torch.LongTensor,
+                   layer_idx: int,
+                   scale: float,
+                   y_offset: int,
+                   y_slice_size: int,
+                   *,
+                   buffer: Optional[torch.Tensor] = None):
+    """
+    Same as `add_lora` but you can operate on slices of y.
+    Pass whole y, define y_offset and y_slice_size.
+
+    Semantics:
+      y[i] += (
+          x[i].unsqueeze(0)
+          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
+          * scale
+        ).squeeze(0)
+
+    Args:
+      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
+      x: Shape: `[B, H1]`. Input vectors.
+      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
+        LoRA A matrices.
+      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
+        LoRA B matrices.
+      indicies: Shape: `[B]`. Indices of the LoRA weights.
+      layer_idx: Layer index of LoRA weights.
+      scale: Scaling factor.
+      y_offset: Offset to apply to the starting column of y.
+      y_slice_size: Size of the y column slice.
+    """
+    try:
+        import vllm._punica_C as punica_kernels
+    except ImportError as e:
+        _raise_import_error(e)
+
+    r = wb_t_all.size(-1)
+    if buffer is None:
+        # We set the buffer to be float32 by default to avoid
+        # numerical inaccuracies that would otherwise happen
+        # due to downcasting.
+        buffer = torch.zeros((x.size(0), r),
+                             dtype=torch.float32,
+                             device=x.device)
+    punica_kernels.dispatch_bgmv_low_level(
+        buffer,
+        x,
+        wa_t_all,
+        indicies,
+        layer_idx,
+        1.0,
+        x.size(1),
+        buffer.size(1),
+        0,
+    )
+    punica_kernels.dispatch_bgmv_low_level(
+        y,
+        buffer,
+        wb_t_all,
+        indicies,
+        layer_idx,
+        scale,
+        buffer.size(1),
+        y_slice_size,
+        y_offset,
+    )
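The docstrings above already specify these wrappers in plain tensor algebra, so the documented behaviour can be restated without the compiled extension. As a readability aid only, here is a naive PyTorch equivalent of the add_lora semantics; the name add_lora_reference and the per-row loop are ours, not part of vLLM, and the real code dispatches to the fused punica kernels instead.

import torch


def add_lora_reference(y: torch.Tensor,
                       x: torch.Tensor,
                       wa_t_all: torch.Tensor,
                       wb_t_all: torch.Tensor,
                       indicies: torch.LongTensor,
                       layer_idx: int,
                       scale: float) -> None:
    """Per-row restatement of the add_lora docstring; modifies y in place."""
    for i in range(x.size(0)):
        idx = indicies[i]
        # [1, H1] @ [H1, R] @ [R, H2] -> [1, H2], scaled and accumulated.
        y[i] += (x[i].unsqueeze(0)
                 @ wa_t_all[idx, layer_idx].transpose(-1, -2)
                 @ wb_t_all[idx, layer_idx].transpose(-1, -2)
                 * scale).squeeze(0)

Writing it as a loop makes the role of indicies explicit: each row of the batch selects its own LoRA A/B pair for the given layer_idx, which is the batched gather the dispatch_bgmv calls are documented to perform.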