[core] fix sleep mode and pytorch checkpoint compatibility (#13001)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 44607e07d3
commit b2496bb07f
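The first hunk below parametrizes the end-to-end sleep-mode test so it covers both a model shipped as safetensors and one shipped as a PyTorch checkpoint. For orientation, here is a minimal sketch of the sleep-mode flow that test exercises, assuming the LLM.sleep()/LLM.wake_up() API with enable_sleep_mode=True and reusing facebook/opt-125m from the diff; this is an illustrative sketch, not the test itself:

# Sketch: load a model with sleep mode enabled, generate, release GPU memory
# with sleep(), then restore it with wake_up() and generate again.
from vllm import LLM, SamplingParams

llm = LLM("facebook/opt-125m", enable_sleep_mode=True)
params = SamplingParams(temperature=0, max_tokens=10)
print(llm.generate("How are you?", params)[0].outputs[0].text)

llm.sleep(level=1)   # offload weights and free KV cache; GPU usage drops
llm.wake_up()        # reload weights; inference can resume
print(llm.generate("How are you?", params)[0].outputs[0].text)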
@@ -115,10 +115,16 @@ def test_cumem_with_cudagraph():
 
 
 @fork_new_process_for_each_test
-def test_end_to_end():
+@pytest.mark.parametrize(
+    "model",
+    [
+        "meta-llama/Llama-3.2-1B",  # sleep mode with safetensors
+        "facebook/opt-125m"  # sleep mode with pytorch checkpoint
+    ])
+def test_end_to_end(model):
     free, total = torch.cuda.mem_get_info()
     used_bytes_baseline = total - free  # in case other process is running
-    llm = LLM("meta-llama/Llama-3.2-1B", enable_sleep_mode=True)
+    llm = LLM(model, enable_sleep_mode=True)
     prompt = "How are you?"
     sampling_params = SamplingParams(temperature=0, max_tokens=10)
     output = llm.generate(prompt, sampling_params)
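As the inline comments note, meta-llama/Llama-3.2-1B exercises the safetensors loading path while facebook/opt-125m exercises the PyTorch .bin checkpoint path, which presumably flows through pt_weights_iterator touched in the next hunk; before this change the test only covered the safetensors case.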
@@ -462,7 +462,6 @@ def pt_weights_iterator(
         state = torch.load(bin_file, map_location="cpu", weights_only=True)
         yield from state.items()
         del state
-        torch.cuda.empty_cache()
 
 
 def get_gguf_extra_tensor_names(
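The second hunk drops the torch.cuda.empty_cache() call from the PyTorch-checkpoint weight iterator. The diff itself does not state the motivation; presumably the point is to avoid touching the CUDA caching allocator while the sleep-mode allocator is in control, which is an assumption. A self-contained sketch of the iteration pattern that loop implements, with hypothetical shard file names:

# Sketch of iterating weights from .bin/.pt checkpoint shards, mirroring the
# loop body shown in the hunk above; shard paths are hypothetical.
from typing import Iterable, Iterator, Tuple
import torch

def iter_pt_weights(bin_files: Iterable[str]) -> Iterator[Tuple[str, torch.Tensor]]:
    for bin_file in bin_files:
        # weights_only=True uses a restricted unpickler (tensors and basic types only)
        state = torch.load(bin_file, map_location="cpu", weights_only=True)
        yield from state.items()
        del state  # release the shard's dict before loading the next one
        # note: no torch.cuda.empty_cache() here, matching the change above

# usage: for name, tensor in iter_pt_weights(["pytorch_model-00001-of-00002.bin"]):
#            print(name, tuple(tensor.shape))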