# NOTE(review): removed non-Python paste artifacts that preceded the imports
# ("Spaces:" / "Runtime error" x2 — captured console output, not source code);
# they would raise a SyntaxError at import time.
# Standard library
import base64
import io
import json
import os
import random
import time
from io import BytesIO

# Third-party
import replicate
import requests
from dotenv import load_dotenv
from PIL import Image, ImageOps
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
# Load environment variables | |
load_dotenv() | |
# Constants | |
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN") | |
# Create the tab for the image analyzer | |
def image_analyzer_tab(): | |
# Function to analyze the image | |
def analyze_image(image): | |
buffered = BytesIO() | |
image.save(buffered, format="PNG") | |
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") | |
analysis = replicate.run( | |
"andreasjansson/blip-2:4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608", | |
input={"image": "data:image/png;base64," + img_str, "prompt": "what's in this picture?"} | |
) | |
return analysis | |
class Config: | |
REPLICATE_API_TOKEN = REPLICATE_API_TOKEN | |
class ImageUtils: | |
def image_to_base64(image): | |
buffered = io.BytesIO() | |
image.save(buffered, format="JPEG") | |
return base64.b64encode(buffered.getvalue()).decode('utf-8') | |
def convert_image_mode(image, mode="RGB"): | |
if image.mode != mode: | |
return image.convert(mode) | |
return image | |
def pad_image(image, padding_color=(255, 255, 255)): | |
width, height = image.size | |
new_width = width + 20 | |
new_height = height + 20 | |
result = Image.new(image.mode, (new_width, new_height), padding_color) | |
result.paste(image, (10, 10)) | |
return result | |
def resize_and_pad_image(image, target_width, target_height, padding_color=(255, 255, 255)): | |
original_width, original_height = image.size | |
aspect_ratio = original_width / original_height | |
target_aspect_ratio = target_width / target_height | |
if aspect_ratio > target_aspect_ratio: | |
new_width = target_width | |
new_height = int(target_width / aspect_ratio) | |
else: | |
new_width = int(target_height * aspect_ratio) | |
new_height = target_height | |
resized_image = image.resize((new_width, new_height), Image.ANTIALIAS) | |
padded_image = Image.new(image.mode, (target_width, target_height), padding_color) | |
padded_image.paste(resized_image, ((target_width - new_width) // 2, (target_height - new_height) // 2)) | |
return padded_image | |
def image_prompt(prompt, cn_img1, cn_img2, cn_img3, cn_img4, weight1, weight2, weight3, weight4): | |
cn_img1 = pad_image(cn_img1) | |
buffered1 = BytesIO() | |
cn_img1.save(buffered1, format="PNG") | |
cn_img1_base64 = base64.b64encode(buffered1.getvalue()).decode('utf-8') | |
buffered2 = BytesIO() | |
cn_img2.save(buffered2, format="PNG") | |
cn_img2_base64 = base64.b64encode(buffered2.getvalue()).decode('utf-8') | |
buffered3 = BytesIO() | |
cn_img3.save(buffered3, format="PNG") | |
cn_img3_base64 = base64.b64encode(buffered3.getvalue()).decode('utf-8') | |
buffered4 = BytesIO() | |
cn_img4.save(buffered4, format="PNG") | |
cn_img4_base64 = base64.b64encode(buffered4.getvalue()).decode('utf-8') | |
# Resize and pad the sketch input image to match the aspect ratio selection | |
aspect_ratio_width, aspect_ratio_height = 1280, 768 | |
uov_input_image = resize_and_pad_image(cn_img1, aspect_ratio_width, aspect_ratio_height) | |
buffered_uov = BytesIO() | |
uov_input_image.save(buffered_uov, format="PNG") | |
uov_input_image_base64 = base64.b64encode(buffered_uov.getvalue()).decode('utf-8') | |
# Call the Replicate API to generate the image | |
fooocus_model = replicate.models.get("vetkastar/fooocus").versions.get("d555a800025fe1c171e386d299b1de635f8d8fc3f1ade06a14faf5154eba50f3") | |
image = replicate.predictions.create(version=fooocus_model, input={ | |
"prompt": prompt, | |
"cn_type1": "PyraCanny", | |
"cn_type2": "ImagePrompt", | |
"cn_type3": "ImagePrompt", | |
"cn_type4": "ImagePrompt", | |
"cn_weight1": weight1, | |
"cn_weight2": weight2, | |
"cn_weight3": weight3, | |
"cn_weight4": weight4, | |
"cn_img1": "data:image/png;base64," + cn_img1_base64, | |
"cn_img2": "data:image/png;base64," + cn_img2_base64, | |
"cn_img3": "data:image/png;base64," + cn_img3_base64, | |
"cn_img4": "data:image/png;base64," + cn_img4_base64, | |
"uov_input_image": "data:image/png;base64," + uov_input_image_base64, | |
"sharpness": 2, | |
"image_seed": -1, | |
"image_number": 1, | |
"guidance_scale": 7, | |
"refiner_switch": 0.5, | |
"negative_prompt": "", | |
"inpaint_strength": 0.5, | |
"style_selections": "Fooocus V2,Fooocus Enhance,Fooocus Sharp", | |
"loras_custom_urls": "", | |
"uov_upscale_value": 0, | |
"use_default_loras": True, | |
"outpaint_selections": "", | |
"outpaint_distance_top": 0, | |
"performance_selection": "Lightning", | |
"outpaint_distance_left": 0, | |
"aspect_ratios_selection": "1280*768", | |
"outpaint_distance_right": 0, | |
"outpaint_distance_bottom": 0, | |
"inpaint_additional_prompt": "", | |
"uov_method": "Vary (Subtle)" | |
}) | |
image.wait() | |
# Fetch the generated image from the output URL | |
response = requests.get(image.output["paths"][0]) | |
img = Image.open(BytesIO(response.content)) | |
with open("output.png", "wb") as f: | |
f.write(response.content) | |
return "output.png", "Job completed successfully using Replicate API." | |
def create_status_image(): | |
if os.path.exists("output.png"): | |
return "output.png" | |
else: | |
return None | |
def preload_images(cn_img2, cn_img3, cn_img4): | |
cn_img2 = f"https://picsum.photos/seed/{random.randint(0, 1000)}/400/400" | |
cn_img3 = f"https://picsum.photos/seed/{random.randint(0, 1000)}/400/400" | |
cn_img4 = f"https://picsum.photos/seed/{random.randint(0, 1000)}/400/400" | |
return cn_img2, cn_img3, cn_img4 | |
def shuffle_and_load_images(files): | |
if not files: | |
return generate_placeholder_image(), generate_placeholder_image(), generate_placeholder_image() | |
else: | |
random.shuffle(files) | |
return files[0], files[1], files[2] | |
def analyze_image(image: Image.Image) -> dict: | |
buffered = BytesIO() | |
image.save(buffered, format="PNG") | |
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") | |
analysis = replicate.run( | |
"andreasjansson/blip-2:4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608", | |
input={"image": "data:image/png;base64," + img_str, "prompt": "what's in this picture?"} | |
) | |
return analysis | |
def get_prompt_from_image(image: Image.Image) -> str: | |
analysis = analyze_image(image) | |
return analysis.get("describe", "") | |
def generate_prompt(image: Image.Image, current_prompt: str) -> str: | |
return get_prompt_from_image(image) | |
import gradio as gr | |
def create_gradio_interface(): | |
with gr.Blocks() as demo: | |
with gr.Row(): | |
with gr.Column(scale=0): | |
with gr.Tab(label="Sketch"): | |
image_input = cn_img1_input = gr.Image(label="Sketch", type="pil") | |
weight1 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.75) | |
copy_to_sketch_button = gr.Button("Grab Last Output") | |
with gr.Accordion("Upload Project Files", open=False): | |
with gr.Accordion("π", open=False): | |
file_upload = gr.File(file_count="multiple", elem_classes="gradio-column") | |
image_gallery = gr.Gallery(label="Image Gallery", elem_classes="gradio-column") | |
file_upload.change(shuffle_and_load_images, inputs=[file_upload], outputs=[image_gallery]) | |
with gr.Column(scale=2): | |
with gr.Tab(label="Node"): | |
with gr.Accordion("Output"): | |
with gr.Column(): | |
status = gr.Textbox(label="Status") | |
status_image = gr.Image(label="Queue Status", interactive=False) | |
with gr.Row(): | |
with gr.Column(scale=1): | |
analysis_output = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here") | |
with gr.Column(scale=0): | |
analyze_button = gr.Button("Analyze Image") | |
analyze_button.click(fn=analyze_image, inputs=image_input, outputs=analysis_output) | |
with gr.Row(): | |
preload_button = gr.Button("πΈ") | |
shuffle_and_load_button = gr.Button("π") | |
generate_button = gr.Button("π Generate π") | |
with gr.Row(): | |
with gr.Column(): | |
cn_img2_input = gr.Image(label="Image Prompt 2", type="pil", height=256) | |
weight2 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5) | |
with gr.Column(): | |
cn_img3_input = gr.Image(label="Image Prompt 3", type="pil", height=256) | |
weight3 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5) | |
with gr.Column(): | |
cn_img4_input = gr.Image(label="Image Prompt 4", type="pil", height=256) | |
weight4 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5) | |
with gr.Row(): | |
preload_button.click(preload_images, inputs=[cn_img2_input, cn_img3_input, cn_img4_input], outputs=[cn_img2_input, cn_img3_input, cn_img4_input]) | |
shuffle_and_load_button.click(shuffle_and_load_images, inputs=[file_upload], outputs=[cn_img2_input, cn_img3_input, cn_img4_input]) | |
generate_button.click( | |
fn=image_prompt, | |
inputs=[analysis_output, cn_img1_input, cn_img2_input, cn_img3_input, cn_img4_input, weight1, weight2, weight3, weight4], | |
outputs=[status_image, status] | |
) | |
copy_to_sketch_button.click( | |
fn=lambda: Image.open("output.png") if os.path.exists("output.png") else None, | |
inputs=[], | |
outputs=[cn_img1_input] | |
) | |
# β²οΈ Update the image every 5 seconds | |
demo.load(create_status_image, every=5, outputs=status_image) | |
demo.launch(server_name="0.0.0.0", server_port=6644, share=True) | |
create_gradio_interface() | |