[V1][BugFix] Fix Generator construction in greedy + seed case (#10097)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
e7b84c394d
commit
1fa020c539
@@ -146,7 +146,7 @@ class GPUModelRunner:
|
||||
for req_data in scheduler_output.scheduled_new_reqs:
|
||||
req_id = req_data.req_id
|
||||
sampling_params = req_data.sampling_params
|
||||
-if sampling_params.seed is not None:
|
||||
+if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||
generator = torch.Generator(device=self.device)
|
||||
generator.manual_seed(sampling_params.seed)
|
||||
else:
|
||||
@@ -382,7 +382,8 @@ class GPUModelRunner:
|
||||
# Rewind the generator state as if the token was not sampled.
|
||||
generator = self.input_batch.generators.get(i)
|
||||
if generator is not None:
|
||||
-generator.set_offset(generator.get_offset() - 1)
|
||||
+# This relies on cuda-specific torch-internal impl details
|
||||
+generator.set_offset(generator.get_offset() - 4)
|
||||
|
||||
if sampler_output.logprob_token_ids is None:
|
||||
logprob_token_ids = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user