# Flo-ImageBlend / app.py
# Last commit: b6b09f9 "Update app.py" by rayochoajr (10.6 kB)
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import json
import base64
import time
import os
import random
import io
from dotenv import load_dotenv
import replicate
from PIL import Image, ImageOps
from io import BytesIO
# Copy variables from a local .env file (if one exists) into os.environ.
# NOTE(review): replicate.run() is called below without an explicit token, so
# the replicate client presumably reads REPLICATE_API_TOKEN from the
# environment populated here — confirm.
load_dotenv()

# Constants
# Replicate API token taken from the environment; None when it is unset.
REPLICATE_API_TOKEN = os.environ.get("REPLICATE_API_TOKEN")
# Create the tab for the image analyzer
def image_analyzer_tab():
# Function to analyze the image
def analyze_image(image):
buffered = BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
analysis = replicate.run(
"andreasjansson/blip-2:4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608",
input={"image": "data:image/png;base64," + img_str, "prompt": "what's in this picture?"}
)
return analysis
class Config:
    """Namespace exposing configuration values captured at import time."""

    # Same object as the module-level REPLICATE_API_TOKEN (None when unset).
    REPLICATE_API_TOKEN: "str | None" = REPLICATE_API_TOKEN
class ImageUtils:
    """Static helpers for converting and encoding PIL images.

    NOTE(review): nothing in this file calls ImageUtils; confirm whether it
    is used elsewhere.
    """

    # Modes Pillow's JPEG writer can store. Saving any other mode
    # (e.g. "RGBA", "P", "LA") as JPEG raises OSError.
    _JPEG_MODES = frozenset({"1", "L", "RGB", "RGBX", "CMYK", "YCbCr"})

    @staticmethod
    def image_to_base64(image):
        """Encode *image* as JPEG and return the bytes as a base64 ``str``.

        Images in a mode JPEG cannot store (for example RGBA from a PNG with
        transparency) are converted to RGB first instead of failing on save.
        """
        if image.mode not in ImageUtils._JPEG_MODES:
            image = ImageUtils.convert_image_mode(image, "RGB")
        buffered = io.BytesIO()
        image.save(buffered, format="JPEG")
        return base64.b64encode(buffered.getvalue()).decode('utf-8')

    @staticmethod
    def convert_image_mode(image, mode="RGB"):
        """Return *image* in *mode*; the same object is returned if it already matches."""
        if image.mode != mode:
            return image.convert(mode)
        return image
def pad_image(image, padding_color=(255, 255, 255), padding=10):
    """Return a new image with a solid border of *padding* pixels around *image*.

    The canvas grows by ``2 * padding`` in each dimension, and the original is
    pasted at ``(padding, padding)``, so the border is equal on all sides.

    Args:
        image: PIL image to pad; the result uses the same mode.
        padding_color: border fill colour, passed unchanged to ``Image.new``.
        padding: border width in pixels (default 10, the previous fixed value).

    Raises:
        ValueError: if *padding* is negative.
    """
    if padding < 0:
        raise ValueError("padding must be non-negative")
    width, height = image.size
    result = Image.new(image.mode, (width + 2 * padding, height + 2 * padding), padding_color)
    result.paste(image, (padding, padding))
    return result
def resize_and_pad_image(image, target_width, target_height, padding_color=(255, 255, 255)):
    """Fit *image* into a ``target_width`` x ``target_height`` canvas, keeping its aspect ratio.

    The image is scaled until it touches the limiting side of the target and
    then centred on a canvas filled with *padding_color* (letterboxing).

    Args:
        image: PIL image to fit; the canvas uses the same mode.
        target_width, target_height: output size in pixels.
        padding_color: fill colour for the unused canvas area.

    Returns:
        A new PIL image of exactly ``(target_width, target_height)``.
    """
    original_width, original_height = image.size
    aspect_ratio = original_width / original_height
    target_aspect_ratio = target_width / target_height
    if aspect_ratio > target_aspect_ratio:
        # Relatively wider than the target: width is the limiting side.
        new_width = target_width
        new_height = max(1, int(target_width / aspect_ratio))
    else:
        # Relatively taller (or equal): height is the limiting side.
        new_width = max(1, int(target_height * aspect_ratio))
        new_height = target_height
    # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
    resized_image = image.resize((new_width, new_height), Image.LANCZOS)
    padded_image = Image.new(image.mode, (target_width, target_height), padding_color)
    padded_image.paste(resized_image, ((target_width - new_width) // 2, (target_height - new_height) // 2))
    return padded_image
def _to_png_data_uri(img):
    """Return *img* PNG-encoded as a ``data:image/png;base64,...`` string."""
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode('utf-8')


def image_prompt(prompt, cn_img1, cn_img2, cn_img3, cn_img4, weight1, weight2, weight3, weight4):
    """Generate an image with the Fooocus model on Replicate and save it as ``output.png``.

    Slot 1 carries the sketch (``cn_type1`` = "PyraCanny"). Slots 2-4 are
    "ImagePrompt" inputs. Each slot has its own weight. The padded sketch is
    also letterboxed to 1280x768 and sent as ``uov_input_image``.

    Args:
        prompt: text prompt.
        cn_img1 .. cn_img4: PIL images for the four control slots.
        weight1 .. weight4: per-slot weights.

    Returns:
        ``("output.png", success message)`` once the image has been downloaded,
        or ``(None, failure message)`` if the prediction did not succeed.

    Raises:
        requests.HTTPError: if downloading the generated image fails.
        requests.Timeout: if the download stalls past the timeout.
    """
    # Border the sketch (via pad_image) before encoding and letterboxing it.
    cn_img1 = pad_image(cn_img1)
    cn_img1_uri = _to_png_data_uri(cn_img1)
    cn_img2_uri = _to_png_data_uri(cn_img2)
    cn_img3_uri = _to_png_data_uri(cn_img3)
    cn_img4_uri = _to_png_data_uri(cn_img4)

    # Resize and pad the sketch to the same canvas size requested below, so
    # the two can never drift apart.
    aspect_ratio_width, aspect_ratio_height = 1280, 768
    uov_input_image = resize_and_pad_image(cn_img1, aspect_ratio_width, aspect_ratio_height)
    uov_input_uri = _to_png_data_uri(uov_input_image)

    # Call the Replicate API to generate the image
    fooocus_model = replicate.models.get("vetkastar/fooocus").versions.get("d555a800025fe1c171e386d299b1de635f8d8fc3f1ade06a14faf5154eba50f3")
    prediction = replicate.predictions.create(version=fooocus_model, input={
        "prompt": prompt,
        "cn_type1": "PyraCanny",
        "cn_type2": "ImagePrompt",
        "cn_type3": "ImagePrompt",
        "cn_type4": "ImagePrompt",
        "cn_weight1": weight1,
        "cn_weight2": weight2,
        "cn_weight3": weight3,
        "cn_weight4": weight4,
        "cn_img1": cn_img1_uri,
        "cn_img2": cn_img2_uri,
        "cn_img3": cn_img3_uri,
        "cn_img4": cn_img4_uri,
        "uov_input_image": uov_input_uri,
        "sharpness": 2,
        "image_seed": -1,
        "image_number": 1,
        "guidance_scale": 7,
        "refiner_switch": 0.5,
        "negative_prompt": "",
        "inpaint_strength": 0.5,
        "style_selections": "Fooocus V2,Fooocus Enhance,Fooocus Sharp",
        "loras_custom_urls": "",
        "uov_upscale_value": 0,
        "use_default_loras": True,
        "outpaint_selections": "",
        "outpaint_distance_top": 0,
        "performance_selection": "Lightning",
        "outpaint_distance_left": 0,
        "aspect_ratios_selection": f"{aspect_ratio_width}*{aspect_ratio_height}",
        "outpaint_distance_right": 0,
        "outpaint_distance_bottom": 0,
        "inpaint_additional_prompt": "",
        "uov_method": "Vary (Subtle)"
    })
    prediction.wait()
    # A failed or canceled prediction has no usable output. Indexing it would
    # raise an opaque TypeError, so report the failure in the status box instead.
    if prediction.status != "succeeded":
        return None, f"Job failed: {prediction.error}"

    # Fetch the generated image from the output URL
    response = requests.get(prediction.output["paths"][0], timeout=60)
    response.raise_for_status()
    with open("output.png", "wb") as f:
        f.write(response.content)
    return "output.png", "Job completed successfully using Replicate API."
def create_status_image():
    """Return ``"output.png"`` if that file exists in the working directory, else ``None``.

    The UI polls this every 5 seconds to refresh the status image.
    """
    output_path = "output.png"
    return output_path if os.path.exists(output_path) else None
def preload_images(cn_img2, cn_img3, cn_img4):
    """Return a tuple of three random 400x400 picsum.photos image URLs.

    The incoming images are ignored. The parameters appear to exist only to
    match the three image inputs this is wired to in the UI.
    """
    urls = tuple(
        f"https://picsum.photos/seed/{random.randint(0, 1000)}/400/400"
        for _slot in range(3)
    )
    return urls
def shuffle_and_load_images(files):
    """Pick up to three of *files* at random, one per image-prompt slot.

    Args:
        files: the uploaded files (a list), or ``None``/empty if none.

    Returns:
        A 3-tuple of randomly ordered, distinct entries from *files*. Slots
        with nothing to fill (no uploads, or fewer than three) are ``None``.
    """
    if not files:
        # The old code called generate_placeholder_image(), which is not
        # defined in this file and raised NameError.
        return None, None, None
    # Shuffle a copy so the caller's list is not mutated in place.
    picked = list(files)
    random.shuffle(picked)
    picked = picked[:3]
    # Pad with None instead of raising IndexError when fewer than 3 files.
    picked += [None] * (3 - len(picked))
    return tuple(picked)
def analyze_image(image: Image.Image) -> dict:
    """Ask the BLIP-2 model on Replicate what *image* shows.

    The image is PNG-encoded into a base64 data URI and sent with a fixed
    question. The result of ``replicate.run`` is returned unchanged.

    NOTE(review): annotated ``-> dict``, but the "Analyze Image" button sends
    this result straight to a Textbox, so it may actually be a string.
    Confirm against the model's output schema.
    """
    png_buffer = BytesIO()
    image.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode("utf-8")
    model_ref = "andreasjansson/blip-2:4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608"
    payload = {"image": "data:image/png;base64," + encoded, "prompt": "what's in this picture?"}
    return replicate.run(model_ref, input=payload)
def get_prompt_from_image(image: Image.Image) -> str:
    """Return a text prompt describing *image*, produced by ``analyze_image``.

    ``analyze_image`` is annotated as returning a dict (read here through its
    ``"describe"`` key). The same result is also sent directly to a Textbox
    by the "Analyze Image" button, though, so it may be a plain string, and
    calling ``.get`` on a string raised AttributeError. Both shapes are
    accepted.

    Returns:
        The ``"describe"`` value for a dict result, the string itself for a
        string result, ``""`` for ``None``, and ``str(result)`` otherwise.
    """
    analysis = analyze_image(image)
    if isinstance(analysis, dict):
        return analysis.get("describe", "")
    if analysis is None:
        return ""
    if isinstance(analysis, str):
        return analysis
    # NOTE(review): unexpected output shape. Stringify rather than crash, and
    # confirm against the model's output schema.
    return str(analysis)
def generate_prompt(image: Image.Image, current_prompt: str) -> str:
    """Return a new prompt derived solely from *image*.

    ``current_prompt`` is accepted but ignored. The result comes entirely from
    ``get_prompt_from_image``.
    """
    new_prompt = get_prompt_from_image(image)
    return new_prompt
import gradio as gr
def create_gradio_interface():
    """Build the Gradio UI, wire up its events and launch the server.

    Left column: the sketch input (passed to ``image_prompt`` as ``cn_img1``),
    its weight slider, a button that copies the last output back into the
    sketch, and a collapsible upload area with a gallery. Right column: the
    status/output area, the prompt box with an "Analyze Image" button, the
    action buttons, and three image-prompt inputs with weight sliders.

    Side effects: polls ``create_status_image`` every 5 seconds, then calls
    ``demo.launch(server_name="0.0.0.0", server_port=6644, share=True)``.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=0):
                with gr.Tab(label="Sketch"):
                    image_input = cn_img1_input = gr.Image(label="Sketch", type="pil")
                    weight1 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.75)
                    copy_to_sketch_button = gr.Button("Grab Last Output")
                    with gr.Accordion("Upload Project Files", open=False):
                        # Label was mojibake ("πŸ“"); restored to the folder emoji.
                        with gr.Accordion("📁", open=False):
                            file_upload = gr.File(file_count="multiple", elem_classes="gradio-column")
                            image_gallery = gr.Gallery(label="Image Gallery", elem_classes="gradio-column")
                            # Mirror the uploads into the gallery. The previous handler
                            # (shuffle_and_load_images) returns three values for this
                            # single output, which Gradio rejects. Both file objects
                            # exposing `.name` and plain path strings are accepted.
                            file_upload.change(
                                lambda files: [getattr(f, "name", f) for f in (files or [])],
                                inputs=[file_upload],
                                outputs=[image_gallery],
                            )
            with gr.Column(scale=2):
                with gr.Tab(label="Node"):
                    with gr.Accordion("Output"):
                        with gr.Column():
                            status = gr.Textbox(label="Status")
                            status_image = gr.Image(label="Queue Status", interactive=False)
                    with gr.Row():
                        with gr.Column(scale=1):
                            analysis_output = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here")
                        with gr.Column(scale=0):
                            analyze_button = gr.Button("Analyze Image")
                            analyze_button.click(fn=analyze_image, inputs=image_input, outputs=analysis_output)
                    with gr.Row():
                        preload_button = gr.Button("🌸")
                        # These labels were mojibake ("πŸ“‚", "πŸš€"); restored.
                        shuffle_and_load_button = gr.Button("📂")
                        generate_button = gr.Button("🚀 Generate 🚀")
                    with gr.Row():
                        with gr.Column():
                            cn_img2_input = gr.Image(label="Image Prompt 2", type="pil", height=256)
                            weight2 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5)
                        with gr.Column():
                            cn_img3_input = gr.Image(label="Image Prompt 3", type="pil", height=256)
                            weight3 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5)
                        with gr.Column():
                            cn_img4_input = gr.Image(label="Image Prompt 4", type="pil", height=256)
                            weight4 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5)

        # Event wiring. Listeners need no layout container, so the empty Row
        # that used to wrap them is gone.
        preload_button.click(
            preload_images,
            inputs=[cn_img2_input, cn_img3_input, cn_img4_input],
            outputs=[cn_img2_input, cn_img3_input, cn_img4_input],
        )
        shuffle_and_load_button.click(
            shuffle_and_load_images,
            inputs=[file_upload],
            outputs=[cn_img2_input, cn_img3_input, cn_img4_input],
        )
        generate_button.click(
            fn=image_prompt,
            inputs=[analysis_output, cn_img1_input, cn_img2_input, cn_img3_input, cn_img4_input, weight1, weight2, weight3, weight4],
            outputs=[status_image, status],
        )
        # Reuse create_status_image: it returns the output path (or None). This
        # avoids the lazily-opened file handle that Image.open left behind.
        copy_to_sketch_button.click(
            fn=create_status_image,
            inputs=[],
            outputs=[cn_img1_input],
        )
        # ⏲️ Update the image every 5 seconds
        demo.load(create_status_image, every=5, outputs=status_image)
    demo.launch(server_name="0.0.0.0", server_port=6644, share=True)
# Build the UI and launch the Gradio server (create_gradio_interface ends with
# demo.launch). This runs as soon as the module executes.
# NOTE(review): consider an `if __name__ == "__main__":` guard if this module
# is ever imported rather than run directly.
create_gradio_interface()