Add contributing guideline and mypy config (#122)
This commit is contained in:
parent
3f942acfe1
commit
a283ec2eec
74
CONTRIBUTING.md
Normal file
74
CONTRIBUTING.md
Normal file
@ -0,0 +1,74 @@
|
||||
# Contributing to CacheFlow
|
||||
|
||||
Thank you for your interest in contributing to CacheFlow!
|
||||
Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large.
|
||||
There are several ways you can contribute to the project:
|
||||
|
||||
- Identify and report any issues or bugs.
|
||||
- Request or add a new model.
|
||||
- Suggest or implement new features.
|
||||
|
||||
However, remember that contributions aren't just about code.
|
||||
We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions.
|
||||
|
||||
Finally, one of the most impactful ways to support us is by raising awareness about CacheFlow.
|
||||
Talk about it in your blog posts, highlighting how it's driving your incredible projects.
|
||||
Express your support on Twitter if CacheFlow aids you, or simply offer your appreciation by starring our repository.
|
||||
|
||||
|
||||
## Setup for development
|
||||
|
||||
### Build from source
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
pip install -e . # This may take several minutes.
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
```bash
|
||||
pip install -r requirements-dev.txt
|
||||
|
||||
# Static type checking
|
||||
mypy
|
||||
# Unit tests
|
||||
pytest tests/
|
||||
```
|
||||
**Note:** Currently, the repository does not pass the mypy tests.
|
||||
|
||||
|
||||
## Contributing Guidelines
|
||||
|
||||
### Issue Reporting
|
||||
|
||||
If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it.
|
||||
If not, please file a new issue, providing as much relevant information as possible.
|
||||
|
||||
### Coding Style Guide
|
||||
|
||||
In general, we adhere to [Google Python style guide](https://google.github.io/styleguide/pyguide.html) and [Google C++ style guide](https://google.github.io/styleguide/cppguide.html).
|
||||
|
||||
### Pull Requests
|
||||
|
||||
When submitting a pull request:
|
||||
|
||||
1. Make sure your code has been rebased on top of the latest commit on the main branch.
|
||||
2. Include a detailed description of the changes in the pull request.
|
||||
Explain why you made the changes you did.
|
||||
If your pull request fixes an open issue, please include a reference to it in the description.
|
||||
|
||||
### Code Reviews
|
||||
|
||||
All submissions, including submissions by project members, require a code review.
|
||||
To make the review process as smooth as possible, please:
|
||||
|
||||
1. Keep your changes as concise as possible.
|
||||
If your pull request involves multiple unrelated changes, consider splitting it into separate pull requests.
|
||||
2. Respond to all comments within a reasonable time frame.
|
||||
If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.
|
||||
|
||||
### Thank You
|
||||
|
||||
Finally, thank you for taking the time to read these guidelines and for your interest in contributing to CacheFlow.
|
||||
Your contributions make CacheFlow a great tool for everyone!
|
@ -87,7 +87,7 @@ class Scheduler:
|
||||
def has_unfinished_seqs(self) -> bool:
|
||||
return self.waiting or self.running or self.swapped
|
||||
|
||||
def _schedule(self) -> Tuple[SchedulerOutputs, List[int]]:
|
||||
def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
|
||||
# Blocks that need to be swaped or copied before model execution.
|
||||
blocks_to_swap_in: Dict[int, int] = {}
|
||||
blocks_to_swap_out: Dict[int, int] = {}
|
||||
|
@ -61,7 +61,7 @@ class GPTCacheFlowAttention(nn.Module):
|
||||
key: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
value: torch.Tensor, # [num_prompt_tokens, num_heads, head_size]
|
||||
attn_bias: xops.AttentionBias,
|
||||
) -> None:
|
||||
) -> torch.Tensor:
|
||||
# TODO(woosuk): The unsqueeze op may incur some CPU overhead. Optimize.
|
||||
out = xops.memory_efficient_attention_forward(
|
||||
query.unsqueeze(0),
|
||||
@ -197,7 +197,7 @@ class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.LongTensor, # [num_tokens]
|
||||
positions: torch.Tensor, # [num_tokens]
|
||||
query: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
key: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
value: torch.Tensor, # [num_tokens, num_heads * head_size]
|
||||
|
@ -347,7 +347,7 @@ def _sample_from_generation_tokens(
|
||||
# Greedy sampling.
|
||||
assert len(seq_ids) == 1
|
||||
next_token_id = torch.argmax(probs, dim=-1)
|
||||
next_token_ids = [next_token_id.item()]
|
||||
next_token_ids = [int(next_token_id.item())]
|
||||
parent_seq_ids = seq_ids
|
||||
else:
|
||||
# Random sampling.
|
||||
|
@ -1,4 +1,6 @@
|
||||
"""Utilities for selecting and loading models."""
|
||||
from typing import Type
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
@ -17,7 +19,7 @@ _MODEL_REGISTRY = {
|
||||
}
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> nn.Module:
|
||||
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _MODEL_REGISTRY:
|
||||
|
@ -168,8 +168,8 @@ class GPT2Model(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
position_ids: torch.LongTensor,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
@ -204,8 +204,8 @@ class GPT2LMHeadModel(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
|
@ -67,7 +67,7 @@ class GPTNeoXAttention(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.LongTensor,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
@ -118,7 +118,7 @@ class GPTNeoXLayer(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
position_ids: torch.LongTensor,
|
||||
position_ids: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
@ -162,8 +162,8 @@ class GPTNeoXModel(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
position_ids: torch.LongTensor,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
@ -199,8 +199,8 @@ class GPTNeoXForCausalLM(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
|
@ -109,7 +109,7 @@ class LlamaAttention(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.LongTensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
@ -143,7 +143,7 @@ class LlamaDecoderLayer(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.LongTensor,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
@ -184,8 +184,8 @@ class LlamaModel(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
@ -222,8 +222,8 @@ class LlamaForCausalLM(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
|
@ -47,7 +47,7 @@ class OPTLearnedPositionalEmbedding(nn.Embedding):
|
||||
self.offset = 2
|
||||
super().__init__(num_embeddings + self.offset, embedding_dim)
|
||||
|
||||
def forward(self, positions: torch.LongTensor):
|
||||
def forward(self, positions: torch.Tensor):
|
||||
return super().forward(positions + self.offset)
|
||||
|
||||
|
||||
@ -199,8 +199,8 @@ class OPTDecoder(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
@ -235,8 +235,8 @@ class OPTModel(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
@ -258,8 +258,8 @@ class OPTForCausalLM(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
|
@ -31,7 +31,7 @@ class RequestOutput:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
request_id: int,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
prompt_token_ids: List[int],
|
||||
outputs: List[CompletionOutput],
|
||||
|
@ -116,10 +116,11 @@ class Sequence:
|
||||
def get_cumulative_logprob(self) -> float:
|
||||
return self.data.cumulative_logprob
|
||||
|
||||
def fork(self, child_seq: 'Sequence') -> 'Sequence':
|
||||
def fork(self, child_seq: 'Sequence') -> None:
|
||||
child_seq.logical_token_blocks = copy.deepcopy(self.logical_token_blocks)
|
||||
child_seq.output_logprobs = copy.deepcopy(self.output_logprobs)
|
||||
child_seq.data = copy.deepcopy(self.data)
|
||||
return None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f'Sequence(seq_id={self.seq_id}, '
|
||||
@ -205,7 +206,9 @@ class SequenceOutputs:
|
||||
f'output_token={self.output_token}), '
|
||||
f'logprobs={self.logprobs}')
|
||||
|
||||
def __eq__(self, other: 'SequenceOutputs') -> bool:
|
||||
def __eq__(self, other: object) -> bool:
|
||||
if not isinstance(other, SequenceOutputs):
|
||||
return NotImplemented
|
||||
return (self.seq_id == other.seq_id and
|
||||
self.parent_seq_id == other.parent_seq_id and
|
||||
self.output_token == other.output_token and
|
||||
|
@ -8,7 +8,7 @@ except ImportError:
|
||||
|
||||
from cacheflow.config import ParallelConfig
|
||||
|
||||
DeviceID = Tuple[int, str, int] # rank, node resource (node IP), device id
|
||||
DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), device id
|
||||
|
||||
|
||||
def initialize_cluster(
|
||||
|
@ -132,7 +132,7 @@ class Worker:
|
||||
def _prepare_inputs(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.LongTensor, torch.LongTensor, InputMetadata]:
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]:
|
||||
seq_groups: List[Tuple[List[int], SamplingParams]] = []
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
@ -216,19 +216,14 @@ class Worker:
|
||||
input_positions = _pad_to_alignment(input_positions, multiple_of=8)
|
||||
|
||||
# Convert to tensors.
|
||||
tokens_tensor = torch.tensor(
|
||||
input_tokens, dtype=torch.long, device='cuda')
|
||||
positions_tensor = torch.tensor(
|
||||
input_positions, dtype=torch.long, device='cuda')
|
||||
slot_mapping_tensor = torch.tensor(
|
||||
slot_mapping, dtype=torch.int, device='cuda')
|
||||
context_lens_tensor = torch.tensor(
|
||||
context_lens, dtype=torch.int, device='cuda')
|
||||
tokens_tensor = torch.cuda.LongTensor(input_tokens)
|
||||
positions_tensor = torch.cuda.LongTensor(input_positions)
|
||||
slot_mapping_tensor = torch.cuda.IntTensor(slot_mapping)
|
||||
context_lens_tensor = torch.cuda.IntTensor(context_lens)
|
||||
padded_block_tables = [
|
||||
_pad_to_max(block_table, max_num_blocks_per_seq)
|
||||
for block_table in generation_block_tables]
|
||||
block_tables_tensor = torch.tensor(
|
||||
padded_block_tables, dtype=torch.int, device='cuda')
|
||||
block_tables_tensor = torch.cuda.IntTensor(padded_block_tables)
|
||||
|
||||
seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
|
8
mypy.ini
Normal file
8
mypy.ini
Normal file
@ -0,0 +1,8 @@
|
||||
[mypy]
|
||||
python_version = 3.8
|
||||
|
||||
ignore_missing_imports = True
|
||||
|
||||
files = cacheflow
|
||||
# TODO(woosuk): Include the code from Megatron and HuggingFace.
|
||||
exclude = cacheflow/model_executor/parallel_utils/|cacheflow/model_executor/models/
|
2
requirements-dev.txt
Normal file
2
requirements-dev.txt
Normal file
@ -0,0 +1,2 @@
|
||||
mypy
|
||||
pytest
|
@ -49,7 +49,7 @@ class RefRotaryEmbeddingNeox(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.LongTensor, # [num_tokens]
|
||||
positions: torch.Tensor, # [num_tokens]
|
||||
query: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
Loading…
x
Reference in New Issue
Block a user