[V1][BugFix] Fix Generator construction in greedy + seed case (#10097)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
e7b84c394d
commit
1fa020c539
@@ -146,7 +146,7 @@ class GPUModelRunner:
|
|||||||
for req_data in scheduler_output.scheduled_new_reqs:
|
for req_data in scheduler_output.scheduled_new_reqs:
|
||||||
req_id = req_data.req_id
|
req_id = req_data.req_id
|
||||||
sampling_params = req_data.sampling_params
|
sampling_params = req_data.sampling_params
|
||||||
if sampling_params.seed is not None:
|
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||||
generator = torch.Generator(device=self.device)
|
generator = torch.Generator(device=self.device)
|
||||||
generator.manual_seed(sampling_params.seed)
|
generator.manual_seed(sampling_params.seed)
|
||||||
else:
|
else:
|
||||||
@@ -382,7 +382,8 @@ class GPUModelRunner:
|
|||||||
# Rewind the generator state as if the token was not sampled.
|
# Rewind the generator state as if the token was not sampled.
|
||||||
generator = self.input_batch.generators.get(i)
|
generator = self.input_batch.generators.get(i)
|
||||||
if generator is not None:
|
if generator is not None:
|
||||||
generator.set_offset(generator.get_offset() - 1)
|
# This relies on cuda-specific torch-internal impl details
|
||||||
|
generator.set_offset(generator.get_offset() - 4)
|
||||||
|
|
||||||
if sampler_output.logprob_token_ids is None:
|
if sampler_output.logprob_token_ids is None:
|
||||||
logprob_token_ids = None
|
logprob_token_ids = None
|
||||||
|
Loading…
x
Reference in New Issue
Block a user