[Core] Implement sharded state loader (#4690)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
parent 52f8107cf2
commit 30e754390c

examples/save_sharded_state.py (new file, 75 lines)
@@ -0,0 +1,75 @@
"""
Saves each worker's model state dict directly to a checkpoint, which enables a
fast load path for large tensor-parallel models where each worker only needs to
read its own shard rather than the entire checkpoint.

Example usage:

python save_sharded_state.py \
    --model /path/to/load \
    --quantization deepspeedfp \
    --tensor-parallel-size 8 \
    --output /path/to/save

Then, the model can be loaded with

llm = LLM(
    model="/path/to/save",
    load_format="sharded_state",
    quantization="deepspeedfp",
    tensor_parallel_size=8,
)
"""
import argparse
import dataclasses
import os
import shutil
from pathlib import Path

from vllm import LLM, EngineArgs

parser = argparse.ArgumentParser()
EngineArgs.add_cli_args(parser)
parser.add_argument("--output",
                    "-o",
                    required=True,
                    type=str,
                    help="path to output checkpoint")
parser.add_argument("--file-pattern",
                    type=str,
                    help="string pattern of saved filenames")
parser.add_argument("--max-file-size",
                    type=int,
                    default=5 * 1024**3,
                    help="max size (in bytes) of each safetensors file")


def main(args):
    engine_args = EngineArgs.from_cli_args(args)
    if engine_args.enable_lora:
        raise ValueError("Saving with enable_lora=True is not supported!")
    model_path = engine_args.model
    if not Path(model_path).is_dir():
        raise ValueError("model path must be a local directory")
    # Create LLM instance from arguments
    llm = LLM(**dataclasses.asdict(engine_args))
    # Prepare output directory
    Path(args.output).mkdir(exist_ok=True)
    # Dump worker states to output directory
    model_executor = llm.llm_engine.model_executor
    model_executor.save_sharded_state(path=args.output,
                                      pattern=args.file_pattern,
                                      max_size=args.max_file_size)
    # Copy metadata files to output directory
    for file in os.listdir(model_path):
        if os.path.splitext(file)[1] not in (".bin", ".pt", ".safetensors"):
            if os.path.isdir(os.path.join(model_path, file)):
                shutil.copytree(os.path.join(model_path, file),
                                os.path.join(args.output, file))
            else:
                shutil.copy(os.path.join(model_path, file), args.output)


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
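For orientation, the sketch below is not part of the commit: it shows what the output directory is expected to contain after running the script with --tensor-parallel-size 8, assuming the loader's default filename pattern model-rank-{rank}-part-{part}.safetensors defined later in this diff. The number of parts per rank depends on --max-file-size, and the listing code itself is purely illustrative.

from pathlib import Path

# Each tensor-parallel rank writes its own shard(s); metadata files such as the
# config and tokenizer are copied over unchanged by the script above.
for shard in sorted(Path("/path/to/save").glob("model-rank-*-part-*.safetensors")):
    print(shard.name)
# Expected names look like:
#   model-rank-0-part-0.safetensors
#   model-rank-1-part-0.safetensors
#   ...
#   model-rank-7-part-0.safetensors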
tests/test_sharded_state_loader.py (new file, 90 lines)
@@ -0,0 +1,90 @@
import os
import shutil
from tempfile import TemporaryDirectory

import pytest
import torch
from huggingface_hub import snapshot_download

from vllm import LLM, SamplingParams
from vllm.model_executor.model_loader.loader import ShardedStateLoader

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

# Create a sampling params object.
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    seed=0,
    max_tokens=256,
    ignore_eos=True,
)


def test_filter_subtensors():
    state_dict = {
        "a": torch.empty(2),
        "b": torch.empty((2, 4)),
        "c": torch.empty((2, 4, 8)),
    }
    state_dict.update({
        "x": state_dict["b"],
        "y": state_dict["c"][1, 2, :],
        "z": state_dict["c"][1, :, 4],
    })
    filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict)
    assert tuple(filtered_state_dict.keys()) == ("a", "b", "c")
    for key, tensor in filtered_state_dict.items():
        assert tensor.equal(state_dict[key])


@pytest.mark.parametrize("enable_lora", [False, True])
def test_sharded_state_loader(enable_lora):
    weights_suffixes = (".bin", ".pt", ".safetensors")

    with TemporaryDirectory() as cache_dir, TemporaryDirectory() as output_dir:
        input_dir = snapshot_download("meta-llama/Llama-2-7b-hf",
                                      cache_dir=cache_dir)

        llm = LLM(
            model=input_dir,
            worker_use_ray=True,
            gpu_memory_utilization=0.3,
        )

        # Dump worker states to output directory
        model_executor = llm.llm_engine.model_executor
        model_executor.save_sharded_state(path=output_dir)
        # Copy metadata files to output directory
        for file in os.listdir(input_dir):
            if not any(file.endswith(ext) for ext in weights_suffixes):
                shutil.copy(f"{input_dir}/{file}", output_dir)
        del llm.llm_engine.model_executor

        llm_before = LLM(
            model=input_dir,
            worker_use_ray=True,
            enable_lora=enable_lora,
            gpu_memory_utilization=0.3,
        )
        gen_before = llm_before.generate(prompts, sampling_params)
        out_before = [gen.outputs[0].__dict__ for gen in gen_before]
        del llm_before.llm_engine.model_executor

        llm_after = LLM(
            model=output_dir,
            worker_use_ray=True,
            enable_lora=enable_lora,
            gpu_memory_utilization=0.3,
            load_format="sharded_state",
        )
        gen_after = llm_after.generate(prompts, sampling_params)
        out_after = [gen.outputs[0].__dict__ for gen in gen_after]
        del llm_after.llm_engine.model_executor

        assert out_before == out_after
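test_filter_subtensors above relies on the fact that PyTorch views share storage with their base tensor, which is exactly what ShardedStateLoader._filter_subtensors (added later in this diff) inspects when deciding which entries to keep. A minimal sketch of that pointer arithmetic, not part of the commit:

import torch

base = torch.empty((2, 4, 8))
view = base[1, 2, :]  # a contiguous view sharing `base`'s untyped storage

assert view.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()

def end_ptr(t: torch.Tensor) -> int:
    # One byte past the last element, mirroring get_end_ptr() in the loader.
    return t.view(-1)[-1].data_ptr() + t.element_size()

# The view's byte range is contained in the base tensor's byte range, so the
# filter keeps the base ("c") and drops "y"/"z"-style views, as the test asserts.
assert base.data_ptr() <= view.data_ptr() and end_ptr(view) <= end_ptr(base)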
@@ -463,6 +463,7 @@ class LoadFormat(str, enum.Enum):
     NPCACHE = "npcache"
     DUMMY = "dummy"
     TENSORIZER = "tensorizer"
+    SHARDED_STATE = "sharded_state"


 @dataclass
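The new enum member is what load_format="sharded_state" resolves to. As a hedged aside that is not part of the commit: the shard filename pattern, which the loader below reads from load_config.model_loader_extra_config, could plausibly be overridden through the engine arguments as sketched here; the exact plumbing of model_loader_extra_config is assumed, not shown in this diff.

from vllm import LLM

# Assumed usage; "pattern" falls back to ShardedStateLoader.DEFAULT_PATTERN
# ("model-rank-{rank}-part-{part}.safetensors") when the key is omitted.
llm = LLM(
    model="/path/to/save",
    load_format="sharded_state",
    tensor_parallel_size=8,
    model_loader_extra_config={"pattern": "model-rank-{rank}-part-{part}.safetensors"},
)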
@@ -77,6 +77,17 @@ class DistributedGPUExecutor(GPUExecutor):
     def list_loras(self) -> Set[int]:
         return self._run_workers("list_loras")

+    def save_sharded_state(
+        self,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        self._run_workers("save_sharded_state",
+                          path=path,
+                          pattern=pattern,
+                          max_size=max_size)
+
     @abstractmethod
     def _run_workers(
         self,
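The executor method above is the fan-out point for saving. A hedged end-to-end sketch mirroring examples/save_sharded_state.py (paths are placeholders, not from the commit): the driver-side call below is broadcast via _run_workers to every tensor-parallel worker; each Worker forwards to its ModelRunner, which hands that rank's shard to ShardedStateLoader.save_model.

import dataclasses

from vllm import LLM, EngineArgs

engine_args = EngineArgs(model="/path/to/load", tensor_parallel_size=8)
llm = LLM(**dataclasses.asdict(engine_args))
# Each rank only writes files matching "model-rank-{rank}-part-{part}.safetensors".
llm.llm_engine.model_executor.save_sharded_state(path="/path/to/save",
                                                 pattern=None,
                                                 max_size=5 * 1024**3)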
@@ -1,4 +1,5 @@
 # ruff: noqa: SIM117
+import collections
 import copy
 import glob
 import os
@@ -366,6 +367,150 @@ class TensorizerLoader(BaseModelLoader):
                                           cache_config)


+class ShardedStateLoader(BaseModelLoader):
+    """
+    Model loader that directly loads each worker's model state dict, which
+    enables a fast load path for large tensor-parallel models where each worker
+    only needs to read its own shard rather than the entire checkpoint. See
+    `examples/save_sharded_state.py` for creating a sharded checkpoint.
+    """
+
+    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        extra_config = ({} if load_config.model_loader_extra_config is None
+                        else load_config.model_loader_extra_config.copy())
+        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
+        if extra_config:
+            raise ValueError(f"Unexpected extra config keys for load format "
+                             f"{load_config.load_format}: "
+                             f"{load_config.model_loader_extra_config.keys()}")
+
+    @staticmethod
+    def _filter_subtensors(
+            tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+        """
+        Filter out all tensors that share the same memory or a subset of the
+        memory of another tensor.
+        """
+        same_storage_groups = collections.defaultdict(list)
+        for key, tensor in tensors.items():
+            if tensor.numel():
+                ptr = tensor.untyped_storage().data_ptr()
+                same_storage_groups[tensor.device, ptr].append((key, tensor))
+
+        def get_end_ptr(tensor: torch.Tensor) -> int:
+            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
+
+        result = {}
+        for group in same_storage_groups.values():
+            for k, t in group:
+                a, b = t.data_ptr(), get_end_ptr(t)
+                for k2, t2 in group:
+                    if not t2.is_contiguous():
+                        continue
+                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
+                    if a < a2 or b2 < b:
+                        continue
+                    if a2 < a or b < b2 or not t.is_contiguous():
+                        break  # t2 covers strictly more memory than t.
+                    if k2 < k:
+                        # Same tensors, keep the one with the smaller key.
+                        break
+                else:
+                    result[k] = t
+        return result
+
+    def load_model(self, *, model_config: ModelConfig,
+                   device_config: DeviceConfig,
+                   lora_config: Optional[LoRAConfig],
+                   vision_language_config: Optional[VisionLanguageConfig],
+                   parallel_config: ParallelConfig,
+                   scheduler_config: SchedulerConfig,
+                   cache_config: CacheConfig) -> nn.Module:
+        from safetensors.torch import safe_open
+
+        from vllm.distributed import get_tensor_model_parallel_rank
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config,
+                                          lora_config, vision_language_config,
+                                          cache_config)
+            rank = get_tensor_model_parallel_rank()
+            pattern = os.path.join(
+                model_config.model,
+                self.pattern.format(rank=rank, part="*"),
+            )
+            filepaths = glob.glob(pattern)
+            if not filepaths:
+                # TODO: support un-sharded checkpoints too
+                raise ValueError(
+                    f"Could not find checkpoint files '{pattern}', only "
+                    f"pre-sharded checkpoints are currently supported!")
+            state_dict = self._filter_subtensors(model.state_dict())
+            for path in filepaths:
+                with safe_open(path, framework="pt") as f:
+                    for key in f.keys():  # noqa: SIM118
+                        tensor = f.get_tensor(key)
+                        # If loading with LoRA enabled, additional padding may
+                        # be added to certain parameters. We only load into a
+                        # narrowed view of the parameter data.
+                        param_data = state_dict[key].data
+                        param_shape = state_dict[key].shape
+                        for dim, size in enumerate(tensor.shape):
+                            if size < param_shape[dim]:
+                                param_data = param_data.narrow(dim, 0, size)
+                        if tensor.shape != param_shape:
+                            logger.warning(
+                                "loading tensor of shape %s into "
+                                "parameter '%s' of shape %s", tensor.shape,
+                                key, param_shape)
+                        param_data.copy_(tensor)
+                        state_dict.pop(key)
+            if state_dict:
+                raise ValueError(
+                    f"Missing keys {tuple(state_dict)} in loaded state!")
+            return model.eval()
+
+    @staticmethod
+    def save_model(
+        model: torch.nn.Module,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        from safetensors.torch import save_file
+
+        from vllm.distributed import get_tensor_model_parallel_rank
+        if pattern is None:
+            pattern = ShardedStateLoader.DEFAULT_PATTERN
+        rank = get_tensor_model_parallel_rank()
+        part_idx = 0
+        total_size = 0
+        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+        state_dict_part: Dict[str, torch.Tensor] = {}
+        for key, tensor in state_dict.items():
+            param_size = tensor.nelement() * tensor.element_size()
+            if max_size is not None and total_size + param_size > max_size:
+                filename = pattern.format(rank=rank, part=part_idx)
+                save_file(
+                    state_dict_part,
+                    os.path.join(path, filename),
+                )
+                part_idx += 1
+                total_size = 0
+                state_dict_part = {}
+            state_dict_part[key] = tensor
+            total_size += param_size
+        if len(state_dict_part) > 0:
+            filename = pattern.format(rank=rank, part=part_idx)
+            save_file(
+                state_dict_part,
+                os.path.join(path, filename),
+            )
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
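A detail worth noting in ShardedStateLoader.load_model above: when LoRA is enabled, some parameters carry extra padding beyond their checkpoint shape, so the loader copies each checkpoint tensor into a narrowed view of the parameter and only logs a warning about the size mismatch. A small self-contained illustration of that narrow-and-copy pattern, not from the commit and with made-up shapes:

import torch

param = torch.zeros(20, 8)    # parameter padded beyond the checkpoint shape (illustrative)
tensor = torch.randn(16, 8)   # tensor as read from a sharded safetensors file

view = param
for dim, size in enumerate(tensor.shape):
    if size < param.shape[dim]:
        view = view.narrow(dim, 0, size)  # restrict the copy to the checkpoint's extent
view.copy_(tensor)

assert torch.equal(param[:16], tensor)             # checkpoint values landed in place
assert torch.equal(param[16:], torch.zeros(4, 8))  # padding rows remain untouched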
@@ -378,4 +523,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.TENSORIZER:
         return TensorizerLoader(load_config)

+    if load_config.load_format == LoadFormat.SHARDED_STATE:
+        return ShardedStateLoader(load_config)
+
     return DefaultModelLoader(load_config)
@@ -182,6 +182,20 @@ class ModelRunner:
                 "but the KV cache data type is not FP8. "
                 "KV cache scaling factors will not be used.")

+    def save_sharded_state(
+        self,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        from vllm.model_executor.model_loader.loader import ShardedStateLoader
+        ShardedStateLoader.save_model(
+            self.model,
+            path,
+            pattern=pattern,
+            max_size=max_size,
+        )
+
     def get_max_block_per_batch(self) -> int:
         block_size = self.block_size
         return (self.max_seq_len_to_capture + block_size - 1) // block_size
@@ -119,6 +119,18 @@ class Worker(WorkerBase):
     def load_model(self):
         self.model_runner.load_model()

+    def save_sharded_state(
+        self,
+        path: str,
+        pattern: Optional[str] = None,
+        max_size: Optional[int] = None,
+    ) -> None:
+        self.model_runner.save_sharded_state(
+            path,
+            pattern=pattern,
+            max_size=max_size,
+        )
+
     @torch.inference_mode()
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Profiles the peak memory usage of the model to determine how many