import io
import logging
import time

import gradio as gr
import requests
from PIL import Image
from tqdm import tqdm


def run_lora(lora, prompt, neg_prompt, progress=gr.Progress(track_tqdm=True)):
    print(f"Inside run_lora, lora: {lora}, prompt: {prompt}, neg_prompt: {neg_prompt}")
    api_url = f"https://api-inference.huggingface.co/models/{lora}"
    payload = {
        "inputs": prompt,
        # Use the user-supplied negative prompt, falling back to a generic one.
        "parameters": {"negative_prompt": neg_prompt or "bad art, ugly, watermark, deformed"},
    }

    # Log the outgoing API request for debugging.
    print(f"API Request: {api_url}")
    print(f"API Payload: {payload}")

    error_count = 0
    pbar = tqdm(total=None, desc="Loading model")
    while True:
        response = requests.post(api_url, json=payload)
        if response.status_code == 200:
            pbar.close()
            return Image.open(io.BytesIO(response.content))
        elif response.status_code == 503:
            # 503 means the model is cold-booting. The response includes a load-time
            # estimate, but it is not very precise, so just poll until ready.
            time.sleep(1)
            pbar.update(1)
        elif response.status_code == 500 and error_count < 5:
            # Transient server error: retry up to five times before giving up.
            print(response.content)
            time.sleep(1)
            error_count += 1
        else:
            logging.error(f"API Error: {response.status_code}")
            raise gr.Error("API Error: Unable to fetch the image.")  # Surface the failure in the Gradio UI.


app = gr.Interface(
    run_lora,
    [
        gr.Textbox(label="LoRA model card", show_label=False, lines=1, max_lines=1,
                   placeholder="Type the LoRA model card here."),
        gr.Textbox(label="Prompt", show_label=False,
                   placeholder="Type a prompt after selecting a LoRA."),
        gr.Textbox(label="Negative Prompt", show_label=False,
                   placeholder="Type negative prompt here."),
    ],
    "image",
)

app.launch()
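
# Optional smoke test of run_lora() without the web UI. "someuser/some-lora" is a
# hypothetical repo id; substitute any text-to-image LoRA hosted on the Hugging Face
# Hub. Run this in a REPL or before app.launch(), since launch() blocks the process.
#
#   img = run_lora("someuser/some-lora", "a watercolor fox", "blurry, low quality")
#   img.save("out.png")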