Spaces:
Runtime error
Runtime error
import websocket # websocket-client | |
import uuid | |
import json | |
import urllib.request | |
import urllib.parse | |
import random | |
from PIL import Image | |
import io | |
import base64 | |
import io | |
import os | |
import gradio as gr | |
# Backend location, interpolated into the "http://{server_address}/..." and
# "ws://{server_address}/ws" URLs built below. None when URL_API is unset.
# NOTE(review): the endpoints used (/prompt, /view, /history, /ws) look like
# ComfyUI's API -- confirm.
server_address = os.getenv("URL_API")

# Raw JSON text that generate_images() parses with json.loads() and submits
# as the prompt. None when JSON_API is unset.
json_data = os.getenv("JSON_API")

# Random identifier generated once per process; sent with every queued prompt
# and passed as the clientId query parameter of the WebSocket URL.
client_id = str(uuid.uuid4())
def queue_prompt(prompt):
    """Submit *prompt* to the server's ``/prompt`` endpoint.

    Args:
        prompt: JSON-serialisable object sent under the ``"prompt"`` key,
            together with the module-level ``client_id``.

    Returns:
        The server's JSON reply, decoded into Python objects.

    Raises:
        urllib.error.URLError: If the HTTP request fails.
        json.JSONDecodeError: If the reply body is not valid JSON.
    """
    payload = json.dumps({"prompt": prompt, "client_id": client_id}).encode('utf-8')
    # Supplying ``data`` makes urllib issue a POST request.
    req = urllib.request.Request(f"http://{server_address}/prompt", data=payload)
    # Use a context manager so the HTTP response is always closed, matching
    # get_image() and get_history().
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_image(filename, subfolder, folder_type):
    """Fetch one file from the server's ``/view`` endpoint.

    The three arguments are URL-encoded into the query string as
    ``filename``, ``subfolder`` and ``type`` respectively.

    Returns:
        The raw response body as ``bytes``.
    """
    query = urllib.parse.urlencode(
        {"filename": filename, "subfolder": subfolder, "type": folder_type}
    )
    view_url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(view_url) as reply:
        payload = reply.read()
    return payload
def get_history(prompt_id):
    """Return the decoded JSON served at ``/history/<prompt_id>``.

    Callers in this file index the result by ``prompt_id`` and then read its
    ``'outputs'`` entry.
    """
    history_url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(history_url) as reply:
        raw = reply.read()
    return json.loads(raw)
def get_images(ws, prompt):
    """Queue *prompt*, wait for it to finish, and download its output images.

    Args:
        ws: A connected WebSocket whose ``recv()`` yields the server's status
            messages (text frames) and previews (binary frames).
        prompt: The prompt object to submit via ``queue_prompt``.

    Returns:
        A dict mapping each output node id that produced images to a list of
        the raw image ``bytes`` fetched with ``get_image``. Nodes without an
        ``'images'`` entry are omitted.

    Raises:
        KeyError: If a server reply lacks an expected key.
        Any exception raised by ``ws.recv()`` or the HTTP helpers.
    """
    prompt_id = queue_prompt(prompt)['prompt_id']

    # Block until the server reports that execution of our prompt is done:
    # an 'executing' message whose node is None and whose prompt_id matches.
    while True:
        out = ws.recv()
        if not isinstance(out, str):
            continue  # Previews are binary data
        message = json.loads(out)
        # Other message types (e.g. 'progress') carry nothing needed here.
        if message.get('type') != 'executing':
            continue
        data = message['data']
        if data['node'] is None and data.get('prompt_id') == prompt_id:
            break  # Execution is done

    outputs = get_history(prompt_id)[prompt_id]['outputs']

    # Visit each output node exactly once so every image is downloaded a
    # single time.
    output_images = {}
    for node_id, node_output in outputs.items():
        if 'images' not in node_output:
            continue
        output_images[node_id] = [
            get_image(image['filename'], image['subfolder'], image['type'])
            for image in node_output['images']
        ]
    return output_images
def pil_to_base64(image):
    """Serialise *image* to PNG and wrap it in a base64 ``data:`` URI.

    Args:
        image: Any object with a ``save(fp, format=...)`` method, such as a
            PIL image.

    Returns:
        A string of the form ``data:image/png;base64,<payload>``.
    """
    with io.BytesIO() as png_stream:
        image.save(png_stream, format="PNG")
        png_bytes = png_stream.getvalue()
    encoded = base64.b64encode(png_bytes).decode("utf-8")
    return "data:image/png;base64," + encoded
def generate_images(positive_prompt, image):
    """Run the configured workflow and return the first image it produces.

    The workflow is parsed from the module-level ``json_data`` and patched
    before submission:

    * node ``"49"`` receives *positive_prompt* as its ``text`` input;
    * when *image* is given, node ``"90"`` receives it as a base64 data URI;
      otherwise nodes ``"90"`` and ``"68"`` are removed and node ``"62"`` is
      rewired to take its ``images`` input from output 0 of node ``"61"``;
    * node ``"47"`` receives a fresh random ``noise_seed``.

    Args:
        positive_prompt: Text placed into the workflow's prompt node.
        image: Optional input image (a PIL image), or ``None``.

    Returns:
        The first output image decoded with ``Image.open``, or ``None`` when
        the workflow produced no images.
        NOTE(review): the reviewed source had lost its indentation, so it is
        unclear whether the original returned the first or the last image;
        confirm that "first" is intended.

    Raises:
        KeyError: If the workflow lacks one of the node ids patched above.
        Any exception raised while connecting to or talking with the server.
    """
    ws = websocket.WebSocket()
    try:
        ws.connect(f"ws://{server_address}/ws?clientId={client_id}")

        data = json.loads(json_data)
        data["49"]["inputs"]["text"] = positive_prompt
        # Compare against None explicitly rather than relying on the
        # truthiness of an image object.
        if image is not None:
            data["90"]["inputs"]["images"]["base64"] = [pil_to_base64(image)]
        else:
            data.pop("90", None)
            data.pop("68", None)
            data["62"]["inputs"]["images"] = ["61", 0]
        data["47"]["inputs"]["noise_seed"] = random.randint(1, 1000000000)

        images = get_images(ws, data)
    finally:
        # Always release the socket, even when the request fails midway.
        ws.close()

    for node_images in images.values():
        for image_data in node_images:
            return Image.open(io.BytesIO(image_data))
    return None