"""Gradio front-end for generating images with the Fooocus model on Replicate.

A sketch plus up to three reference images are sent to the
``vetkastar/fooocus`` model; BLIP-2 (``andreasjansson/blip-2``) derives a
text prompt from the sketch.  The last generated image is cached on disk as
``output.png`` and polled by the UI.

NOTE(review): the original file had lost all indentation; the structure below
is a best-effort reconstruction — confirm nesting against the upstream repo.
"""

import base64
import io
import json
import os
import random
import time
from io import BytesIO

import gradio as gr
import replicate
import requests
from dotenv import load_dotenv
from PIL import Image, ImageOps
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

# Load environment variables (expects REPLICATE_API_TOKEN in .env).
load_dotenv()

# Constants
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")

# Pinned Replicate model versions.
BLIP2_VERSION = (
    "andreasjansson/blip-2:"
    "4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608"
)
FOOOCUS_VERSION = "d555a800025fe1c171e386d299b1de635f8d8fc3f1ade06a14faf5154eba50f3"


def _image_to_data_uri(image: Image.Image, fmt: str = "PNG") -> str:
    """Encode *image* as a ``data:image/...;base64,`` URI for the Replicate API.

    Replaces four copy-pasted BytesIO/b64encode sequences in the original.
    """
    buffered = BytesIO()
    image.save(buffered, format=fmt)
    encoded = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/{fmt.lower()};base64," + encoded


def analyze_image(image: Image.Image) -> dict:
    """Ask BLIP-2 what is in *image*.

    Returns the raw model output (the original annotated it as ``dict``;
    NOTE(review): replicate.run may return a plain string for this model —
    callers should not assume dict, see get_prompt_from_image).
    """
    return replicate.run(
        BLIP2_VERSION,
        input={
            "image": _image_to_data_uri(image),
            "prompt": "what's in this picture?",
        },
    )


def image_analyzer_tab():
    """Legacy tab factory kept for backward compatibility.

    NOTE(review): appears unused — nothing in this file calls it.  Its nested
    ``analyze_image`` duplicated the module-level one, so it now delegates.
    """

    def analyze_image(image):  # deliberately shadows the module-level helper
        return globals()["analyze_image"](image)

    return analyze_image


class Config:
    """Simple configuration holder (mirrors the module-level constant)."""

    REPLICATE_API_TOKEN = REPLICATE_API_TOKEN


class ImageUtils:
    """Stateless PIL helpers."""

    @staticmethod
    def image_to_base64(image):
        """Return *image* JPEG-encoded as a base64 string (no data-URI prefix)."""
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode("utf-8")

    @staticmethod
    def convert_image_mode(image, mode="RGB"):
        """Convert *image* to *mode* if needed; otherwise return it unchanged."""
        if image.mode != mode:
            return image.convert(mode)
        return image


def pad_image(image, padding_color=(255, 255, 255), padding=10):
    """Return *image* surrounded by a uniform border of *padding* pixels.

    The border width was hard-coded to 10 px per side in the original; it is
    now a parameter with the same default, so callers are unaffected.
    """
    width, height = image.size
    result = Image.new(
        image.mode, (width + 2 * padding, height + 2 * padding), padding_color
    )
    result.paste(image, (padding, padding))
    return result


def resize_and_pad_image(image, target_width, target_height,
                         padding_color=(255, 255, 255)):
    """Fit *image* inside ``target_width x target_height`` preserving aspect
    ratio, centered on a *padding_color* canvas of exactly the target size.
    """
    original_width, original_height = image.size
    aspect_ratio = original_width / original_height
    target_aspect_ratio = target_width / target_height
    if aspect_ratio > target_aspect_ratio:
        # Image is wider than the target box: clamp width, scale height.
        new_width = target_width
        new_height = int(target_width / aspect_ratio)
    else:
        # Image is taller (or equal): clamp height, scale width.
        new_width = int(target_height * aspect_ratio)
        new_height = target_height
    # BUGFIX: Image.ANTIALIAS was removed in Pillow 10 (AttributeError at
    # runtime); Image.LANCZOS is the long-standing equivalent filter.
    resized_image = image.resize((new_width, new_height), Image.LANCZOS)
    padded_image = Image.new(image.mode, (target_width, target_height),
                             padding_color)
    padded_image.paste(
        resized_image,
        ((target_width - new_width) // 2, (target_height - new_height) // 2),
    )
    return padded_image


def image_prompt(prompt, cn_img1, cn_img2, cn_img3, cn_img4,
                 weight1, weight2, weight3, weight4):
    """Run the Fooocus model with a sketch (cn_img1) and three image prompts.

    Blocks until the prediction finishes, writes the first result to
    ``output.png`` and returns ``(path, status_message)`` for the Gradio
    outputs ``[status_image, status]``.

    Raises requests.HTTPError if the result download fails (the original
    silently wrote an error body to output.png).
    """
    cn_img1 = pad_image(cn_img1)

    # Resize/pad the sketch to match the aspect_ratios_selection ("1280*768")
    # sent in the payload below.
    aspect_ratio_width, aspect_ratio_height = 1280, 768
    uov_input_image = resize_and_pad_image(
        cn_img1, aspect_ratio_width, aspect_ratio_height
    )

    fooocus_model = (
        replicate.models.get("vetkastar/fooocus").versions.get(FOOOCUS_VERSION)
    )
    image = replicate.predictions.create(version=fooocus_model, input={
        "prompt": prompt,
        "cn_type1": "PyraCanny",
        "cn_type2": "ImagePrompt",
        "cn_type3": "ImagePrompt",
        "cn_type4": "ImagePrompt",
        "cn_weight1": weight1,
        "cn_weight2": weight2,
        "cn_weight3": weight3,
        "cn_weight4": weight4,
        "cn_img1": _image_to_data_uri(cn_img1),
        "cn_img2": _image_to_data_uri(cn_img2),
        "cn_img3": _image_to_data_uri(cn_img3),
        "cn_img4": _image_to_data_uri(cn_img4),
        "uov_input_image": _image_to_data_uri(uov_input_image),
        "sharpness": 2,
        "image_seed": -1,
        "image_number": 1,
        "guidance_scale": 7,
        "refiner_switch": 0.5,
        "negative_prompt": "",
        "inpaint_strength": 0.5,
        "style_selections": "Fooocus V2,Fooocus Enhance,Fooocus Sharp",
        "loras_custom_urls": "",
        "uov_upscale_value": 0,
        "use_default_loras": True,
        "outpaint_selections": "",
        "outpaint_distance_top": 0,
        "performance_selection": "Lightning",
        "outpaint_distance_left": 0,
        "aspect_ratios_selection": "1280*768",
        "outpaint_distance_right": 0,
        "outpaint_distance_bottom": 0,
        "inpaint_additional_prompt": "",
        "uov_method": "Vary (Subtle)",
    })
    image.wait()

    # Download the first generated image and persist it for the status poller.
    # BUGFIX: added a timeout (requests never times out by default) and a
    # status check; removed an unused Image.open of the same bytes.
    response = requests.get(image.output["paths"][0], timeout=120)
    response.raise_for_status()
    with open("output.png", "wb") as f:
        f.write(response.content)
    return "output.png", "Job completed successfully using Replicate API."


def create_status_image():
    """Path of the last generated image, or None before the first run."""
    return "output.png" if os.path.exists("output.png") else None


def generate_placeholder_image():
    """Neutral placeholder used when too few files were uploaded.

    BUGFIX: this function was *called* by shuffle_and_load_images in the
    original but never defined anywhere, so an empty upload list raised
    NameError.  400x400 matches the picsum placeholders used elsewhere.
    """
    return Image.new("RGB", (400, 400), (230, 230, 230))


def preload_images(cn_img2, cn_img3, cn_img4):
    """Fill the three image-prompt slots with random picsum photo URLs.

    The incoming values are ignored; the signature only mirrors the Gradio
    inputs/outputs wiring.
    """
    cn_img2 = f"https://picsum.photos/seed/{random.randint(0, 1000)}/400/400"
    cn_img3 = f"https://picsum.photos/seed/{random.randint(0, 1000)}/400/400"
    cn_img4 = f"https://picsum.photos/seed/{random.randint(0, 1000)}/400/400"
    return cn_img2, cn_img3, cn_img4


def shuffle_and_load_images(files):
    """Pick three random uploaded files; pad with placeholders when short.

    BUGFIX: the original indexed files[0..2] unconditionally and raised
    IndexError when only one or two files were uploaded.
    """
    if not files:
        return (
            generate_placeholder_image(),
            generate_placeholder_image(),
            generate_placeholder_image(),
        )
    files = list(files)
    random.shuffle(files)
    while len(files) < 3:
        files.append(generate_placeholder_image())
    return files[0], files[1], files[2]


def get_prompt_from_image(image: Image.Image) -> str:
    """Derive a text prompt from *image* via BLIP-2.

    ROBUSTNESS: replicate.run may return a plain string for this model;
    the original unconditionally called .get() and would AttributeError.
    """
    analysis = analyze_image(image)
    if isinstance(analysis, dict):
        return analysis.get("describe", "")
    return str(analysis)


def generate_prompt(image: Image.Image, current_prompt: str) -> str:
    """Regenerate the prompt from *image*; *current_prompt* is ignored."""
    return get_prompt_from_image(image)


def create_gradio_interface():
    """Build the Gradio UI, wire the events, and launch the server."""
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=0):
                with gr.Tab(label="Sketch"):
                    image_input = cn_img1_input = gr.Image(label="Sketch", type="pil")
                    weight1 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.75)
                    copy_to_sketch_button = gr.Button("Grab Last Output")
                with gr.Accordion("Upload Project Files", open=False):
                    with gr.Accordion("📁", open=False):
                        file_upload = gr.File(file_count="multiple",
                                              elem_classes="gradio-column")
                        image_gallery = gr.Gallery(label="Image Gallery",
                                                   elem_classes="gradio-column")
                        # NOTE(review): shuffle_and_load_images returns three
                        # values but only one output is wired here — confirm
                        # intended gallery behavior against a running app.
                        file_upload.change(shuffle_and_load_images,
                                           inputs=[file_upload],
                                           outputs=[image_gallery])
            with gr.Column(scale=2):
                with gr.Tab(label="Node"):
                    with gr.Accordion("Output"):
                        with gr.Column():
                            status = gr.Textbox(label="Status")
                            status_image = gr.Image(label="Queue Status",
                                                    interactive=False)
                    with gr.Row():
                        with gr.Column(scale=1):
                            analysis_output = gr.Textbox(
                                label="Prompt",
                                placeholder="Enter your text prompt here")
                        with gr.Column(scale=0):
                            analyze_button = gr.Button("Analyze Image")
                            analyze_button.click(fn=analyze_image,
                                                 inputs=image_input,
                                                 outputs=analysis_output)
                    with gr.Row():
                        preload_button = gr.Button("🌸")
                        shuffle_and_load_button = gr.Button("📂")
                        generate_button = gr.Button("🚀 Generate 🚀")
                    with gr.Row():
                        with gr.Column():
                            cn_img2_input = gr.Image(label="Image Prompt 2",
                                                     type="pil", height=256)
                            weight2 = gr.Slider(minimum=0, maximum=1,
                                                step=0.1, value=0.5)
                        with gr.Column():
                            cn_img3_input = gr.Image(label="Image Prompt 3",
                                                     type="pil", height=256)
                            weight3 = gr.Slider(minimum=0, maximum=1,
                                                step=0.1, value=0.5)
                        with gr.Column():
                            cn_img4_input = gr.Image(label="Image Prompt 4",
                                                     type="pil", height=256)
                            weight4 = gr.Slider(minimum=0, maximum=1,
                                                step=0.1, value=0.5)
                    with gr.Row():
                        preload_button.click(
                            preload_images,
                            inputs=[cn_img2_input, cn_img3_input, cn_img4_input],
                            outputs=[cn_img2_input, cn_img3_input, cn_img4_input])
                        shuffle_and_load_button.click(
                            shuffle_and_load_images,
                            inputs=[file_upload],
                            outputs=[cn_img2_input, cn_img3_input, cn_img4_input])
                        generate_button.click(
                            fn=image_prompt,
                            inputs=[analysis_output, cn_img1_input,
                                    cn_img2_input, cn_img3_input, cn_img4_input,
                                    weight1, weight2, weight3, weight4],
                            outputs=[status_image, status])
                        copy_to_sketch_button.click(
                            fn=lambda: Image.open("output.png")
                            if os.path.exists("output.png") else None,
                            inputs=[],
                            outputs=[cn_img1_input])
        # ⏲️ Update the image every 5 seconds
        demo.load(create_status_image, every=5, outputs=status_image)
    demo.launch(server_name="0.0.0.0", server_port=6644, share=True)


# BUGFIX: the original launched the server at import time; guard the entry
# point so the module can be imported (e.g. for testing) without side effects.
if __name__ == "__main__":
    create_gradio_interface()