[Bugfix] Fix pickle of input when async output processing is on (#9931)
Signed-off-by: Wallas Santos <wallashss@ibm.com>
This commit is contained in:
parent
43300bd98a
commit
966e31697b
@ -156,3 +156,29 @@ def test_model_with_failure(vllm_runner) -> None:
|
|||||||
ModelInputForGPUWithSamplingMetadata)
|
ModelInputForGPUWithSamplingMetadata)
|
||||||
finally:
|
finally:
|
||||||
os.remove(filename)
|
os.remove(filename)
|
||||||
|
|
||||||
|
|
||||||
|
def test_failure_with_async_out_proc(vllm_runner) -> None:
    """Regression test for pickling model input with async output processing.

    The model's forward pass is patched to always raise ValueError.  With
    async output processing enabled, the engine must still be able to pickle
    the failing model input (``async_callback`` is excluded from the pickled
    state) and dump it to a ``.pkl`` file whose path appears in the error
    message.  The dump file is removed afterwards.
    """
    filename = None
    try:
        with vllm_runner("facebook/opt-125m",
                         dtype="half",
                         enforce_eager=False,
                         gpu_memory_utilization=0.7) as vllm_model,\
             patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
                   side_effect=ValueError()):
            model_config = vllm_model.model.llm_engine.model_config
            # The test is only meaningful when async output processing is on.
            assert model_config.use_async_output_proc
            with pytest.raises(ValueError) as exc_info:
                vllm_model.generate_greedy('how to make pizza?', 250)

            # The engine reports where it dumped the failing input.
            # Escape the dot so ".pkl" is matched literally (the previous
            # pattern's bare `.` matched any character before "pkl").
            matches = re.search(r"input dumped to (.+)\.pkl",
                                str(exc_info.value))
            assert matches is not None

            filename = f"{matches.group(1)}.pkl"
    finally:
        # Clean up the dump file even when an assertion above failed.
        # (A dead trailing `pass` statement was removed here.)
        if filename is not None:
            os.remove(filename)
|
||||||
|
@ -136,6 +136,18 @@ class ModelInputForGPU(ModelRunnerInputBase):
|
|||||||
attn_backend, tensor_dict)
|
attn_backend, tensor_dict)
|
||||||
return cls(**tensor_dict)
|
return cls(**tensor_dict)
|
||||||
|
|
||||||
|
# Exclude `async_callback` to be able to pickle this object
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
del state["async_callback"]
|
||||||
|
return state
|
||||||
|
|
||||||
|
# TODO: What happens when we depickle this object?
|
||||||
|
# How can we update this callback to properly pass it to the engine?
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__.update(state)
|
||||||
|
self.__dict__.update({'async_callback': None})
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
|
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user