File size: 10,645 Bytes
2b799f2
 
3a88d20
2b799f2
 
1479d8c
e6f2d3f
3a88d20
 
 
 
 
 
74c4534
3a88d20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79b7222
3a88d20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79b7222
3a88d20
 
 
 
 
 
 
 
 
 
86bb1a1
3a88d20
2b799f2
 
3a88d20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6b09f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import json
import base64
import time
import os
import random
import io
from dotenv import load_dotenv
import replicate
from PIL import Image, ImageOps
from io import BytesIO

# Load environment variables
load_dotenv()
# Constants
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")

# Create the tab for the image analyzer
def image_analyzer_tab():
    """Placeholder builder for an image-analyzer tab.

    NOTE(review): the nested helper below is never called, returned, or
    wired to any UI element, so this function currently has no effect —
    presumably the tab wiring was planned but not finished. A duplicate
    module-level ``analyze_image`` exists and is the one actually used.
    """
    def analyze_image(image):
        # Serialize the PIL image to an in-memory PNG, then base64-encode
        # it into a data URI for the Replicate API.
        stream = BytesIO()
        image.save(stream, format="PNG")
        encoded = base64.b64encode(stream.getvalue()).decode("utf-8")
        return replicate.run(
            "andreasjansson/blip-2:4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608",
            input={"image": "data:image/png;base64," + encoded, "prompt": "what's in this picture?"},
        )



class Config:
    """Central holder for runtime configuration values."""

    # Mirrors the module-level constant read from the environment at import time.
    REPLICATE_API_TOKEN = REPLICATE_API_TOKEN

class ImageUtils:
    """Stateless helpers for PIL image serialization and mode conversion."""

    @staticmethod
    def image_to_base64(image):
        """Serialize *image* as JPEG and return its base64 text encoding."""
        stream = io.BytesIO()
        image.save(stream, format="JPEG")
        encoded = base64.b64encode(stream.getvalue())
        return encoded.decode('utf-8')

    @staticmethod
    def convert_image_mode(image, mode="RGB"):
        """Return *image* converted to *mode*; unchanged if already in it."""
        return image if image.mode == mode else image.convert(mode)

def pad_image(image, padding_color=(255, 255, 255), padding=10):
    """Return a copy of *image* surrounded by a uniform border.

    Args:
        image: PIL image to pad.
        padding_color: Border fill color (default white).
        padding: Border width in pixels on each side. The default of 10
            reproduces the original fixed behavior (+20 px per axis).

    Returns:
        A new PIL image of size (width + 2*padding, height + 2*padding)
        with *image* pasted in the center.
    """
    width, height = image.size
    result = Image.new(image.mode, (width + 2 * padding, height + 2 * padding), padding_color)
    result.paste(image, (padding, padding))
    return result

def resize_and_pad_image(image, target_width, target_height, padding_color=(255, 255, 255)):
    """Resize *image* to fit inside the target box, preserving aspect ratio,
    then center it on a *padding_color* canvas of exactly the target size.

    Args:
        image: PIL image to fit.
        target_width: Output canvas width in pixels.
        target_height: Output canvas height in pixels.
        padding_color: Fill color for the letterbox/pillarbox bars.

    Returns:
        A new PIL image of exactly (target_width, target_height).
    """
    original_width, original_height = image.size
    aspect_ratio = original_width / original_height
    target_aspect_ratio = target_width / target_height

    # Fit to the limiting dimension: width when the image is proportionally
    # wider than the target box, height otherwise.
    if aspect_ratio > target_aspect_ratio:
        new_width = target_width
        new_height = int(target_width / aspect_ratio)
    else:
        new_width = int(target_height * aspect_ratio)
        new_height = target_height

    # BUGFIX: Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
    # filter under its current name and exists on all supported versions.
    resized_image = image.resize((new_width, new_height), Image.LANCZOS)
    padded_image = Image.new(image.mode, (target_width, target_height), padding_color)
    padded_image.paste(resized_image, ((target_width - new_width) // 2, (target_height - new_height) // 2))
    return padded_image

def _image_to_png_base64(image):
    """Encode a PIL image as a base64 PNG string (without the data-URI prefix)."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def image_prompt(prompt, cn_img1, cn_img2, cn_img3, cn_img4, weight1, weight2, weight3, weight4):
    """Generate an image with the Fooocus model on Replicate.

    Args:
        prompt: Text prompt driving generation.
        cn_img1: Sketch image; used for PyraCanny control and, resized to
            1280x768, as the upscale/variation (uov) input.
        cn_img2, cn_img3, cn_img4: ImagePrompt reference images.
        weight1..weight4: Control weights for the four image slots.

    Returns:
        Tuple of ("output.png", status message). The generated image is
        also written to output.png in the working directory.

    Raises:
        requests.HTTPError: if downloading the generated image fails.
    """
    # Pad the sketch so strokes touching the border survive edge detection.
    cn_img1 = pad_image(cn_img1)

    # The four PNG->base64 encodings were previously copy-pasted inline;
    # they now share one helper.
    cn_img1_base64 = _image_to_png_base64(cn_img1)
    cn_img2_base64 = _image_to_png_base64(cn_img2)
    cn_img3_base64 = _image_to_png_base64(cn_img3)
    cn_img4_base64 = _image_to_png_base64(cn_img4)

    # Resize and pad the sketch input image to match the aspect ratio selection
    aspect_ratio_width, aspect_ratio_height = 1280, 768
    uov_input_image = resize_and_pad_image(cn_img1, aspect_ratio_width, aspect_ratio_height)
    uov_input_image_base64 = _image_to_png_base64(uov_input_image)

    # Call the Replicate API to generate the image
    fooocus_model = replicate.models.get("vetkastar/fooocus").versions.get("d555a800025fe1c171e386d299b1de635f8d8fc3f1ade06a14faf5154eba50f3")
    image = replicate.predictions.create(version=fooocus_model, input={
        "prompt": prompt,
        "cn_type1": "PyraCanny",
        "cn_type2": "ImagePrompt",
        "cn_type3": "ImagePrompt",
        "cn_type4": "ImagePrompt",
        "cn_weight1": weight1,
        "cn_weight2": weight2,
        "cn_weight3": weight3,
        "cn_weight4": weight4,
        "cn_img1": "data:image/png;base64," + cn_img1_base64,
        "cn_img2": "data:image/png;base64," + cn_img2_base64,
        "cn_img3": "data:image/png;base64," + cn_img3_base64,
        "cn_img4": "data:image/png;base64," + cn_img4_base64,
        "uov_input_image": "data:image/png;base64," + uov_input_image_base64,
        "sharpness": 2,
        "image_seed": -1,
        "image_number": 1,
        "guidance_scale": 7,
        "refiner_switch": 0.5,
        "negative_prompt": "",
        "inpaint_strength": 0.5,
        "style_selections": "Fooocus V2,Fooocus Enhance,Fooocus Sharp",
        "loras_custom_urls": "",
        "uov_upscale_value": 0,
        "use_default_loras": True,
        "outpaint_selections": "",
        "outpaint_distance_top": 0,
        "performance_selection": "Lightning",
        "outpaint_distance_left": 0,
        "aspect_ratios_selection": "1280*768",
        "outpaint_distance_right": 0,
        "outpaint_distance_bottom": 0,
        "inpaint_additional_prompt": "",
        "uov_method": "Vary (Subtle)"
    })
    # Block until the prediction finishes (or fails) on Replicate's side.
    image.wait()
    # Fetch the generated image from the output URL
    response = requests.get(image.output["paths"][0])
    # BUGFIX: previously an HTTP error page would be silently written to
    # output.png; fail loudly instead. (Also dropped an unused local that
    # decoded the response into a PIL image and discarded it.)
    response.raise_for_status()

    with open("output.png", "wb") as f:
        f.write(response.content)
    return "output.png", "Job completed successfully using Replicate API."

def create_status_image():
    """Return the path of the last generated image, or None if none exists yet."""
    return "output.png" if os.path.exists("output.png") else None

def preload_images(cn_img2, cn_img3, cn_img4):
    """Fill the three image-prompt slots with random placeholder photos.

    The incoming slot values are intentionally ignored; each slot gets a
    fresh random picsum.photos URL (gradio loads images from URLs).
    """
    slot_a, slot_b, slot_c = (
        f"https://picsum.photos/seed/{random.randint(0, 1000)}/400/400"
        for _ in range(3)
    )
    return slot_a, slot_b, slot_c

def shuffle_and_load_images(files):
    """Return three items drawn from *files* in random order.

    Args:
        files: List of uploaded files (may be None or empty).

    Returns:
        A 3-tuple of files; slots without a file are None (gradio treats
        None as an empty image component).

    BUGFIX: the original empty-files branch called an undefined
    ``generate_placeholder_image`` (NameError), raised IndexError for
    fewer than three files, and shuffled the caller's list in place.
    """
    if not files:
        return None, None, None
    shuffled = list(files)  # copy so the caller's list is not reordered
    random.shuffle(shuffled)
    first, second, third = (shuffled + [None, None, None])[:3]
    return first, second, third

def analyze_image(image: Image.Image) -> dict:
    """Send *image* to the BLIP-2 model on Replicate and return its analysis."""
    stream = BytesIO()
    image.save(stream, format="PNG")
    encoded = base64.b64encode(stream.getvalue()).decode("utf-8")
    payload = {
        "image": "data:image/png;base64," + encoded,
        "prompt": "what's in this picture?",
    }
    return replicate.run(
        "andreasjansson/blip-2:4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608",
        input=payload,
    )

def get_prompt_from_image(image: Image.Image) -> str:
    """Describe *image* via analyze_image and return its 'describe' field.

    Returns an empty string when the analysis has no 'describe' key.
    """
    result = analyze_image(image)
    return result.get("describe", "")

def generate_prompt(image: Image.Image, current_prompt: str) -> str:
    """Produce a text prompt from *image*.

    ``current_prompt`` is accepted for UI-callback compatibility but is
    deliberately ignored — the result comes solely from image analysis.
    """
    new_prompt = get_prompt_from_image(image)
    return new_prompt

import gradio as gr

def create_gradio_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=0):
                with gr.Tab(label="Sketch"):
                    image_input = cn_img1_input = gr.Image(label="Sketch", type="pil")
                    weight1 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.75)
                    copy_to_sketch_button = gr.Button("Grab Last Output")

                
                with gr.Accordion("Upload Project Files", open=False):
                    with gr.Accordion("πŸ“", open=False):
                        file_upload = gr.File(file_count="multiple", elem_classes="gradio-column")
                        image_gallery = gr.Gallery(label="Image Gallery", elem_classes="gradio-column")
                        file_upload.change(shuffle_and_load_images, inputs=[file_upload], outputs=[image_gallery])
            with gr.Column(scale=2):
                with gr.Tab(label="Node"):
                    with gr.Accordion("Output"):
                        with gr.Column():
                            status = gr.Textbox(label="Status")
                            status_image = gr.Image(label="Queue Status", interactive=False)
                            with gr.Row():
                                with gr.Column(scale=1):
                                    analysis_output = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here")
                                with gr.Column(scale=0):
                                    analyze_button = gr.Button("Analyze Image")
                                    analyze_button.click(fn=analyze_image, inputs=image_input, outputs=analysis_output)
                            with gr.Row():
                                preload_button = gr.Button("🌸")
                                shuffle_and_load_button = gr.Button("πŸ“‚")
                                generate_button = gr.Button("πŸš€ Generate πŸš€")

                with gr.Row():
                    with gr.Column():
                        cn_img2_input = gr.Image(label="Image Prompt 2", type="pil", height=256)
                        weight2 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5)
                    with gr.Column():
                        cn_img3_input = gr.Image(label="Image Prompt 3", type="pil", height=256)
                        weight3 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5)
                    with gr.Column():
                        cn_img4_input = gr.Image(label="Image Prompt 4", type="pil", height=256)
                        weight4 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5)

                with gr.Row():
                    preload_button.click(preload_images, inputs=[cn_img2_input, cn_img3_input, cn_img4_input], outputs=[cn_img2_input, cn_img3_input, cn_img4_input])
                    shuffle_and_load_button.click(shuffle_and_load_images, inputs=[file_upload], outputs=[cn_img2_input, cn_img3_input, cn_img4_input])

                    generate_button.click(
                        fn=image_prompt,
                        inputs=[analysis_output, cn_img1_input, cn_img2_input, cn_img3_input, cn_img4_input, weight1, weight2, weight3, weight4],
                        outputs=[status_image, status]
                    )

                    copy_to_sketch_button.click(
                        fn=lambda: Image.open("output.png") if os.path.exists("output.png") else None,
                        inputs=[],
                        outputs=[cn_img1_input]
                    )

                # ⏲️ Update the image every 5 seconds
                demo.load(create_status_image, every=5, outputs=status_image)

    demo.launch(server_name="0.0.0.0", server_port=6644, share=True)

if __name__ == "__main__":
    # Launch the UI only when run as a script; importing this module for
    # its helpers no longer starts a public server as a side effect.
    create_gradio_interface()