[Misc/Testing] Use torch.testing.assert_close (#7324)

commit 50b8d08dbd
parent e165528778
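
The diff below is a mechanical sweep over the distributed, kernel, quantization, LoRA, sampler, and model-runner tests: every assert torch.allclose(...) becomes torch.testing.assert_close(...), which raises an AssertionError describing the mismatch instead of returning a bare boolean. A minimal sketch of the pattern, using made-up tensors rather than anything taken from the tests:

    import torch

    a = torch.tensor([1.0, 2.0, 3.0])
    b = a + 1e-9  # tiny numerical noise

    # Old style: boolean check, no diagnostics when it fails.
    assert torch.allclose(a, b)

    # New style: on failure the error message reports the number of
    # mismatched elements and the greatest absolute and relative
    # differences, together with the tolerances that were used.
    torch.testing.assert_close(a, b)

    # Tolerances can still be pinned explicitly, but assert_close requires
    # rtol and atol to be passed together (or both left to their
    # dtype-dependent defaults), unlike torch.allclose.
    torch.testing.assert_close(a, b, rtol=0.0, atol=1e-6)
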
@@ -34,7 +34,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
     expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
     t = all_tensors[rank % tp_size]
     t = tensor_model_parallel_all_reduce(t)
-    assert torch.allclose(t, expected)
+    torch.testing.assert_close(t, expected)


 @ray.remote(num_gpus=1, max_calls=1)
@@ -62,7 +62,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
     expected = torch.cat(all_tensors, dim=all_gather_dimension)
     t = all_tensors[rank % tp_size]
     t = tensor_model_parallel_all_gather(t, all_gather_dimension)
-    assert torch.allclose(t, expected)
+    torch.testing.assert_close(t, expected)


 @ray.remote(num_gpus=1, max_calls=1)
@@ -96,12 +96,12 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
     else:
         recv_dict = broadcast_tensor_dict(src=0)
         assert len(recv_dict) == len(test_dict)
-        assert torch.allclose(recv_dict["a"], test_dict["a"])
-        assert torch.allclose(recv_dict["b"], test_dict["b"])
+        torch.testing.assert_close(recv_dict["a"], test_dict["a"])
+        torch.testing.assert_close(recv_dict["b"], test_dict["b"])
         assert recv_dict["c"] == test_dict["c"]
         assert recv_dict["d"] == test_dict["d"]
         assert recv_dict["e"] == test_dict["e"]
-        assert torch.allclose(recv_dict["f"], test_dict["f"])
+        torch.testing.assert_close(recv_dict["f"], test_dict["f"])


 @ray.remote(num_gpus=1, max_calls=1)
@@ -136,12 +136,12 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,

     if not get_pp_group().is_first_rank:
         assert len(recv_dict) == len(test_dict)
-        assert torch.allclose(recv_dict["a"], test_dict["a"])
-        assert torch.allclose(recv_dict["b"], test_dict["b"])
+        torch.testing.assert_close(recv_dict["a"], test_dict["a"])
+        torch.testing.assert_close(recv_dict["b"], test_dict["b"])
         assert recv_dict["c"] == test_dict["c"]
         assert recv_dict["d"] == test_dict["d"]
         assert recv_dict["e"] == test_dict["e"]
-        assert torch.allclose(recv_dict["f"], test_dict["f"])
+        torch.testing.assert_close(recv_dict["f"], test_dict["f"])


 @ray.remote(num_gpus=1, max_calls=1)
@@ -163,7 +163,7 @@ def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
         get_pp_group().send(test_tensor)

     if not get_pp_group().is_first_rank:
-        assert torch.allclose(test_tensor, recv_tensor)
+        torch.testing.assert_close(test_tensor, recv_tensor)


 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -72,8 +72,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
         out2 = tensor_model_parallel_all_reduce(inp2)
         dist.all_reduce(inp2, group=group)
     graph.replay()
-    assert torch.allclose(out1, inp1)
-    assert torch.allclose(out2, inp2)
+    torch.testing.assert_close(out1, inp1)
+    torch.testing.assert_close(out2, inp2)


 @ray.remote(num_gpus=1, max_calls=1)
@@ -96,13 +96,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
     out = inp
     for _ in range(num_communication):
         out = fa.all_reduce_unreg(out)
-    assert torch.allclose(out, inp * (tp_size**num_communication))
+    torch.testing.assert_close(out, inp * (tp_size**num_communication))

     inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
     out = inp
     for _ in range(num_communication):
         out = fa.all_reduce_unreg(out)
-    assert torch.allclose(out, inp * (tp_size**num_communication))
+    torch.testing.assert_close(out, inp * (tp_size**num_communication))


 @pytest.mark.parametrize("tp_size", [2])
@@ -69,4 +69,4 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
     ref_iscale = one / ref_scale
     ref_out = (as_float32_tensor(x) * ref_iscale).clamp(
         fp8_traits.min, fp8_traits.max).to(dtype=torch.float8_e4m3fn)
-    return ref_out, ref_scale
+    return ref_out, ref_scale.view((1, ))
@@ -47,7 +47,7 @@ def test_act_and_mul(
     ref_out = layer.forward_native(x)
     # The SiLU and GELU implementations are equivalent to the native PyTorch
     # implementations, so we can do exact comparison.
-    assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
+    torch.testing.assert_close(out, ref_out, atol=0.0, rtol=0.0)


 @pytest.mark.parametrize("activation", [FastGELU, NewGELU])
@@ -73,7 +73,7 @@ def test_activation(
     layer = activation()
     out = layer(x)
     ref_out = layer.forward_native(x)
-    assert torch.allclose(out,
+    torch.testing.assert_close(out,
         ref_out,
         atol=get_default_atol(out),
         rtol=get_default_rtol(out))
@@ -276,7 +276,7 @@ def test_paged_attention(
     atol, rtol = 1e-3, 1e-5
     if kv_cache_dtype == "fp8":
         atol, rtol = 1e-2, 1e-5
-    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


 def ref_multi_query_kv_attention(
@@ -379,4 +379,4 @@ def test_multi_query_kv_attention(
     )
     atol = get_default_atol(output) if is_hip() else 1e-3
     rtol = get_default_rtol(output) if is_hip() else 1e-5
-    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
@@ -327,7 +327,7 @@ def test_paged_attention(
     atol, rtol = 1e-3, 1e-5
     if kv_cache_dtype == "fp8":
         atol, rtol = 1e-2, 1e-5
-    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
+    torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)


 def ref_multi_query_kv_attention(
@@ -441,4 +441,4 @@ def test_varlen_blocksparse_attention_prefill(
         scale,
         dtype,
     )
-    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2)
+    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
@@ -98,10 +98,10 @@ def test_copy_blocks(

     # Compare the results.
     for key_cache, cloned_key_cache in zip(key_caches, cloned_key_caches):
-        assert torch.allclose(key_cache, cloned_key_cache)
+        torch.testing.assert_close(key_cache, cloned_key_cache)
     for value_cache, cloned_value_cache in zip(value_caches,
             cloned_value_caches):
-        assert torch.allclose(value_cache, cloned_value_cache)
+        torch.testing.assert_close(value_cache, cloned_value_cache)


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -184,17 +184,17 @@ def test_reshape_and_cache(
         cloned_value_cache[block_idx, :, :, block_offset] = value[i]

     if kv_cache_dtype == "fp8":
-        assert torch.allclose(result_key_cache,
+        torch.testing.assert_close(result_key_cache,
             cloned_key_cache,
             atol=0.001,
             rtol=0.1)
-        assert torch.allclose(result_value_cache,
+        torch.testing.assert_close(result_value_cache,
             cloned_value_cache,
             atol=0.001,
             rtol=0.1)
     else:
-        assert torch.allclose(key_cache, cloned_key_cache)
-        assert torch.allclose(value_cache, cloned_value_cache)
+        torch.testing.assert_close(key_cache, cloned_key_cache)
+        torch.testing.assert_close(value_cache, cloned_value_cache)


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -290,17 +290,17 @@ def test_reshape_and_cache_flash(
         cloned_value_cache[block_idx, block_offset, :, :] = value[i]

     if kv_cache_dtype == "fp8":
-        assert torch.allclose(result_key_cache,
+        torch.testing.assert_close(result_key_cache,
             cloned_key_cache,
             atol=0.001,
             rtol=0.1)
-        assert torch.allclose(result_value_cache,
+        torch.testing.assert_close(result_value_cache,
             cloned_value_cache,
             atol=0.001,
             rtol=0.1)
     else:
-        assert torch.allclose(key_cache, cloned_key_cache)
-        assert torch.allclose(value_cache, cloned_value_cache)
+        torch.testing.assert_close(key_cache, cloned_key_cache)
+        torch.testing.assert_close(value_cache, cloned_value_cache)


 @pytest.mark.parametrize("direction", COPYING_DIRECTION)
@@ -372,9 +372,9 @@ def test_swap_blocks(
         block_mapping_tensor)

     for src, dst in block_mapping:
-        assert torch.allclose(src_key_caches_clone[src].cpu(),
+        torch.testing.assert_close(src_key_caches_clone[src].cpu(),
             dist_key_caches[0][dst].cpu())
-        assert torch.allclose(src_value_caches_clone[src].cpu(),
+        torch.testing.assert_close(src_value_caches_clone[src].cpu(),
             dist_value_caches[0][dst].cpu())


@@ -411,4 +411,4 @@ def test_fp8_e4m3_conversion(
     converted_cache = torch.empty_like(cache)
     ops.convert_fp8(converted_cache, cache_fp8)

-    assert torch.allclose(cache, converted_cache, atol=0.001, rtol=0.1)
+    torch.testing.assert_close(cache, converted_cache, atol=0.001, rtol=0.1)
@@ -74,7 +74,7 @@ def cutlass_fp8_gemm_helper(m: int,
     out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

-    assert torch.allclose(out, baseline, rtol=1e-2, atol=5e-2)
+    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)


 def cutlass_int8_gemm_helper(m: int,
@@ -106,7 +106,7 @@ def cutlass_int8_gemm_helper(m: int,
     out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
     baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

-    assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
+    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


 @pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
@@ -252,7 +252,7 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
     azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

     a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
-    assert torch.allclose(a_dq, scale_a * aq_f32 + azp_a)
+    torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)

     baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)

@@ -271,8 +271,8 @@ def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
         scale_b,
         out_dtype=out_dtype,
         bias=azp_bias[0, :])
-    assert torch.allclose(out, baseline_dq, rtol=1e-2, atol=1e0)
-    assert torch.allclose(out, baseline_q, rtol=1e-2, atol=1e0)
+    torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
+    torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)


 @pytest.mark.parametrize("m", [32, 64, 128])
@@ -302,7 +302,10 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
     azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a  # correct for rounding

     a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
-    assert torch.allclose(a_dq, scale_a * aq_f32 - azp_a, rtol=1e-4, atol=1e-3)
+    torch.testing.assert_close(a_dq,
+        scale_a * aq_f32 - azp_a,
+        rtol=1e-4,
+        atol=1e-3)

     if use_bias:
         bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
@@ -335,8 +338,8 @@ def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
     # float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
     rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
     atol = 1e-3
-    assert torch.allclose(out, baseline_dq, rtol=rtol, atol=atol)
-    assert torch.allclose(out, baseline_q, rtol=rtol, atol=atol)
+    torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
+    torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)


 # Test working with a subset of A and B
@@ -363,7 +366,7 @@ def test_cutlass_subset():
         scale_b,
         out_dtype=torch.bfloat16)

-    assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
+    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


 # Test to make sure cuda graphs work
@@ -411,4 +414,4 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):

     baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
         scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
-    assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)
+    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
@@ -126,7 +126,7 @@ def test_flash_attn_with_paged_kv(
         scale=scale,
         soft_cap=soft_cap,
     )
-    assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \
+    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
         f"{torch.max(torch.abs(output - ref_output))}"


@@ -211,5 +211,5 @@ def test_varlen_with_paged_kv(
         sliding_window=sliding_window,
         soft_cap=soft_cap,
     )
-    assert torch.allclose(output, ref_output, atol=2e-2, rtol=1e-2), \
+    torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"
@@ -144,7 +144,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
         block_tables=block_tables,
         scale=scale,
         soft_cap=soft_cap)
-    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
+    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"


@@ -244,5 +244,5 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
         block_tables=block_tables,
         scale=scale,
         soft_cap=soft_cap)
-    assert torch.allclose(output, ref_output, atol=1e-2, rtol=1e-2), \
+    torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
        f"{torch.max(torch.abs(output - ref_output))}"
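
In the flash-attention and FlashInfer hunks above, the old form assert ..., f"..." attached a hand-built failure message. torch.testing.assert_close already prints the greatest absolute and relative difference when it fails; if an extra hint is still wanted, it would go through the msg= keyword rather than the assert-statement form. A small sketch (the tensor names mirror the tests, the values are made up):

    import torch

    output = torch.zeros(4)
    ref_output = torch.zeros(4)

    # assert_close raises with its own diagnostics; msg= replaces the old
    # "assert cond, message" idiom when a custom note is desired.
    torch.testing.assert_close(
        output, ref_output, atol=2e-2, rtol=1e-2,
        msg=f"max abs diff: {torch.max(torch.abs(output - ref_output))}")
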
@@ -37,8 +37,8 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
         scale_ub=scale_ub,
         use_per_token_if_dynamic=True)

-    assert torch.allclose(ref_scales, ops_scales)
-    assert torch.allclose(ref_out.to(dtype=torch.float32),
+    torch.testing.assert_close(ref_scales, ops_scales)
+    torch.testing.assert_close(ref_out.to(dtype=torch.float32),
         ops_out.to(dtype=torch.float32))


@@ -57,8 +57,8 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
     ref_out, ref_scale = ref_dynamic_per_tensor_fp8_quant(x)
     ops_out, ops_scale = ops.scaled_fp8_quant(x)

-    assert torch.allclose(ref_scale, ops_scale)
-    assert torch.allclose(ref_out.to(dtype=torch.float32),
+    torch.testing.assert_close(ref_scale, ops_scale)
+    torch.testing.assert_close(ref_out.to(dtype=torch.float32),
         ops_out.to(dtype=torch.float32))


@@ -84,4 +84,4 @@ def test_fp8_quant_large(seed: int) -> None:
     ref_out = ref_out.to(dtype=dtype)
     ops_out = ops_out.to(dtype=dtype)

-    assert torch.allclose(ref_out, ops_out)
+    torch.testing.assert_close(ref_out, ops_out)
@@ -29,9 +29,10 @@ def test_dynamic_scaled_int8_quant(num_tokens: int, hidden_size: int,
     # kernel
     ops_out, ops_scales = scaled_int8_quant(x)

-    assert torch.allclose(ops_scales, ref_scales)
-    assert torch.allclose(ops_out, ref_out,
-        atol=1)  # big atol to account for rounding errors
+    torch.testing.assert_close(ops_scales, ref_scales)
+    torch.testing.assert_close(
+        ops_out, ref_out, atol=1,
+        rtol=0.0)  # big atol to account for rounding errors


 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -54,5 +55,6 @@ def test_static_scaled_int8_quant(num_tokens: int, hidden_size: int,
         int8_traits.max).to(torch.int8)
     out2, _ = scaled_int8_quant(x, scale)

-    assert torch.allclose(out1, out2,
-        atol=1)  # big atol to account for rounding errors
+    torch.testing.assert_close(
+        out1, out2, atol=1,
+        rtol=0.0)  # big atol to account for rounding errors
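
The int8 quantization hunks above also adjust the tolerance arguments: torch.allclose(ops_out, ref_out, atol=1) silently kept its default rtol=1e-5, while assert_close insists that rtol and atol be given together, so the port spells out rtol=0.0 and keeps the comparison purely absolute. A hedged sketch with made-up int8 values:

    import torch

    # Hypothetical kernel vs. reference outputs that differ by at most one
    # unit because of rounding.
    ops_out = torch.tensor([127, -12, 3], dtype=torch.int8)
    ref_out = torch.tensor([126, -11, 3], dtype=torch.int8)

    # Both tolerances must be spelled out; rtol=0.0 makes the check absolute.
    torch.testing.assert_close(ops_out, ref_out, atol=1, rtol=0.0)
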
@@ -48,7 +48,7 @@ def test_rms_norm(
     # numerical errors than other operators because they involve reductions.
     # Therefore, we use a larger tolerance.
     if add_residual:
-        assert torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
-        assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
+        torch.testing.assert_close(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
+        torch.testing.assert_close(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
     else:
-        assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)
+        torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
@@ -122,7 +122,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
     )
     torch.cuda.synchronize()

-    assert torch.allclose(marlin_q_w_1, marlin_q_w_2)
+    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)


 @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
@@ -174,7 +174,7 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size,
     )
     torch.cuda.synchronize()

-    assert torch.allclose(marlin_q_w_1, marlin_q_w_2)
+    torch.testing.assert_close(marlin_q_w_1, marlin_q_w_2)


 @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
@@ -50,7 +50,7 @@ def test_fused_moe(
     score = torch.randn((m, e), device='cuda', dtype=dtype)
     triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk)
-    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
+    torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0)


 @pytest.mark.parametrize("dtype",
@@ -95,7 +95,7 @@ def test_mixtral_moe(dtype: torch.dtype):
         torch.bfloat16: 1e-2,
     }

-    assert torch.allclose(hf_states.flatten(0, 1),
+    torch.testing.assert_close(hf_states.flatten(0, 1),
         vllm_states,
         rtol=mixtral_moe_tol[dtype],
         atol=mixtral_moe_tol[dtype])
@@ -67,11 +67,11 @@ def test_rotary_embedding(
     ref_query, ref_key = rope.forward_native(positions, query, key)
     out_query, out_key = rope.forward(positions, query, key)
     # Compare the results.
-    assert torch.allclose(out_query,
+    torch.testing.assert_close(out_query,
         ref_query,
         atol=get_default_atol(out_query),
         rtol=get_default_rtol(out_query))
-    assert torch.allclose(out_key,
+    torch.testing.assert_close(out_key,
         ref_key,
         atol=get_default_atol(out_key),
         rtol=get_default_rtol(out_key))
@@ -129,11 +129,11 @@ def test_batched_rotary_embedding(
         dtype=torch.long,
         device=device))
     # Compare the results.
-    assert torch.allclose(out_query,
+    torch.testing.assert_close(out_query,
         ref_query,
         atol=get_default_atol(out_query),
         rtol=get_default_rtol(out_query))
-    assert torch.allclose(out_key,
+    torch.testing.assert_close(out_key,
         ref_key,
         atol=get_default_atol(out_key),
         rtol=get_default_rtol(out_key))
@@ -200,11 +200,11 @@ def test_batched_rotary_embedding_multi_lora(
     out_query, out_key = rope.forward(positions, query, key,
         query_offsets.flatten())
     # Compare the results.
-    assert torch.allclose(out_query,
+    torch.testing.assert_close(out_query,
         ref_query,
         atol=get_default_atol(out_query),
         rtol=get_default_rtol(out_query))
-    assert torch.allclose(out_key,
+    torch.testing.assert_close(out_key,
         ref_key,
         atol=get_default_atol(out_key),
         rtol=get_default_rtol(out_key))
@@ -100,11 +100,11 @@ def test_sample_decoding_only(random_sampling, max_best_of,
         if modify_greedy_probs and not request_uses_random_sampling:
             # If we are modifying greedy probs and the request is greedy,
             # we want to make sure the probs tensor is modified in place
-            assert torch.allclose(
+            torch.testing.assert_close(
                 probs[i][sampled_tokens[i]],
                 torch.full_like(probs[i][sampled_tokens[i]], 1.0))
             assert torch.sum(probs[i]) == 1.0
-            assert torch.allclose(
+            torch.testing.assert_close(
                 sampled_modified_probs[i][0],
                 torch.full_like(sampled_modified_probs[i][0], 1.0))
         elif request_uses_random_sampling:
@@ -117,7 +117,7 @@ def test_sample_decoding_only(random_sampling, max_best_of,
             # If the request is greedy and we are not modifying greedy probs,
             # we want to make sure sampled_modified_probs tensor is the same as
             # the probs tensor.
-            assert torch.allclose(sampled_modified_probs[i][0],
+            torch.testing.assert_close(sampled_modified_probs[i],
                 probs[i][sampled_tokens[i]])

         if save_logprobs:
@@ -924,5 +924,5 @@ def assert_actual_matches_ideal(test_params: PhaseTestParameters,
     * output_under_test: actually observed output value
     '''
     ideal_output = test_params.packed_qkvo.ideal_output
-    assert torch.allclose(ideal_output,
+    torch.testing.assert_close(ideal_output,
         output_under_test.view_as(ideal_output))
@@ -247,7 +247,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
     expected_result = torch.cat(expected_results)

     rtol, atol = TOLERANCES[lora_result.dtype]
-    assert torch.allclose(lora_result,
+    torch.testing.assert_close(lora_result,
         expected_result,
         rtol=rtol,
         atol=atol)
@@ -274,7 +274,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
     expected_result = embedding(torch.cat(inputs))

     rtol, atol = TOLERANCES[lora_result.dtype]
-    assert torch.allclose(lora_result,
+    torch.testing.assert_close(lora_result,
         expected_result,
         rtol=rtol,
         atol=atol)
@@ -384,7 +384,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
     expected_result = torch.cat(expected_results)

     rtol, atol = TOLERANCES[lora_result.dtype]
-    assert torch.allclose(lora_result,
+    torch.testing.assert_close(lora_result,
         expected_result,
         rtol=rtol,
         atol=atol)
@@ -411,7 +411,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
     expected_result = expanded_embedding(torch.cat(inputs))

     rtol, atol = TOLERANCES[lora_result.dtype]
-    assert torch.allclose(lora_result,
+    torch.testing.assert_close(lora_result,
         expected_result,
         rtol=rtol,
         atol=atol)
@@ -541,7 +541,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
         embedding_bias=None)

     rtol, atol = TOLERANCES[lora_result.dtype]
-    assert torch.allclose(lora_result,
+    torch.testing.assert_close(lora_result,
         expected_result,
         rtol=rtol,
         atol=atol)
@@ -614,7 +614,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
     expected_result = torch.cat(expected_results)

     rtol, atol = TOLERANCES[lora_result.dtype]
-    assert torch.allclose(lora_result,
+    torch.testing.assert_close(lora_result,
         expected_result,
         rtol=rtol,
         atol=atol)
@@ -642,7 +642,7 @@ def test_linear_replicated(dist_init, num_loras, device, stage) -> None:
     expected_result = linear(torch.cat(inputs))[0]

     rtol, atol = TOLERANCES[lora_result.dtype]
-    assert torch.allclose(lora_result,
+    torch.testing.assert_close(lora_result,
         expected_result,
         rtol=rtol,
         atol=atol)
@@ -728,7 +728,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
     expected_result = torch.cat(expected_results)

     rtol, atol = TOLERANCES[lora_result.dtype]
-    assert torch.allclose(lora_result,
+    torch.testing.assert_close(lora_result,
         expected_result,
         rtol=rtol,
         atol=atol)
@@ -756,7 +756,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
     expected_result = linear(torch.cat(inputs))[0]

     rtol, atol = TOLERANCES[lora_result.dtype]
-    assert torch.allclose(lora_result,
+    torch.testing.assert_close(lora_result,
         expected_result,
         rtol=rtol,
         atol=atol)
@@ -868,7 +868,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
     expected_result = torch.cat(expected_results)

     rtol, atol = TOLERANCES[lora_result.dtype]
-    assert torch.allclose(lora_result,
+    torch.testing.assert_close(lora_result,
         expected_result,
         rtol=rtol,
         atol=atol)
@@ -900,7 +900,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
     expected_result = linear(torch.cat(inputs))[0]

     rtol, atol = TOLERANCES[lora_result.dtype]
-    assert torch.allclose(lora_result,
+    torch.testing.assert_close(lora_result,
         expected_result,
         rtol=rtol,
         atol=atol)
@@ -533,13 +533,13 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
     packed_lora = model_lora.get_lora("gate_up_proj")
     assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

-    assert torch.allclose(packed_lora.lora_a[0],
+    torch.testing.assert_close(packed_lora.lora_a[0],
         model_lora.get_lora("gate_proj").lora_a)
-    assert torch.allclose(packed_lora.lora_b[0],
+    torch.testing.assert_close(packed_lora.lora_b[0],
         model_lora.get_lora("gate_proj").lora_b)
-    assert torch.allclose(packed_lora.lora_a[1],
+    torch.testing.assert_close(packed_lora.lora_a[1],
         model_lora.get_lora("up_proj").lora_a)
-    assert torch.allclose(packed_lora.lora_b[1],
+    torch.testing.assert_close(packed_lora.lora_b[1],
         model_lora.get_lora("up_proj").lora_b)

     packed_lora1 = model_lora1.get_lora("gate_up_proj")
@@ -547,7 +547,7 @@ def test_packed_loras(dist_init, dummy_model_gate_up):

     assert packed_lora1.lora_a[0] is None
     assert packed_lora1.lora_b[0] is None
-    assert torch.allclose(packed_lora1.lora_a[1],
+    torch.testing.assert_close(packed_lora1.lora_a[1],
         model_lora1.get_lora("up_proj").lora_a)
-    assert torch.allclose(packed_lora1.lora_b[1],
+    torch.testing.assert_close(packed_lora1.lora_b[1],
         model_lora1.get_lora("up_proj").lora_b)
@@ -127,16 +127,18 @@ def test_scaled_fp8_quant(dtype) -> None:

     # Reference dynamic quantizaton
     y = quantize_ref(x, inv_scale)
-    assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
+    torch.testing.assert_close(ref_y,
+        per_tensor_dequantize(y, inv_scale, dtype))

     # Static quantization
     y, _ = ops.scaled_fp8_quant(x, inv_scale)
-    assert torch.allclose(ref_y, per_tensor_dequantize(y, inv_scale, dtype))
+    torch.testing.assert_close(ref_y,
+        per_tensor_dequantize(y, inv_scale, dtype))

     # Padding
     y, _ = ops.scaled_fp8_quant(x, inv_scale, num_token_padding=17)
     assert y.shape[0] == 17
-    assert torch.allclose(
+    torch.testing.assert_close(
         ref_y,
         per_tensor_dequantize(torch.narrow(y, 0, 0, x.shape[0]), inv_scale,
             dtype))
@@ -632,7 +632,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):

     hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
     hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
-    assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
+    torch.testing.assert_close(hf_probs, sample_probs, rtol=0.0, atol=1e-5)
     assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))

@@ -161,7 +161,7 @@ def assert_logprobs_dict_allclose(
         single_step_actual_logprobs[token_id].logprob)
     expected = torch.tensor(
         single_step_expected_logprobs[token_id].logprob)
-    assert torch.allclose(actual, expected)
+    torch.testing.assert_close(actual, expected)


 def create_sampler_output_list(
@@ -90,5 +90,7 @@ def test_logits_processors(seed: int, device: str):
     assert torch.isinf(logits_processor_output[:, 0]).all()

     fake_logits *= logits_processor.scale
-    assert torch.allclose(logits_processor_output[:, 1], fake_logits[:, 1],
-        1e-4)
+    torch.testing.assert_close(logits_processor_output[:, 1],
+        fake_logits[:, 1],
+        rtol=1e-4,
+        atol=0.0)
@@ -77,7 +77,7 @@ def test_prepare_prompt(batch_size):
     device = model_runner.device
     assert attn_metadata.num_prefills > 0
     assert attn_metadata.num_decode_tokens == 0
-    assert torch.allclose(
+    torch.testing.assert_close(
         attn_metadata.seq_lens_tensor,
         torch.tensor(seq_lens, device=device, dtype=torch.int))
     assert attn_metadata.seq_lens == seq_lens
@@ -90,7 +90,7 @@ def test_prepare_prompt(batch_size):
     for seq_len in seq_lens:
         start_idx += seq_len
         start_loc.append(start_idx)
-    assert torch.allclose(
+    torch.testing.assert_close(
         attn_metadata.query_start_loc,
         torch.tensor(start_loc, dtype=torch.int32, device=device))

@@ -102,10 +102,10 @@ def test_prepare_prompt(batch_size):
         start_idx += seq_len
         seq_start_loc.append(start_idx)

-    assert torch.allclose(
+    torch.testing.assert_close(
         attn_metadata.seq_start_loc,
         torch.tensor(start_loc, dtype=torch.int32, device=device))
-    assert torch.allclose(
+    torch.testing.assert_close(
         attn_metadata.context_lens_tensor,
         torch.zeros(attn_metadata.context_lens_tensor.shape[0],
             dtype=torch.int,
@@ -114,7 +114,7 @@ def test_prepare_prompt(batch_size):
     expected = torch.tensor([[] for _ in range(len(seq_group_metadata_list))],
         dtype=torch.int32,
         device=model_runner.device)
-    assert torch.allclose(attn_metadata.block_tables, expected)
+    torch.testing.assert_close(attn_metadata.block_tables, expected)
     # Cuda graph should not be used for prerill.
     assert attn_metadata.use_cuda_graph is False

@@ -201,7 +201,7 @@ def test_prepare_decode_cuda_graph(batch_size):
         # decode has only 1 token for query.
         start_idx += 1
         start_loc.append(start_idx)
-    assert torch.allclose(
+    torch.testing.assert_close(
         attn_metadata.query_start_loc,
         torch.tensor(start_loc, dtype=torch.int32, device=device))

@@ -210,15 +210,15 @@ def test_prepare_decode_cuda_graph(batch_size):
     for seq_len in seq_lens:
         start_idx += seq_len
         seq_start_loc.append(start_idx)
-    assert torch.allclose(
+    torch.testing.assert_close(
         attn_metadata.seq_start_loc,
         torch.tensor(seq_start_loc, dtype=torch.int32, device=device))

-    assert torch.allclose(
+    torch.testing.assert_close(
         attn_metadata.context_lens_tensor,
         torch.tensor(context_lens, dtype=torch.int, device=device))
     assert attn_metadata.max_decode_seq_len == max(seq_lens)
-    assert torch.allclose(
+    torch.testing.assert_close(
         attn_metadata.seq_lens_tensor[:len(seq_lens)],
         torch.tensor(seq_lens, dtype=torch.int, device=device))
