# Author: Srimanth Agastyaraju — initial commit (5372b88)
# NOTE(review): converted leftover Hugging Face file-viewer chrome
# ("raw", "history blame", "2.89 kB") into comments — as bare text it
# made the module invalid Python.
import streamlit as st
import torch
from huggingface_hub import model_info
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
def inference(prompt, model, n_images, seed):
    """Generate ``n_images`` images for ``prompt`` with a LoRA-finetuned model.

    Parameters
    ----------
    prompt : str
        Text prompt to condition generation on.
    model : str
        Hub repo id of the LoRA attention weights; its model card's
        ``base_model`` field names the Stable Diffusion checkpoint to load.
    n_images : int
        Number of images to generate.
    seed : int
        Starting seed; image ``i`` uses seed ``seed + i`` for reproducibility.
    """
    # Resolve the base checkpoint from the LoRA repo's model card, then
    # load the pipeline and apply the LoRA attention processors on top.
    info = model_info(model)
    model_base = info.cardData["base_model"]
    pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float32)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.unet.load_attn_procs(model)

    # Placeholders so the progress bar and image grid can be redrawn in place.
    progress_bar_ui = st.empty()
    with progress_bar_ui.container():
        progress_bar = st.progress(0, text=f"Performing inference on {n_images} images...")
    image_grid_ui = st.empty()

    # Run inference with one deterministic generator per image.
    result_images = []
    generators = [torch.Generator().manual_seed(i) for i in range(seed, n_images + seed)]
    print(f"Inferencing '{prompt}' for {n_images} images.")
    for i in range(n_images):
        result = pipe(prompt, generator=generators[i], num_inference_steps=25).images[0]
        result_images.append(result)
        # Clear and redraw the UI with the latest state.
        progress_bar_ui.empty()
        image_grid_ui.empty()
        with progress_bar_ui.container():
            # BUG FIX: the original divided by len(dataset), an undefined name
            # (NameError every iteration); progress is over n_images.
            value = (i + 1) / n_images
            progress_bar.progress(value, text=f"{i + 1} out of {n_images} images processed.")
        # Redraw a 3-column grid of everything generated so far.
        with image_grid_ui.container():
            col1, col2, col3 = st.columns(3)
            # BUG FIX: captions in columns 2 and 3 used i+2 / i+3, skipping
            # numbers; every caption is the 1-based index of the image shown.
            with col1:
                for j in range(0, len(result_images), 3):
                    st.image(result_images[j], caption=f"Image - {j + 1}")
            with col2:
                for j in range(1, len(result_images), 3):
                    st.image(result_images[j], caption=f"Image - {j + 1}")
            with col3:
                for j in range(2, len(result_images), 3):
                    st.image(result_images[j], caption=f"Image - {j + 1}")
def main():
    """Unused placeholder entry point; returns ``None``.

    NOTE(review): never called — the Streamlit UI runs directly under the
    ``__main__`` guard. Kept so the module keeps a conventional entry point.
    """
    return None
if __name__ == "__main__":
# --- START UI ---
st.title("Finetune LoRA inference")
with st.form(key='form_parameters'):
prompt = st.text_input("Enter the prompt: ")
model_options = ["asrimanth/person-thumbs-up-plain-lora", "asrimanth/person-thumbs-up-lora", "asrimanth/person-thumbs-up-lora-no-cap"]
current_model = st.selectbox("Choose a model", options=model_options)
col1_inp, col2_inp = st.columns(2)
with col1_inp:
n_images = int(st.number_input("Enter the number of images", min_value=0, max_value=50))
with col2_inp:
seed_input = int(st.number_input("Enter the seed (default=25)", value=25, min_value=0))
submitted = st.form_submit_button("Predict")
if submitted: # The form is submitted
inference(prompt, current_model, n_images, seed_input)