[mis][ci/test] fix flaky test in test_sharded_state_loader.py (#5361)

[mis][ci/test] fix flaky test in tests/test_sharded_state_loader.py (#5361)
This commit is contained in:
youkaichao 2024-06-08 20:50:14 -07:00 committed by GitHub
parent 0373e1837e
commit 5d7e3d0176
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -39,7 +39,8 @@ def test_filter_subtensors():
filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict)
assert tuple(filtered_state_dict.keys()) == ("a", "b", "c")
for key, tensor in filtered_state_dict.items():
assert tensor.equal(state_dict[key])
# NOTE: don't use `euqal` here, as the tensor might contain NaNs
assert tensor is state_dict[key]
@pytest.fixture(scope="module")