import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from datasets import load_dataset
from torch.utils.data import Dataset
from sklearn.preprocessing import LabelEncoder

# Load the first 10,000 rows of the CivitAI dataset
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')

# Text preprocessing function: lowercase, split into words,
# then truncate or pad the word list to max_length
def preprocess_text(text, max_length=100):
    words = text.lower().split()
    if len(words) > max_length:
        words = words[:max_length]
    else:
        words.extend([''] * (max_length - len(words)))
    return words

class CustomDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        # Encode model names as integer class labels
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(dataset['Model'])
        # Build a bag-of-words vocabulary from all prompts
        # (sorted so word indices are stable across runs)
        self.vocab = set()
        for item in dataset['prompt']:
            self.vocab.update(preprocess_text(item))
        self.vocab = sorted(self.vocab)
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}

    def __len__(self):
        return len(self.dataset)

    def text_to_vector(self, text):
        # Bag-of-words count vector over the prompt vocabulary
        words = preprocess_text(text)
        vector = torch.zeros(len(self.vocab))
        for word in words:
            if word in self.word_to_idx:
                vector[self.word_to_idx[word]] += 1
        return vector

    def __getitem__(self, idx):
        # Force RGB so grayscale/RGBA images fit the 3-channel backbone
        image = self.transform(self.dataset[idx]['image'].convert('RGB'))
        text_vector = self.text_to_vector(self.dataset[idx]['prompt'])
        label = self.labels[idx]
        return image, text_vector, label

# CNN branch: ResNet-18 backbone projected to a 512-d image embedding
class ImageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

    def forward(self, x):
        return self.model(x)

# MLP branch: bag-of-words prompt vector to a 512-d text embedding
class TextMLP(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(vocab_size, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 512),
        )

    def forward(self, x):
        return self.layers(x)

# Combined model: concatenate both 512-d embeddings and classify the model name
class CombinedModel(nn.Module):
    def __init__(self, vocab_size, num_classes):
        super().__init__()
        self.image_model = ImageModel()
        self.text_model = TextMLP(vocab_size)
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        image_features = self.image_model(image)
        text_features = self.text_model(text)
        combined = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined)

# Create dataset instance and model
custom_dataset = CustomDataset(dataset)
num_classes = len(custom_dataset.label_encoder.classes_)
model = CombinedModel(len(custom_dataset.vocab), num_classes)

def get_recommendations(image):
    model.eval()
    with torch.no_grad():
        # Process the input image
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        image_tensor = transform(image.convert('RGB')).unsqueeze(0)
        # Dummy text vector, since recommendations here are image-only
        dummy_text = torch.zeros((1, len(custom_dataset.vocab)))
        # Top-5 predicted classes (model names)
        output = model(image_tensor, dummy_text)
        _, indices = torch.topk(output, 5)
        # Map each predicted class back to a model name and return one
        # example image for it from the dataset
        recommendations = []
        for class_idx in indices[0]:
            model_name = custom_dataset.label_encoder.inverse_transform([class_idx.item()])[0]
            row_idx = int((custom_dataset.labels == class_idx.item()).nonzero()[0][0])
            recommendations.append((dataset[row_idx]['image'], str(model_name)))
        return recommendations
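# NOTE: the script above never trains CombinedModel, so get_recommendations
# runs on randomly initialized weights. Below is a minimal training-loop
# sketch under that assumption; the batch size, epoch count, and learning
# rate are illustrative choices, not values from the original.
from torch.utils.data import DataLoader

def train(model, train_dataset, epochs=3, batch_size=32, lr=1e-4):
    loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    for epoch in range(epochs):
        total_loss = 0.0
        for images, text_vectors, labels in loader:
            optimizer.zero_grad()
            logits = model(images, text_vectors)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"epoch {epoch + 1}: mean loss {total_loss / len(loader):.4f}")

# Uncomment to train before serving recommendations:
# train(model, custom_dataset)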
# Set up the Gradio interface; gr.Gallery renders (image, caption) pairs
interface = gr.Interface(
    fn=get_recommendations,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Recommended Images"),
    title="Image Recommendation System",
    description="Upload an image and get similar images with their model names.",
)

# Launch the app
interface.launch()
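# Example of exercising the recommender outside Gradio (the file name is
# hypothetical; any PIL image works):
# from PIL import Image
# print(get_recommendations(Image.open("query.png")))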