80 lines
2.5 KiB
Python
80 lines
2.5 KiB
Python
![]() |
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
||
|
import torch
|
||
|
|
||
|
from vllm.model_executor.models.utils import AutoWeightsLoader
|
||
|
|
||
|
|
||
|
class ModuleWithBatchNorm(torch.nn.Module):
    """Minimal module that wraps a single ``BatchNorm1d`` layer.

    The layer is built for 2 features and is exposed as ``self.bn`` so that
    tests can inspect its ``running_mean``, ``running_var`` and
    ``num_batches_tracked`` statistics.
    """

    def __init__(self) -> None:
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through the batchnorm layer and return its output."""
        return self.bn(x)
|
||
|
|
||
|
|
||
|
class ModuleWithNestedBatchNorm(torch.nn.Module):
    """Module whose only child is a ``ModuleWithBatchNorm``.

    The batchnorm layer therefore sits one level deeper in the module tree,
    at ``self.nested_mod.bn``.
    """

    def __init__(self) -> None:
        super().__init__()
        self.nested_mod = ModuleWithBatchNorm()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate the forward pass to the nested module."""
        return self.nested_mod(x)
|
||
|
|
||
|
|
||
|
def test_module_with_batchnorm_can_load():
    """Ensure the auto weight loader can load batchnorm stats."""
    mod = ModuleWithBatchNorm()
    # Run one batch through the module so its batchnorm statistics move away
    # from the values a freshly constructed module starts with.
    mod(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))

    # Serve the source module's state dict as a stream of (name, tensor)
    # pairs for the loader to consume.
    def weight_generator():
        yield from mod.state_dict().items()

    new_mod = ModuleWithBatchNorm()

    # Sanity check: before loading, the fresh instance must not already
    # carry the source module's statistics.
    assert not torch.equal(new_mod.bn.running_mean, mod.bn.running_mean)
    assert not torch.equal(new_mod.bn.running_var, mod.bn.running_var)
    assert new_mod.bn.num_batches_tracked.item() == 0

    loader = AutoWeightsLoader(new_mod)
    loader.load_weights(weight_generator())

    # Ensure the stats are updated
    assert torch.equal(new_mod.bn.running_mean, mod.bn.running_mean)
    assert torch.equal(new_mod.bn.running_var, mod.bn.running_var)
    assert new_mod.bn.num_batches_tracked.item() == 1
|
||
|
|
||
|
|
||
|
def test_module_with_child_containing_batchnorm_can_autoload():
    """Ensure the auto weight loader can load nested modules batchnorm stats."""
    mod = ModuleWithNestedBatchNorm()
    # Run one batch through the module so the nested batchnorm statistics
    # move away from the values a freshly constructed module starts with.
    mod(torch.tensor([[1.0, 2.0], [3.0, 4.0]]))

    # Serve the source module's state dict as a stream of (name, tensor)
    # pairs for the loader to consume.
    def weight_generator():
        yield from mod.state_dict().items()

    new_mod = ModuleWithNestedBatchNorm()

    # The loader updates the existing layer in place, so these references
    # stay valid across the load below.
    src_bn = mod.nested_mod.bn
    dst_bn = new_mod.nested_mod.bn

    # Sanity check: before loading, the fresh instance must not already
    # carry the source module's statistics.
    assert not torch.equal(dst_bn.running_mean, src_bn.running_mean)
    assert not torch.equal(dst_bn.running_var, src_bn.running_var)
    assert dst_bn.num_batches_tracked.item() == 0

    loader = AutoWeightsLoader(new_mod)
    loader.load_weights(weight_generator())

    # Ensure the stats are updated
    assert torch.equal(new_mod.nested_mod.bn.running_mean,
                       src_bn.running_mean)
    assert torch.equal(new_mod.nested_mod.bn.running_var, src_bn.running_var)
    assert new_mod.nested_mod.bn.num_batches_tracked.item() == 1
|