"""Streamlit app for exploring the ``louiecerv/cats_dogs_dataset`` images.

Offers four views: the first images of the split, a random sample, the class
distribution as a bar chart, and the images belonging to one selected label.
"""

import streamlit as st
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Number of images shown by the "first images" and "random images" views.
SAMPLE_SIZE = 25
# Upper bound on images rendered by the filtered view; a single figure holding
# every image of a class grows without bound with the dataset size.
MAX_FILTERED_IMAGES = 100


# Load dataset from Hugging Face with caching
@st.cache_resource
def load_hf_dataset():
    """Load the dataset from the Hugging Face Hub.

    Cached as a resource so Streamlit's script reruns (one per widget
    interaction) reuse the same loaded object instead of calling
    ``load_dataset`` again.
    """
    repo_id = "louiecerv/cats_dogs_dataset"
    return load_dataset(repo_id)


dataset = load_hf_dataset()
split = "train"
data = dataset[split]
label_names = dataset[split].features["label"].names


@st.cache_data
def get_label_counts():
    """Return the number of rows per label id, ordered by label id."""
    return pd.Series(data["label"]).value_counts().sort_index()


def display_images(images, labels, label_names, cols=5):
    """Display images in a grid.

    Args:
        images: Sequence of images accepted by ``Axes.imshow``.
        labels: Sequence of label ids, parallel to ``images``.
        label_names: Sequence mapping a label id to its display name.
        cols: Number of grid columns.
    """
    # len() rather than truthiness so array-like inputs work as well.
    if len(images) == 0:
        # plt.subplots() rejects a grid with zero rows.
        st.write("No images to display.")
        return

    rows = (len(images) + cols - 1) // cols
    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid.
    fig, axes = plt.subplots(
        rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False
    )
    for i, ax in enumerate(axes.ravel()):
        if i < len(images):
            ax.imshow(images[i])
            ax.set_title(label_names[labels[i]])
        # Cells past the last image are left blank, but still lose their frame.
        ax.axis("off")
    st.pyplot(fig)
    # pyplot keeps every figure alive until closed; release it after rendering.
    plt.close(fig)


def _fetch_samples(indices):
    """Return ``(images, labels)`` for ``indices``, reading each row once."""
    rows = [data[i] for i in indices]
    return [row["image"] for row in rows], [row["label"] for row in rows]


def main():
    """Render the explorer page and react to the sidebar controls."""
    st.title("Image Dataset Explorer")
    st.subheader(f"Displaying images from the {split} set")

    # Never request more rows than the split actually holds.
    sample_size = min(SAMPLE_SIZE, len(data))

    # Show Initial Images
    if st.button("Show First 25 Images"):
        with st.spinner("Loading images..."):
            images, labels = _fetch_samples(range(sample_size))
            display_images(images, labels, label_names)

    st.sidebar.title("Explore the Dataset")

    # Random Image Viewer
    if st.sidebar.button("Show Random Images"):
        with st.spinner("Loading images..."):
            # Convert NumPy integers to plain ints before indexing the dataset.
            rand_indices = [
                int(i)
                for i in np.random.choice(len(data), sample_size, replace=False)
            ]
            images, labels = _fetch_samples(rand_indices)
            display_images(images, labels, label_names)

    # Class Distribution
    if st.sidebar.button("Show Class Distribution"):
        label_counts = get_label_counts()
        fig, ax = plt.subplots(figsize=(8, 4))
        sns.barplot(
            x=[label_names[i] for i in label_counts.index],
            y=label_counts.values,
            ax=ax,
        )
        ax.set_title("Class Distribution")
        ax.set_ylabel("Count")
        ax.set_xlabel("Class")
        st.pyplot(fig)
        plt.close(fig)

    # Filter by class label
    selected_label = st.sidebar.selectbox("Filter by Label", label_names)
    if st.sidebar.button("Show Filtered Images"):
        with st.spinner("Loading images..."):
            target = label_names.index(selected_label)
            matches = np.where(np.array(data["label"]) == target)[0]
            if len(matches) > 0:
                shown = [int(i) for i in matches[:MAX_FILTERED_IMAGES]]
                if len(matches) > len(shown):
                    st.caption(
                        f"Showing the first {len(shown)} of "
                        f"{len(matches)} images."
                    )
                images, labels = _fetch_samples(shown)
                display_images(images, labels, label_names)
            else:
                st.write(f"No images found for label: {selected_label}")


if __name__ == "__main__":
    main()