File size: 4,617 Bytes
f2ca68f
9587045
 
 
 
f2ca68f
935a747
9587045
 
f2ca68f
9587045
 
f2ca68f
9587045
 
 
 
 
 
 
 
 
 
26d55ba
9587045
 
 
 
 
 
 
 
 
107b2a4
9587045
 
 
 
 
 
26d55ba
9587045
 
935a747
9587045
 
 
 
 
 
 
935a747
9587045
 
 
 
 
 
 
 
 
 
 
 
935a747
9587045
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
935a747
9587045
 
935a747
9587045
 
 
 
 
 
 
ec1fd1e
9587045
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26d55ba
9587045
 
ec1fd1e
9587045
 
 
ec1fd1e
9587045
 
 
 
 
 
ec1fd1e
9587045
f2ca68f
9587045
 
 
 
 
 
 
f2ca68f
935a747
9587045
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision import models
import pandas as pd
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import LabelEncoder

# Load the first 10,000 rows of the train split. The resulting module-level
# `dataset` is read directly by CombinedModel and get_recommendations below.
dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')

# Text preprocessing function
def preprocess_text(text, max_length=100):
    """Tokenize a prompt into a fixed-length list of lowercase words.

    Args:
        text: Prompt string. ``None`` is treated as an empty prompt so that
            rows without a prompt do not raise ``AttributeError``.
        max_length: Number of tokens to keep. Longer prompts are truncated;
            shorter ones are right-padded with empty strings.

    Returns:
        A new list of lowercase, whitespace-separated tokens. For a
        non-negative ``max_length`` the list has exactly ``max_length``
        entries, with ``''`` used as the padding token.
    """
    words = text.lower().split() if text is not None else []
    # Slicing is a no-op for short prompts and the pad count is <= 0 for long
    # ones, so truncation and padding can share a single code path.
    words = words[:max_length]
    words.extend([''] * (max_length - len(words)))
    return words

class CustomDataset(Dataset):
    """Torch dataset yielding ``(image, text_vector, label)`` triples.

    Wraps a dataset object exposing ``'image'``, ``'prompt'`` and ``'Model'``
    fields. On construction it label-encodes the ``'Model'`` column and builds
    a bag-of-words vocabulary from every prompt.

    Attributes set in ``__init__``:
        dataset: The wrapped dataset.
        transform: Resize to 224x224 followed by ``ToTensor``.
        label_encoder: ``LabelEncoder`` fitted on the ``'Model'`` column.
        labels: Encoded label for each row, in row order.
        vocab: Sorted list of distinct tokens seen in the prompts.
        word_to_idx: Mapping from token to its position in ``vocab``.
    """

    def __init__(self, dataset):
        self.dataset = dataset
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])
        self.label_encoder = LabelEncoder()
        self.labels = self.label_encoder.fit_transform(dataset['Model'])

        # Collect every token that occurs in any prompt. The '' padding token
        # emitted by preprocess_text is part of the vocabulary as well.
        vocab = set()
        for prompt in dataset['prompt']:
            vocab.update(preprocess_text(prompt))
        # Set iteration order for strings changes between interpreter runs
        # (hash randomization); sorting keeps the text-vector layout, and
        # therefore any weights trained on it, reproducible.
        self.vocab = sorted(vocab)
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def text_to_vector(self, text):
        """Return a bag-of-words count vector of length ``len(self.vocab)``.

        Tokens that are not in the vocabulary are ignored.
        """
        vector = torch.zeros(len(self.vocab))
        for word in preprocess_text(text):
            position = self.word_to_idx.get(word)
            if position is not None:
                vector[position] += 1
        return vector

    def __getitem__(self, idx):
        """Return ``(image_tensor, text_vector, label)`` for row ``idx``."""
        # Fetch the row once; indexing the dataset twice would load (and, for
        # image columns, presumably decode) the same record twice.
        row = self.dataset[idx]
        image = row['image']
        # Assumes PIL-style images (TODO confirm against the dataset schema):
        # force three channels so grayscale/RGBA rows produce tensors of the
        # same shape. Objects without ``convert`` are passed through untouched.
        if hasattr(image, 'convert'):
            image = image.convert('RGB')
        return self.transform(image), self.text_to_vector(row['prompt']), self.labels[idx]

# Define CNN for image processing
class ImageModel(nn.Module):
    """Image encoder: a pretrained ResNet-18 whose head outputs 512 features."""

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Swap the final fully connected layer for a 512-unit projection.
        head_inputs = backbone.fc.in_features
        backbone.fc = nn.Linear(head_inputs, 512)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and return its output."""
        features = self.model(x)
        return features

# Define MLP for text processing
class TextMLP(nn.Module):
    """Text encoder: an MLP mapping a ``vocab_size``-wide vector to 512 features."""

    def __init__(self, vocab_size):
        super().__init__()
        # Linear -> ReLU -> Dropout twice, then a final linear projection.
        stack = [
            nn.Linear(vocab_size, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 512),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the result."""
        encoded = self.layers(x)
        return encoded

# Combined model
class CombinedModel(nn.Module):
    """Classifier over concatenated image and text features.

    Args:
        vocab_size: Width of the bag-of-words vector fed to the text branch.
        num_classes: Number of output classes. When omitted, it is derived
            from the distinct values of the module-level ``dataset['Model']``
            column.
    """

    def __init__(self, vocab_size, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Column access on a Hugging Face dataset does not provide a
            # pandas-style ``.unique()``; counting a set works for any
            # iterable column.
            num_classes = len(set(dataset['Model']))
        self.image_model = ImageModel()
        self.text_model = TextMLP(vocab_size)
        # 1024 = 512 image features + 512 text features, concatenated.
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, image, text):
        """Return class scores for a batch of images and text vectors."""
        image_features = self.image_model(image)
        text_features = self.text_model(text)
        combined = torch.cat((image_features, text_features), dim=1)
        return self.fc(combined)

# Create dataset instance and model. The model's text branch is sized to the
# vocabulary built by CustomDataset.
# NOTE(review): no training loop or checkpoint loading is visible in this
# file, so the classifier head and text branch appear to start from random
# initialization — confirm this is intended.
custom_dataset = CustomDataset(dataset)
model = CombinedModel(len(custom_dataset.vocab))

def get_recommendations(image):
    """Return gallery items for the classes the model scores highest.

    The model outputs one score per label-encoded ``'Model'`` class. The
    top-scoring classes (at most five) are decoded back to their names, and
    the first dataset row carrying each class supplies the example image.

    Args:
        image: PIL image supplied by the Gradio input, or ``None`` when no
            image was provided.

    Returns:
        A list of ``(image, caption)`` tuples; the caption is the decoded
        class name. Empty when ``image`` is ``None``.
    """
    # The callback can be invoked without an image; show an empty gallery
    # instead of failing inside the transform.
    if image is None:
        return []

    model.eval()
    with torch.no_grad():
        # Reuse the dataset's preprocessing so inference matches what
        # CustomDataset produces. Force three channels because uploads may be
        # grayscale or carry an alpha channel.
        image_tensor = custom_dataset.transform(image.convert('RGB')).unsqueeze(0)

        # Create dummy text vector (since we're only doing image-based recommendations)
        dummy_text = torch.zeros((1, len(custom_dataset.vocab)))

        output = model(image_tensor, dummy_text)
        # topk raises if k exceeds the number of classes.
        top_k = min(5, output.shape[1])
        _, indices = torch.topk(output, top_k)
        class_ids = indices[0].tolist()

    # The indices are class ids, not dataset row numbers: decode them with the
    # label encoder and look up a row that actually has that class.
    class_names = custom_dataset.label_encoder.inverse_transform(class_ids)
    recommendations = []
    for class_id, model_name in zip(class_ids, class_names):
        matching_rows = (custom_dataset.labels == class_id).nonzero()[0]
        if matching_rows.size == 0:
            continue
        example_image = dataset[int(matching_rows[0])]['image']
        recommendations.append((example_image, f"{model_name}"))

    return recommendations

# Set up Gradio interface: a single PIL image in, a gallery of
# (image, caption) pairs from get_recommendations out.
interface = gr.Interface(
    fn=get_recommendations,
    inputs=gr.Image(type="pil"),
    outputs=gr.Gallery(label="Recommended Images"),
    title="Image Recommendation System",
    description="Upload an image and get similar images with their model names."
)

# Launch the app. This runs at module level (there is no __main__ guard), so
# importing this file also starts the server.
interface.launch()