[V1][BugFix] Fix Generator construction in greedy + seed case (#10097)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2024-11-07 05:06:57 +00:00 committed by GitHub
parent e7b84c394d
commit 1fa020c539
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -146,7 +146,7 @@ class GPUModelRunner:
for req_data in scheduler_output.scheduled_new_reqs:
req_id = req_data.req_id
sampling_params = req_data.sampling_params
-            if sampling_params.seed is not None:
+            if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
generator = torch.Generator(device=self.device)
generator.manual_seed(sampling_params.seed)
else:
@@ -382,7 +382,8 @@ class GPUModelRunner:
# Rewind the generator state as if the token was not sampled.
generator = self.input_batch.generators.get(i)
if generator is not None:
-                generator.set_offset(generator.get_offset() - 1)
+                # This relies on cuda-specific torch-internal impl details
+                generator.set_offset(generator.get_offset() - 4)
if sampler_output.logprob_token_ids is None:
logprob_token_ids = None