33 lines
1.1 KiB
Python
33 lines
1.1 KiB
Python
![]() |
import torch
|
||
|
|
||
|
from vllm import LLM, ModelRegistry, SamplingParams
|
||
|
from vllm.model_executor.models.opt import OPTForCausalLM
|
||
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||
|
|
||
|
|
||
|
class MyOPTForCausalLM(OPTForCausalLM):
    """Dummy OPT variant used to check out-of-tree model registration.

    It keeps the stock ``OPTForCausalLM`` behaviour except for
    ``compute_logits``, which is overridden so that token id 0 always has
    the highest logit. This makes the model's greedy output predictable.
    """

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Return the parent's logits rewritten so token id 0 always wins.

        Args:
            hidden_states: Forwarded unchanged to the parent implementation.
            sampling_metadata: Forwarded unchanged to the parent
                implementation.

        Returns:
            The tensor produced by the parent, modified in place: every
            entry is set to 0 and then index 0 along dimension 1 is set to
            1.0. If the parent returns ``None``, ``None`` is passed through
            unchanged.
        """
        logits = super().compute_logits(hidden_states, sampling_metadata)
        # NOTE(review): the parent may return None (e.g. on workers that do
        # not receive the gathered logits under tensor parallelism -- confirm
        # for the vLLM version in use). The in-place tensor ops below would
        # raise AttributeError on None, so only touch a real tensor.
        if logits is not None:
            # this dummy model always predicts the first token
            # (assumes logits is (num_tokens, vocab_size); the [:, 0]
            # indexing requires at least two dimensions)
            logits.zero_()
            logits[:, 0] += 1.0
        return logits
|
||
|
|
||
|
|
||
|
def test_oot_registration():
    """Check that an out-of-tree model registered with vLLM is actually used.

    ``MyOPTForCausalLM`` is registered under the ``OPTForCausalLM``
    architecture name, so loading ``facebook/opt-125m`` should pick it up.
    That class forces token id 0 to have the highest logit. Greedy decoding
    (``temperature=0``) must therefore generate nothing but token id 0.
    """
    # register our dummy model
    # NOTE(review): this registration is never undone here. If the registry
    # is process-global (it is called on the class), later tests in the same
    # process may also resolve OPTForCausalLM to the dummy model.
    ModelRegistry.register_model("OPTForCausalLM", MyOPTForCausalLM)
    prompts = ["Hello, my name is", "The text does not matter"]
    sampling_params = SamplingParams(temperature=0)
    llm = LLM(model="facebook/opt-125m")
    first_token = llm.get_tokenizer().decode(0)
    outputs = llm.generate(prompts, sampling_params)

    for output in outputs:
        completion = output.outputs[0]
        # Check the sampled ids directly. A text-only check can pass
        # vacuously when generated ids decode to text that detokenization
        # may strip (e.g. special tokens).
        assert completion.token_ids, "no tokens were generated"
        assert all(token_id == 0 for token_id in completion.token_ids), (
            f"expected only token id 0, got {completion.token_ids}")
        # make sure only the first token is generated
        rest = completion.text.replace(first_token, "")
        assert rest == "", f"unexpected generated text: {completion.text!r}"
|