Fix example, load weights safely and remove extra whitespace
#2
by
PartyParrot
- opened
- README.md +11 -11
- image2.png +0 -0
- model.py +27 -27
README.md
CHANGED
@@ -13,34 +13,34 @@ tags:
|
|
13 |
|
14 |
# BEN - Background Erase Network (Beta Base Model)
|
15 |
|
16 |
-
BEN is a deep learning model designed to automatically remove backgrounds from images, producing both a mask and a foreground image.
|
17 |
|
18 |
- MADE IN AMERICA
|
19 |
|
20 |
## Quick Start Code
|
|
|
21 |
```python
|
22 |
-
|
23 |
from PIL import Image
|
24 |
import torch
|
25 |
|
26 |
|
27 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
28 |
|
29 |
-
file = "./image2.
|
30 |
|
31 |
-
model = model.BEN_Base().to(device).eval() #init pipeline
|
32 |
|
33 |
-
model.loadcheckpoints("./
|
34 |
image = Image.open(file)
|
35 |
-
|
|
|
36 |
|
37 |
mask.save("./mask.png")
|
38 |
foreground.save("./foreground.png")
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
```
|
43 |
-
|
|
|
44 |
|
45 |
![Demo Results](demo.jpg)
|
46 |
|
@@ -84,4 +84,4 @@ foreground.save("./foreground.png")
|
|
84 |
|
85 |
## Installation
|
86 |
1. Clone Repo
|
87 |
-
2. Install requirements.txt
|
|
|
13 |
|
14 |
# BEN - Background Erase Network (Beta Base Model)
|
15 |
|
16 |
+
BEN is a deep learning model designed to automatically remove backgrounds from images, producing both a mask and a foreground image.
|
17 |
|
18 |
- MADE IN AMERICA
|
19 |
|
20 |
## Quick Start Code
|
21 |
+
|
22 |
```python
|
23 |
+
import model
|
24 |
from PIL import Image
|
25 |
import torch
|
26 |
|
27 |
|
28 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
29 |
|
30 |
+
file = "./image2.png" # input image
|
31 |
|
32 |
+
model = model.BEN_Base().to(device).eval() #init pipeline
|
33 |
|
34 |
+
model.loadcheckpoints("./BEN_Base.pth")
|
35 |
image = Image.open(file)
|
36 |
+
with torch.no_grad():
|
37 |
+
mask, foreground = model.inference(image)
|
38 |
|
39 |
mask.save("./mask.png")
|
40 |
foreground.save("./foreground.png")
|
|
|
|
|
|
|
41 |
```
|
42 |
+
|
43 |
+
# BEN SOA Benchmarks on Disk 5k Eval
|
44 |
|
45 |
![Demo Results](demo.jpg)
|
46 |
|
|
|
84 |
|
85 |
## Installation
|
86 |
1. Clone Repo
|
87 |
+
2. Install requirements.txt
|
image2.png
ADDED
model.py
CHANGED
@@ -560,7 +560,7 @@ class SwinTransformer(nn.Module):
|
|
560 |
# interpolate the position embedding to the corresponding size
|
561 |
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
562 |
x = (x + absolute_pos_embed) # B Wh*Ww C
|
563 |
-
|
564 |
outs = [x.contiguous()]
|
565 |
x = x.flatten(2).transpose(1, 2)
|
566 |
x = self.pos_drop(x)
|
@@ -634,7 +634,7 @@ class PositionEmbeddingSine:
|
|
634 |
scale = 2 * math.pi
|
635 |
self.scale = scale
|
636 |
self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
|
637 |
-
|
638 |
def __call__(self, b, h, w):
|
639 |
device = self.dim_t.device
|
640 |
mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
|
@@ -646,18 +646,18 @@ class PositionEmbeddingSine:
|
|
646 |
eps = 1e-6
|
647 |
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
|
648 |
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
|
649 |
-
|
650 |
dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
|
651 |
pos_x = x_embed[:, :, :, None] / dim_t
|
652 |
pos_y = y_embed[:, :, :, None] / dim_t
|
653 |
-
|
654 |
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
655 |
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
656 |
-
|
657 |
-
return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
658 |
|
|
|
659 |
|
660 |
-
|
|
|
661 |
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
662 |
super(MCLM, self).__init__()
|
663 |
self.attention = nn.ModuleList([
|
@@ -688,10 +688,10 @@ class MCLM(nn.Module):
|
|
688 |
l: 4,c,h,w
|
689 |
g: 1,c,h,w
|
690 |
"""
|
691 |
-
b, c, h, w = l.size()
|
692 |
# 4,c,h,w -> 1,c,2h,2w
|
693 |
concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
|
694 |
-
|
695 |
pools = []
|
696 |
for pool_ratio in self.pool_ratios:
|
697 |
# b,c,h,w
|
@@ -734,7 +734,7 @@ class MCLM(nn.Module):
|
|
734 |
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
|
735 |
l_hw_b_c = self.norm1(l_hw_b_c)
|
736 |
l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
|
737 |
-
l_hw_b_c = self.norm2(l_hw_b_c)
|
738 |
|
739 |
l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
|
740 |
return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
|
@@ -770,42 +770,42 @@ class MCRM(nn.Module):
|
|
770 |
|
771 |
def forward(self, x):
|
772 |
device = x.device
|
773 |
-
b, c, h, w = x.size()
|
774 |
loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
|
775 |
-
|
776 |
patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
777 |
-
|
778 |
token_attention_map = self.sigmoid(self.sal_conv(glb))
|
779 |
token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
|
780 |
loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
781 |
-
|
782 |
pools = []
|
783 |
for pool_ratio in self.pool_ratios:
|
784 |
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
785 |
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
|
786 |
pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
|
787 |
-
|
788 |
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
|
789 |
loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
|
790 |
-
|
791 |
outputs = []
|
792 |
for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
|
793 |
v = pools[i]
|
794 |
k = v
|
795 |
outputs.append(self.attention[i](q, k, v)[0])
|
796 |
-
|
797 |
-
outputs = torch.cat(outputs, 1)
|
798 |
src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
|
799 |
src = self.norm1(src)
|
800 |
src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
|
801 |
src = self.norm2(src)
|
802 |
src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
|
803 |
glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # freshed glb
|
804 |
-
|
805 |
return torch.cat((src, glb), 0), token_attention_map
|
806 |
|
807 |
|
808 |
-
class BEN_Base(nn.Module):
|
809 |
def __init__(self):
|
810 |
super().__init__()
|
811 |
|
@@ -868,7 +868,7 @@ class BEN_Base(nn.Module):
|
|
868 |
e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
|
869 |
|
870 |
e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
|
871 |
-
e4 = self.conv4(e4)
|
872 |
e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
|
873 |
e3 = self.conv3(e3)
|
874 |
e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
|
@@ -909,11 +909,11 @@ class BEN_Base(nn.Module):
|
|
909 |
return blurred_mask, foreground
|
910 |
|
911 |
def loadcheckpoints(self,model_path):
|
912 |
-
model_dict = torch.load(model_path,map_location="cpu")
|
913 |
self.load_state_dict(model_dict['model_state_dict'], strict=True)
|
914 |
del model_path
|
915 |
|
916 |
-
|
917 |
|
918 |
|
919 |
def rgb_loader_refiner( original_image):
|
@@ -923,16 +923,16 @@ def rgb_loader_refiner( original_image):
|
|
923 |
# Convert to RGB if necessary
|
924 |
if image.mode != 'RGB':
|
925 |
image = image.convert('RGB')
|
926 |
-
|
927 |
# Resize the image
|
928 |
image = image.resize((1024, 1024), resample=Image.LANCZOS)
|
929 |
|
930 |
-
return image.convert('RGB'), h, w,original_image
|
931 |
-
|
932 |
# Define the image transformation
|
933 |
img_transform = transforms.Compose([
|
934 |
transforms.ToTensor(),
|
935 |
-
transforms.ConvertImageDtype(torch.float32),
|
936 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
937 |
])
|
938 |
|
|
|
560 |
# interpolate the position embedding to the corresponding size
|
561 |
absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
|
562 |
x = (x + absolute_pos_embed) # B Wh*Ww C
|
563 |
+
|
564 |
outs = [x.contiguous()]
|
565 |
x = x.flatten(2).transpose(1, 2)
|
566 |
x = self.pos_drop(x)
|
|
|
634 |
scale = 2 * math.pi
|
635 |
self.scale = scale
|
636 |
self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32)
|
637 |
+
|
638 |
def __call__(self, b, h, w):
|
639 |
device = self.dim_t.device
|
640 |
mask = torch.zeros([b, h, w], dtype=torch.bool, device=device)
|
|
|
646 |
eps = 1e-6
|
647 |
y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
|
648 |
x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
|
649 |
+
|
650 |
dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats)
|
651 |
pos_x = x_embed[:, :, :, None] / dim_t
|
652 |
pos_y = y_embed[:, :, :, None] / dim_t
|
653 |
+
|
654 |
pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
655 |
pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
|
|
|
|
|
656 |
|
657 |
+
return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
658 |
|
659 |
+
|
660 |
+
class MCLM(nn.Module):
|
661 |
def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]):
|
662 |
super(MCLM, self).__init__()
|
663 |
self.attention = nn.ModuleList([
|
|
|
688 |
l: 4,c,h,w
|
689 |
g: 1,c,h,w
|
690 |
"""
|
691 |
+
b, c, h, w = l.size()
|
692 |
# 4,c,h,w -> 1,c,2h,2w
|
693 |
concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2)
|
694 |
+
|
695 |
pools = []
|
696 |
for pool_ratio in self.pool_ratios:
|
697 |
# b,c,h,w
|
|
|
734 |
l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re)
|
735 |
l_hw_b_c = self.norm1(l_hw_b_c)
|
736 |
l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone())))
|
737 |
+
l_hw_b_c = self.norm2(l_hw_b_c)
|
738 |
|
739 |
l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c
|
740 |
return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w)
|
|
|
770 |
|
771 |
def forward(self, x):
|
772 |
device = x.device
|
773 |
+
b, c, h, w = x.size()
|
774 |
loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w
|
775 |
+
|
776 |
patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
777 |
+
|
778 |
token_attention_map = self.sigmoid(self.sal_conv(glb))
|
779 |
token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest')
|
780 |
loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2)
|
781 |
+
|
782 |
pools = []
|
783 |
for pool_ratio in self.pool_ratios:
|
784 |
tgt_hw = (round(h / pool_ratio), round(w / pool_ratio))
|
785 |
pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw)
|
786 |
pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw
|
787 |
+
|
788 |
pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c")
|
789 |
loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c')
|
790 |
+
|
791 |
outputs = []
|
792 |
for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches
|
793 |
v = pools[i]
|
794 |
k = v
|
795 |
outputs.append(self.attention[i](q, k, v)[0])
|
796 |
+
|
797 |
+
outputs = torch.cat(outputs, 1)
|
798 |
src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs)
|
799 |
src = self.norm1(src)
|
800 |
src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone())))
|
801 |
src = self.norm2(src)
|
802 |
src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc
|
803 |
glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # freshed glb
|
804 |
+
|
805 |
return torch.cat((src, glb), 0), token_attention_map
|
806 |
|
807 |
|
808 |
+
class BEN_Base(nn.Module):
|
809 |
def __init__(self):
|
810 |
super().__init__()
|
811 |
|
|
|
868 |
e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16)
|
869 |
|
870 |
e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4))
|
871 |
+
e4 = self.conv4(e4)
|
872 |
e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3))
|
873 |
e3 = self.conv3(e3)
|
874 |
e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2))
|
|
|
909 |
return blurred_mask, foreground
|
910 |
|
911 |
def loadcheckpoints(self,model_path):
|
912 |
+
model_dict = torch.load(model_path, map_location="cpu", weights_only=True)
|
913 |
self.load_state_dict(model_dict['model_state_dict'], strict=True)
|
914 |
del model_path
|
915 |
|
916 |
+
|
917 |
|
918 |
|
919 |
def rgb_loader_refiner( original_image):
|
|
|
923 |
# Convert to RGB if necessary
|
924 |
if image.mode != 'RGB':
|
925 |
image = image.convert('RGB')
|
926 |
+
|
927 |
# Resize the image
|
928 |
image = image.resize((1024, 1024), resample=Image.LANCZOS)
|
929 |
|
930 |
+
return image.convert('RGB'), h, w,original_image
|
931 |
+
|
932 |
# Define the image transformation
|
933 |
img_transform = transforms.Compose([
|
934 |
transforms.ToTensor(),
|
935 |
+
transforms.ConvertImageDtype(torch.float32),
|
936 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
937 |
])
|
938 |
|