# vllm/cacheflow/decoding.py
# (scrape metadata: 34 lines, 957 B, Python; last commits 2023-02-09 / 2023-02-23)
from typing import Optional, Set
# commit 2023-02-09 11:27:06 +00:00
class DecodingParams:
    """Parameters that control how tokens are decoded/sampled.

    Validates the parameter combination on construction:
      - beam search requires n > 1, a positive temperature, and top_p == 1.0;
      - temperature == 0.0 (greedy decoding) requires n == 1 and top_p == 1.0.

    Raises:
        AssertionError: if the parameter combination is invalid.
    """

    def __init__(
        self,
        n: int = 1,
        temperature: float = 1.0,
        top_p: float = 1.0,
        use_beam_search: bool = False,
        # NOTE: default is None, not a shared mutable container. The original
        # `Set[int] = []` bug made every instance share one list (and the
        # default's type contradicted the annotation).
        stop_token_ids: Optional[Set[int]] = None,
        max_context_len: Optional[int] = None,
    ) -> None:
        assert n >= 1
        assert temperature >= 0.0
        assert 0.0 < top_p <= 1.0
        if use_beam_search:
            # Beam search needs multiple candidates and full-distribution
            # sampling (no nucleus truncation, nonzero temperature).
            assert n > 1
            assert temperature > 0.0
            assert top_p == 1.0
        elif temperature == 0.0:
            # Zero temperature means greedy decoding.
            assert n == 1
            assert top_p == 1.0
        assert max_context_len is None or max_context_len >= 0

        self.n = n
        self.temperature = temperature
        self.top_p = top_p
        self.use_beam_search = use_beam_search
        # Fresh set per instance; never alias a default container.
        self.stop_token_ids = set() if stop_token_ids is None else stop_token_ids
        self.max_context_len = max_context_len