import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder
# Load the first 10,000 rows of the CivitAI Stable Diffusion prompt dataset
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
# Text preprocessing function
def preprocess_text(text, max_length=100):
    # Convert text to lowercase and split into words
    words = text.lower().split()
    # Truncate or pad with empty strings to exactly max_length tokens
    if len(words) > max_length:
        words = words[:max_length]
    else:
        words.extend([''] * (max_length - len(words)))
    return words
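
# Example (illustrative): preprocess_text("Masterpiece, 4K", max_length=5)
#   -> ['masterpiece,', '4k', '', '', '']  (lowercased, padded with '')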
class CustomDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        # Encode model names as integer class labels
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(dataset['Model'])
        # Create vocabulary from all prompts
        self.vocab = set()
        for item in dataset['prompt']:
            self.vocab.update(preprocess_text(item))
        self.vocab = list(self.vocab)
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}

    def __len__(self):
        return len(self.dataset)

    def text_to_vector(self, text):
        # Bag-of-words encoding: count each known word in the prompt
        words = preprocess_text(text)
        vector = torch.zeros(len(self.vocab))
        for word in words:
            if word in self.word_to_idx:
                vector[self.word_to_idx[word]] += 1
        return vector

    def __getitem__(self, idx):
        image = self.transform(self.dataset[idx]['image'])
        text_vector = self.text_to_vector(self.dataset[idx]['prompt'])
        label = self.labels[idx]
        return image, text_vector, label
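
# Illustrative usage, assuming the `custom_dataset` instance created below
# and RGB source images:
#   image, text_vec, label = custom_dataset[0]
#   image.shape    -> torch.Size([3, 224, 224])
#   text_vec.shape -> torch.Size([len(custom_dataset.vocab)])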
# Define CNN for image processing: ResNet-18 backbone projected to 512-d features
class ImageModel(nn.Module):
    def __init__(self):
        super(ImageModel, self).__init__()
        # Pretrained backbone; the weights= API replaces the deprecated pretrained=True
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        self.model.fc = nn.Linear(self.model.fc.in_features, 512)

    def forward(self, x):
        return self.model(x)
# Define MLP for text processing
class TextMLP(nn.Module):
    def __init__(self, vocab_size):
        super(TextMLP, self).__init__()
        self.layers = nn.Sequential(
            nn.Linear(vocab_size, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 512)
        )

    def forward(self, x):
        return self.layers(x)
# Combined model: concatenate the two 512-d feature vectors and classify
class CombinedModel(nn.Module):
    def __init__(self, vocab_size):
        super(CombinedModel, self).__init__()
        self.image_model = ImageModel()
        self.text_model = TextMLP(vocab_size)
        # 512 image + 512 text features -> one logit per model name
        self.fc = nn.Linear(1024, len(dataset.unique('Model')))

    def forward(self, image, text):
        image_features = self.image_model(image)
        text_features = self.text_model(text)
        combined = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined)
# Create dataset instance and model
custom_dataset = CustomDataset(dataset)
model = CombinedModel(len(custom_dataset.vocab))
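
# --- Optional training sketch (an assumption, not in the original file) ---
# The classifier head above starts randomly initialized, so predictions are
# meaningless until the model is fit. A minimal loop over the (already
# imported) DataLoader might look like this; batch size, learning rate, and
# epoch count are illustrative guesses, not tuned values.
def train_model(num_epochs=3, batch_size=32, lr=1e-4):
    loader = DataLoader(custom_dataset, batch_size=batch_size, shuffle=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(num_epochs):
        for images, text_vectors, labels in loader:
            optimizer.zero_grad()
            loss = criterion(model(images, text_vectors), labels)
            loss.backward()
            optimizer.step()

# Uncomment to train before serving:
# train_model()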
def get_recommendations(image):
    model.eval()
    with torch.no_grad():
        # Process the input image the same way as the training data
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])
        # convert() guards against RGBA or grayscale uploads from Gradio
        image_tensor = transform(image.convert('RGB')).unsqueeze(0)
        # Create dummy text vector (since we're only doing image-based recommendations)
        dummy_text = torch.zeros((1, len(custom_dataset.vocab)))
        # Top-5 predicted model classes for the query image
        output = model(image_tensor, dummy_text)
        _, indices = torch.topk(output, 5)
        # Map each predicted class back to a representative dataset row
        # (the raw topk indices are class ids, not row ids)
        recommendations = []
        for class_idx in indices[0]:
            row = int((custom_dataset.labels == class_idx.item()).nonzero()[0][0])
            recommendations.append((dataset[row]['image'], dataset[row]['Model']))
    return recommendations
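
# Quick local check (illustrative, assumes the 'image' column holds PIL images):
# feeding a dataset image back through the recommender should yield a list of
# five (image, model name) pairs.
# recs = get_recommendations(dataset[0]['image'])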
# Set up Gradio interface
interface = gr.Interface(
    fn=get_recommendations,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Recommended Images"),
    title="Image Recommendation System",
    description="Upload an image and get similar images with their model names."
)
# Launch the app
interface.launch()