File size: 3,381 Bytes
5372b88
00deb53
5372b88
 
 
7c6ffc8
5372b88
 
 
00deb53
5372b88
 
 
 
 
 
 
 
 
 
 
 
2aad33f
5372b88
 
 
7c6ffc8
5372b88
 
 
 
 
 
 
 
3827896
 
5372b88
 
 
 
 
 
 
 
 
7c6ffc8
5372b88
 
7c6ffc8
5372b88
 
 
 
 
 
 
7c6ffc8
 
 
 
 
 
5372b88
2c38794
 
7c6ffc8
 
5372b88
7c6ffc8
5372b88
42a9d1a
7c6ffc8
5372b88
 
 
 
7c6ffc8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import streamlit as st
import torch
from huggingface_hub import model_info
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

def inference(prompt, model, n_images, seed, n_inference_steps):
    """Generate images for *prompt* with a LoRA fine-tune and show them live.

    The Hub model card of *model* must declare a ``base_model``. That base
    ``StableDiffusionPipeline`` is loaded in float32, its scheduler is replaced
    by ``DPMSolverMultistepScheduler`` (built from the original scheduler's
    config), and the LoRA attention processors from *model* are loaded into
    its UNet. Images are generated one at a time so that a progress bar and a
    three-column image grid can be refreshed after each image.

    Args:
        prompt: Text prompt passed to the pipeline for every image.
        model: Hugging Face Hub repo id holding the LoRA attention weights.
        n_images: Number of images to generate.
        seed: Seed of the first image; image ``k`` (0-based) uses ``seed + k``.
        n_inference_steps: ``num_inference_steps`` passed to the pipeline.

    Raises:
        ValueError: If the model card of *model* has no ``base_model`` entry.
    """
    # NOTE(review): the full base pipeline is reloaded on every call (i.e. on
    # every form submission); caching it (e.g. with ``st.cache_resource``)
    # would avoid that if memory allows.
    info = model_info(model)
    card_data = info.cardData
    model_base = card_data.get("base_model") if card_data else None
    if not model_base:
        raise ValueError(f"Model card of '{model}' does not declare a 'base_model'.")
    pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float32)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.unet.load_attn_procs(model)

    # Create the UI elements up-front so they keep their position on the page
    # and are updated in place by the loop below.
    progress_bar = st.progress(0, text=f"Performing inference on {n_images} images...")
    image_grid_ui = st.empty()

    n_columns = 3
    result_images = []
    print(f"Inferencing '{prompt}' for {n_images} images.")

    for img_idx in range(n_images):
        # One generator per image, seeded seed, seed+1, ..., so every image is
        # reproducible on its own.
        generator = torch.Generator().manual_seed(seed + img_idx)
        result = pipe(prompt, generator=generator, num_inference_steps=n_inference_steps).images[0]
        result_images.append(result)

        done = img_idx + 1
        progress_bar.progress(done / n_images, text=f"{done} out of {n_images} images processed.")

        # container() replaces whatever the placeholder currently holds, so the
        # grid is redrawn from scratch with every image generated so far.
        with image_grid_ui.container():
            columns = st.columns(n_columns)
            for grid_idx, image in enumerate(result_images):
                # Fill columns round-robin: 0,3,6,... go to the first column, etc.
                columns[grid_idx % n_columns].image(image, caption=f"Image - {grid_idx+1}")


if __name__ == "__main__":
    # --- START UI ---
    st.title("Finetune LoRA inference")

    # Each option is "<hub repo id> : <sample prompt for that model>". Only the
    # repo id is passed to inference(); the sample prompt is a hint shown to
    # the user in the dropdown.
    model_options = [
        "asrimanth/person-thumbs-up-plain-lora : Tom Cruise thumbs up",
        "asrimanth/srimanth-thumbs-up-lora-plain : srimanth thumbs up",
        "asrimanth/person-thumbs-up-lora : <tom_cruise> #thumbsup",
        "asrimanth/person-thumbs-up-lora-no-cap : <tom_cruise> #thumbsup",
    ]

    with st.form(key='form_parameters'):
        current_model = st.selectbox("Choose a model", options=model_options)
        # partition() keeps only the repo id and never fails on an unexpected
        # number of " : " separators.
        model = current_model.partition(" : ")[0]
        prompt = st.text_input("Enter the prompt: (sample prompts in dropdown)")
        col1_inp, col2_inp, col3_inp = st.columns(3)
        with col1_inp:
            # At least one image: 0 would load the whole pipeline for nothing.
            n_images = int(st.number_input("Enter the number of images", value=3, min_value=1, max_value=50))
        with col2_inp:
            # At least one step: 0 denoising steps is not a meaningful run.
            n_inference_steps = int(st.number_input("Enter the number of inference steps", value=5, min_value=1))
        with col3_inp:
            seed_input = int(st.number_input("Enter the seed (default=25)", value=25, min_value=0))
        submitted = st.form_submit_button("Predict")

    if submitted:  # The form is submitted
        inference(prompt, model, n_images, seed_input, n_inference_steps)