Spaces:
Running
Running
import gradio as gr | |
import torch | |
import torch.nn as nn | |
import torchvision.transforms as transforms | |
from torchvision import models | |
import pandas as pd | |
from datasets import load_dataset | |
from torch.utils.data import DataLoader, Dataset | |
from sklearn.preprocessing import LabelEncoder | |
# Load the first 10,000 rows of the train split of the CivitAI Stable
# Diffusion dataset; this module-level `dataset` is used by the classes and
# functions below.
dataset = load_dataset(
    'thefcraft/civitai-stable-diffusion-337k',
    split='train[:10000]',
)
# Text preprocessing function
def preprocess_text(text, max_length=100):
    """Tokenise *text* into a fixed-length list of lower-cased words.

    The text is split on whitespace. Longer inputs are truncated to the first
    ``max_length`` words; shorter ones are right-padded with empty strings so
    that (for a non-negative ``max_length``) every result has exactly
    ``max_length`` entries.

    Args:
        text: prompt string; ``None`` is treated as an empty prompt.
        max_length: number of tokens in the returned list.

    Returns:
        A list of strings in which ``''`` marks padding.
    """
    # Rows may carry a missing prompt; treat it as empty instead of failing
    # with AttributeError on ``None.lower()``.
    if text is None:
        text = ''
    words = text.lower().split()[:max_length]
    # Right-pad so every prompt yields the same number of tokens.
    return words + [''] * (max_length - len(words))
class CustomDataset(Dataset):
    """Torch dataset yielding ``(image_tensor, text_vector, label)`` triples.

    Each image is resized to 224x224 and converted to a 3-channel tensor, each
    prompt becomes a bag-of-words count vector over a vocabulary built from
    all prompts, and each ``Model`` value is label-encoded to an integer.

    Args:
        dataset: indexable dataset exposing ``image``, ``prompt`` and
            ``Model`` columns (here a Hugging Face ``datasets.Dataset``).
    """

    # preprocess_text pads short prompts with this token; it is not a real
    # word and must be neither in the vocabulary nor counted.
    _PAD_TOKEN = ''

    def __init__(self, dataset):
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(dataset['Model'])

        # Build the vocabulary from all prompts (missing prompts count as
        # empty). Sorting makes the word -> index mapping identical across
        # runs; plain ``list(set(...))`` order depends on string hash
        # randomisation, which would scramble vectors between processes.
        vocab = set()
        for prompt in dataset['prompt']:
            vocab.update(preprocess_text(prompt or ''))
        vocab.discard(self._PAD_TOKEN)
        self.vocab = sorted(vocab)
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def text_to_vector(self, text):
        """Return a float tensor of length ``len(self.vocab)`` with word counts.

        Padding tokens and out-of-vocabulary words are ignored; ``None`` is
        treated as an empty prompt.
        """
        vector = torch.zeros(len(self.vocab))
        for word in preprocess_text(text or ''):
            # The pad token was removed from the vocabulary, so .get() skips
            # padding as well as unknown words.
            idx = self.word_to_idx.get(word)
            if idx is not None:
                vector[idx] += 1
        return vector

    def __getitem__(self, idx):
        """Return ``(image_tensor, text_vector, label)`` for row *idx*."""
        # Fetch the row once instead of indexing the dataset twice.
        row = self.dataset[idx]
        # ResNet needs 3 channels; RGBA/greyscale images would break it.
        # Assumes the image column decodes to PIL images (it is passed to
        # torchvision transforms) -- TODO confirm for other dataset sources.
        image = self.transform(row['image'].convert('RGB'))
        text_vector = self.text_to_vector(row['prompt'])
        label = self.labels[idx]
        return image, text_vector, label
# Define CNN for image processing
class ImageModel(nn.Module):
    """ResNet-18 image encoder producing 512-dimensional embeddings.

    The ImageNet-pretrained backbone is kept and its classification head is
    replaced by a freshly initialised ``Linear(in_features, 512)`` projection.
    """

    def __init__(self):
        super().__init__()
        # ``pretrained=True`` is deprecated since torchvision 0.13; the
        # IMAGENET1K_V1 weights enum loads the same ImageNet checkpoint.
        self.model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

    def forward(self, x):
        """Return 512-d features for a batch of images.

        Expects the standard torchvision ResNet input, an (N, 3, H, W) float
        tensor.
        """
        return self.model(x)
# Define MLP for text processing
class TextMLP(nn.Module):
    """Three-layer MLP mapping bag-of-words counts to 512-d text embeddings.

    Args:
        vocab_size: size of the last dimension of the input count vectors.
    """

    def __init__(self, vocab_size):
        super().__init__()
        hidden_dim, embed_dim = 1024, 512
        # Layer order (and therefore state_dict keys layers.0 ... layers.6)
        # is: Linear, ReLU, Dropout, Linear, ReLU, Dropout, Linear.
        self.layers = nn.Sequential(
            nn.Linear(vocab_size, hidden_dim), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(hidden_dim, embed_dim), nn.ReLU(), nn.Dropout(0.2),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, x):
        """Project ``(..., vocab_size)`` inputs to ``(..., 512)`` features."""
        return self.layers(x)
# Combined model
class CombinedModel(nn.Module):
    """Late-fusion classifier over image and bag-of-words text features.

    ImageModel and TextMLP each produce a 512-d embedding; the two are
    concatenated and mapped to one logit per class.

    Args:
        vocab_size: length of the text bag-of-words vectors.
        num_classes: number of output classes. When ``None`` (the default),
            it is the number of distinct ``Model`` values in the module-level
            ``dataset``.
    """

    def __init__(self, vocab_size, num_classes=None):
        super().__init__()
        self.image_model = ImageModel()
        self.text_model = TextMLP(vocab_size)
        if num_classes is None:
            # Indexing a datasets.Dataset by column name does not return a
            # pandas Series, so ``.unique()`` is unavailable; count distinct
            # values directly.
            num_classes = len(set(dataset['Model']))
        # 512 image features + 512 text features.
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        """Return class logits of shape ``(batch, num_classes)``."""
        image_features = self.image_model(image)
        text_features = self.text_model(text)
        combined = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined)
# Wrap the raw rows for PyTorch, then size the model's text branch to the
# vocabulary the wrapper built.
custom_dataset = CustomDataset(dataset)
model = CombinedModel(vocab_size=len(custom_dataset.vocab))
def get_recommendations(image):
    """Predict likely generator models for *image* and show an example of each.

    The combined model is run on the uploaded image together with an all-zero
    text vector, since no prompt is supplied. Its top-scoring classes are then
    mapped back to dataset rows: for each predicted class, the first row
    carrying that label is returned.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when nothing was uploaded.

    Returns:
        A list of ``(image, model_name)`` tuples for ``gr.Gallery``, with at
        most five entries. The list is empty if no image was given.
    """
    # NOTE(review): no training step is visible in this file, so the fusion
    # head is randomly initialised and predictions are not meaningful until
    # the model is trained.
    if image is None:
        return []

    top_k = 5
    model.eval()
    with torch.no_grad():
        # Reuse the dataset's transform so inference preprocessing matches the
        # training pipeline; force 3 channels for the ResNet backbone.
        image_tensor = custom_dataset.transform(image.convert("RGB")).unsqueeze(0)
        # Create dummy text vector (since we're only doing image-based recommendations)
        dummy_text = torch.zeros((1, len(custom_dataset.vocab)))
        output = model(image_tensor, dummy_text)
        # torch.topk raises if k exceeds the number of classes.
        _, class_indices = torch.topk(output, min(top_k, output.size(1)))

    recommendations = []
    for class_idx in class_indices[0].tolist():
        # The model outputs *class* indices, not dataset row indices: find a
        # row labelled with this class to use as the example image.
        matching_rows = (custom_dataset.labels == class_idx).nonzero()[0]
        if len(matching_rows) == 0:
            continue
        row = dataset[int(matching_rows[0])]
        recommendations.append((row['image'], f"{row['Model']}"))
    return recommendations
# Set up Gradio interface: one uploaded image in, a gallery of
# (image, caption) pairs out.
interface = gr.Interface(
    fn=get_recommendations,
    # Force 3-channel input; the ResNet backbone cannot take RGBA/greyscale.
    inputs=gr.Image(type="pil", image_mode="RGB"),
    outputs=gr.Gallery(label="Recommended Images"),
    title="Image Recommendation System",
    description="Upload an image and get similar images with their model names."
)

# Launch the app only when run as a script, so importing this module does
# not start a web server.
if __name__ == "__main__":
    interface.launch()