[Core] Adding Priority Scheduling (#5958)

Archit Patke 2024-09-24 21:50:50 -05:00 committed by GitHub
parent 01b6f9e1f0
commit 6da1ab6b41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 410 additions and 8 deletions

View File

@@ -0,0 +1,295 @@
"""Benchmark offline prioritization."""
import argparse
import json
import random
import time
from typing import List, Optional, Tuple

from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int],
) -> List[Tuple[str, int, int, int]]:
    if fixed_output_len is not None and fixed_output_len < 4:
        raise ValueError("output_len too small")

    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # Shuffle the dataset.
    random.shuffle(dataset)

    # Filter out sequences that are too long or too short.
    filtered_dataset: List[Tuple[str, int, int, int]] = []
    for i in range(len(dataset)):
        if len(filtered_dataset) == num_requests:
            break

        # Tokenize the prompts and completions.
        prompt = dataset[i][0]
        prompt_token_ids = tokenizer(prompt).input_ids
        completion = dataset[i][1]
        completion_token_ids = tokenizer(completion).input_ids
        prompt_len = len(prompt_token_ids)
        output_len = len(completion_token_ids
                         ) if fixed_output_len is None else fixed_output_len
        if prompt_len < 4 or output_len < 4:
            # Prune too short sequences.
            continue
        if prompt_len > 1024 or prompt_len + output_len > 2048:
            # Prune too long sequences.
            continue

        # Select an equi-probable random priority (0 or 1).
        priority = 0 if random.random() < 0.5 else 1

        filtered_dataset.append((prompt, prompt_len, output_len, priority))

    return filtered_dataset

def run_vllm(
    requests: List[Tuple[str, int, int, int]],
    model: str,
    tokenizer: str,
    quantization: Optional[str],
    tensor_parallel_size: int,
    seed: int,
    n: int,
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    quantization_param_path: Optional[str],
    device: str,
    enable_prefix_caching: bool,
    enable_chunked_prefill: bool,
    max_num_batched_tokens: int,
    gpu_memory_utilization: float = 0.9,
    download_dir: Optional[str] = None,
) -> float:
    from vllm import LLM, SamplingParams
    llm = LLM(
        model=model,
        tokenizer=tokenizer,
        quantization=quantization,
        tensor_parallel_size=tensor_parallel_size,
        seed=seed,
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        gpu_memory_utilization=gpu_memory_utilization,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        quantization_param_path=quantization_param_path,
        device=device,
        enable_prefix_caching=enable_prefix_caching,
        download_dir=download_dir,
        enable_chunked_prefill=enable_chunked_prefill,
        max_num_batched_tokens=max_num_batched_tokens,
        disable_log_stats=False,
    )

    # Add the requests to the engine.
    prompts = []
    sampling_params = []
    priority = []
    for prompt, _, output_len, _priority in requests:
        prompts.append(prompt)
        priority.append(_priority)
        sampling_params.append(
            SamplingParams(
                n=n,
                temperature=0.0 if use_beam_search else 1.0,
                top_p=1.0,
                use_beam_search=use_beam_search,
                ignore_eos=True,
                max_tokens=output_len,
            ))

    start = time.perf_counter()
    llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True)
    end = time.perf_counter()
    return end - start

def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    # Sample the requests.
    tokenizer = AutoTokenizer.from_pretrained(
        args.tokenizer, trust_remote_code=args.trust_remote_code)
    if args.dataset is None:
        # Synthesize a prompt with the given input length and give every
        # synthetic request the default priority 0.
        prompt = "hi" * (args.input_len - 1)
        requests = [(prompt, args.input_len, args.output_len, 0)
                    for _ in range(args.num_prompts)]
    else:
        requests = sample_requests(args.dataset, args.num_prompts, tokenizer,
                                   args.output_len)

    if args.backend == "vllm":
        elapsed_time = run_vllm(
            requests, args.model, args.tokenizer, args.quantization,
            args.tensor_parallel_size, args.seed, args.n, args.use_beam_search,
            args.trust_remote_code, args.dtype, args.max_model_len,
            args.enforce_eager, args.kv_cache_dtype,
            args.quantization_param_path, args.device,
            args.enable_prefix_caching, args.enable_chunked_prefill,
            args.max_num_batched_tokens, args.gpu_memory_utilization,
            args.download_dir)
    else:
        raise ValueError(f"Unknown backend: {args.backend}")
    total_num_tokens = sum(prompt_len + output_len
                           for _, prompt_len, output_len, priority in requests)
    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
          f"{total_num_tokens / elapsed_time:.2f} tokens/s")

    # Output JSON results if specified.
    if args.output_json:
        results = {
            "elapsed_time": elapsed_time,
            "num_requests": len(requests),
            "total_num_tokens": total_num_tokens,
            "requests_per_second": len(requests) / elapsed_time,
            "tokens_per_second": total_num_tokens / elapsed_time,
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--backend",
                        type=str,
                        choices=["vllm", "hf", "mii"],
                        default="vllm")
    parser.add_argument("--dataset",
                        type=str,
                        default=None,
                        help="Path to the dataset.")
    parser.add_argument("--input-len",
                        type=int,
                        default=None,
                        help="Input prompt length for each request")
    parser.add_argument("--output-len",
                        type=int,
                        default=None,
                        help="Output length for each request. Overrides the "
                        "output length from the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
                        type=int,
                        default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts",
                        type=int,
                        default=200,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--max-model-len',
        type=int,
        default=None,
        help='Maximum length of a sequence (including prompt and output). '
        'If None, will be derived from the model.')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        '--kv-cache-dtype',
        type=str,
        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
        default="auto",
        help='Data type for kv cache storage. If "auto", will use model '
        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help='device type for vLLM execution, supporting CUDA and CPU.')
    parser.add_argument(
        "--enable-prefix-caching",
        action='store_true',
        help="enable automatic prefix caching for vLLM backend.")
    parser.add_argument("--enable-chunked-prefill",
                        action='store_true',
                        help="enable chunked prefill for vLLM backend.")
    parser.add_argument('--max-num-batched-tokens',
                        type=int,
                        default=None,
                        help='maximum number of batched tokens per '
                        'iteration')
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the throughput results in JSON format.')
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model
    if args.dataset is None:
        assert args.input_len is not None
        assert args.output_len is not None
    else:
        assert args.input_len is None

    main(args)
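
The per-request priority hook exercised by run_vllm() above can also be used directly. The sketch below is illustrative only: it assumes the scheduler policy has already been set to "priority" (how that is wired up to the user-facing engine arguments is not shown in this diff), and with the default "fcfs" policy a non-zero priority is rejected by the engine.

from vllm import LLM, SamplingParams

# Minimal sketch, not part of the diff. Assumes priority scheduling is
# enabled in the scheduler config (policy="priority"); otherwise a
# non-zero priority raises a ValueError in LLMEngine.add_request().
llm = LLM(model="facebook/opt-125m")
prompts = ["Hello, my name is", "The capital of France is"]
params = SamplingParams(temperature=1.0, top_p=1.0, max_tokens=32)

# One priority value per prompt; lower values are preferred, and arrival
# time breaks ties (see Scheduler._get_priority below).
outputs = llm.generate(prompts, params, priority=[1, 0])
for out in outputs:
    print(out.request_id, out.outputs[0].text)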

View File

@@ -961,7 +961,7 @@ class SchedulerConfig:
             workers instead of an entire data. It should be enabled only
             when SPMD worker architecture is enabled. I.e.,
             VLLM_USE_RAY_SPMD_WORKER=1
+        policy: The scheduling policy to use. "fcfs" (default) or "priority".
     """

     def __init__(self,
@@ -977,7 +977,8 @@ class SchedulerConfig:
                  preemption_mode: Optional[str] = None,
                  num_scheduler_steps: int = 1,
                  multi_step_stream_outputs: bool = False,
-                 send_delta_data: bool = False) -> None:
+                 send_delta_data: bool = False,
+                 policy: str = "fcfs") -> None:
         if max_num_batched_tokens is None:
             if enable_chunked_prefill:
                 # It is the values that have the best balance between ITL
@@ -1019,6 +1020,7 @@ class SchedulerConfig:
         self.num_scheduler_steps = num_scheduler_steps
         self.multi_step_stream_outputs = multi_step_stream_outputs
         self.send_delta_data = send_delta_data
+        self.policy = policy
         self._verify_args()

     def _verify_args(self) -> None:
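
To make the new knob concrete, here is a minimal, self-contained sketch (not vLLM code) of what this hunk adds: SchedulerConfig simply carries a policy string that defaults to "fcfs" and is later consulted by the scheduler. The validation shown is an assumption for illustration; the body of the real _verify_args() is not part of this excerpt.

# Illustrative sketch only; the class name and the validation are assumptions.
class SchedulerConfigSketch:
    def __init__(self, policy: str = "fcfs") -> None:
        self.policy = policy
        self._verify_args()

    def _verify_args(self) -> None:
        if self.policy not in ("fcfs", "priority"):
            raise ValueError(f"Unsupported scheduling policy: {self.policy}")

config = SchedulerConfigSketch(policy="priority")
assert config.policy == "priority"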

View File

@@ -766,6 +766,79 @@ class Scheduler:
         else:
             return prompt_limit

+    def _get_priority(self,
+                      seq_group: SequenceGroup) -> Tuple[Optional[int], float]:
+        """Get the priority of the sequence group.
+        Highest preference to user-defined priority, followed by arrival time.
+
+        Args:
+            seq_group: The sequence group input.
+
+        Returns:
+            The priority of the sequence group.
+        """
+        return seq_group.priority, seq_group.arrival_time
+
+    def _schedule_priority_preemption(
+        self,
+        budget: SchedulingBudget,
+    ) -> int:
+        """Sorts the waiting and running queues and force-preempts requests
+        from the running queue whose priority is lower.
+        Priority-based preemption is only used with the "priority" policy.
+
+        Args:
+            budget: The scheduling budget. The argument is in-place updated
+                when any requests are scheduled.
+
+        Returns:
+            A count of priority-based preemptions.
+        """
+        waiting_queue = self.waiting
+        running_queue = deque(sorted(self.running, key=self._get_priority))
+
+        blocks_to_swap_out: List[Tuple[int, int]] = []
+        force_preemption_count = 0
+
+        if waiting_queue:
+            seq_group = waiting_queue.popleft()
+            num_new_seqs = seq_group.get_max_num_running_seqs()
+            num_new_tokens = self._get_num_new_tokens(seq_group,
+                                                      SequenceStatus.WAITING,
+                                                      False, budget)
+
+            # Only preempt if a priority inversion exists.
+            while running_queue and self._get_priority(
+                    running_queue[-1]) > self._get_priority(seq_group):
+                # Only preempt if the waiting sequence cannot be allocated.
+                can_allocate = self.block_manager.can_allocate(seq_group)
+                if (num_new_tokens and can_allocate == AllocStatus.OK
+                        and budget.can_schedule(num_new_tokens=num_new_tokens,
+                                                num_new_seqs=num_new_seqs)):
+                    break
+
+                # Adjust the budget to remove the victim sequence group.
+                vseq_group = running_queue.pop()
+                num_running_tokens = self._get_num_new_tokens(
+                    vseq_group, SequenceStatus.RUNNING, False, budget)
+                budget.subtract_num_batched_tokens(vseq_group.request_id,
+                                                   num_running_tokens)
+                num_running_seqs = vseq_group.get_max_num_running_seqs()
+                budget.subtract_num_seqs(vseq_group.request_id,
+                                         num_running_seqs)
+
+                # Preempt the victim sequence group.
+                self._preempt(vseq_group, blocks_to_swap_out,
+                              PreemptionMode.RECOMPUTE)
+                waiting_queue.appendleft(vseq_group)
+                force_preemption_count += 1
+            # Put the sequence back into the waiting queue.
+            waiting_queue.appendleft(seq_group)
+
+        waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))
+
+        self.waiting = waiting_queue
+        self.running = running_queue
+        return force_preemption_count
+
     def _schedule_prefills(
         self,
         budget: SchedulingBudget,
@@ -917,6 +990,10 @@
                                                curr_loras,
                                                enable_chunking=False)

+        if len(prefills.seq_groups
+               ) == 0 and self.scheduler_config.policy == "priority":
+            self._schedule_priority_preemption(budget)
+
         # Don't schedule decodes if prefills are scheduled.
         # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
         # only contains decode requests, not chunked prefills.
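
The heart of _schedule_priority_preemption() is a tuple comparison: requests are ordered by (priority, arrival_time), so a smaller user-defined priority wins and earlier arrival breaks ties. The toy below is plain Python with no vLLM types; MAX_RUNNING stands in for the token/sequence budget and block-allocation checks, which are omitted.

from collections import deque

MAX_RUNNING = 2  # stand-in for the budget / block-manager checks

# Toy requests as (priority, arrival_time, name) -- exactly the key that
# Scheduler._get_priority() returns.
running = deque(sorted([(0, 0.2, "B"), (1, 0.1, "A")]))
waiting = deque(sorted([(0, 0.4, "C")]))

if waiting:
    candidate = waiting.popleft()
    # Preempt the least-preferred running request (end of the sorted queue)
    # while a priority inversion exists and the candidate still does not fit.
    while running and running[-1][:2] > candidate[:2]:
        if len(running) < MAX_RUNNING:
            break  # the candidate fits; stop preempting
        victim = running.pop()
        waiting.appendleft(victim)  # the victim will be recomputed later
    waiting.appendleft(candidate)
    waiting = deque(sorted(waiting))

print("running:", list(running))  # [(0, 0.2, 'B')]
print("waiting:", list(waiting))  # [(0, 0.4, 'C'), (1, 0.1, 'A')]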

View File

@@ -631,6 +631,7 @@ class LLMEngine:
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
         trace_headers: Optional[Mapping[str, str]] = None,
+        priority: int = 0,
     ) -> None:
         self._validate_model_inputs(processed_inputs)
         # Create the sequences.
@@ -661,7 +662,8 @@ class LLMEngine:
                 lora_request=lora_request,
                 trace_headers=trace_headers,
                 prompt_adapter_request=prompt_adapter_request,
-                encoder_seq=encoder_seq)
+                encoder_seq=encoder_seq,
+                priority=priority)
         elif isinstance(params, PoolingParams):
             seq_group = self._create_sequence_group_with_pooling(
                 request_id,
@@ -670,7 +672,8 @@ class LLMEngine:
                 arrival_time=arrival_time,
                 lora_request=lora_request,
                 prompt_adapter_request=prompt_adapter_request,
-                encoder_seq=encoder_seq)
+                encoder_seq=encoder_seq,
+                priority=priority)
         else:
             raise ValueError(
                 "Either SamplingParams or PoolingParams must be provided.")
@@ -695,6 +698,7 @@
         lora_request: Optional[LoRARequest] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         """Add a request to the engine's request pool.
@@ -713,6 +717,8 @@
             arrival_time: The arrival time of the request. If None, we use
                 the current monotonic time.
             trace_headers: OpenTelemetry trace headers.
+            priority: The priority of the request.
+                Only applicable with priority scheduling.

         Details:
             - Set arrival_time to the current time if it is None.
@@ -741,6 +747,11 @@
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
                              "not enabled!")
+
+        if priority > 0 and not self.scheduler_config.policy == "priority":
+            raise ValueError(f"Got priority {priority} but "
+                             "Priority scheduling is not enabled.")
+
         if arrival_time is None:
             arrival_time = time.time()
@@ -760,6 +771,7 @@
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
             trace_headers=trace_headers,
+            priority=priority,
         )

     def _create_sequence_group_with_sampling(
@@ -772,6 +784,7 @@
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         encoder_seq: Optional[Sequence] = None,
+        priority: int = 0,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with SamplingParams."""
         max_logprobs = self.get_model_config().max_logprobs
@@ -798,7 +811,8 @@
             lora_request=lora_request,
             trace_headers=trace_headers,
             prompt_adapter_request=prompt_adapter_request,
-            encoder_seq=encoder_seq)
+            encoder_seq=encoder_seq,
+            priority=priority)

         return seq_group
@@ -811,6 +825,7 @@
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
         encoder_seq: Optional[Sequence] = None,
+        priority: int = 0,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with PoolingParams."""
         # Defensive copy of PoolingParams, which are used by the pooler
@@ -823,7 +838,8 @@
             lora_request=lora_request,
             pooling_params=pooling_params,
             prompt_adapter_request=prompt_adapter_request,
-            encoder_seq=encoder_seq)
+            encoder_seq=encoder_seq,
+            priority=priority)

         return seq_group

     def abort_request(self, request_id: Union[str, Iterable[str]]) -> None:
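
At the engine level, the new priority argument flows through add_request() into the SequenceGroup. A hedged sketch of driving this by hand follows; the engine construction and, in particular, the way the "priority" policy gets enabled are assumptions and are not part of this excerpt.

from vllm import EngineArgs, LLMEngine, SamplingParams

# Assumption: the scheduler policy of this engine has been set to
# "priority"; with the default "fcfs" policy, a non-zero priority is
# rejected by the check added above.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

requests = [("Write a haiku about GPUs.", 1),
            ("Summarize priority scheduling in one line.", 0)]
for i, (prompt, prio) in enumerate(requests):
    engine.add_request(str(i), prompt, SamplingParams(max_tokens=32),
                       priority=prio)

while engine.has_unfinished_requests():
    for output in engine.step():
        if output.finished:
            print(output.request_id, output.outputs[0].text)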

View File

@@ -320,7 +320,8 @@ class LLM:
         lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         guided_options_request: Optional[Union[LLMGuidedOptions,
-                                               GuidedDecodingRequest]] = None
+                                               GuidedDecodingRequest]] = None,
+        priority: Optional[List[int]] = None,
     ) -> List[RequestOutput]:
         """Generates the completions for the input prompts.
@@ -339,6 +340,8 @@ class LLM:
             lora_request: LoRA request to use for generation, if any.
             prompt_adapter_request: Prompt Adapter request to use for
                 generation, if any.
+            priority: The priority of the requests, if any.
+                Only applicable when priority scheduling policy is enabled.

         Returns:
             A list of ``RequestOutput`` objects containing the
@@ -379,7 +382,8 @@ class LLM:
             params=sampling_params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
-            guided_options=guided_options_request)
+            guided_options=guided_options_request,
+            priority=priority)

         outputs = self._run_engine(use_tqdm=use_tqdm)
         return LLMEngine.validate_outputs(outputs, RequestOutput)
@@ -782,6 +786,7 @@ class LLM:
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
         prompt_adapter_request: Optional[PromptAdapterRequest],
         guided_options: Optional[GuidedDecodingRequest] = None,
+        priority: Optional[List[int]] = None,
     ) -> None:
         if isinstance(inputs, (str, dict)):
             # Convert a single prompt to a list.
@@ -811,6 +816,7 @@ class LLM:
                 lora_request=lora_request[i] if isinstance(
                     lora_request, Sequence) else lora_request,
                 prompt_adapter_request=prompt_adapter_request,
+                priority=priority[i] if priority else 0,
             )

     def _add_request(
@@ -819,6 +825,7 @@
         params: Union[SamplingParams, PoolingParams],
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         request_id = str(next(self.request_counter))
         self.llm_engine.add_request(
@@ -827,6 +834,7 @@
             params,
             lora_request=lora_request,
             prompt_adapter_request=prompt_adapter_request,
+            priority=priority,
         )

     def _add_guided_processor(

View File

@@ -646,6 +646,7 @@ class SequenceGroup:
             unless you are working with an encoder/decoder model.
         trace_headers: OpenTelemetry trace headers.
         prompt_adapter_request: Prompt Adapter request.
+        priority: User-defined priority of the request.
     """

     def __init__(
@@ -660,9 +661,11 @@ class SequenceGroup:
         encoder_seq: Optional[Sequence] = None,
         trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        priority: int = 0,
     ) -> None:
         self.request_id = request_id
         self.seqs = seqs
+        self.arrival_time = arrival_time
         self.is_single_seq = len(seqs) == 1
         self.seqs_dict = {seq.seq_id: seq for seq in seqs}
@@ -680,6 +683,7 @@ class SequenceGroup:
         self.prompt_adapter_request = prompt_adapter_request
         self.encoder_seq = encoder_seq
         self.trace_headers = trace_headers
+        self.priority = priority

         self.cached_request_output = None