import cv2
from PIL import Image
import numpy as np

import torch
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
from diffusers import StableDiffusionInpaintPipeline


class VirtualStagingToolV2:
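    """Virtually stage interior photos.

    Segments a room photo with an ADE20K-trained model, builds an inpainting
    mask that preserves structural surfaces (walls, floor, ceiling, windows,
    doors), and repaints the rest with Stable Diffusion inpainting in the
    requested style.
    """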

    def __init__(self,
                 segmentation_version='openmmlab/upernet-convnext-tiny',
                 diffusion_version="stabilityai/stable-diffusion-2-inpainting"
                 ):

        self.segmentation_version = segmentation_version
        self.diffusion_version = diffusion_version

        if segmentation_version == "openmmlab/upernet-convnext-tiny":
            self.feature_extractor = AutoImageProcessor.from_pretrained(self.segmentation_version)
            self.segmentation_model = UperNetForSemanticSegmentation.from_pretrained(self.segmentation_version)
        elif segmentation_version == "nvidia/segformer-b5-finetuned-ade-640-640":
            self.feature_extractor = SegformerFeatureExtractor.from_pretrained(self.segmentation_version)
            self.segmentation_model = SegformerForSemanticSegmentation.from_pretrained(self.segmentation_version)
        else:
            raise ValueError(f"unsupported segmentation_version: {segmentation_version}")

        # float16 weights require a CUDA device for inference.
        self.diffusion_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
            self.diffusion_version,
            torch_dtype=torch.float16,
        )
        self.diffusion_pipeline = self.diffusion_pipeline.to("cuda")

    def _predict(self, image):
        inputs = self.feature_extractor(images=image, return_tensors="pt")
        outputs = self.segmentation_model(**inputs)
        # PIL reports size as (width, height); post-processing expects (height, width).
        prediction = self.feature_extractor.post_process_semantic_segmentation(
            outputs, target_sizes=[image.size[::-1]]
        )[0]
        return prediction

    def _save_mask(self, img, prediction_array, mask_items=()):
        # Black (0) marks pixels to keep (the room structure in mask_items);
        # white (255) marks pixels the inpainting model is allowed to repaint.
        mask = np.zeros_like(prediction_array, dtype=np.uint8)
        mask[~np.isin(prediction_array, mask_items)] = 255

        # Dilate the white region by a small buffer so the repainted area
        # slightly overlaps the kept structure and seams blend cleanly.
        buffer_size = 10
        kernel = np.ones((buffer_size, buffer_size), np.uint8)
        dilated_image = cv2.dilate(mask, kernel, iterations=1)
        buffer_area = dilated_image - mask
        mask = cv2.bitwise_or(mask, buffer_area)

        mask_image = Image.fromarray(mask, mode='L')
        return mask_image

    def _save_transparent_mask(self, img, prediction_array, mask_items=()):
        # White-out everything that is not part of the kept structure.
        mask = np.array(img)
        mask[~np.isin(prediction_array, mask_items), :] = 255
        mask_image = Image.fromarray(mask).convert('RGBA')

        # Make the whited-out (masked) pixels fully transparent; kept pixels stay opaque.
        mask_data = mask_image.getdata()
        mask_data = [(r, g, b, 0) if (r, g, b) == (255, 255, 255) else (r, g, b, 255)
                     for (r, g, b, a) in mask_data]
        mask_image.putdata(mask_data)

        return mask_image

    def get_mask(self, image_path=None, image=None):
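        """Segment the image, infer the room type, and build the inpainting masks.

        Returns a tuple of (binary mask, transparent RGBA overlay, source
        image, labels of detected items outside the mask, room type).
        """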
        if image_path:
            image = Image.open(image_path)
        elif image is None:
            raise ValueError("no image provided")

        prediction = self._predict(image)

        label_ids = np.unique(prediction)

        # Class ids refer to the ADE20K label set shared by both supported
        # segmentation models. The default keeps the structural surfaces:
        # 0 wall, 3 floor, 5 ceiling, 8 windowpane, 14 door.
        mask_items = [0, 3, 5, 8, 14]

        # Infer the room type from the detected classes and extend the kept
        # classes accordingly.
        if 1 in label_ids or 25 in label_ids:  # building, house
            mask_items = [1, 2, 4, 25, 32]  # building, sky, tree, house, fence
            room = 'backyard'
        elif 73 in label_ids or 50 in label_ids or 61 in label_ids:  # e.g. kitchen island, refrigerator
            # Also keep built-in kitchen fixtures (stove, oven, microwave, dishwasher).
            mask_items = [0, 3, 5, 8, 14, 50, 61, 71, 73, 118, 124, 129]
            room = 'kitchen'
        elif 37 in label_ids or 65 in label_ids or (27 in label_ids and 47 in label_ids and 70 in label_ids):
            # bathtub, toilet, or mirror + sink + countertop
            mask_items = [0, 3, 5, 8, 14, 27, 65]
            room = 'bathroom'
        elif 7 in label_ids:  # bed
            room = 'bedroom'
        elif 23 in label_ids or 49 in label_ids:  # sofa, fireplace
            mask_items = [0, 3, 5, 8, 14, 49]
            room = 'living room'
        elif 15 in label_ids and 19 in label_ids:  # table, chair
            room = 'dining room'
        else:
            room = 'room'

        # Human-readable labels for everything that will be repainted.
        label_ids_without_mask = [i for i in label_ids if i not in mask_items]
        items = [self.segmentation_model.config.id2label[i] for i in label_ids_without_mask]

        mask_image = self._save_mask(image, prediction, mask_items)
        transparent_mask_image = self._save_transparent_mask(image, prediction, mask_items)
        return mask_image, transparent_mask_image, image, items, room

    def _edit_image(self, init_image, mask_image, prompt, number_images=1):
        # The inpainting checkpoint works at its native 512x512 resolution;
        # outputs are resized back to the source size in virtual_stage().
        init_image = init_image.resize((512, 512)).convert("RGB")
        mask_image = mask_image.resize((512, 512)).convert("RGB")

        output_images = self.diffusion_pipeline(
            prompt=prompt, image=init_image, mask_image=mask_image,
            num_images_per_prompt=number_images).images
        return output_images

    def virtual_stage(self, image_path=None, image=None, style=None,
                      color_preference=None, additional_info=None, number_images=1):
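        """Generate restaged variants of the input photo.

        Folds the detected furniture for the inferred room type into an
        inpainting prompt together with the requested style, optional color
        preference, and any additional instructions, then returns the
        generated images plus the transparent mask overlay.
        """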
        if not style:
            raise ValueError('style not provided.')

        mask_image, transparent_mask_image, init_image, items, room = self.get_mask(image_path, image)

        # Keep only the detected items that make sense for the inferred room
        # type, so the prompt does not mention stray labels.
        if room == 'kitchen':
            items = [i for i in items if i in ['cabinet', 'shelf', 'counter', 'countertop', 'stool']]
        elif room == 'bedroom':
            # 'bed ' (trailing space) matches the label string in the model config.
            items = [i for i in items if i in ['bed ', 'table', 'chest of drawers', 'desk', 'armchair', 'wardrobe']]
        elif room == 'bathroom':
            items = [i for i in items if i in ['shower', 'bathtub', 'screen door', 'cabinet']]
        elif room == 'living room':
            items = [i for i in items
                     if i in ['table', 'sofa', 'chest of drawers', 'armchair', 'cabinet', 'coffee table']]
        elif room == 'dining room':
            items = [i for i in items if i in ['table', 'chair', 'cabinet']]

        items = ', '.join(items)

        if room == 'backyard':
            prompt = f'Realistic, high resolution, {room} with {style}'
        else:
            prompt = f'Realistic {items}, high resolution, in the {style} style {room}'

        if color_preference:
            prompt = f"{prompt} in {color_preference}"

        if additional_info:
            prompt = f'{prompt}. {additional_info}'
        print(prompt)

        output_images = self._edit_image(init_image, mask_image, prompt, number_images)

        final_output_images = []
        for output_image in output_images:
            output_image = output_image.resize(init_image.size)
            final_output_images.append(output_image)
        return final_output_images, transparent_mask_image
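

# Minimal usage sketch (assumes a CUDA GPU; the input path and the style,
# color, and output names below are hypothetical placeholders):
if __name__ == "__main__":
    tool = VirtualStagingToolV2()
    staged_images, transparent_mask = tool.virtual_stage(
        image_path="example_room.jpg",  # hypothetical input photo
        style="mid-century modern",
        color_preference="warm earth tones",
        number_images=2,
    )
    transparent_mask.save("transparent_mask.png", "PNG")
    for i, img in enumerate(staged_images):
        img.save(f"staged_{i}.png", "PNG")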