Bug when num_classes=0
Here is my code snippet:
from urllib.request import urlopen
from PIL import Image
import timm
# Download the sample image over HTTP and open it with PIL.
img = Image.open(urlopen(
'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png'
))
# Create the pretrained MobileNetV4 hybrid model with num_classes=0.
# NOTE(review): num_classes=0 presumably removes the classifier so the model
# returns pooled features instead of logits -- confirm against the timm docs.
model = timm.create_model('mobilenetv4_hybrid_medium.e200_r256_in12k_ft_in1k', pretrained=True, num_classes=0)
# Put the model in eval mode; in train mode the BatchNorm head rejects the
# batch-of-1 input used below (that is the ValueError reported in this thread).
model = model.eval()
# get model specific transforms (normalization, resize)
data_config = timm.data.resolve_model_data_config(model)
transforms = timm.data.create_transform(**data_config, is_training=False)
output = model(transforms(img).unsqueeze(0)) # unsqueeze single image into batch of 1
This causes the error shown below:
```
ValueError Traceback (most recent call last)
Cell In[24], line 1
----> 1 output = model(transforms(img).unsqueeze(0)) # unsqueeze single image into batch of 1
File /opt/homebrew/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /opt/homebrew/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /opt/homebrew/lib/python3.12/site-packages/timm/models/mobilenetv3.py:273, in MobileNetV3.forward(self, x)
271 def forward(self, x: torch.Tensor) -> torch.Tensor:
272 x = self.forward_features(x)
--> 273 x = self.forward_head(x)
274 return x
File /opt/homebrew/lib/python3.12/site-packages/timm/models/mobilenetv3.py:262, in MobileNetV3.forward_head(self, x, pre_logits)
260 x = self.global_pool(x)
261 x = self.conv_head(x)
--> 262 x = self.norm_head(x)
263 x = self.act2(x)
264 x = self.flatten(x)
File /opt/homebrew/lib/python3.12/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /opt/homebrew/lib/python3.12/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /opt/homebrew/lib/python3.12/site-packages/timm/layers/norm_act.py:115, in BatchNormAct2d.forward(self, x)
108 bn_training = (self.running_mean is None) and (self.running_var is None)
110 r"""
111 Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
112 passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
113 used for normalization (i.e. in eval mode when buffers are not None).
114 """
--> 115 x = F.batch_norm(
116 x,
117 # If buffers are not to be tracked, ensure that they won't be updated
118 self.running_mean if not self.training or self.track_running_stats else None,
119 self.running_var if not self.training or self.track_running_stats else None,
120 self.weight,
121 self.bias,
122 bn_training,
123 exponential_average_factor,
124 self.eps,
125 )
126 x = self.drop(x)
127 x = self.act(x)
File /opt/homebrew/lib/python3.12/site-packages/torch/nn/functional.py:2507, in batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)
2494 return handle_torch_function(
2495 batch_norm,
2496 (input, running_mean, running_var, weight, bias),
(...)
2504 eps=eps,
2505 )
2506 if training:
-> 2507 _verify_batch_size(input.size())
2509 return torch.batch_norm(
2510 input, weight, bias, running_mean, running_var, training, momentum, eps, torch.backends.cudnn.enabled
2511 )
File /opt/homebrew/lib/python3.12/site-packages/torch/nn/functional.py:2475, in _verify_batch_size(size)
2473 size_prods *= size[i + 2]
2474 if size_prods == 1:
-> 2475 raise ValueError(f"Expected more than 1 value per channel when training, got input size {size}")
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 1280, 1, 1])
```
Use a batch size > 1 when you're in train mode, or switch to eval. I don't think you used that snippet verbatim.
Wow, sorry, yes. I hadn't called model.eval() locally, so the model was still in train mode, and it was the following snippet that caused the error:
def _get_output_dimension(input_size: tuple, model: nn.Module) -> int:
    """Return the size of the last dimension of *model*'s output.

    A single random sample of shape ``(1, *input_size)`` is passed through
    the model under ``torch.inference_mode()`` and the trailing dimension of
    the result is reported.

    The probe is always run in eval mode. In train mode, BatchNorm layers
    raise ``ValueError: Expected more than 1 value per channel when
    training`` for a batch of one that has been pooled to a single value per
    channel, and a train-mode pass that did succeed would mix the random
    probe into their running statistics. Every submodule's previous
    train/eval flag is restored before returning, so the caller's model
    state is left as it was.

    Args:
        input_size: Shape of one input sample, without the batch dimension.
        model: Module to probe.

    Returns:
        ``model(dummy).shape[-1]`` for a dummy batch of size one.
    """
    # Remember the mode of each submodule individually: a caller may have
    # deliberately frozen some layers in eval mode inside a training model.
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        dummy_input = torch.randn(1, *input_size)
        with torch.inference_mode():
            dim = model(dummy_input).shape[-1]
    finally:
        # Restore the original modes even if the forward pass raises.
        for module, was_training in saved_modes:
            module.training = was_training
    return dim
Would this behaviour get "fixed", so to speak, so that it accepts a batch size of 1, or is this something that will simply not work with mobilenetv4 for some reason?
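Nothing to do with mobilenetv4 specifically; it's a restriction of the BatchNorm layer. In train mode BatchNorm needs more than one value per channel to compute batch statistics, so any model that applies BatchNorm to a tensor like the [1, 1280, 1, 1] one here fails the same way. The batch size must be > 1 in train mode.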