initial - v0
Browse files
- .gitignore +3 -0
- app.py +29 -0
- model_checkpoints/sam_vit.pth +3 -0
- models/inpainting.py +26 -0
- models/product.py +30 -0
- models/segmentation.py +43 -0
- requirements.txt +0 -0
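
Together these files stand up a product-photo background editor: app.py exposes a Gradio UI that routes an uploaded image through SAM-based segmentation (models/segmentation.py) to mask the product, then repaints the background with Kandinsky 2.2 inpainting (models/inpainting.py), orchestrated by ProductBackgroundModifier (models/product.py).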
.gitignore
ADDED
.env
__pycache__/
flagged/

app.py
ADDED
import gradio as gr
from PIL import Image
from models.segmentation import SamSegmentationModel
from models.inpainting import KandinskyInpaintingModel
from models.product import ProductBackgroundModifier
import torch

def generate(image: Image.Image, prompt: str) -> Image.Image:
    # Note: both models are rebuilt on every request; a long-running app
    # would construct them once at module load instead.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = ProductBackgroundModifier(
        segmentation_model=SamSegmentationModel(
            model_type="vit_h",
            checkpoint_path="model_checkpoints/sam_vit.pth",
            device=device,
        ),
        inpainting_model=KandinskyInpaintingModel(),
        device=device,
    )
    generated = model.generate(image=image, prompt=prompt)
    return generated

gr.Interface(
    fn=generate,
    inputs=[
        gr.Image(type="pil"),
        gr.Text(),
    ],
    outputs=gr.Image(type="pil"),
).launch()

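One caveat: gr.Interface(...).launch() runs at import time, so importing generate from another module would also start the server. A minimal sketch of the usual guard (an assumption about intent, not part of the commit); the trailing Interface block could be wrapped as:

# Sketch: launch only when app.py is executed directly (python app.py),
# not when generate is imported elsewhere.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=generate,
        inputs=[gr.Image(type="pil"), gr.Text()],
        outputs=gr.Image(type="pil"),
    )
    demo.launch()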
model_checkpoints/sam_vit.pth
ADDED
version https://git-lfs.github.com/spec/v1
oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
size 2564550879

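The checkpoint is stored as a Git LFS pointer; at roughly 2.56 GB it is consistent with the official ViT-H SAM weights. Where LFS is unavailable, a small helper along these lines could fetch the checkpoint directly (the URL is the official facebookresearch segment-anything release; treat the mapping to sam_vit.pth as an assumption):

# Sketch: download the ViT-H SAM weights into the path app.py expects.
# Assumption: model_checkpoints/sam_vit.pth is the official sam_vit_h_4b8939.pth.
import os
import urllib.request

CKPT_PATH = "model_checkpoints/sam_vit.pth"
CKPT_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"

if not os.path.exists(CKPT_PATH):
    os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
    urllib.request.urlretrieve(CKPT_URL, CKPT_PATH)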
models/inpainting.py
ADDED
import torch
from diffusers import AutoPipelineForInpainting
from PIL import Image

class InpaintingModel:
    """Interface for inpainting backends: repaint the masked region of an image from a prompt."""
    def __init__(self) -> None:
        pass

    def generate(self, image: Image.Image, mask_image: Image.Image, prompt: str) -> Image.Image:
        pass

class KandinskyInpaintingModel(InpaintingModel):
    def __init__(
        self,
        device = torch.device("cpu"),
    ) -> None:
        super().__init__()
        self.device = device
        # Kandinsky 2.2 decoder inpainting pipeline, loaded in fp16.
        self.model = AutoPipelineForInpainting.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
        )
        # Offload pipeline submodules to CPU between forward passes to reduce
        # GPU memory use; the pipeline then manages device placement itself.
        self.model.enable_model_cpu_offload()
        self.negative_prompt = "deformed, ugly, disfigured"

    def generate(self, image: Image.Image, mask_image: Image.Image, prompt: str) -> Image.Image:
        # White mask regions are repainted (diffusers >= 0.22 mask convention);
        # black regions are preserved.
        output = self.model(prompt=prompt, negative_prompt=self.negative_prompt, image=image, mask_image=mask_image)
        return output.images[0]

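For reference, a minimal sketch of exercising the inpainting model on its own, assuming a local scene.png (a hypothetical sample file, not part of this commit):

# Sketch: repaint the left half of a photo from a text prompt.
# "scene.png" is a hypothetical sample image.
from PIL import Image, ImageDraw
from models.inpainting import KandinskyInpaintingModel

model = KandinskyInpaintingModel()
image = Image.open("scene.png").convert("RGB")
mask = Image.new("RGB", image.size, "black")
# White pixels mark the region to repaint.
ImageDraw.Draw(mask).rectangle((0, 0, image.width // 2, image.height), fill="white")
result = model.generate(image=image, mask_image=mask, prompt="a sunlit marble countertop")
result.save("inpainted.png")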
models/product.py
ADDED
import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from models import segmentation, inpainting
from PIL import Image

class ProductBackgroundModifier:
    def __init__(
        self,
        segmentation_model: segmentation.SegmentationModel,
        inpainting_model: inpainting.InpaintingModel,
        device = torch.device("cpu"),
    ) -> None:
        self.segmentation_model = segmentation_model
        self.inpainting_model = inpainting_model
        self.device = device
        # Resize the shorter side to 1024 px, then center-crop to the
        # 1024x1024 square the SAM wrapper expects.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(1024),
            transforms.CenterCrop((1024, 1024))
        ])

    def generate(self, image: Image.Image, prompt: str) -> Image.Image:
        image_tensor = self.transform(image).to(self.device)
        # Background mask computed on the cropped view (white = repaint).
        mask_image = self.segmentation_model.generate(image_tensor)
        mask_image.show()  # debug: preview the mask in a local image viewer
        # Inpaint the same cropped view so the image and mask stay aligned.
        generated_image = self.inpainting_model.generate(
            image=to_pil_image(image_tensor.cpu()), mask_image=mask_image, prompt=prompt
        )
        return generated_image

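A quick sanity check on the preprocessing above: the Compose maps any RGB photo to the 3x1024x1024 tensor the segmentation wrapper expects (a standalone sketch; the sample dimensions are arbitrary):

# Sketch: verify the output shape of the preprocessing pipeline.
from PIL import Image
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),                # HWC image -> CHW float in [0, 1]
    transforms.Resize(1024),              # shorter side -> 1024 px
    transforms.CenterCrop((1024, 1024)),  # square crop for SAM
])
sample = Image.new("RGB", (1600, 900))    # stand-in for a product photo
print(transform(sample).shape)            # torch.Size([3, 1024, 1024])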
models/segmentation.py
ADDED
import torch
from torchvision.transforms.functional import to_pil_image
from segment_anything import SamPredictor, sam_model_registry
from PIL import Image

class SegmentationModel:
    """Interface for segmentation backends: produce a background mask for an image tensor."""
    def __init__(self) -> None:
        pass

    def generate(self, image: torch.Tensor) -> Image.Image:
        pass

class SamSegmentationModel(SegmentationModel):
    def __init__(
        self,
        model_type: str,
        checkpoint_path: str,
        device = torch.device("cpu"),
    ) -> None:
        super().__init__()
        sam = sam_model_registry[model_type](checkpoint=checkpoint_path)
        sam.to(device)
        self.device = device
        self.model = SamPredictor(sam)

    def generate(self, image: torch.Tensor) -> Image.Image:
        _, H, W = image.size()
        # SamPredictor.set_torch_image expects a BCHW batch with pixel values
        # in [0, 255]; the tensor from ToTensor is in [0, 1], so rescale.
        image = image.unsqueeze(0) * 255.0
        self.model.set_torch_image(image, original_image_size=(H, W))
        # Prompt SAM with a single foreground click at the image center;
        # SAM point coordinates are in (x, y) order.
        center_point = [W / 2, H / 2]
        input_point = torch.tensor([[center_point]]).to(self.device)
        input_point = self.model.transform.apply_coords_torch(input_point, (H, W))
        input_label = torch.tensor([[1]]).to(self.device)
        masks, scores, logits = self.model.predict_torch(
            point_coords=input_point,
            point_labels=input_label,
            boxes=None,
            multimask_output=True
        )
        masks = masks.squeeze(0)
        scores = scores.squeeze(0)
        # Keep the highest-scoring of the candidate masks.
        bmask = masks[torch.argmax(scores).item()]
        # Invert so the product is 0 (kept) and the background 1 (repainted).
        mask_float = 1.0 - bmask.float()
        # Replicate to three channels so it converts to an RGB PIL image.
        final = torch.stack([mask_float, mask_float, mask_float])
        return to_pil_image(final)

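And a matching sketch of running the segmentation step in isolation, assuming the committed checkpoint and a hypothetical product.jpg:

# Sketch: produce a background mask for one photo.
import torch
from PIL import Image
from torchvision import transforms
from models.segmentation import SamSegmentationModel

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
seg = SamSegmentationModel(
    model_type="vit_h",
    checkpoint_path="model_checkpoints/sam_vit.pth",
    device=device,
)
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(1024),
    transforms.CenterCrop((1024, 1024)),
])
image = Image.open("product.jpg")  # hypothetical sample photo
mask = seg.generate(preprocess(image).to(device))
mask.save("mask.png")  # white = background (repainted), black = product (kept)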
requirements.txt
ADDED
File without changes
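requirements.txt is committed empty. Judging from the imports in this commit, a plausible starting point would be the following (package names assumed, versions unpinned; accelerate backs enable_model_cpu_offload, and segment-anything can also be installed from the facebookresearch GitHub repo):

gradio
torch
torchvision
diffusers
transformers
accelerate
segment-anything
Pillow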