ZhengPeng7 commited on
Commit
f6b7155
1 Parent(s): 937a7a8

Dtype adaptability between FP32 and FP16 in inference.

Browse files
Files changed (1) hide show
  1. birefnet.py +1 -1
birefnet.py CHANGED
@@ -992,7 +992,7 @@ class BasicLayer(nn.Module):
992
  mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
993
  mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
994
  attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
995
- attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
996
 
997
  for blk in self.blocks:
998
  blk.H, blk.W = H, W
 
992
  mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
993
  mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
994
  attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
995
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)).to(x.dtype)
996
 
997
  for blk in self.blocks:
998
  blk.H, blk.W = H, W