"""Streamlit app that runs inference with LoRA fine-tuned Stable Diffusion models.

Launch with ``streamlit run <this file>``.
"""
import streamlit as st
import torch
from huggingface_hub import model_info
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler


def _load_pipeline(model):
    """Build a Stable Diffusion pipeline with the LoRA weights of *model* attached.

    The base checkpoint id is read from the ``base_model`` field of the Hub
    model card of *model*; the LoRA attention processors are then loaded from
    *model* itself.
    """
    info = model_info(model)
    # NOTE(review): assumes every selectable repo's model card declares
    # `base_model`; if it is missing this lookup raises — confirm.
    model_base = info.cardData["base_model"]
    pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float32)
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.unet.load_attn_procs(model)
    return pipe


def _render_image_grid(placeholder, images, n_columns=3):
    """Replace the contents of *placeholder* with *images* laid out in columns.

    Image ``k`` (0-based) goes into column ``k % n_columns`` and is captioned
    ``Image - {k + 1}``, so captions always match generation order.
    """
    placeholder.empty()
    with placeholder.container():
        columns = st.columns(n_columns)
        for idx, image in enumerate(images):
            with columns[idx % n_columns]:
                st.image(image, caption=f"Image - {idx + 1}")


def inference(prompt, model, n_images, seed):
    """Generate *n_images* images for *prompt* and stream them into the page.

    Args:
        prompt: Text prompt passed to the diffusion pipeline.
        model: Hugging Face Hub repo id of the LoRA weights.
        n_images: Number of images to generate.
        seed: Seed of the first image; image ``k`` (0-based) uses ``seed + k``.
    """
    pipe = _load_pipeline(model)

    # Reserve placeholders up front so the progress bar and the image grid are
    # updated in place rather than appended to the page on every step.
    progress_bar_ui = st.empty()
    with progress_bar_ui.container():
        progress_bar = st.progress(0, text=f"Performing inference on {n_images} images...")
    image_grid_ui = st.empty()

    result_images = []
    print(f"Inferencing '{prompt}' for {n_images} images.")
    for idx in range(n_images):
        # One generator per image, seeded consecutively, so every image is
        # individually reproducible from (seed + idx).
        generator = torch.Generator().manual_seed(seed + idx)
        result = pipe(prompt, generator=generator, num_inference_steps=25).images[0]
        result_images.append(result)

        # Update the existing bar in place (previously referenced an
        # undefined `dataset` and cleared the bar's container first).
        progress_bar.progress(
            (idx + 1) / n_images,
            text=f"{idx + 1} out of {n_images} images processed.",
        )
        _render_image_grid(image_grid_ui, result_images)


def main():
    """Render the parameter form and run inference when it is submitted."""
    st.title("Finetune LoRA inference")
    with st.form(key='form_parameters'):
        prompt = st.text_input("Enter the prompt: ")
        model_options = [
            "asrimanth/person-thumbs-up-plain-lora",
            "asrimanth/person-thumbs-up-lora",
            "asrimanth/person-thumbs-up-lora-no-cap",
        ]
        current_model = st.selectbox("Choose a model", options=model_options)
        col1_inp, col2_inp = st.columns(2)
        with col1_inp:
            n_images = int(st.number_input("Enter the number of images", min_value=0, max_value=50))
        with col2_inp:
            seed_input = int(st.number_input("Enter the seed (default=25)", value=25, min_value=0))
        submitted = st.form_submit_button("Predict")

    if not submitted:
        return
    # Guard against the widget defaults (empty prompt, 0 images), which would
    # otherwise load the model and then produce nothing useful.
    if not prompt.strip():
        st.warning("Please enter a prompt.")
        return
    if n_images < 1:
        st.warning("Please choose at least one image.")
        return
    inference(prompt, current_model, n_images, seed_input)


if __name__ == "__main__":
    main()