File size: 1,060 Bytes
bc05b03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from models import segmentation, inpainting
from PIL import Image

class ProductBackgroundModifier:
    """Run a segmentation model and an inpainting model over one image.

    The input image is resized and centre-cropped to a square frame, the
    segmentation model produces a mask for that frame, and the inpainting
    model is then called with the same frame, the mask and a text prompt.

    NOTE(review): judging by the class name the mask presumably separates a
    product from its background -- confirm against the segmentation model.
    """

    # Edge length (pixels) of the square frame both models operate on.
    IMAGE_SIZE = 1024

    def __init__(
        self,
        segmentation_model: segmentation.SegmentationModel,
        inpainting_model: inpainting.InpaintingModel,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """Store the two models and build the preprocessing pipeline.

        Args:
            segmentation_model: Object whose ``generate(tensor)`` returns the
                mask handed to the inpainting model.
            inpainting_model: Object whose ``generate(image=, mask_image=,
                prompt=)`` returns the final image.
            device: Device the image tensor is moved to before segmentation.
        """
        self.segmentation_model = segmentation_model
        self.inpainting_model = inpainting_model
        self.device = device
        # PIL image -> float tensor, shorter edge scaled to IMAGE_SIZE, then a
        # centred IMAGE_SIZE x IMAGE_SIZE crop.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(self.IMAGE_SIZE),
            transforms.CenterCrop((self.IMAGE_SIZE, self.IMAGE_SIZE)),
        ])

    def generate(self, image: Image.Image, prompt: str) -> Image.Image:
        """Segment ``image`` and inpaint it according to ``prompt``.

        Args:
            image: Source picture; it is resized/cropped by ``self.transform``
                before either model sees it.
            prompt: Text prompt forwarded unchanged to the inpainting model.

        Returns:
            Whatever ``inpainting_model.generate`` returns for the
            preprocessed frame, its mask and the prompt.
        """
        # Preprocess once on the CPU; only the copy given to the segmentation
        # model is moved to the target device.
        image_tensor = self.transform(image)
        mask_image = self.segmentation_model.generate(image_tensor.to(self.device))

        # The mask describes the resized/cropped frame, so the inpainting model
        # must receive that same frame. Passing the untouched input (as the
        # code previously did) misaligns image and mask whenever the input is
        # not already IMAGE_SIZE x IMAGE_SIZE.
        processed_image = to_pil_image(image_tensor)

        return self.inpainting_model.generate(
            image=processed_image,
            mask_image=mask_image,
            prompt=prompt,
        )