File size: 2,630 Bytes
b8a701a
dddb041
 
 
 
 
 
b8a701a
 
 
 
dddb041
 
 
 
 
0347515
223f8b4
 
ccde57b
 
 
 
 
 
 
 
 
 
b8a701a
 
 
0347515
b8a701a
0347515
b8a701a
0347515
 
 
 
 
 
 
 
dddb041
 
c49f67b
0347515
c49f67b
0347515
c49f67b
06fd0d1
bf71179
8669462
0347515
538d554
057bc07
 
 
 
 
 
 
 
a200bb2
 
1b088a5
057bc07
c133494
057bc07
c49f67b
 
057bc07
a200bb2
057bc07
c49f67b
057bc07
 
c03b3ba
057bc07
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Existing imports
import gradio as gr
import requests
import io
from PIL import Image
import json
import os
import logging

# Initialize logging (DEBUG so the handler trace messages below are visible).
logging.basicConfig(level=logging.DEBUG)

# Load the LoRA catalogue from JSON. Each entry is read elsewhere in this file
# for the keys "image" and "title" (gallery thumbnails) and "repo" and
# "trigger_word" (inference request). An explicit encoding keeps the read
# independent of the platform's default locale encoding.
with open('loras.json', 'r', encoding='utf-8') as f:
    loras = json.load(f)

# Handler for gallery clicks: records which LoRA the user selected.
def update_selection(selected_state: gr.SelectData):
    """Store the index of the LoRA clicked in the gallery.

    Wired to ``gallery.select`` with a single output (the ``gr.State`` that
    holds the current selection), so exactly one value must be returned.
    Gradio injects the click event into this parameter because of the
    ``gr.SelectData`` annotation.

    Args:
        selected_state: Gradio selection event. Its ``.index`` is the
            position of the clicked gallery item; the gallery is built from
            the module-level ``loras`` list in order, so it is also the
            index into ``loras``.

    Returns:
        A dict ``{"index": <position>}`` -- the shape ``run_lora`` reads via
        ``selected_state['index']``.
    """
    logging.debug(f"Inside update_selection, selected_state: {selected_state}")
    # Return only the state value: the select event has a single output.
    new_state = {"index": selected_state.index}
    logging.debug(f"Updated selected_state: {new_state}")
    return new_state

def run_lora(prompt, selected_state, progress=gr.Progress(track_tqdm=True)):
    """Generate an image with the selected LoRA via the Hugging Face Inference API.

    The LoRA's trigger word is appended to the user's prompt and the result
    is POSTed to the model's inference endpoint.

    Args:
        prompt: User-entered text prompt.
        selected_state: Current selection state; expected to be a dict with
            an ``'index'`` key into the module-level ``loras`` list.
        progress: Gradio progress tracker; with ``track_tqdm=True`` any tqdm
            bars are mirrored into the UI.

    Returns:
        A ``PIL.Image.Image`` opened from the API response body.

    Raises:
        gr.Error: If no LoRA is selected, the HTTP request fails, the API
            responds with a non-200 status, or the body is not an image.
    """
    logging.debug(f"Inside run_lora, selected_state: {selected_state}")
    if not selected_state:
        logging.error("selected_state is None or empty.")
        raise gr.Error("You must select a LoRA")

    selected_lora = loras[selected_state['index']]
    api_url = f"https://api-inference.huggingface.co/models/{selected_lora['repo']}"
    trigger_word = selected_lora["trigger_word"]
    payload = {"inputs": f"{prompt} {trigger_word}"}

    # Only send Authorization when a token is configured; otherwise the
    # header would literally read "Bearer None".
    token = os.getenv("API_TOKEN")
    headers = {"Authorization": f"Bearer {token}"} if token else {}

    try:
        # Bound the wait so a stalled request cannot occupy a queue slot forever.
        response = requests.post(api_url, headers=headers, json=payload, timeout=120)
    except requests.RequestException as err:
        logging.exception("Request to %s failed", api_url)
        raise gr.Error(f"API request failed: {err}") from err

    if response.status_code != 200:
        # The output component is a gr.Image, which cannot display a plain
        # string, so report the failure as a Gradio error instead.
        logging.error("API error %s: %s", response.status_code, response.text[:500])
        raise gr.Error(f"API Error ({response.status_code})")

    try:
        return Image.open(io.BytesIO(response.content))
    except OSError as err:  # PIL.UnidentifiedImageError subclasses OSError
        logging.error("Response from %s is not a readable image", api_url)
        raise gr.Error("API returned a response that is not an image") from err

# Gradio UI
# Layout: a title, then a row with the LoRA gallery beside a column holding
# the prompt box, Run button and output image. Gradio places components in
# the order they are created inside the `with` contexts, so statement order
# here defines the layout.
# NOTE(review): `css` is given a file name; whether Gradio reads it as a path
# or as literal CSS depends on the installed Gradio version -- confirm.
with gr.Blocks(css="custom.css") as app:
    title = gr.HTML("<h1>LoRA the Explorer</h1>")
    selected_state = gr.State()  # Per-session selection; no initial value (None) until a gallery item is clicked
    with gr.Row():
        # One (image, caption) pair per entry in `loras`, in list order, so a
        # gallery index is also an index into `loras`. allow_preview=False
        # keeps a click from enlarging the image.
        gallery = gr.Gallery(
            [(item["image"], item["title"]) for item in loras],
            label="LoRA Gallery",
            allow_preview=False,
            columns=3
        )
        with gr.Column():
            prompt_title = gr.Markdown("### Click on a LoRA in the gallery to select it")
            with gr.Row():
                prompt = gr.Textbox(label="Prompt", show_label=False, lines=1, max_lines=1, placeholder="Type a prompt after selecting a LoRA")
                button = gr.Button("Run")
            result = gr.Image(interactive=False, label="Generated Image")

    # Clicking a gallery item stores update_selection's return value in
    # selected_state (single output).
    gallery.select(
        update_selection,
        outputs=[selected_state]
    )
    # Run generates an image from the prompt and the stored selection.
    button.click(
        fn=run_lora,
        inputs=[prompt, selected_state],
        outputs=[result]
    )

# Enable the request queue, holding at most 20 pending events.
app.queue(max_size=20)
# Starts the server. There is no `if __name__ == "__main__"` guard, so
# importing this module also launches the app.
app.launch()