[mis][ci/test] fix flaky test in test_sharded_state_loader.py (#5361)
[mis][ci/test] fix flaky test in tests/test_sharded_state_loader.py (#5361)
This commit is contained in:
parent
0373e1837e
commit
5d7e3d0176
@ -39,7 +39,8 @@ def test_filter_subtensors():
|
|||||||
filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict)
|
filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict)
|
||||||
assert tuple(filtered_state_dict.keys()) == ("a", "b", "c")
|
assert tuple(filtered_state_dict.keys()) == ("a", "b", "c")
|
||||||
for key, tensor in filtered_state_dict.items():
|
for key, tensor in filtered_state_dict.items():
|
||||||
assert tensor.equal(state_dict[key])
|
# NOTE: don't use `euqal` here, as the tensor might contain NaNs
|
||||||
|
assert tensor is state_dict[key]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
@pytest.fixture(scope="module")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user