Add contributing guideline and mypy config (#122)

Woosuk Kwon 2023-05-23 17:58:51 -07:00 committed by GitHub
parent 3f942acfe1
commit a283ec2eec
16 changed files with 128 additions and 44 deletions

CONTRIBUTING.md (new file, +74 lines)

@@ -0,0 +1,74 @@
# Contributing to CacheFlow
Thank you for your interest in contributing to CacheFlow!
Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large.
There are several ways you can contribute to the project:
- Identify and report any issues or bugs.
- Request or add a new model.
- Suggest or implement new features.
However, remember that contributions aren't just about code.
We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions.
Finally, one of the most impactful ways to support us is by raising awareness about CacheFlow.
Talk about it in your blog posts, highlighting how it's driving your incredible projects.
Express your support on Twitter if CacheFlow aids you, or simply offer your appreciation by starring our repository.
## Setup for development
### Build from source
```bash
pip install -r requirements.txt
pip install -e . # This may take several minutes.
```
### Testing
```bash
pip install -r requirements-dev.txt
# Static type checking
mypy
# Unit tests
pytest tests/
```
**Note:** Currently, the repository does not pass the mypy tests.
## Contributing Guidelines
### Issue Reporting
If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it.
If not, please file a new issue, providing as much relevant information as possible.
### Coding Style Guide
In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
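In practice, for Python this means descriptive names, type annotations on function signatures, and Google-style docstrings. Below is a minimal illustrative sketch of the expected style; the function and its names are hypothetical, not taken from the codebase.
```python
from typing import List, Optional


def pad_to_multiple(values: List[int],
                    multiple_of: int,
                    pad_value: Optional[int] = None) -> List[int]:
    """Pads a list of integers so its length is a multiple of `multiple_of`.

    Args:
        values: The integers to pad.
        multiple_of: The target alignment.
        pad_value: The value used for padding. Defaults to 0.

    Returns:
        A new list whose length is the smallest multiple of `multiple_of`
        that is greater than or equal to `len(values)`.
    """
    pad_value = 0 if pad_value is None else pad_value
    remainder = len(values) % multiple_of
    if remainder == 0:
        return list(values)
    return list(values) + [pad_value] * (multiple_of - remainder)
```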
### Pull Requests
When submitting a pull request:
1. Make sure your code has been rebased on top of the latest commit on the main branch.
2. Include a detailed description of the changes in the pull request.
Explain why you made the changes you did.
If your pull request fixes an open issue, please include a reference to it in the description.
### Code Reviews
All submissions, including submissions by project members, require a code review.
To make the review process as smooth as possible, please:
1. Keep your changes as concise as possible.
If your pull request involves multiple unrelated changes, consider splitting it into separate pull requests.
2. Respond to all comments within a reasonable time frame.
If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.
### Thank You
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to CacheFlow.
Your contributions make CacheFlow a great tool for everyone!


@@ -87,7 +87,7 @@ class Scheduler:
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped
-    def _schedule(self) -> Tuple[SchedulerOutputs, List[int]]:
+    def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
         blocks_to_swap_out: Dict[int, int] = {}


@@ -61,7 +61,7 @@ class GPTCacheFlowAttention(nn.Module):
         key: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
         value: torch.Tensor,  # [num_prompt_tokens, num_heads, head_size]
         attn_bias: xops.AttentionBias,
-    ) -> None:
+    ) -> torch.Tensor:
         # TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
         out = xops.memory_efficient_attention_forward(
             query.unsqueeze(0),
@@ -197,7 +197,7 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
     def forward(
         self,
-        positions: torch.LongTensor,  # [num_tokens]
+        positions: torch.Tensor,  # [num_tokens]
         query: torch.Tensor,  # [num_tokens, num_heads * head_size]
         key: torch.Tensor,  # [num_tokens, num_heads * head_size]
         value: torch.Tensor,  # [num_tokens, num_heads * head_size]
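The `torch.LongTensor` to `torch.Tensor` signature changes in this and the following model files reflect how mypy treats tensors: values built with `torch.tensor(...)` are statically typed as `torch.Tensor` whatever their dtype, so annotating parameters with the dtype-specific legacy classes makes ordinary call sites fail type checking. A small standalone sketch of the idea (not code from the repository):
```python
import torch


def lookup(embedding: torch.nn.Embedding, positions: torch.Tensor) -> torch.Tensor:
    # The int64-index requirement is conveyed by convention (comments/docs);
    # mypy only tracks the general torch.Tensor type.
    return embedding(positions)


positions = torch.tensor([0, 1, 2], dtype=torch.long)  # statically a torch.Tensor
embedding = torch.nn.Embedding(num_embeddings=16, embedding_dim=4)
out = lookup(embedding, positions)  # type-checks; a torch.LongTensor annotation would reject it
```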


@@ -347,7 +347,7 @@ def _sample_from_generation_tokens(
         # Greedy sampling.
         assert len(seq_ids) == 1
         next_token_id = torch.argmax(probs, dim=-1)
-        next_token_ids = [next_token_id.item()]
+        next_token_ids = [int(next_token_id.item())]
         parent_seq_ids = seq_ids
     else:
         # Random sampling.
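Wrapping `.item()` in `int(...)` above is a small typing fix: `Tensor.item()` is typed as a general Python number (not specifically `int`) in the PyTorch stubs, so the explicit conversion makes the resulting list unambiguously `List[int]`. A minimal standalone sketch, with variable names that mirror the hunk:
```python
from typing import List

import torch

probs = torch.tensor([0.1, 0.7, 0.2])
next_token_id = torch.argmax(probs, dim=-1)

# int(...) pins the element type down for the type checker.
next_token_ids: List[int] = [int(next_token_id.item())]
print(next_token_ids)  # [1]
```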


@@ -1,4 +1,6 @@
 """Utilities for selecting and loading models."""
+from typing import Type
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
@@ -17,7 +19,7 @@ _MODEL_REGISTRY = {
 }
-def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
+def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
     architectures = getattr(config, "architectures", [])
     for arch in architectures:
         if arch in _MODEL_REGISTRY:
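The new return annotation distinguishes a model class from a model instance: the registry maps architecture names to classes, which the caller instantiates later, so `Type[nn.Module]` is the accurate type. A small sketch of the pattern with a hypothetical registry, not the real one:
```python
from typing import Dict, Type

import torch.nn as nn


class TinyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 4)


# Maps architecture names to classes, not instances.
_REGISTRY: Dict[str, Type[nn.Module]] = {"TinyModel": TinyModel}


def get_architecture(name: str) -> Type[nn.Module]:
    return _REGISTRY[name]


model_cls = get_architecture("TinyModel")
model = model_cls()  # instantiation happens only at the call site
```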


@@ -168,8 +168,8 @@ class GPT2Model(nn.Module):
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        position_ids: torch.LongTensor,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
@@ -204,8 +204,8 @@ class GPT2LMHeadModel(nn.Module):
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        positions: torch.LongTensor,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],


@@ -67,7 +67,7 @@ class GPTNeoXAttention(nn.Module):
     def forward(
         self,
-        position_ids: torch.LongTensor,
+        position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
@@ -118,7 +118,7 @@ class GPTNeoXLayer(nn.Module):
     def forward(
         self,
-        position_ids: torch.LongTensor,
+        position_ids: torch.Tensor,
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
@@ -162,8 +162,8 @@ class GPTNeoXModel(nn.Module):
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        position_ids: torch.LongTensor,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
@@ -199,8 +199,8 @@ class GPTNeoXForCausalLM(nn.Module):
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        positions: torch.LongTensor,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],


@@ -109,7 +109,7 @@ class LlamaAttention(nn.Module):
     def forward(
         self,
-        positions: torch.LongTensor,
+        positions: torch.Tensor,
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
@@ -143,7 +143,7 @@ class LlamaDecoderLayer(nn.Module):
     def forward(
         self,
-        positions: torch.LongTensor,
+        positions: torch.Tensor,
         hidden_states: torch.Tensor,
         kv_cache: KVCache,
         input_metadata: InputMetadata,
@@ -184,8 +184,8 @@ class LlamaModel(nn.Module):
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        positions: torch.LongTensor,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
@@ -222,8 +222,8 @@ class LlamaForCausalLM(nn.Module):
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        positions: torch.LongTensor,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],


@@ -47,7 +47,7 @@ class OPTLearnedPositionalEmbedding(nn.Embedding):
         self.offset = 2
         super().__init__(num_embeddings + self.offset, embedding_dim)
-    def forward(self, positions: torch.LongTensor):
+    def forward(self, positions: torch.Tensor):
         return super().forward(positions + self.offset)
@@ -199,8 +199,8 @@ class OPTDecoder(nn.Module):
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        positions: torch.LongTensor,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
@@ -235,8 +235,8 @@ class OPTModel(nn.Module):
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        positions: torch.LongTensor,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
@@ -258,8 +258,8 @@ class OPTForCausalLM(nn.Module):
     def forward(
         self,
-        input_ids: torch.LongTensor,
-        positions: torch.LongTensor,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],


@@ -31,7 +31,7 @@ class RequestOutput:
     def __init__(
         self,
-        request_id: int,
+        request_id: str,
         prompt: str,
         prompt_token_ids: List[int],
        outputs: List[CompletionOutput],


@@ -116,10 +116,11 @@ class Sequence:
     def get_cumulative_logprob(self) -> float:
         return self.data.cumulative_logprob
-    def fork(self, child_seq: 'Sequence') -> 'Sequence':
+    def fork(self, child_seq: 'Sequence') -> None:
         child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
         child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
         child_seq.data = copy.deepcopy(self.data)
+        return None
     def __repr__(self) -> str:
         return (f'Sequence(seq_id={self.seq_id}, '
@@ -205,7 +206,9 @@ class SequenceOutputs:
                 f'output_token={self.output_token}), '
                 f'logprobs={self.logprobs}')
-    def __eq__(self, other: 'SequenceOutputs') -> bool:
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, SequenceOutputs):
+            return NotImplemented
         return (self.seq_id == other.seq_id and
                 self.parent_seq_id == other.parent_seq_id and
                 self.output_token == other.output_token and
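The `__eq__` change follows the standard pattern for type-safe equality checks: the parameter is typed `object` to match the signature of `object.__eq__` (which mypy requires for overrides), and returning `NotImplemented` for unrelated types lets Python fall back to the other operand's comparison. A self-contained sketch of the idiom, using a hypothetical class:
```python
class TokenOutput:
    def __init__(self, token_id: int, logprob: float) -> None:
        self.token_id = token_id
        self.logprob = logprob

    def __eq__(self, other: object) -> bool:
        # `other: object` matches object.__eq__; narrowing is done via isinstance.
        if not isinstance(other, TokenOutput):
            return NotImplemented  # defer to the other operand's __eq__
        return self.token_id == other.token_id and self.logprob == other.logprob
```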


@@ -8,7 +8,7 @@ except ImportError:
 from cacheflow.config import ParallelConfig
-DeviceID = Tuple[int, str, int]  # rank, node resource (node IP), device id
+DeviceID = Tuple[int, Optional[str], int]  # rank, node resource (node IP), device id
 def initialize_cluster(


@@ -132,7 +132,7 @@ class Worker:
     def _prepare_inputs(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
         seq_groups: List[Tuple[List[int], SamplingParams]] = []
         input_tokens: List[int] = []
         input_positions: List[int] = []
@@ -216,19 +216,14 @@ class Worker:
         input_positions = _pad_to_alignment(input_positions, multiple_of=8)
         # Convert to tensors.
-        tokens_tensor = torch.tensor(
-            input_tokens, dtype=torch.long, device='cuda')
-        positions_tensor = torch.tensor(
-            input_positions, dtype=torch.long, device='cuda')
-        slot_mapping_tensor = torch.tensor(
-            slot_mapping, dtype=torch.int, device='cuda')
-        context_lens_tensor = torch.tensor(
-            context_lens, dtype=torch.int, device='cuda')
+        tokens_tensor = torch.cuda.LongTensor(input_tokens)
+        positions_tensor = torch.cuda.LongTensor(input_positions)
+        slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping)
+        context_lens_tensor = torch.cuda.IntTensor(context_lens)
         padded_block_tables = [
             _pad_to_max(block_table, max_num_blocks_per_seq)
             for block_table in generation_block_tables]
-        block_tables_tensor = torch.tensor(
-            padded_block_tables, dtype=torch.int, device='cuda')
+        block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)
         seq_data: Dict[int, SequenceData] = {}
         for seq_group_metadata in seq_group_metadata_list:
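For reference, the two construction styles in this hunk produce equivalent tensors: the legacy `torch.cuda.LongTensor` / `torch.cuda.IntTensor` constructors simply bake the dtype and device into the type name. A small sketch, assuming a CUDA-enabled PyTorch build (it will not run on CPU-only machines):
```python
import torch

input_tokens = [101, 2023, 2003, 102]

# Keyword form: dtype and device passed explicitly.
a = torch.tensor(input_tokens, dtype=torch.long, device="cuda")
# Legacy typed-constructor form used after this change.
b = torch.cuda.LongTensor(input_tokens)

assert a.dtype == b.dtype == torch.int64
assert a.device.type == b.device.type == "cuda"
assert torch.equal(a, b)
```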

mypy.ini (new file, +8 lines)

@@ -0,0 +1,8 @@
[mypy]
python_version = 3.8
ignore_missing_imports = True
files = cacheflow
# TODO(woosuk): Include the code from Megatron and HuggingFace.
exclude = cacheflow/model_executor/parallel_utils/|cacheflow/model_executor/models/

requirements-dev.txt (new file, +2 lines)

@@ -0,0 +1,2 @@
mypy
pytest


@@ -49,7 +49,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
     def forward(
         self,
-        positions: torch.LongTensor,  # [num_tokens]
+        positions: torch.Tensor,  # [num_tokens]
         query: torch.Tensor,  # [num_tokens, num_heads, head_size]
         key: torch.Tensor,  # [num_tokens, num_heads, head_size]
     ) -> Tuple[torch.Tensor, torch.Tensor]: