[mis][ci/test] fix flaky test in test_sharded_state_loader.py (#5361)
[mis][ci/test] fix flaky test in tests/test_sharded_state_loader.py (#5361)
This commit is contained in:
parent
0373e1837e
commit
5d7e3d0176
@ -39,7 +39,8 @@ def test_filter_subtensors():
|
||||
filtered_state_dict = ShardedStateLoader._filter_subtensors(state_dict)
|
||||
assert tuple(filtered_state_dict.keys()) == ("a", "b", "c")
|
||||
for key, tensor in filtered_state_dict.items():
|
||||
assert tensor.equal(state_dict[key])
|
||||
# NOTE: don't use `euqal` here, as the tensor might contain NaNs
|
||||
assert tensor is state_dict[key]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
|
Loading…
x
Reference in New Issue
Block a user