[Misc][LoRA] Move the implementation of lora bias to punica.py (#10829)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
parent
a4c4daf364
commit
b45f0d7946
@@ -55,15 +55,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts
 
 
-@fork_new_process_for_each_test
-def test_llama_lora(sql_lora_files):
-
-    llm = vllm.LLM(MODEL_PATH,
-                   enable_lora=True,
-                   max_num_seqs=16,
-                   max_loras=4,
-                   tensor_parallel_size=1)
-
+def generate_and_test(llm, sql_lora_files):
     print("lora adapter created")
     assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
@@ -79,6 +71,17 @@ def test_llama_lora(sql_lora_files):
     print("removing lora")
 
 
+@fork_new_process_for_each_test
+def test_llama_lora(sql_lora_files):
+
+    llm = vllm.LLM(MODEL_PATH,
+                   enable_lora=True,
+                   max_num_seqs=16,
+                   max_loras=4,
+                   tensor_parallel_size=1)
+    generate_and_test(llm, sql_lora_files)
+
+
 @fork_new_process_for_each_test
 def test_llama_lora_warmup(sql_lora_files):
     """Test that the LLM initialization works with a warmup LORA path and
@@ -118,20 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
         max_loras=4,
         tensor_parallel_size=4,
     )
-
-    print("lora adapter created")
-    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
-
-    print("lora 1")
-    assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT
-
-    print("no lora")
-    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
-
-    print("lora 2")
-    assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT
-
-    print("removing lora")
+    generate_and_test(llm, sql_lora_files)
 
 
 @multi_gpu_test(num_gpus=4)
@@ -146,16 +136,20 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
         tensor_parallel_size=4,
         fully_sharded_loras=True,
     )
-    print("lora adapter created")
-    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
+    generate_and_test(llm, sql_lora_files)
 
-    print("lora 1")
-    assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT
 
-    print("no lora")
-    assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT
-
-    print("lora 2")
-    assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT
-
-    print("removing lora")
+@multi_gpu_test(num_gpus=4)
+@fork_new_process_for_each_test
+def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):
+
+    llm = vllm.LLM(
+        MODEL_PATH,
+        enable_lora=True,
+        max_num_seqs=16,
+        max_loras=4,
+        tensor_parallel_size=4,
+        fully_sharded_loras=True,
+        enable_lora_bias=True,
+    )
+    generate_and_test(llm, sql_lora_files)
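The new test_llama_lora_tp4_fully_sharded_enable_bias case is the only test that turns on enable_lora_bias, i.e. the bias path that the rest of this commit moves into PunicaWrapper. A minimal single-GPU sketch of the same flow through the public API, assuming a locally available adapter; the model name, adapter path, and prompt below are illustrative only:

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "meta-llama/Llama-2-7b-hf"   # illustrative
LORA_PATH = "/path/to/sql_lora_adapter"   # illustrative

llm = vllm.LLM(
    MODEL_PATH,
    enable_lora=True,
    enable_lora_bias=True,   # also apply the adapter's lora_b bias
    max_loras=4,
    max_num_seqs=16,
)
outputs = llm.generate(
    ["Write a SQL query that lists all users."],
    vllm.SamplingParams(temperature=0, max_tokens=64),
    lora_request=LoRARequest("sql_adapter", 1, LORA_PATH),
)
print(outputs[0].outputs[0].text)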
@@ -73,6 +73,7 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
         self.punica_wrapper.add_expand(output,
                                        buffer,
                                        self.lora_b_stacked,
+                                       self.bias_stacked,
                                        add_input=True)
         # now have column partitioned output
 
@@ -131,27 +132,14 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
                                        layer.lora_a_stacked[idx], 1.0)
 
     buffers = tensor_model_parallel_all_gather(buffers)
-    left_offset = 0
-    for idx in range(n):
-        shard_size = layer.lora_b_stacked[idx].shape[2]
-
-        if layer.bias_stacked is not None:
-            bias = layer.bias_stacked[idx]
-            if bias is not None:
-                bias = bias.view(-1, bias.shape[-1])
-                bias = bias[layer.punica_wrapper.token_lora_indices]
-                bias[layer.punica_wrapper.token_lora_indices == -1] = 0
-                output[:, left_offset:left_offset + shard_size] += bias
-
-        layer.punica_wrapper.add_expand_slice(
-            output,
-            buffers[idx],
-            layer.lora_b_stacked[idx],
-            left_offset,
-            shard_size,
-            add_input=True,
-        )
-        left_offset += shard_size
+    layer.punica_wrapper.add_expand_packed_nslice(
+        output,
+        buffers,
+        layer.lora_b_stacked,
+        layer.bias_stacked,
+        1.0,
+        layer.output_slices,
+    )
 
     output = output.view(*out_orig_shape)
     # now have column partitioned and packed output
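The per-slice loop that _mcp_apply used to carry inline — gather each slice's bias row per token, zero it for tokens without a LoRA (index == -1), and add it into that slice's column range — is what add_expand_packed_nslice now performs internally. A small self-contained sketch of that bias step with toy shapes (plain torch, not the punica kernels):

import torch

output_slices = (4, 4)                          # two packed column slices (toy sizes)
num_loras = 3
token_lora_indices = torch.tensor([0, 2, -1])   # -1 means "no LoRA" for that token
output = torch.zeros(len(token_lora_indices), sum(output_slices))
bias_stacked = tuple(torch.randn(num_loras, size) for size in output_slices)

offset = 0
for slice_idx, slice_size in enumerate(output_slices):
    bias = bias_stacked[slice_idx][token_lora_indices]  # (num_tokens, slice_size)
    bias[token_lora_indices == -1] = 0                  # no-LoRA tokens get no bias
    output[:, offset:offset + slice_size] += bias
    offset += slice_size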
@@ -234,6 +222,7 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
         self.punica_wrapper.add_expand(output,
                                        buffer,
                                        self.lora_b_stacked,
+                                       self.bias_all,
                                        add_input=True)
         # now have column partitioned output
         output = output.view(*out_orig_shape)
@@ -350,15 +339,9 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
         # reduced before being used
         shard_size = self.lora_b_stacked.shape[2]
         start_idx = self.tp_rank * shard_size
-
-        if self.bias_stacked is not None:
-            bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
-            bias = bias[self.punica_wrapper.token_lora_indices]
-            bias[self.punica_wrapper.token_lora_indices == -1] = 0
-            output += bias
-
         self.punica_wrapper.add_expand_slice(output, buffer,
-                                             self.lora_b_stacked, start_idx,
+                                             self.lora_b_stacked,
+                                             self.bias_stacked, start_idx,
                                              shard_size)
         output = output.view(*out_orig_shape)
         return output
@@ -67,63 +67,6 @@ def _not_fully_sharded_can_replace(can_replace):
     return dec
 
 
-def apply_bias(
-    indices: torch.Tensor,
-    output: torch.Tensor,
-    bias_stacked: torch.Tensor,
-):
-    """Applies bias to output
-
-    Input shapes:
-        bias_stacked:    (num_loras, output_dim)
-        indices:         (batch_size)
-        output:          (batch_size, output_dim)
-    """
-    org_output = output
-    output = output.view(-1, output.shape[-1])
-    indices = indices.view(-1)
-
-    bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
-    bias_stacked = bias_stacked[indices]
-    bias_stacked[indices == -1] = 0
-    output += bias_stacked
-
-    return output.view_as(org_output)
-
-
-def apply_bias_packed_nslice(
-    indices: torch.Tensor,
-    output: torch.Tensor,
-    output_slices: Tuple[int, ...],
-    bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-):
-    """Applies bias to output
-
-    Input shapes:
-        bias_stacked:      3 element tuple of (num_loras, output_dim)
-        indices:           (batch_size)
-        output:            (batch_size, q_slice_size + 2*kv_slice_size)
-        output_slices:     n-1 element tuple of (slice_size...),
-                           where n is number of slices
-    """
-    org_output = output
-    output = output.view(-1, output.shape[-1])
-    indices = indices.view(-1)
-
-    offset_left = 0
-    for slice_idx, slice in enumerate(output_slices):
-        bias = bias_stacked[slice_idx]
-        if bias is not None:
-            bias = bias.view(-1, bias.shape[-1])
-            bias = bias[indices]
-            bias[indices == -1] = 0
-            output[:, offset_left:offset_left + slice] += bias
-
-        offset_left += slice
-
-    return output.view_as(org_output)
-
-
 @dataclass
 class LoRAMapping(AdapterMapping):
     is_prefill: bool = False
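These two module-level helpers are deleted from the layer code and re-added, essentially unchanged, as PunicaWrapper methods in the punica.py hunks below. A tiny standalone illustration of the gather-and-mask semantics they implement, with made-up shapes:

import torch

def apply_bias_ref(indices, output, bias_stacked):
    # Same idea as the helper above: pick each token's bias row by its LoRA
    # index and zero it for tokens that have no LoRA (index == -1).
    bias = bias_stacked[indices]
    bias[indices == -1] = 0
    return output + bias

bias_stacked = torch.tensor([[1., 1.], [2., 2.]])  # (num_loras=2, output_dim=2)
indices = torch.tensor([1, -1, 0])                 # per-token LoRA ids
output = torch.zeros(3, 2)
print(apply_bias_ref(indices, output, bias_stacked))
# rows: [2., 2.], [0., 0.], [1., 1.]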
@@ -311,6 +254,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
         self.punica_wrapper.add_expand(full_output,
                                        full_lora_a_embeddings,
                                        self.lora_b_stacked,
+                                       bias_all=None,
                                        add_input=True)
         return full_output.view_as(full_output_org)
 
@@ -399,15 +343,9 @@ class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias(
-                self.indices,
-                output,
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
-                                     self.lora_b_stacked, 1.0)
+                                     self.lora_b_stacked, self.bias_stacked,
+                                     1.0)
         return output
 
     def forward(self, input_):
@@ -576,15 +514,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias(
-                self.indices,
-                output,
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
-                                     self.lora_b_stacked, 1.0)
+                                     self.lora_b_stacked, self.bias_stacked,
+                                     1.0)
         return output
 
     def forward(self, input_):
@@ -687,8 +619,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
                 ) for _ in range(n_slices))
         else:
             self.bias_stacked = None
 
         self.output_dim = self.lora_b_stacked[0].shape[2]
+        self.output_slices = (self.output_dim, self.output_dim)
 
     def reset_lora(self, index: int):
         self.lora_a_stacked[0][index] = 0
@@ -772,17 +704,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias_packed_nslice(
-                self.indices,
-                output,
-                (self.output_dim, self.output_dim),
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora_packed_nslice(
-            output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
-            (self.output_dim, self.output_dim))
+            output, x, self.lora_a_stacked, self.lora_b_stacked,
+            self.bias_stacked, 1.0, (self.output_dim, self.output_dim))
         return output
 
     @classmethod
@@ -1129,17 +1053,10 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
     def apply(self, x: torch.Tensor,
               bias: Optional[torch.Tensor]) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias_packed_nslice(
-                self.indices,
-                output,
-                self.output_slices,
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora_packed_nslice(output, x,
                                                    self.lora_a_stacked,
-                                                   self.lora_b_stacked, 1.0,
+                                                   self.lora_b_stacked,
+                                                   self.bias_stacked, 1.0,
                                                    self.output_slices)
         return output
 
@@ -1264,15 +1181,9 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
 
     def apply(self, x: torch.Tensor) -> torch.Tensor:
         output = self.base_layer.quant_method.apply(self.base_layer, x)
-        if self.bias_stacked is not None:
-            self.indices = self.punica_wrapper.token_lora_indices
-            output = apply_bias(
-                self.indices,
-                output,
-                self.bias_stacked,
-            )
         self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
-                                     self.lora_b_stacked, 1.0)
+                                     self.lora_b_stacked, self.bias_stacked,
+                                     1.0)
         return output
 
     def forward(self, input_):
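Taken together, the layer-side hunks above all make the same move: each apply() stops calling the local bias helpers and simply hands self.bias_stacked to the wrapper. A minimal sketch of the resulting shape of a LoRA layer (a hypothetical class, not one of vLLM's):

class MinimalLoRALayer:
    """Sketch only: the layer owns the stacked tensors, PunicaWrapper owns the math."""

    def __init__(self, base_layer, punica_wrapper, lora_a_stacked,
                 lora_b_stacked, bias_stacked=None):
        self.base_layer = base_layer
        self.punica_wrapper = punica_wrapper
        self.lora_a_stacked = lora_a_stacked
        self.lora_b_stacked = lora_b_stacked
        # None when the engine runs without enable_lora_bias.
        self.bias_stacked = bias_stacked

    def apply(self, x, bias=None):
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
        # No local apply_bias() branch any more: the wrapper adds the LoRA
        # delta and, when bias_stacked is not None, the LoRA bias too.
        self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
                                     self.lora_b_stacked, self.bias_stacked,
                                     1.0)
        return output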
@@ -450,6 +450,62 @@ class PunicaWrapper:
         bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                           y_slice_size, add_input)
 
+    def apply_bias(
+        self,
+        indices: torch.Tensor,
+        output: torch.Tensor,
+        bias_stacked: torch.Tensor,
+    ):
+        """Applies bias to output
+
+        Input shapes:
+            bias_stacked:    (num_loras, output_dim)
+            indices:         (batch_size)
+            output:          (batch_size, output_dim)
+        """
+        org_output = output
+        output = output.view(-1, output.shape[-1])
+        indices = indices.view(-1)
+
+        bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
+        bias_stacked = bias_stacked[indices]
+        bias_stacked[indices == -1] = 0
+        output += bias_stacked
+
+        return output.view_as(org_output)
+
+    def apply_bias_packed_nslice(
+        self,
+        indices: torch.Tensor,
+        output: torch.Tensor,
+        output_slices: Tuple[int, ...],
+        bias_stacked: Tuple[Optional[torch.Tensor], ...],
+    ):
+        """Applies bias to output
+
+        Input shapes:
+            bias_stacked:      3 element tuple of (num_loras, output_dim)
+            indices:           (batch_size)
+            output:            (batch_size, q_slice_size + 2*kv_slice_size)
+            output_slices:     n-1 element tuple of (slice_size...),
+                               where n is number of slices
+        """
+        org_output = output
+        output = output.view(-1, output.shape[-1])
+        indices = indices.view(-1)
+
+        offset_left = 0
+        for slice_idx, slice in enumerate(output_slices):
+            bias = bias_stacked[slice_idx]
+            if bias is not None:
+                bias = bias.view(-1, bias.shape[-1])
+                bias = bias[indices]
+                bias[indices == -1] = 0
+                output[:, offset_left:offset_left + slice] += bias
+            offset_left += slice
+
+        return output.view_as(org_output)
+
     def add_shrink(
         self,
         y: torch.Tensor,
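One behavior of these helpers worth noting: although they return output.view_as(org_output), the += runs on a view of the caller's tensor, so the caller's output/y buffer is updated in place even if the return value were ignored. A quick check of that property with plain torch, mirroring the method body:

import torch

output = torch.zeros(2, 3)
bias_stacked = torch.ones(1, 3)               # a single LoRA's bias
indices = torch.tensor([0, 0])

flat = output.view(-1, output.shape[-1])      # shares storage with `output`
flat += bias_stacked[indices]                 # in-place add on the view
assert torch.equal(output, torch.ones(2, 3))  # the caller's tensor was updated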
@@ -474,16 +530,19 @@ class PunicaWrapper:
         y: torch.Tensor,
         x: torch.Tensor,
         w_t_all: torch.Tensor,
+        bias_all: Optional[torch.Tensor],
         add_input: bool = True,
     ):
         """
-        Perform the ` y+=x@w_t_all` computation, which is suitable for the
+        Perform the ` y+=x@w_t_all+bias` computation, which is suitable for the
         GEMM of lora'b.
         When `is_prefill` is true, it indicates that it is currently the
         prefill stage, and the `expand_prefill` function should be called.
         Otherwise, it is the decode stage, and the expand_decode function
         should be called.
         """
+        if bias_all is not None:
+            y = self.apply_bias(self.token_lora_indices, y, bias_all)
 
         expand_fun: Callable = (self.expand_prefill
                                 if self.is_prefill else self.expand_decode)
@@ -493,23 +552,54 @@ class PunicaWrapper:
                          y: torch.Tensor,
                          x: torch.Tensor,
                          w_t_all: torch.Tensor,
+                         bias_all: Optional[torch.Tensor],
                          y_offset: Optional[int],
                          y_slice_size: Optional[int],
                          add_input: bool = True):
         """
         Similar to `add_expand`
         """
+        if bias_all is not None:
+            y = self.apply_bias(self.token_lora_indices, y, bias_all)
         expand_slice_fun: Callable = (self.expand_slice_prefill
                                       if self.is_prefill else
                                       self.expand_slice_decode)
         expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)
 
+    def add_expand_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
+                                 lora_b_stacked: Tuple[torch.Tensor, ...],
+                                 bias_stacked: Optional[Tuple[torch.Tensor,
+                                                              ...]],
+                                 scale: float,
+                                 output_slices: Tuple[int, ...]) -> None:
+        """
+        Similar to `add_expand`
+        """
+        y_org = y
+        y = y.view(-1, y.shape[-1])
+        offset_left = 0
+        if bias_stacked is not None:
+            self.apply_bias_packed_nslice(self.token_lora_indices, y,
+                                          output_slices, bias_stacked)
+        for slice_idx in range(len(lora_b_stacked)):
+            self.add_expand_slice(y,
+                                  x[slice_idx],
+                                  lora_b_stacked[slice_idx],
+                                  None,
+                                  offset_left,
+                                  output_slices[slice_idx],
+                                  add_input=True)
+            offset_left += output_slices[slice_idx]
+
+        y = y.view_as(y_org)
+
     def add_lora(self,
                  y: torch.Tensor,
                  x: torch.Tensor,
                  wa_t_all: torch.Tensor,
                  wb_t_all: torch.Tensor,
+                 bias_all: Optional[torch.Tensor],
                  scale: float,
                  y_offset: Optional[int] = None,
                  y_slice_size: Optional[int] = None,
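Note the convention inside add_expand_packed_nslice: the packed bias is applied once via apply_bias_packed_nslice, and the per-slice add_expand_slice calls receive None for bias_all so the bias is not added a second time. A toy mock (not the real wrapper) that makes the invariant explicit:

class MockWrapper:
    """Control-flow sketch only; no punica kernels involved."""

    def __init__(self):
        self.bias_applications = 0

    def apply_bias_packed_nslice(self, indices, output, output_slices,
                                 bias_stacked):
        self.bias_applications += 1

    def add_expand_slice(self, y, x, w_t_all, bias_all, y_offset, y_slice_size,
                         add_input=True):
        assert bias_all is None  # inner calls must not re-apply the bias

    def add_expand_packed_nslice(self, y, x, lora_b_stacked, bias_stacked,
                                 scale, output_slices):
        # Mirrors the control flow added above: bias once, then per-slice expand.
        if bias_stacked is not None:
            self.apply_bias_packed_nslice(None, y, output_slices, bias_stacked)
        offset = 0
        for i in range(len(lora_b_stacked)):
            self.add_expand_slice(y, x[i], lora_b_stacked[i], None, offset,
                                  output_slices[i])
            offset += output_slices[i]


w = MockWrapper()
w.add_expand_packed_nslice(y=None, x=[0, 0], lora_b_stacked=(0, 0),
                           bias_stacked=(0, 0), scale=1.0,
                           output_slices=(2, 2))
assert w.bias_applications == 1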
@@ -522,12 +612,13 @@ class PunicaWrapper:
               @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
               @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
               * scale
-              ).squeeze(0)
+              ).squeeze(0)+bias[i]
         Args:
             y (torch.Tensor): Output tensor. Will be changed in-place.
             x (torch.Tensor): Input tensor
             wa_t_all (torch.Tensor): lora_a's weight
             wb_t_all (torch.Tensor): lora_b's weight
+            bias_all: (torch.Tensor): lora's bias
             scale (float): Scaling factor.
             y_offset (Optional[int], optional): Offset to apply to the starting
                 column of y.
@@ -544,27 +635,26 @@ class PunicaWrapper:
             buffer = torch.zeros((x.size(0), r),
                                  dtype=torch.float32,
                                  device=x.device)
-
+        if bias_all is not None:
+            y = self.apply_bias(self.token_lora_indices, y, bias_all)
         self.add_shrink(buffer, x, wa_t_all, scale)
         if y_offset is None and y_slice_size is None:
-            self.add_expand(y, buffer, wb_t_all, add_input=True)
+            self.add_expand(y, buffer, wb_t_all, bias_all=None, add_input=True)
         else:
             self.add_expand_slice(y,
                                   buffer,
                                   wb_t_all,
+                                  None,
                                   y_offset,
                                   y_slice_size,
                                   add_input=True)
         y = y.view_as(y_org)
 
     def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
-                               lora_a_stacked: Tuple[torch.Tensor,
-                                                     torch.Tensor,
-                                                     torch.Tensor],
-                               lora_b_stacked: Tuple[torch.Tensor,
-                                                     torch.Tensor,
-                                                     torch.Tensor],
-                               scale: float,
+                               lora_a_stacked: Tuple[torch.Tensor, ...],
+                               lora_b_stacked: Tuple[torch.Tensor, ...],
+                               bias_all: Tuple[Optional[torch.Tensor],
+                                               ...], scale: float,
                                output_slices: Tuple[int, ...]) -> None:
         """
         Applies lora to each input. Similar to add_lora, This method is
@@ -575,10 +665,13 @@ class PunicaWrapper:
         x = x.view(-1, x.shape[-1])
         y = y.view(-1, y.shape[-1])
         offset_left = 0
+        if bias_all is not None:
+            y = self.apply_bias_packed_nslice(self.token_lora_indices, y,
+                                              output_slices, bias_all)
         # TODO fuse these kernels
         for slice_idx in range(len(output_slices)):
             self.add_lora(y, x, lora_a_stacked[slice_idx],
-                          lora_b_stacked[slice_idx], scale, offset_left,
+                          lora_b_stacked[slice_idx], None, scale, offset_left,
                           output_slices[slice_idx])
             offset_left += output_slices[slice_idx]
 
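After this change every bias-aware entry point lives on PunicaWrapper. A condensed, untyped summary of the signatures introduced or extended in the hunks above (bodies elided; the trailing buffer argument of add_lora is a pre-existing optional parameter not shown in these hunks):

class PunicaWrapper:
    def apply_bias(self, indices, output, bias_stacked): ...
    def apply_bias_packed_nslice(self, indices, output, output_slices,
                                 bias_stacked): ...
    def add_expand(self, y, x, w_t_all, bias_all, add_input=True): ...
    def add_expand_slice(self, y, x, w_t_all, bias_all, y_offset,
                         y_slice_size, add_input=True): ...
    def add_expand_packed_nslice(self, y, x, lora_b_stacked, bias_stacked,
                                 scale, output_slices): ...
    def add_lora(self, y, x, wa_t_all, wb_t_all, bias_all, scale,
                 y_offset=None, y_slice_size=None, buffer=None): ...
    def add_lora_packed_nslice(self, y, x, lora_a_stacked, lora_b_stacked,
                               bias_all, scale, output_slices): ...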