import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
import pandas as pd
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder

# Number of gallery entries returned by get_recommendations.
NUM_RECOMMENDATIONS = 5

# Shared preprocessing for dataset images and uploaded query images, so both
# paths feed the network identically. Normalize uses the ImageNet statistics
# that the pretrained ResNet-18 backbone expects.
IMAGE_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load the first 10,000 rows of the training split.
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')


def preprocess_text(text, max_length=100):
    """Tokenize *text* into exactly *max_length* lowercase whitespace tokens.

    Non-string input (including None) is treated as an empty string.
    Token lists longer than *max_length* are truncated; shorter ones are
    padded with empty strings ('').
    """
    if not isinstance(text, str):
        text = ""
    words = text.lower().split()
    if len(words) > max_length:
        words = words[:max_length]
    else:
        words.extend([''] * (max_length - len(words)))
    return words


class CustomDataset(Dataset):
    """Torch dataset yielding (image tensor, bag-of-words prompt vector, label).

    Rows whose 'Model' field is None are dropped. The remaining 'Model' names
    are label-encoded into integer class ids stored in ``self.labels``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.transform = IMAGE_TRANSFORM

        # Keep only rows that carry a model name (the classification target).
        valid_indices = [
            i for i, model_name in enumerate(dataset['Model']) if model_name is not None
        ]
        self.valid_dataset = dataset.select(valid_indices)

        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(list(self.valid_dataset['Model']))

        # First row index for each class id. This lets a predicted class be
        # shown with a representative image from the dataset.
        self.first_row_for_class = {}
        for row_idx, label in enumerate(self.labels):
            self.first_row_for_class.setdefault(int(label), row_idx)

        # Build the vocabulary from all prompts.
        vocab = set()
        for item in self.valid_dataset['prompt']:
            try:
                vocab.update(preprocess_text(item))
            except Exception as e:
                print(f"Error processing prompt: {e}")
                continue
        # '' is only the padding token produced by preprocess_text.
        vocab.discard('')
        # Sort so the word -> index mapping is identical across runs; set
        # iteration order for str depends on per-process hash randomization.
        self.vocab = sorted(vocab)
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}

    def __len__(self):
        return len(self.valid_dataset)

    def text_to_vector(self, text):
        """Return a float count vector of length len(self.vocab) for *text*.

        Words not in the vocabulary are ignored.
        """
        try:
            words = preprocess_text(text)
            vector = torch.zeros(len(self.vocab))
            for word in words:
                if word in self.word_to_idx:
                    vector[self.word_to_idx[word]] += 1
            return vector
        except Exception as e:
            print(f"Error converting text to vector: {e}")
            return torch.zeros(len(self.vocab))

    def __getitem__(self, idx):
        try:
            # Decode the row once; it holds both the image and the prompt.
            row = self.valid_dataset[idx]
            # convert('RGB') so RGBA / grayscale images yield 3 channels.
            image = self.transform(row['image'].convert('RGB'))
            text_vector = self.text_to_vector(row['prompt'])
            label = self.labels[idx]
            return image, text_vector, label
        except Exception as e:
            print(f"Error getting item at index {idx}: {e}")
            # Return zero tensors as fallback
            return (torch.zeros((3, 224, 224)), torch.zeros(len(self.vocab)), 0)


class ImageModel(nn.Module):
    """ResNet-18 backbone with its final layer replaced by a 512-d projection."""

    def __init__(self):
        super(ImageModel, self).__init__()
        try:
            weights = models.ResNet18_Weights.DEFAULT
        except AttributeError:
            # torchvision < 0.13 has no weights enums; use the legacy flag.
            self.model = models.resnet18(pretrained=True)
        else:
            self.model = models.resnet18(weights=weights)
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

    def forward(self, x):
        return self.model(x)


class TextMLP(nn.Module):
    """MLP mapping a bag-of-words vector of size *vocab_size* to 512 features."""

    def __init__(self, vocab_size):
        super(TextMLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(vocab_size, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 512),
        )

    def forward(self, x):
        return self.layers(x)


class CombinedModel(nn.Module):
    """Concatenate 512-d image and 512-d text features and classify them."""

    def __init__(self, vocab_size, num_classes):
        super(CombinedModel, self).__init__()
        self.image_model = ImageModel()
        self.text_model = TextMLP(vocab_size)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        image_features = self.image_model(image)
        text_features = self.text_model(text)
        combined = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined)


# Create dataset instance
print("Creating dataset...")
custom_dataset = CustomDataset(dataset)
print(f"Vocabulary size: {len(custom_dataset.vocab)}")
print(f"Number of valid samples: {len(custom_dataset)}")

# Create model.
# NOTE(review): nothing in this file trains or loads weights for `model`, so
# the text MLP and both projection/classifier heads are randomly initialized.
# The predicted ranking is arbitrary until trained weights are supplied.
num_classes = len(custom_dataset.label_encoder.classes_)
model = CombinedModel(len(custom_dataset.vocab), num_classes)


def get_recommendations(image):
    """Return up to NUM_RECOMMENDATIONS (image, model name) gallery entries.

    The uploaded image is classified into model-name classes. For each of the
    top-scoring classes, the first dataset image labelled with that class is
    returned, captioned with its model name. Returns [] if no image is given.
    """
    if image is None:
        return []

    model.eval()
    image_tensor = IMAGE_TRANSFORM(image.convert('RGB')).unsqueeze(0)
    # Only an image is available at inference, so the text branch receives an
    # all-zero bag-of-words vector.
    dummy_text = torch.zeros((1, len(custom_dataset.vocab)))

    with torch.no_grad():
        output = model(image_tensor, dummy_text)

    # topk raises if k exceeds the number of classes.
    k = min(NUM_RECOMMENDATIONS, output.shape[1])
    _, class_indices = torch.topk(output, k)

    recommendations = []
    # topk yields *class ids*, not dataset rows; map each to a sample row.
    for class_idx in class_indices[0].tolist():
        row_idx = custom_dataset.first_row_for_class.get(class_idx)
        if row_idx is None:
            continue
        try:
            row = custom_dataset.valid_dataset[row_idx]
            recommendations.append((row['image'], f"{row['Model']}"))
        except Exception as e:
            print(f"Error getting recommendation for class {class_idx}: {e}")
            continue
    return recommendations


# Set up Gradio interface
interface = gr.Interface(
    fn=get_recommendations,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Recommended Images"),
    title="Image Recommendation System",
    description="Upload an image and get similar images with their model names.",
)

# Launch the app
if __name__ == "__main__":
    interface.launch()