# SPDX-License-Identifier: Apache-2.0
import torch

from vllm.model_executor.models.utils import AutoWeightsLoader


class ModuleWithBatchNorm(torch.nn.Module):
    """Toy module whose only child is a two-feature ``BatchNorm1d``."""

    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm1d(2)

    def forward(self, x):
        return self.bn(x)


class ModuleWithNestedBatchNorm(torch.nn.Module):
    """Toy module that wraps ``ModuleWithBatchNorm`` one level deep."""

    def __init__(self):
        super().__init__()
        self.nested_mod = ModuleWithBatchNorm()

    def forward(self, x):
        return self.nested_mod(x)


def test_module_with_batchnorm_can_load():
    """Check that AutoWeightsLoader restores batchnorm stats on a module."""
    source = ModuleWithBatchNorm()
    # A single forward pass gives the source stats that differ from those
    # of a freshly constructed module (verified by the asserts below).
    source(torch.Tensor([[1, 2], [3, 4]]))
    source_bn = source.bn

    def iter_source_weights():
        yield from source.state_dict().items()

    target = ModuleWithBatchNorm()
    target_bn = target.bn

    # Precondition: the fresh module has not seen any batch yet.
    assert not torch.equal(target_bn.running_mean, source_bn.running_mean)
    assert not torch.equal(target_bn.running_var, source_bn.running_var)
    assert target_bn.num_batches_tracked.item() == 0

    weights_loader = AutoWeightsLoader(target)
    weights_loader.load_weights(iter_source_weights())

    # Postcondition: every batchnorm statistic now mirrors the source.
    assert torch.equal(target_bn.running_mean, source_bn.running_mean)
    assert torch.equal(target_bn.running_var, source_bn.running_var)
    assert target_bn.num_batches_tracked.item() == 1


def test_module_with_child_containing_batchnorm_can_autoload():
    """Check that AutoWeightsLoader restores stats of a nested batchnorm."""
    source = ModuleWithNestedBatchNorm()
    # A single forward pass gives the source stats that differ from those
    # of a freshly constructed module (verified by the asserts below).
    source(torch.Tensor([[1, 2], [3, 4]]))
    source_bn = source.nested_mod.bn

    def iter_source_weights():
        yield from source.state_dict().items()

    target = ModuleWithNestedBatchNorm()
    target_bn = target.nested_mod.bn

    # Precondition: the fresh module has not seen any batch yet.
    assert not torch.equal(target_bn.running_mean, source_bn.running_mean)
    assert not torch.equal(target_bn.running_var, source_bn.running_var)
    assert target_bn.num_batches_tracked.item() == 0

    weights_loader = AutoWeightsLoader(target)
    weights_loader.load_weights(iter_source_weights())

    # Postcondition: every batchnorm statistic now mirrors the source.
    assert torch.equal(target_bn.running_mean, source_bn.running_mean)
    assert torch.equal(target_bn.running_var, source_bn.running_var)
    assert target_bn.num_batches_tracked.item() == 1