"""Streamlit page for generating and browsing background template images.

A text prompt is turned into an image with Stable Diffusion XL. The result is
saved under ``Assets/bgImages``, and every image in that folder is shown in a
gallery grid.
"""
import os
import re

import streamlit as st
import torch
from diffusers import AutoPipelineForText2Image
from PIL import Image

# Folder generated backgrounds are written to and the gallery reads from.
IMAGE_FOLDER = "Assets/bgImages"
# Lower-case extensions treated as gallery images (matched case-insensitively).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")
MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"
# Fixed seed: the same prompt reproduces the same background.
SEED = 31
# Cap on the filename stem derived from a prompt; long prompts would
# otherwise exceed common filesystem name limits.
MAX_FILENAME_LENGTH = 100


def load_images_from_folder(folder, extensions=IMAGE_EXTENSIONS):
    """Return the sorted paths of image files directly inside *folder*.

    Args:
        folder: Directory to scan (non-recursive).
        extensions: Iterable of lower-case suffixes to accept; the match is
            case-insensitive, so ``photo.PNG`` is included.

    Returns:
        A list of ``folder/<name>`` paths in sorted filename order, or an
        empty list when *folder* does not exist yet (e.g. before the first
        image has been generated).
    """
    if not os.path.isdir(folder):
        return []
    suffixes = tuple(ext.lower() for ext in extensions)
    return [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if name.lower().endswith(suffixes)
        and os.path.isfile(os.path.join(folder, name))
    ]


def _safe_filename(prompt, max_length=MAX_FILENAME_LENGTH):
    """Turn a free-text prompt into a filesystem-safe file stem."""
    # Replace anything other than word chars, hyphens and spaces (notably
    # path separators and ':'), so the file cannot escape IMAGE_FOLDER.
    stem = re.sub(r"[^\w\- ]+", "_", prompt).strip(" ._")
    return stem[:max_length] or "background"


@st.cache_resource(show_spinner="Loading model...")
def _load_pipeline():
    """Load the SDXL text-to-image pipeline onto CUDA once per server process."""
    # NOTE(review): the cached pipeline is shared by all sessions; diffusers
    # pipelines are not documented as thread-safe, so concurrent users may
    # need a lock around generation. Confirm for multi-user deployments.
    return AutoPipelineForText2Image.from_pretrained(
        MODEL_ID, torch_dtype=torch.float16, variant="fp16"
    ).to("cuda")


def _generate_background(prompt):
    """Generate a background for *prompt*, save it as PNG and display it."""
    pipeline = _load_pipeline()
    generator = torch.Generator("cuda").manual_seed(SEED)

    with st.spinner("Generating image..."):
        image_prompt = f"{prompt}, muted colors, detailed, 8k"
        image = pipeline(
            image_prompt, generator=generator, height=512, width=768
        ).images[0]

    os.makedirs(IMAGE_FOLDER, exist_ok=True)
    image_path = os.path.join(IMAGE_FOLDER, f"{_safe_filename(prompt)}.png")
    image.save(image_path)

    st.image(image, caption=f"Generated image for: {prompt}", width=300)


def _show_gallery(folder, images_per_row=3, col_width=350):
    """Render every image in *folder* in a grid, captioned by file stem."""
    images = load_images_from_folder(folder)

    st.header('Available Templates', divider='red')

    for start in range(0, len(images), images_per_row):
        cols = st.columns(images_per_row)
        # zip stops at the shorter sequence, leaving trailing columns of the
        # last row empty.
        for col, image_path in zip(cols, images[start:start + images_per_row]):
            image_name = os.path.splitext(os.path.basename(image_path))[0]
            col.image(image_path, width=col_width)
            col.write(image_name)


def BGIMAGES():
    """Render the page: a prompt box that generates a background, plus a gallery."""
    st.title("Background Images")
    st.header('Create a template', divider='orange')

    prompt = st.text_input("Prompt for a Background").strip()
    if prompt:
        _generate_background(prompt)
    else:
        st.write("Please enter prompt for background to generate an image.")

    _show_gallery(IMAGE_FOLDER)


if __name__ == "__main__":
    BGIMAGES()