"""Tests for the container behaviour of ``vllm.sequence.SamplerOutput``."""
import pytest

from vllm.sequence import SamplerOutput, SequenceGroupOutput, SequenceOutput


def _make_group_output(token: int) -> SequenceGroupOutput:
    """Build a group output holding a single sample for ``token``.

    The sample uses ``parent_seq_id=0`` and empty logprobs, and the group has
    no prompt logprobs. The fixtures and tests share this helper so every
    constructed output has the same shape.
    """
    return SequenceGroupOutput(
        samples=[
            SequenceOutput(parent_seq_id=0, output_token=token, logprobs={})
        ],
        prompt_logprobs=None,
    )


@pytest.fixture
def sample_outputs():
    """Five group outputs whose single samples carry tokens 0..4."""
    return [_make_group_output(i) for i in range(5)]


@pytest.fixture
def sampler_output(sample_outputs):
    """A SamplerOutput wrapping ``sample_outputs``."""
    return SamplerOutput(outputs=sample_outputs)


def test_sampler_output_initialization(sampler_output, sample_outputs):
    """Only ``outputs`` is supplied, so the optional fields stay ``None``."""
    assert len(sampler_output) == len(sample_outputs)
    assert sampler_output.sampled_token_probs is None
    assert sampler_output.sampled_token_ids is None
    assert sampler_output.spec_decode_worker_metrics is None


def test_sampler_output_getitem(sampler_output, sample_outputs):
    """Indexing returns the group output at that position."""
    assert sampler_output[2] == sample_outputs[2]


def test_sampler_output_setitem(sampler_output):
    """Item assignment replaces an element without resizing the container."""
    original_len = len(sampler_output)
    new_output = _make_group_output(99)
    sampler_output[2] = new_output
    assert sampler_output[2] == new_output
    # Assignment must replace, not insert.
    assert len(sampler_output) == original_len


def test_sampler_output_len(sampler_output, sample_outputs):
    """len() reports the number of wrapped group outputs."""
    assert len(sampler_output) == len(sample_outputs)


def test_sampler_output_eq(sample_outputs):
    """Equality compares the wrapped outputs by value."""
    sampler_output1 = SamplerOutput(outputs=sample_outputs)
    sampler_output2 = SamplerOutput(outputs=sample_outputs.copy())
    sampler_output3 = SamplerOutput(outputs=sample_outputs[:-1])
    # ``list.copy()`` is shallow: sampler_output2 holds the *same* element
    # objects, and list comparison treats identical elements as equal. That
    # check alone would therefore pass even if group outputs had no value
    # equality, so also compare against freshly built, equal-valued elements.
    sampler_output4 = SamplerOutput(
        outputs=[_make_group_output(i) for i in range(len(sample_outputs))])
    assert sampler_output1 == sampler_output2
    assert sampler_output1 == sampler_output4
    assert sampler_output1 != sampler_output3