gmastrapas commited on
Commit
952897b
1 Parent(s): b845577

feat: add autocasting in vision.patch_embed

Browse files
Files changed (1) hide show
  1. eva_model.py +2 -1
eva_model.py CHANGED
@@ -462,13 +462,14 @@ class PatchEmbed(nn.Module):
462
  )
463
 
464
  def forward(self, x, **kwargs):
 
465
  B, C, H, W = x.shape
466
  # FIXME look at relaxing size constraints
467
  assert H == self.img_size[0] and W == self.img_size[1], (
468
  f"Input image size ({H}*{W}) doesn't match model "
469
  f'({self.img_size[0]}*{self.img_size[1]}).'
470
  )
471
- x = self.proj(x).flatten(2).transpose(1, 2)
472
  return x
473
 
474
 
 
462
  )
463
 
464
  def forward(self, x, **kwargs):
465
+ target_dtype = self.proj.weight.dtype
466
  B, C, H, W = x.shape
467
  # FIXME look at relaxing size constraints
468
  assert H == self.img_size[0] and W == self.img_size[1], (
469
  f"Input image size ({H}*{W}) doesn't match model "
470
  f'({self.img_size[0]}*{self.img_size[1]}).'
471
  )
472
+ x = self.proj(x.to(dtype=target_dtype)).flatten(2).transpose(1, 2)
473
  return x
474
 
475