import json from collections import deque from dataclasses import dataclass import threading from typing import Optional import gradio as gr import websockets from gradio.processing_utils import decode_base64_to_image, encode_pil_to_base64 from PIL import Image from websockets.sync.client import connect from constants import DESCRIPTION, WS_ADDRESS, LOGO from utils import replace_background from gradio_examples import EXAMPLES MAX_QUEUE_SIZE = 4 @dataclass class GenerationState: prompts: deque responses: deque def get_initial_state() -> GenerationState: return GenerationState( prompts=deque(maxlen=MAX_QUEUE_SIZE), responses=deque(maxlen=MAX_QUEUE_SIZE), ) def load_initial_state(request: gr.Request) -> GenerationState: print("Loading initial state for", request.client.host) print("Total number of active threads", threading.active_count()) return get_initial_state() async def put_to_queue( image: Optional[Image.Image], prompt: str, seed: int, strength: float, state: GenerationState, ): prompts_queue = state.prompts if prompt and image is not None: prompts_queue.append((image, prompt, seed, strength)) return state def send_inference_request(state: GenerationState) -> Image.Image: prompts_queue = state.prompts response_queue = state.responses if len(prompts_queue) == 0: return state image, prompt, seed, strength = prompts_queue.popleft() original_image_size = image.size image = replace_background(image.resize((512, 512))) arguments = { "prompt": prompt, "image_url": encode_pil_to_base64(image), "strength": strength, "negative_prompt": "cartoon, illustration, animation. face. male, female", "seed": seed, "guidance_scale": 1, "num_inference_steps": 4, "sync_mode": 1, "num_images": 1, } connection = connect(WS_ADDRESS) connection.send(json.dumps(arguments)) try: response = json.loads(connection.recv()) except websockets.exceptions.ConnectionClosedOK: print("Connection closed, reconnecting...") # TODO: This is a hacky way to reconnect, but it works for now # Ideally, we should be able to reconnect to the same connection # and not have to create a new one connection = connect(WS_ADDRESS) try: response = json.loads(connection.recv()) except websockets.exceptions.ConnectionClosedOK: print("Connection closed again, aborting...") return state # TODO: If a new connection is created, the response do not contain the images. if "images" in response: response_queue.append((response, original_image_size)) return state def update_output_image(state: GenerationState): image_update = gr.update() inference_time_update = gr.update() response_queue = state.responses if len(response_queue) > 0: response, original_image_size = response_queue.popleft() generated_image = decode_base64_to_image(response["images"][0]["url"]) inference_time = response["timings"]["inference"] image_update = gr.update(value=generated_image.resize(original_image_size)) inference_time_update = gr.update(value=round(inference_time, 4)) return image_update, inference_time_update, state with gr.Blocks(css="style.css", title=f"Realtime Latent Consistency Model") as demo: generation_state = gr.State(get_initial_state()) gr.HTML(f'