Add miscellaneous updates (#8)
parent e9d3f2ff77
commit cfae35b861
@@ -158,8 +158,8 @@ class Scheduler:
         # 3. Join new sequences if possible.
         # NOTE: Here we implicitly assume FCFS scheduling.
         # TODO(woosuk): Add a batching policy to control the batch size.
+        self._fetch_inputs()
         if not self.swapped:
-            self._fetch_inputs()
             for i, seq_group in enumerate(self.pending):
                 num_prompt_tokens = seq_group.seqs[0].get_len()
                 if self.block_manager.can_allocate(seq_group):
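The hunk above moves self._fetch_inputs() out of the `if not self.swapped:` block, so newly arrived sequence groups are pulled into the pending queue on every scheduling step, while admission of pending groups is still gated on the swapped queue. A minimal sketch of that control flow under assumed data structures (only _fetch_inputs, swapped, and pending come from the diff; everything else is illustrative):

# Minimal sketch of the scheduling step after this change (assumed structure,
# not the actual Scheduler class). New requests are fetched unconditionally;
# only the admission of pending groups is gated on the swapped queue.
class TinyScheduler:
    def __init__(self):
        self.pending = []    # sequence groups waiting to be admitted
        self.swapped = []    # sequence groups currently swapped out to CPU
        self.inbox = []      # hypothetical stand-in for the input queue

    def _fetch_inputs(self):
        # Move newly arrived groups into the pending queue every step.
        self.pending.extend(self.inbox)
        self.inbox.clear()

    def step(self):
        self._fetch_inputs()          # always runs now
        if not self.swapped:          # admit new work only if nothing is swapped out
            admitted, self.pending = self.pending, []
            return admitted
        return []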
@@ -211,12 +211,13 @@ class Scheduler:
             input_seq_groups.append(input_seq_group)
 
         # 5. Execute the first stage of the pipeline.
-        self.controllers[0].execute_stage(
-            input_seq_groups,
-            blocks_to_swap_in,
-            blocks_to_swap_out,
-            blocks_to_copy,
-        )
+        if (input_seq_groups or blocks_to_swap_in or blocks_to_swap_out):
+            self.controllers[0].execute_stage(
+                input_seq_groups,
+                blocks_to_swap_in,
+                blocks_to_swap_out,
+                blocks_to_copy,
+            )
 
     def post_step(
         self,
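This second scheduler hunk wraps the first-stage launch in a guard so the pipeline is only invoked when the step actually produced work (new input sequence groups or pending swaps). A hedged sketch of the same skip-if-idle pattern; the argument names mirror the diff, while the types and the controller object are assumptions:

# Illustrative skip-if-idle pattern (assumed types; only the argument names
# mirror the diff). The stage is launched only when the step produced work.
from typing import Dict, List

def maybe_execute_stage(controller,
                        input_seq_groups: List,
                        blocks_to_swap_in: Dict[int, int],
                        blocks_to_swap_out: Dict[int, int],
                        blocks_to_copy: Dict[int, List[int]]) -> None:
    if input_seq_groups or blocks_to_swap_in or blocks_to_swap_out:
        controller.execute_stage(
            input_seq_groups,
            blocks_to_swap_in,
            blocks_to_swap_out,
            blocks_to_copy,
        )
    # Otherwise the step is a no-op and no stage execution is issued.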
@@ -12,7 +12,7 @@ from cacheflow.models import InputMetadata
 class OPTCacheFlowAttention(nn.Module):
 
     def __init__(self, scale: float) -> None:
-        super().__init__()
+        super(OPTCacheFlowAttention, self).__init__()
         self.scale = float(scale)
 
         self.flash_attn = FlashAttention(softmax_scale=self.scale)
@@ -106,8 +106,8 @@ class OPTCacheFlowAttention(nn.Module):
         output = output.view(-1, num_heads, head_size)
 
         # Compute the attention op for prompts.
-        if input_metadata.num_prompts > 0:
-            num_prompt_tokens = sum(input_metadata.prompt_lens)
+        num_prompt_tokens = input_metadata.num_prompt_tokens
+        if num_prompt_tokens > 0:
             self.multi_query_kv_attention(
                 output[:num_prompt_tokens],
                 query[:num_prompt_tokens],
@@ -126,10 +126,9 @@ class OPTCacheFlowAttention(nn.Module):
 
         if input_metadata.num_generation_tokens > 0:
             # Compute the attention op for generation tokens.
-            start_idx = sum(input_metadata.prompt_lens)
             self.single_query_cached_kv_attention(
-                output[start_idx:],
-                query[start_idx:],
+                output[num_prompt_tokens:],
+                query[num_prompt_tokens:],
                 key_cache,
                 value_cache,
                 input_metadata)
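The two attention hunks replace the recomputed offset start_idx = sum(input_metadata.prompt_lens) with input_metadata.num_prompt_tokens, presumably the same total read once from the metadata and reused to slice out both the prompt and the generation segments of the flattened token batch. A small standalone sketch of that slicing with plain PyTorch tensors (the lengths and shapes below are illustrative, not from the diff):

# The new code reads num_prompt_tokens from the metadata instead of
# recomputing sum(prompt_lens); the same offset then splits the flattened
# batch into the prompt (prefill) rows and the cached-KV decode rows.
import torch

prompt_lens = [3, 5, 2]            # per-sequence prompt lengths (illustrative)
num_generation_tokens = 4          # one token per decoding sequence (illustrative)
num_prompt_tokens = sum(prompt_lens)

query = torch.randn(num_prompt_tokens + num_generation_tokens, 4, 8)
prompt_query = query[:num_prompt_tokens]       # rows for the prompt attention path
generation_query = query[num_prompt_tokens:]   # rows for the cached-KV decode path
assert prompt_query.shape[0] == num_prompt_tokens
assert generation_query.shape[0] == num_generation_tokens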
@@ -5,7 +5,7 @@ from cacheflow.models.utils import get_cpu_memory
 from cacheflow.models.utils import get_dtype_size
 from cacheflow.models.utils import get_gpu_memory
 
 _GiB = 1 << 30
 
 
 class CacheFlowMemoryAnalyzer:
@@ -117,9 +117,19 @@ class OPTMemoryAnalyzer(CacheFlowMemoryAnalyzer):
 
     def get_max_num_cpu_blocks(
         self,
-        memory_utilization: float = 0.25,
+        swap_space: int,
     ) -> int:
+        swap_space = swap_space * _GiB
         cpu_memory = get_cpu_memory()
-        usable_memory = int(memory_utilization * cpu_memory)
-        max_num_blocks = usable_memory // self._get_cache_block_size()
+        if swap_space > 0.8 * cpu_memory:
+            raise ValueError(f'The swap space ({swap_space / _GiB:.2f} GiB) '
+                             'takes more than 80% of the available memory '
+                             f'({cpu_memory / _GiB:.2f} GiB).'
+                             'Please check the swap space size.')
+        if swap_space > 0.5 * cpu_memory:
+            print(f'WARNING: The swap space ({swap_space / _GiB:.2f} GiB) '
+                  'takes more than 50% of the available memory '
+                  f'({cpu_memory / _GiB:.2f} GiB).'
+                  'This may slow the system performance.')
+        max_num_blocks = swap_space // self._get_cache_block_size()
         return max_num_blocks
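The rewritten get_max_num_cpu_blocks converts a swap-space budget given in GiB into a KV-cache block count and checks it against total host memory (hard error above 80%, warning above 50%). A self-contained sketch of the same arithmetic; the block size, helper names, and example figures are illustrative, only the thresholds and the GiB conversion mirror the diff:

# Standalone sketch of the swap-space -> CPU block count calculation.
# cache_block_size and total_cpu_memory are stand-ins for
# self._get_cache_block_size() and get_cpu_memory() in the diff.
_GiB = 1 << 30

def max_num_cpu_blocks(swap_space_gib: int,
                       total_cpu_memory: int,
                       cache_block_size: int) -> int:
    swap_space = swap_space_gib * _GiB
    if swap_space > 0.8 * total_cpu_memory:
        raise ValueError('Swap space exceeds 80% of available CPU memory.')
    if swap_space > 0.5 * total_cpu_memory:
        print('WARNING: swap space exceeds 50% of available CPU memory.')
    return swap_space // cache_block_size

# Example: 20 GiB of swap, 64 GiB of host RAM, 128 KiB per cache block.
print(max_num_cpu_blocks(20, 64 * _GiB, 128 * 1024))   # -> 163840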
@@ -11,7 +11,7 @@ from cacheflow.sequence import SequenceOutputs
 class Sampler(nn.Module):
 
     def __init__(self) -> None:
-        super().__init__()
+        super(Sampler, self).__init__()
 
     def forward(
         self,
@@ -191,6 +191,13 @@ class Worker:
         else:
             cache_events = None
 
+        # If there is no input, we don't need to execute the model.
+        if not input_seq_groups:
+            if cache_events is not None:
+                for event in cache_events:
+                    event.wait()
+            return {}
+
         # Prepare input tensors.
         input_tokens, input_positions, input_metadata = self.prepare_inputs(
             input_seq_groups)
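The worker hunk adds an early return for steps that carry no model input: the worker still waits on any outstanding cache events so swap copies complete, then returns an empty result. A hedged sketch of that branch structure, using threading.Event as a stand-in for the CUDA events implied by the diff:

# Sketch of the early-return path added to the worker. threading.Event is a
# stand-in for the events attached to cache (swap) operations.
import threading
from typing import Dict, List, Optional

def execute_step(input_seq_groups: List,
                 cache_events: Optional[List[threading.Event]]) -> Dict:
    # If there is no input, we don't need to run the model, but we still
    # wait for outstanding cache operations to complete.
    if not input_seq_groups:
        if cache_events is not None:
            for event in cache_events:
                event.wait()
        return {}
    # ... otherwise prepare inputs and run the model (omitted).
    return {'ran_model': True}

ev = threading.Event()
ev.set()
assert execute_step([], [ev]) == {}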
@@ -1,5 +1,4 @@
 #include <torch/extension.h>
-
 #include <ATen/cuda/CUDAContext.h>
 
 #include <algorithm>
@@ -73,6 +72,8 @@ void copy_blocks(
   }
 }
 
+namespace cacheflow {
+
 template<typename scalar_t>
 __global__ void reshape_and_cache_kernel(
   const scalar_t* __restrict__ key,     // [num_tokens, num_heads, head_size]
@@ -112,6 +113,8 @@ __global__ void reshape_and_cache_kernel(
   }
 }
 
+} // namespace cacheflow
+
 void reshape_and_cache(
   torch::Tensor& key,
   torch::Tensor& value,
@@ -131,7 +134,7 @@ void reshape_and_cache(
     key.scalar_type(),
     "reshape_and_cache_kernel",
     [&] {
-      reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
+      cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
@@ -15,7 +15,8 @@ parser.add_argument('--block-size', type=int, default=8, choices=[8, 16], help='
 parser.add_argument('--dtype', type=str, default='half', choices=['half', 'float'], help='data type')
 # TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
 parser.add_argument('--seed', type=int, default=0, help='random seed')
-parser.add_argument('--max-batch-size', type=int, default=2048, help='maximum number of batched tokens')
+parser.add_argument('--swap-space', type=int, default=20, help='CPU swap space size (GiB) per GPU')
+parser.add_argument('--max-batch-size', type=int, default=2560, help='maximum number of batched tokens')
 args = parser.parse_args()
 
 
@@ -27,7 +28,8 @@ def main():
     )
     num_gpu_blocks = memory_analyzer.get_max_num_gpu_blocks(
         max_num_batched_tokens=args.max_batch_size)
-    num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks()
+    num_cpu_blocks = memory_analyzer.get_max_num_cpu_blocks(
+        swap_space=args.swap_space)
     print(f'# GPU blocks: {num_gpu_blocks}, # CPU blocks: {num_cpu_blocks}')
 
     # Create a controller for each node.
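Taken together with the memory-analyzer change, the new --swap-space flag flows from the argument parser into get_max_num_cpu_blocks. A minimal, self-contained sketch of that plumbing; the analyzer class and block size below are toy stand-ins, not the real OPTMemoryAnalyzer:

# Illustrative plumbing of the new --swap-space argument (GiB per GPU) into
# a CPU block-count query; the analyzer below is a toy stand-in.
import argparse

_GiB = 1 << 30

class ToyMemoryAnalyzer:
    def __init__(self, cache_block_size: int = 128 * 1024):
        self.cache_block_size = cache_block_size

    def get_max_num_cpu_blocks(self, swap_space: int) -> int:
        return (swap_space * _GiB) // self.cache_block_size

parser = argparse.ArgumentParser()
parser.add_argument('--swap-space', type=int, default=20,
                    help='CPU swap space size (GiB) per GPU')
parser.add_argument('--max-batch-size', type=int, default=2560,
                    help='maximum number of batched tokens')
args = parser.parse_args([])   # use defaults for the example

analyzer = ToyMemoryAnalyzer()
num_cpu_blocks = analyzer.get_max_num_cpu_blocks(swap_space=args.swap_space)
print(f'# CPU blocks: {num_cpu_blocks}')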
|
Loading…
x
Reference in New Issue
Block a user