[HPU][Bugfix] set_forward_context and CI test execution (#12014)

Signed-off-by: Konrad Zawora <kzawora@habana.ai>
Konrad Zawora 2025-01-14 04:04:18 +01:00 committed by GitHub
parent 1a401252b5
commit 078da31903
3 changed files with 23 additions and 18 deletions


@@ -8,9 +8,12 @@ set -ex
 docker build -t hpu-test-env -f Dockerfile.hpu .

 # Setup cleanup
+EXITCODE=1
 remove_docker_container() { docker rm -f hpu-test || true; }
-trap remove_docker_container EXIT
+remove_docker_container_and_exit() { remove_docker_container; exit $EXITCODE; }
+trap remove_docker_container_and_exit EXIT
 remove_docker_container

 # Run the image and launch offline inference
 docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic.py
+EXITCODE=$?


@@ -1,4 +1,4 @@
-FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
+FROM vault.habana.ai/gaudi-docker/1.19.1/ubuntu22.04/habanalabs/pytorch-installer-2.5.1:latest

 COPY ./ /workspace/vllm


@@ -289,12 +289,14 @@ def modify_decoder_layer(module: torch.nn.Module, suffix="DecoderLayer"):
 class HpuModelAdapter:

-    def __init__(self, model, block_size, dtype, enforce_eager):
+    def __init__(self, model, vllm_config):
         self.model = model
         self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                                '0').lower() in ['1', 'true']
-        self.block_size = block_size
-        self.dtype = dtype
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+        self.dtype = vllm_config.model_config.dtype
+        enforce_eager = vllm_config.model_config.enforce_eager
         if not htorch.utils.internal.is_lazy() and not enforce_eager:
             self.model = torch.compile(self.model,
                                        backend='hpu_backend',
@@ -353,14 +355,20 @@ class HpuModelAdapter:
         selected_token_indices = kwargs.pop('selected_token_indices')
         if 'warmup_mode' in kwargs:
             kwargs.pop('warmup_mode')
+        virtual_engine = 0
+        if 'virtual_engine' in kwargs:
+            virtual_engine = kwargs.pop('virtual_engine')
         input_ids = kwargs['input_ids']
         kwargs['attn_metadata'] = self._update_metadata(
             kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
             input_ids.device, self.dtype)
         LoraMask.setLoraMask(kwargs.pop('lora_mask'))
-        hidden_states = self.model(*args, **kwargs)
-        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        hidden_states = hidden_states.index_select(0, selected_token_indices)
+        with set_forward_context(kwargs['attn_metadata'], self.vllm_config,
+                                 virtual_engine):
+            hidden_states = self.model(*args, **kwargs)
+            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+            hidden_states = hidden_states.index_select(0,
+                                                       selected_token_indices)
         return hidden_states

     def compute_logits(self, *args, **kwargs):
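
The hunk above is the core of the fix: set_forward_context is now entered inside HpuModelAdapter.forward rather than in HPUModelRunner.execute_model, so any code path that calls the adapter directly still runs the wrapped model with a forward context, and the virtual_engine index arrives through the adapter's kwargs. A minimal, self-contained sketch of this pattern follows; the names are toy stand-ins, not the vLLM API (the real set_forward_context lives in vLLM itself and takes the same three arguments shown in the diff):

from contextlib import contextmanager

_FORWARD_CONTEXT = None  # stand-in for vLLM's per-thread forward context


@contextmanager
def set_forward_context(attn_metadata, vllm_config, virtual_engine=0):
    # Toy replacement: publish the metadata for the duration of one forward.
    global _FORWARD_CONTEXT
    prev = _FORWARD_CONTEXT
    _FORWARD_CONTEXT = (attn_metadata, vllm_config, virtual_engine)
    try:
        yield
    finally:
        _FORWARD_CONTEXT = prev


class ToyAdapter:
    """Mirrors the shape of HpuModelAdapter after this commit."""

    def __init__(self, model, vllm_config):
        self.model = model
        self.vllm_config = vllm_config  # kept so forward() can open the context

    def forward(self, *args, **kwargs):
        # Pop runner-only kwargs before the wrapped model sees them.
        virtual_engine = kwargs.pop('virtual_engine', 0)
        # Opening the context here (instead of in execute_model) means it is
        # active no matter which code path invokes the adapter.
        with set_forward_context(kwargs['attn_metadata'], self.vllm_config,
                                 virtual_engine):
            return self.model(*args, **kwargs)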
@@ -660,10 +668,7 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
         with HabanaMemoryProfiler() as m_wrap:
             self.model = _maybe_wrap_in_hpu_graph(
-                self.model,
-                self.block_size,
-                dtype=self.model_config.dtype,
-                enforce_eager=self.enforce_eager)
+                self.model, vllm_config=self.vllm_config)
         msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
         logger.info(msg)
@@ -1934,6 +1939,7 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
             "attn_metadata": self.trim_attn_metadata(attn_metadata),
             "intermediate_tensors": intermediate_tensors,
             "lora_mask": lora_mask,
+            "virtual_engine": model_input.virtual_engine,
             **(model_input.multi_modal_kwargs or {}),
         }
         if htorch.utils.internal.is_lazy():
@@ -1948,11 +1954,7 @@ class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]):
                               f"graphs{'T' if use_graphs else 'F'}")
         else:
             model_event_name = 'model_executable'
-        with set_forward_context(
-                model_input.attn_metadata, self.vllm_config,
-                model_input.virtual_engine), \
-                self.profiler.record_event(
-                        'internal', model_event_name):
+        with self.profiler.record_event('internal', model_event_name):
             hidden_states = self.model.forward(
                 **execute_model_kwargs,
                 selected_token_indices=sampling_metadata.selected_token_indices
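
On the runner side, execute_model no longer opens the forward context itself: it keeps only the profiler event and passes virtual_engine along with the other execute_model_kwargs (previous hunk), leaving the context to the adapter. Continuing the ToyAdapter sketch above, a hypothetical call would look like this (placeholder values, not real vLLM inputs):

# The caller passes virtual_engine with the other kwargs and never touches
# set_forward_context directly; the adapter pops it and opens the context.
adapter = ToyAdapter(model=lambda **kw: kw['input_ids'], vllm_config=None)

execute_model_kwargs = {
    'input_ids': [[1, 2, 3, 4]],        # placeholder batch
    'attn_metadata': {'seq_len': 4},    # placeholder metadata
    'virtual_engine': 0,                # added to the kwargs by this commit
}

hidden_states = adapter.forward(**execute_model_kwargs)
assert hidden_states == [[1, 2, 3, 4]]  # the toy model just echoes input_ids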