# NOTE: the lines that were here were scraped Hugging Face Space page metadata
# (runtime status, file size, blame hashes, line-number gutter) — not Python
# source. Kept as a comment so the file parses.
import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import json
import base64
import time
import os
import random
import io
from dotenv import load_dotenv
import replicate
from PIL import Image, ImageOps
from io import BytesIO
# Load environment variables
load_dotenv()
# Constants
REPLICATE_API_TOKEN = os.getenv("REPLICATE_API_TOKEN")
# Create the tab for the image analyzer
def image_analyzer_tab():
    """Tab-scoped helper for captioning an image with BLIP-2 on Replicate.

    NOTE(review): the nested ``analyze_image`` is defined but never returned
    or registered, so calling this function currently has no external effect
    — confirm whether the tab wiring was lost.
    """
    def analyze_image(image):
        # Serialize the PIL image to an in-memory PNG, then base64 it for the API.
        sink = BytesIO()
        image.save(sink, format="PNG")
        encoded = base64.b64encode(sink.getvalue()).decode("utf-8")
        # Ask the hosted BLIP-2 model to describe the picture.
        result = replicate.run(
            "andreasjansson/blip-2:4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608",
            input={"image": "data:image/png;base64," + encoded, "prompt": "what's in this picture?"}
        )
        return result
class Config:
    """Runtime configuration container.

    Mirrors the module-level constant loaded from ``.env`` via ``load_dotenv()``.
    """
    # Replicate API credential; None when the env var is unset.
    REPLICATE_API_TOKEN = REPLICATE_API_TOKEN
class ImageUtils:
    """Small, stateless helpers for converting PIL-style images."""

    @staticmethod
    def image_to_base64(image):
        """Serialize *image* as JPEG and return its base64 string."""
        sink = io.BytesIO()
        image.save(sink, format="JPEG")
        payload = sink.getvalue()
        return base64.b64encode(payload).decode('utf-8')

    @staticmethod
    def convert_image_mode(image, mode="RGB"):
        """Return *image* in *mode*, converting only when it differs."""
        return image if image.mode == mode else image.convert(mode)
def pad_image(image, padding_color=(255, 255, 255), border=10):
    """Return a copy of *image* surrounded by a uniform border.

    Args:
        image: source PIL image.
        padding_color: fill color for the border (default white).
        border: border width in pixels on every side. Defaults to 10,
            matching the previously hard-coded padding, so existing
            callers are unaffected.

    Returns:
        A new image of size (w + 2*border, h + 2*border) with *image*
        pasted in the center.
    """
    # ImageOps.expand (already imported in this file) performs exactly this
    # uniform-border new-canvas-and-paste in one call.
    return ImageOps.expand(image, border=border, fill=padding_color)
def resize_and_pad_image(image, target_width, target_height, padding_color=(255, 255, 255)):
    """Fit *image* inside a target box, preserving aspect ratio, and pad the rest.

    Args:
        image: source PIL image.
        target_width: width of the output canvas in pixels.
        target_height: height of the output canvas in pixels.
        padding_color: fill color for the unused canvas area (default white).

    Returns:
        A new image of exactly (target_width, target_height) with the resized
        source centered on it.
    """
    original_width, original_height = image.size
    aspect_ratio = original_width / original_height
    target_aspect_ratio = target_width / target_height
    if aspect_ratio > target_aspect_ratio:
        # Source is wider than the box: width binds, scale height down.
        new_width = target_width
        new_height = int(target_width / aspect_ratio)
    else:
        # Source is taller (or equal): height binds, scale width down.
        new_width = int(target_height * aspect_ratio)
        new_height = target_height
    # BUGFIX: Image.ANTIALIAS was removed in Pillow 10 (deprecated since 9.1),
    # raising AttributeError at runtime. LANCZOS is its designated replacement
    # and produces identical resampling.
    resized_image = image.resize((new_width, new_height), Image.LANCZOS)
    padded_image = Image.new(image.mode, (target_width, target_height), padding_color)
    # Center the resized image on the padded canvas.
    padded_image.paste(resized_image, ((target_width - new_width) // 2,
                                       (target_height - new_height) // 2))
    return padded_image
def _png_data_url(image):
    """Encode a PIL image as a ``data:image/png;base64,...`` URL string."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode('utf-8')

def image_prompt(prompt, cn_img1, cn_img2, cn_img3, cn_img4, weight1, weight2, weight3, weight4):
    """Generate an image with the Fooocus model on Replicate.

    Args:
        prompt: text prompt for generation.
        cn_img1: sketch image; used both as the PyraCanny control image and
            (resized/padded to 1280x768) as the upscale/vary input.
        cn_img2, cn_img3, cn_img4: ImagePrompt control images.
        weight1..weight4: control weights for the four images.

    Returns:
        Tuple of (path to the saved output PNG, status message).

    Side effects:
        Blocks until the Replicate prediction finishes; writes "output.png"
        in the current working directory.
    """
    cn_img1 = pad_image(cn_img1)
    # The four base64 conversions were copy-pasted; one helper replaces them.
    cn_img1_data = _png_data_url(cn_img1)
    cn_img2_data = _png_data_url(cn_img2)
    cn_img3_data = _png_data_url(cn_img3)
    cn_img4_data = _png_data_url(cn_img4)
    # Resize and pad the sketch input image to match the aspect ratio selection.
    aspect_ratio_width, aspect_ratio_height = 1280, 768
    uov_input_image = resize_and_pad_image(cn_img1, aspect_ratio_width, aspect_ratio_height)
    uov_input_image_data = _png_data_url(uov_input_image)
    # Call the Replicate API to generate the image.
    fooocus_model = replicate.models.get("vetkastar/fooocus").versions.get("d555a800025fe1c171e386d299b1de635f8d8fc3f1ade06a14faf5154eba50f3")
    prediction = replicate.predictions.create(version=fooocus_model, input={
        "prompt": prompt,
        "cn_type1": "PyraCanny",
        "cn_type2": "ImagePrompt",
        "cn_type3": "ImagePrompt",
        "cn_type4": "ImagePrompt",
        "cn_weight1": weight1,
        "cn_weight2": weight2,
        "cn_weight3": weight3,
        "cn_weight4": weight4,
        "cn_img1": cn_img1_data,
        "cn_img2": cn_img2_data,
        "cn_img3": cn_img3_data,
        "cn_img4": cn_img4_data,
        "uov_input_image": uov_input_image_data,
        "sharpness": 2,
        "image_seed": -1,  # -1 requests a random seed from the model
        "image_number": 1,
        "guidance_scale": 7,
        "refiner_switch": 0.5,
        "negative_prompt": "",
        "inpaint_strength": 0.5,
        "style_selections": "Fooocus V2,Fooocus Enhance,Fooocus Sharp",
        "loras_custom_urls": "",
        "uov_upscale_value": 0,
        "use_default_loras": True,
        "outpaint_selections": "",
        "outpaint_distance_top": 0,
        "performance_selection": "Lightning",
        "outpaint_distance_left": 0,
        "aspect_ratios_selection": "1280*768",
        "outpaint_distance_right": 0,
        "outpaint_distance_bottom": 0,
        "inpaint_additional_prompt": "",
        "uov_method": "Vary (Subtle)"
    })
    # Block until the remote job completes.
    prediction.wait()
    # Fetch the generated image from the output URL and persist it locally.
    response = requests.get(prediction.output["paths"][0])
    with open("output.png", "wb") as f:
        f.write(response.content)
    return "output.png", "Job completed successfully using Replicate API."
def create_status_image():
    """Return the path of the most recent output image, or None if absent."""
    return "output.png" if os.path.exists("output.png") else None
def preload_images(cn_img2, cn_img3, cn_img4):
    """Fill the three image-prompt slots with random placeholder photo URLs.

    The incoming slot values are intentionally ignored; each slot receives a
    fresh picsum.photos URL with a random seed so every click reshuffles.
    """
    slots = tuple(
        f"https://picsum.photos/seed/{random.randint(0, 1000)}/400/400"
        for _ in range(3)
    )
    return slots
def shuffle_and_load_images(files):
    """Pick three files at random from an uploaded batch.

    Args:
        files: list of uploaded file objects/paths; may be None or empty.

    Returns:
        A 3-tuple of picks. Slots are padded with None when fewer than three
        files exist (None clears the corresponding Gradio image component).

    Fixes over the original:
        - empty input called the undefined ``generate_placeholder_image``
          (NameError); now returns three None values instead,
        - uploads of 1-2 files raised IndexError; now padded with None,
        - the caller's list is no longer shuffled in place.
    """
    if not files:
        # Nothing uploaded yet: clear all three slots.
        return None, None, None
    pool = list(files)  # copy so the caller's list order is untouched
    random.shuffle(pool)
    pool += [None] * 3  # pad so short uploads cannot IndexError
    return pool[0], pool[1], pool[2]
def analyze_image(image: Image.Image) -> dict:
    """Caption *image* with the BLIP-2 model hosted on Replicate.

    The image is serialized to PNG in memory, base64-encoded into a data URL,
    and sent with a fixed captioning prompt. Returns whatever the model run
    produces.
    """
    sink = BytesIO()
    image.save(sink, format="PNG")
    encoded = base64.b64encode(sink.getvalue()).decode("utf-8")
    payload = {
        "image": "data:image/png;base64," + encoded,
        "prompt": "what's in this picture?",
    }
    return replicate.run(
        "andreasjansson/blip-2:4b32258c42e9efd4288bb9910bc532a69727f9acd26aa08e175713a0a857a608",
        input=payload,
    )
def get_prompt_from_image(image: Image.Image) -> str:
    """Analyze *image* and return its 'describe' field, or "" when absent.

    NOTE(review): assumes analyze_image returns a dict with a 'describe'
    key — verify against the BLIP-2 model's actual output shape.
    """
    result = analyze_image(image)
    return result.get("describe", "")
def generate_prompt(image: Image.Image, current_prompt: str) -> str:
    """Derive a fresh text prompt from *image*.

    *current_prompt* is accepted for interface compatibility but
    intentionally discarded — the new prompt comes solely from the image.
    """
    return get_prompt_from_image(image)
import gradio as gr
def create_gradio_interface():
    """Build and launch the Gradio UI for the sketch-to-image app.

    Layout: left column holds the sketch input and project-file upload;
    right column holds status output, prompt analysis, three image-prompt
    slots, and the action buttons. Launches a shared server on port 6644.
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=0):
                with gr.Tab(label="Sketch"):
                    # One component serves as both the analyzer input and the
                    # first control image (cn_img1).
                    image_input = cn_img1_input = gr.Image(label="Sketch", type="pil")
                    weight1 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.75)
                    copy_to_sketch_button = gr.Button("Grab Last Output")
                with gr.Accordion("Upload Project Files", open=False):
                    with gr.Accordion("π", open=False):
                        file_upload = gr.File(file_count="multiple", elem_classes="gradio-column")
                        image_gallery = gr.Gallery(label="Image Gallery", elem_classes="gradio-column")
                        # NOTE(review): shuffle_and_load_images returns three
                        # values but only one gallery output is wired here —
                        # confirm the intended mapping.
                        file_upload.change(shuffle_and_load_images, inputs=[file_upload], outputs=[image_gallery])
            with gr.Column(scale=2):
                with gr.Tab(label="Node"):
                    with gr.Accordion("Output"):
                        with gr.Column():
                            status = gr.Textbox(label="Status")
                            status_image = gr.Image(label="Queue Status", interactive=False)
                    with gr.Row():
                        with gr.Column(scale=1):
                            analysis_output = gr.Textbox(label="Prompt", placeholder="Enter your text prompt here")
                        with gr.Column(scale=0):
                            analyze_button = gr.Button("Analyze Image")
                            # BLIP-2 caption of the sketch becomes the prompt text.
                            analyze_button.click(fn=analyze_image, inputs=image_input, outputs=analysis_output)
                    with gr.Row():
                        preload_button = gr.Button("πΈ")
                        shuffle_and_load_button = gr.Button("π")
                        generate_button = gr.Button("π Generate π")
                    with gr.Row():
                        with gr.Column():
                            cn_img2_input = gr.Image(label="Image Prompt 2", type="pil", height=256)
                            weight2 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5)
                        with gr.Column():
                            cn_img3_input = gr.Image(label="Image Prompt 3", type="pil", height=256)
                            weight3 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5)
                        with gr.Column():
                            cn_img4_input = gr.Image(label="Image Prompt 4", type="pil", height=256)
                            weight4 = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.5)
                    with gr.Row():
                        # Fill slots 2-4 with random placeholder photos.
                        preload_button.click(preload_images, inputs=[cn_img2_input, cn_img3_input, cn_img4_input], outputs=[cn_img2_input, cn_img3_input, cn_img4_input])
                        # Re-deal three of the uploaded project files into slots 2-4.
                        shuffle_and_load_button.click(shuffle_and_load_images, inputs=[file_upload], outputs=[cn_img2_input, cn_img3_input, cn_img4_input])
                        # Kick off the Fooocus generation on Replicate.
                        generate_button.click(
                            fn=image_prompt,
                            inputs=[analysis_output, cn_img1_input, cn_img2_input, cn_img3_input, cn_img4_input, weight1, weight2, weight3, weight4],
                            outputs=[status_image, status]
                        )
                        # Copy the last generated file back into the sketch slot.
                        copy_to_sketch_button.click(
                            fn=lambda: Image.open("output.png") if os.path.exists("output.png") else None,
                            inputs=[],
                            outputs=[cn_img1_input]
                        )
        # Refresh the status image every 5 seconds.
        demo.load(create_status_image, every=5, outputs=status_image)
    demo.launch(server_name="0.0.0.0", server_port=6644, share=True)

# Build and launch the UI at import time (this file is the app entry point).
create_gradio_interface()