Spaces:
Running
Running
# HF Spaces | |
import gradio as gr | |
import asyncio | |
import fal_client | |
import requests | |
from PIL import Image | |
from io import BytesIO | |
import time | |
import base64 | |
import json | |
# Local Dev | |
import os | |
from dotenv import load_dotenv | |
load_dotenv() | |
FAL_KEY = os.getenv("FAL_KEY") | |
with open("examples/examples.json") as f: | |
examples = json.load(f) | |
# IC Light, Replace Background | |
async def submit_ic_light_bria(image_data, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color): | |
if not lightsource_start_color.startswith("#"): | |
lightsource_start_color = f"#{lightsource_start_color}" | |
if not lightsource_end_color.startswith("#"): | |
lightsource_end_color = f"#{lightsource_end_color}" | |
retries = 3 | |
for attempt in range(retries): | |
try: | |
handler = await fal_client.submit_async( | |
"comfy/martintmv-git/ic-light-bria", | |
arguments={ | |
"loadimage_1": image_data, | |
"Positive Prompt": positive_prompt, | |
"Negative Prompt": negative_prompt, | |
"lightsource_start_color": lightsource_start_color, | |
"lightsource_end_color": lightsource_end_color | |
}, | |
credentials={"fal_key": FAL_KEY} | |
) | |
log_index = 0 | |
output_logs = [] | |
async for event in handler.iter_events(with_logs=True): | |
if isinstance(event, fal_client.InProgress): | |
if event.logs: | |
new_logs = event.logs[log_index:] | |
for log in new_logs: | |
output_logs.append(log["message"]) | |
log_index = len(event.logs) | |
result = await handler.get() | |
output_logs.append("Processing completed") | |
# Debug log | |
print("API Result:", result) | |
# Extract the image URL | |
image_url = result["outputs"]["9"]["images"][0]["url"] | |
response = requests.get(image_url) | |
image = Image.open(BytesIO(response.content)) | |
return output_logs, image | |
except Exception as e: | |
print(f"Attempt {attempt + 1} failed: {e}") | |
if attempt < retries - 1: | |
time.sleep(2) # HTTP req retry mechanism | |
else: | |
return [f"Error: {str(e)}"], None | |
# SDXL, Depth Anything, Replace Background | |
async def submit_sdxl_rembg(image_data, positive_prompt, negative_prompt): | |
retries = 3 | |
for attempt in range(retries): | |
try: | |
handler = await fal_client.submit_async( | |
"comfy/martintmv-git/sdxl-depthanything-rembg", | |
arguments={ | |
"loadimage_1": image_data, | |
"Positive prompt": positive_prompt, | |
"Negative prompt": negative_prompt | |
}, | |
credentials={"fal_key": FAL_KEY} | |
) | |
log_index = 0 | |
output_logs = [] | |
async for event in handler.iter_events(with_logs=True): | |
if isinstance(event, fal_client.InProgress): | |
if event.logs: | |
new_logs = event.logs[log_index:] | |
for log in new_logs: | |
output_logs.append(log["message"]) | |
log_index = len(event.logs) | |
result = await handler.get() | |
output_logs.append("Processing completed") | |
# Debug log | |
print("API Result:", result) | |
# Extract the image URL | |
image_url = result["outputs"]["9"]["images"][0]["url"] | |
response = requests.get(image_url) | |
image = Image.open(BytesIO(response.content)) | |
return output_logs, image | |
except Exception as e: | |
print(f"Attempt {attempt + 1} failed: {e}") | |
if attempt < retries - 1: | |
time.sleep(2) # HTTP req retry mechanism | |
else: | |
return [f"Error: {str(e)}"], None | |
# SV3D, AnimateDiff | |
async def submit_sv3d(image_data, fps, loop_frames_count, gif_loop): | |
retries = 3 | |
for attempt in range(retries): | |
try: | |
handler = await fal_client.submit_async( | |
"comfy/martintmv-git/sv3d", | |
arguments={ | |
"loadimage_1": image_data, | |
"FPS (bigger number = more speed)": fps, | |
"Loop Frames Count": loop_frames_count, | |
"GIF Loop": gif_loop | |
}, | |
credentials={"fal_key": FAL_KEY} | |
) | |
log_index = 0 | |
output_logs = [] | |
async for event in handler.iter_events(with_logs=True): | |
if isinstance(event, fal_client.InProgress): | |
if event.logs: | |
new_logs = event.logs[log_index:] | |
for log in new_logs: | |
output_logs.append(log["message"]) | |
log_index = len(event.logs) | |
result = await handler.get() | |
output_logs.append("Processing completed") | |
print("API Result:", result) | |
gif_url = result["outputs"]["15"]["gifs"][0]["url"] | |
return output_logs, gif_url | |
except Exception as e: | |
print(f"Attempt {attempt + 1} failed: {e}") | |
if attempt < retries - 1: | |
time.sleep(2) | |
else: | |
return [f"Error: {str(e)}"], None | |
def convert_image_to_base64(image): | |
buffered = BytesIO() | |
image.save(buffered, format="PNG") | |
return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode() | |
def submit_sync_ic_light_bria(image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color): | |
image_data = convert_image_to_base64(Image.open(image_upload)) | |
return asyncio.run(submit_ic_light_bria(image_data, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color)) | |
def submit_sync_sdxl_rembg(image_upload, positive_prompt, negative_prompt): | |
image_data = convert_image_to_base64(Image.open(image_upload)) | |
return asyncio.run(submit_sdxl_rembg(image_data, positive_prompt, negative_prompt)) | |
def submit_sync_sv3d(image_upload, fps, loop_frames_count, gif_loop): | |
image_data = convert_image_to_base64(Image.open(image_upload)) | |
return asyncio.run(submit_sv3d(image_data, fps, loop_frames_count, gif_loop)) | |
def run_gradio_app(): | |
with gr.Blocks() as demo: | |
gr.Markdown("# Comfy Anything π") | |
gr.Markdown("### Community ComfyUI workflows running on [fal.ai](https://fal.ai)") | |
gr.Markdown("#### Comfy Anything on [GitHub](https://github.com/martintomov/comfy-anything)") | |
gr.Markdown("#### Support the project:") | |
gr.Markdown("𧑠Bitcoin address - bc1qs3q0rjpr9fvn9knjy5aktfr8w5duvvjpezkgt9") | |
gr.Markdown("π Want to run your own workflow? Import it into [fal.ai](https://fal.ai)'s ComfyUI and get a Python API endpoint.") | |
# API Key input section | |
with gr.Row(): | |
api_key_input = gr.Textbox(label="Enter your fal.ai API Key", type="password") | |
api_key_submit = gr.Button("Submit API Key") | |
# Main app content (initially hidden) | |
with gr.Row(visible=False) as main_content: | |
with gr.Column(scale=1): | |
workflow = gr.Dropdown(label="Select Workflow", choices=["IC Light, Replace Background", "SDXL, Depth Anything, Replace Background", "SV3D"], value="IC Light, Replace Background") | |
image_upload = gr.Image(label="Upload Image", type="filepath") | |
positive_prompt = gr.Textbox(label="Positive Prompt", visible=True) | |
negative_prompt = gr.Textbox(label="Negative Prompt", value="Watermark", visible=True) | |
lightsource_start_color = gr.ColorPicker(label="Start Color", value="#FFFFFF", visible=True) | |
lightsource_end_color = gr.ColorPicker(label="End Color", value="#000000", visible=True) | |
fps = gr.Slider(label="FPS (bigger number = more speed)", minimum=1, maximum=60, step=1, value=8, visible=False) | |
loop_frames_count = gr.Slider(label="Loop Frames Count", minimum=1, maximum=100, step=1, value=30, visible=False) | |
gif_loop = gr.Checkbox(label="GIF Loop", value=True, visible=False) | |
submit_btn = gr.Button("Submit") | |
with gr.Column(scale=2): | |
output_logs = gr.Textbox(label="Logs") | |
output_result = gr.Image(label="Result") | |
def validate_api_key(api_key): | |
global FAL_KEY | |
FAL_KEY = api_key | |
return gr.Row(visible=True) | |
api_key_submit.click( | |
fn=validate_api_key, | |
inputs=api_key_input, | |
outputs=main_content | |
) | |
def update_ui(workflow): | |
if workflow == "IC Light, Replace Background": | |
return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)] | |
elif workflow == "SDXL, Depth Anything, Replace Background": | |
return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)] | |
elif workflow == "SV3D": | |
return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)] | |
workflow.change(fn=update_ui, inputs=workflow, outputs=[positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color, fps, loop_frames_count, gif_loop]) | |
def on_submit(image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color, fps, loop_frames_count, gif_loop, workflow): | |
if workflow == "IC Light, Replace Background": | |
logs, image = submit_sync_ic_light_bria(image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color) | |
return logs, image | |
elif workflow == "SDXL, Depth Anything, Replace Background": | |
logs, image = submit_sync_sdxl_rembg(image_upload, positive_prompt, negative_prompt) | |
return logs, image | |
elif workflow == "SV3D": | |
logs, gif_url = submit_sync_sv3d(image_upload, fps, loop_frames_count, gif_loop) | |
return logs, gif_url | |
submit_btn.click( | |
fn=on_submit, | |
inputs=[image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color, fps, loop_frames_count, gif_loop, workflow], | |
outputs=[output_logs, output_result] | |
) | |
gr.Examples( | |
examples=[ | |
[example['input_image'], example['positive_prompt'], example['negative_prompt'], example.get('lightsource_start_color', "#FFFFFF"), example.get('lightsource_end_color', "#000000"), example.get('fps', 8), example.get('loop_frames_count', 30), example.get('gif_loop', True), example['workflow']] | |
for example in examples | |
], | |
inputs=[image_upload, positive_prompt, negative_prompt, lightsource_start_color, lightsource_end_color, fps, loop_frames_count, gif_loop, workflow], | |
outputs=[output_logs, output_result], | |
fn=on_submit, | |
cache_examples=True | |
) | |
demo.launch() | |
if __name__ == "__main__": | |
run_gradio_app() |