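"""Test out-of-tree (OOT) model registration: a custom model class is
registered via ModelRegistry.register_model so that it overrides vLLM's
built-in implementation for the same architecture name."""
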
from typing import Optional

import torch

from vllm import LLM, ModelRegistry, SamplingParams
from vllm.model_executor.models.opt import OPTForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata


class MyOPTForCausalLM(OPTForCausalLM):
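    """An OPT variant whose logits are rigged so that every decoding step
    predicts token id 0."""
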
    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        # This dummy model always predicts the first token: compute the real
        # logits, zero them out, and put all of the mass on token id 0.
        logits = super().compute_logits(hidden_states, sampling_metadata)
        logits.zero_()
        logits[:, 0] += 1.0
        return logits


def test_oot_registration():
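    """Register MyOPTForCausalLM over the built-in OPT implementation and
    check that every generated token decodes to token id 0."""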
    # Register our dummy model under the "OPTForCausalLM" architecture name,
    # replacing vLLM's built-in implementation for facebook/opt-125m.
    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
    prompts = ["Hello, my name is", "The text does not matter"]
    # temperature=0 selects greedy sampling, so the rigged logits fully
    # determine the output.
    sampling_params = SamplingParams(temperature=0)
    llm = LLM(model="facebook/opt-125m")
    first_token = llm.get_tokenizer().decode(0)
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        generated_text = output.outputs[0].text
        # Make sure only the first token (id 0) is ever generated: stripping
        # its text from the output should leave nothing behind.
        rest = generated_text.replace(first_token, "")
        assert rest == ""