# (Hugging Face Spaces page residue "Spaces: / Sleeping" removed — app source begins below.)
import gradio as gr
import pandas as pd
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from datasets import load_dataset
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from torchvision import models
# Load the first 10k rows of the civitai prompt/image dataset.
dataset = load_dataset(
    'thefcraft/civitai-stable-diffusion-337k',
    split='train[:10000]',
)
def preprocess_text(text, max_length=100):
    """Lowercase *text*, split on whitespace, and pad/truncate to *max_length* tokens.

    Non-string input (including None) is treated as an empty prompt.  Padding
    uses the empty string, so the result always has exactly *max_length* items.

    Args:
        text: Prompt string (or None / non-string, treated as empty).
        max_length: Fixed token count of the returned list.

    Returns:
        List of exactly *max_length* lowercase tokens ('' used as padding).
    """
    if not isinstance(text, str):
        text = ""
    # Truncate first, then pad back up to the fixed length.
    tokens = text.lower().split()[:max_length]
    tokens.extend([''] * (max_length - len(tokens)))
    return tokens
class CustomDataset(Dataset):
    """Wraps the civitai dataset: image tensor + bag-of-words prompt vector + model label.

    Rows whose ``Model`` field is None are dropped.  Prompts are tokenized with
    ``preprocess_text`` and encoded as bag-of-words counts over a vocabulary
    built from every valid prompt.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # 224x224 tensor pipeline matching the ResNet input size.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        # Keep only rows with a usable Model label.
        valid_indices = [i for i, model in enumerate(dataset['Model']) if model is not None]
        self.valid_dataset = dataset.select(valid_indices)
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(self.valid_dataset['Model'])
        # Build the prompt vocabulary from all valid prompts.
        self.vocab = set()
        for item in self.valid_dataset['prompt']:
            try:
                self.vocab.update(preprocess_text(item))
            except Exception as e:
                print(f"Error processing prompt: {e}")
                continue
        # Drop the padding token, then fix a deterministic word order: iterating
        # a set of strings is hash-randomized across runs, which would silently
        # shuffle the word->index mapping between processes.
        self.vocab.discard('')
        self.vocab = sorted(self.vocab)
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}

    def __len__(self):
        """Number of rows with a valid Model label."""
        return len(self.valid_dataset)

    def text_to_vector(self, text):
        """Encode *text* as a bag-of-words count vector over the vocabulary."""
        try:
            words = preprocess_text(text)
            vector = torch.zeros(len(self.vocab))
            for word in words:
                if word in self.word_to_idx:
                    vector[self.word_to_idx[word]] += 1
            return vector
        except Exception as e:
            print(f"Error converting text to vector: {e}")
            return torch.zeros(len(self.vocab))

    def __getitem__(self, idx):
        """Return ``(image_tensor, prompt_vector, label)`` for row *idx*."""
        try:
            image = self.transform(self.valid_dataset[idx]['image'])
            text_vector = self.text_to_vector(self.valid_dataset[idx]['prompt'])
            label = self.labels[idx]
            return image, text_vector, label
        except Exception as e:
            print(f"Error getting item at index {idx}: {e}")
            # Zero tensors keep the DataLoader batch shape intact on bad rows.
            return (torch.zeros((3, 224, 224)),
                    torch.zeros(len(self.vocab)),
                    0)
# Image branch of the recommender.
class ImageModel(nn.Module):
    """ResNet-18 backbone with its classifier head replaced by a 512-unit projection."""

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        backbone.fc = nn.Linear(backbone.fc.in_features, 512)
        self.model = backbone

    def forward(self, x):
        """Map a batch of (3, 224, 224) images to 512-dim feature vectors."""
        return self.model(x)
# Text branch of the recommender.
class TextMLP(nn.Module):
    """Three-layer MLP projecting a bag-of-words vector down to 512 features."""

    def __init__(self, vocab_size):
        super().__init__()
        stack = [
            nn.Linear(vocab_size, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 512),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return 512-dim text features for input of shape (batch, vocab_size)."""
        return self.layers(x)
# Fusion head: concatenates both 512-dim feature branches and classifies.
class CombinedModel(nn.Module):
    """Concatenates image and text features (512 + 512) and predicts the model class."""

    def __init__(self, vocab_size, num_classes):
        super().__init__()
        self.image_model = ImageModel()
        self.text_model = TextMLP(vocab_size)
        # 1024 = 512 image features + 512 text features.
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        """Return class logits for an (image batch, text batch) pair."""
        fused = torch.cat((self.image_model(image), self.text_model(text)), dim=1)
        return self.fc(fused)
# Materialize the dataset wrapper, then size the classifier from it.
print("Creating dataset...")
custom_dataset = CustomDataset(dataset)
print(f"Vocabulary size: {len(custom_dataset.vocab)}")
print(f"Number of valid samples: {len(custom_dataset)}")

# One output class per distinct Model label seen in the data.
num_classes = len(custom_dataset.label_encoder.classes_)
model = CombinedModel(len(custom_dataset.vocab), num_classes)
def get_recommendations(image):
    """Predict the top model classes for *image* and return example (image, name) pairs.

    The classifier outputs one logit per *model-name class*, so ``topk`` yields
    class indices, not dataset row indices.  Each predicted class is therefore
    translated to a dataset row carrying that label before fetching the image
    (the previous version indexed the dataset directly with the class index,
    returning unrelated rows).

    Args:
        image: PIL image from the Gradio input.

    Returns:
        List of up to five ``(PIL image, model name)`` tuples for the gallery.
    """
    model.eval()
    with torch.no_grad():
        # Same preprocessing pipeline the training dataset uses.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])
        image_tensor = transform(image).unsqueeze(0)
        # No prompt is available at inference time; use an all-zero bag of words.
        dummy_text = torch.zeros((1, len(custom_dataset.vocab)))
        output = model(image_tensor, dummy_text)
        # Guard against datasets with fewer than 5 distinct model classes.
        k = min(5, output.shape[1])
        _, class_indices = torch.topk(output, k)
        recommendations = []
        for class_idx in class_indices[0]:
            target = int(class_idx)
            try:
                # First dataset row whose encoded label matches the predicted class.
                matches = (custom_dataset.labels == target).nonzero()[0]
                if len(matches) == 0:
                    continue
                row = int(matches[0])
                recommended_image = custom_dataset.valid_dataset[row]['image']
                model_name = custom_dataset.valid_dataset[row]['Model']
                recommendations.append((recommended_image, f"{model_name}"))
            except Exception as e:
                print(f"Error getting recommendation for index {target}: {e}")
                continue
        return recommendations
# Gradio front end: upload an image, get back a gallery of recommendations.
interface = gr.Interface(
    fn=get_recommendations,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Recommended Images"),
    title="Image Recommendation System",
    description="Upload an image and get similar images with their model names.",
)

# Launch only when executed as a script (not on import).
if __name__ == "__main__":
    interface.launch()