bgaspra committed on
Commit 9587045 · verified · 1 Parent(s): 650aead

Update app.py

Files changed (1)
  1. app.py +119 -171
app.py CHANGED
@@ -1,192 +1,140 @@
- import torch
  import gradio as gr
- from transformers import AutoProcessor, AutoModelForCausalLM
- from PIL import Image
  import pandas as pd
  from datasets import load_dataset
- from sklearn.metrics.pairwise import cosine_similarity
- import numpy as np
- import warnings
- warnings.filterwarnings('ignore')
-
- # Load Florence-2 model and processor
- model_name = "microsoft/Florence-2-base"
- device = "cuda" if torch.cuda.is_available() else "cpu"
- torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-
- model = AutoModelForCausalLM.from_pretrained(
-     model_name,
-     torch_dtype=torch_dtype,
-     trust_remote_code=True
- ).to(device)
- processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
-
- # Create a dummy image for text-only processing
- DUMMY_IMAGE = Image.new('RGB', (224, 224), color='white')
-
- # Load CivitAI dataset
- print("Loading dataset...")
- dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k", split="train[:1000]")
- df = pd.DataFrame(dataset)
- print("Dataset loaded successfully!")
-
- text_embedding_cache = {}
-
- def get_image_embedding(image):
-     try:
-         inputs = processor(
-             images=image,
-             text="Generate image description",
-             return_tensors="pt",
-             padding=True
-         ).to(device, torch_dtype)
-
-         # Generate decoder_input_ids with adjusted parameters
-         decoder_input_ids = model.generate(
-             **inputs,
-             max_new_tokens=20,  # Increased from max_length
-             min_length=1,
-             num_beams=1,
-             do_sample=False,
-             pad_token_id=processor.tokenizer.pad_token_id,
-             return_dict_in_generate=True,
-         ).sequences
-
-         inputs['decoder_input_ids'] = decoder_input_ids
-
-         with torch.no_grad():
-             outputs = model(**inputs)
-         image_embeddings = outputs.last_hidden_state.mean(dim=1)
-         return image_embeddings.cpu().numpy()
-     except Exception as e:
-         print(f"Error in get_image_embedding: {str(e)}")
-         return None
-
- def get_text_embedding(text):
-     try:
-         if text in text_embedding_cache:
-             return text_embedding_cache[text]
-
-         # Process text with dummy image
-         inputs = processor(
-             images=DUMMY_IMAGE,
-             text=text,
-             return_tensors="pt",
-             padding=True
-         ).to(device, torch_dtype)
-
-         # Generate decoder_input_ids with adjusted parameters
-         decoder_input_ids = model.generate(
-             **inputs,
-             max_new_tokens=20,  # Using max_new_tokens instead of max_length
-             min_length=1,
-             num_beams=1,
-             do_sample=False,
-             pad_token_id=processor.tokenizer.pad_token_id,
-             return_dict_in_generate=True,
-         ).sequences
-
-         inputs['decoder_input_ids'] = decoder_input_ids
-
-         with torch.no_grad():
-             outputs = model(**inputs)
-         text_embeddings = outputs.last_hidden_state.mean(dim=1)
-
-         embedding = text_embeddings.cpu().numpy()
-         text_embedding_cache[text] = embedding
-         return embedding
-     except Exception as e:
-         print(f"Error in get_text_embedding: {str(e)}")
-         return None
-
- def precompute_embeddings():
-     print("Pre-computing text embeddings...")
-     for idx, row in df.iterrows():
-         if row['prompt'] not in text_embedding_cache:
-             _ = get_text_embedding(row['prompt'])
-         if idx % 100 == 0:
-             print(f"Processed {idx}/1000 embeddings")
-     print("Finished pre-computing embeddings")
-
- def find_similar_images(uploaded_image, top_k=5):
-     query_embedding = get_image_embedding(uploaded_image)
-     if query_embedding is None:
-         return [], []
-
-     similarities = []
-     for idx, row in df.iterrows():
-         prompt_embedding = get_text_embedding(row['prompt'])
-         if prompt_embedding is not None:
-             similarity = cosine_similarity(query_embedding, prompt_embedding)[0][0]
-             similarities.append({
-                 'similarity': similarity,
-                 'model': row['Model'],
-                 'prompt': row['prompt']
-             })
-
-     sorted_results = sorted(similarities, key=lambda x: x['similarity'], reverse=True)
-     top_models = []
-     top_prompts = []
-     seen_models = set()
-     seen_prompts = set()
-
-     for result in sorted_results:
-         if len(top_models) < top_k and result['model'] not in seen_models:
-             top_models.append(result['model'])
-             seen_models.add(result['model'])
-
-         if len(top_prompts) < top_k and result['prompt'] not in seen_prompts:
-             top_prompts.append(result['prompt'])
-             seen_prompts.add(result['prompt'])
-
-         if len(top_models) == top_k and len(top_prompts) == top_k:
-             break
-
-     return top_models, top_prompts
-
- def process_image(input_image):
-     if input_image is None:
-         return "Please upload an image.", "Please upload an image."
-
-     try:
-         if not isinstance(input_image, Image.Image):
-             input_image = Image.fromarray(input_image)
-
-         # Resize image to expected size
-         input_image = input_image.resize((224, 224))
-
-         recommended_models, recommended_prompts = find_similar_images(input_image)
-
-         if not recommended_models or not recommended_prompts:
-             return "Error processing image.", "Error processing image."
-
-         models_text = "Recommended Models:\n" + "\n".join([f"{i+1}. {model}" for i, model in enumerate(recommended_models)])
-         prompts_text = "Recommended Prompts:\n" + "\n".join([f"{i+1}. {prompt}" for i, prompt in enumerate(recommended_prompts)])
-
-         return models_text, prompts_text
-     except Exception as e:
-         print(f"Error in process_image: {str(e)}")
-         return "Error processing image.", "Error processing image."
-
- # Pre-compute embeddings when starting the application
- try:
-     precompute_embeddings()
- except Exception as e:
-     print(f"Error in precompute_embeddings: {str(e)}")
-
- # Create Gradio interface
- iface = gr.Interface(
-     fn=process_image,
-     inputs=gr.Image(type="pil", label="Upload AI-generated image"),
-     outputs=[
-         gr.Textbox(label="Recommended Models", lines=6),
-         gr.Textbox(label="Recommended Prompts", lines=6)
-     ],
-     title="AI Image Model & Prompt Recommender",
-     description="Upload an AI-generated image to get recommendations for Stable Diffusion models and prompts.",
-     examples=[],
-     cache_examples=False
  )

- # Launch the interface
- iface.launch()
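The version removed above ranks dataset prompts by cosine similarity between a mean-pooled embedding of the uploaded image and cached prompt embeddings. As a rough illustration of that ranking step in isolation, here is a minimal sketch using random placeholder vectors; the 768-dimensional embeddings and the `prompts` list are made up for the example and are not part of the commit.

```python
# Minimal sketch of the cosine-similarity ranking used above.
# The 768-dim vectors and the prompt strings are placeholder data.
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

rng = np.random.default_rng(0)
query_embedding = rng.normal(size=(1, 768))               # stands in for get_image_embedding(...)
prompts = ["a castle at sunset", "portrait photo", "cyberpunk city"]
prompt_embeddings = rng.normal(size=(len(prompts), 768))  # stands in for cached get_text_embedding(...)

scores = cosine_similarity(query_embedding, prompt_embeddings)[0]
top_indices = scores.argsort()[::-1]                      # highest similarity first
for rank, i in enumerate(top_indices, start=1):
    print(f"{rank}. {prompts[i]} (similarity={scores[i]:.3f})")
```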
 
 
  import gradio as gr
+ import torch
+ import torch.nn as nn
+ import torchvision.transforms as transforms
+ from torchvision import models
  import pandas as pd
  from datasets import load_dataset
+ from torch.utils.data import DataLoader, Dataset
+ from sklearn.preprocessing import LabelEncoder
+
+ # Load dataset
+ dataset = load_dataset('thefcraft/civitai-stable-diffusion-337k', split='train[:10000]')
+
+ # Text preprocessing function
+ def preprocess_text(text, max_length=100):
+     # Convert text to lowercase and split into words
+     words = text.lower().split()
+     # Truncate or pad to max_length
+     if len(words) > max_length:
+         words = words[:max_length]
+     else:
+         words.extend([''] * (max_length - len(words)))
+     return words
+
+ class CustomDataset(Dataset):
+     def __init__(self, dataset):
+         self.dataset = dataset
+         self.transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+         ])
+         self.label_encoder = LabelEncoder()
+         self.labels = self.label_encoder.fit_transform(dataset['Model'])
+
+         # Create vocabulary from all prompts
+         self.vocab = set()
+         for item in dataset['prompt']:
+             self.vocab.update(preprocess_text(item))
+         self.vocab = list(self.vocab)
+         self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def text_to_vector(self, text):
+         words = preprocess_text(text)
+         vector = torch.zeros(len(self.vocab))
+         for word in words:
+             if word in self.word_to_idx:
+                 vector[self.word_to_idx[word]] += 1
+         return vector
+
+     def __getitem__(self, idx):
+         image = self.transform(self.dataset[idx]['image'])
+         text_vector = self.text_to_vector(self.dataset[idx]['prompt'])
+         label = self.labels[idx]
+         return image, text_vector, label
+
+ # Define CNN for image processing
+ class ImageModel(nn.Module):
+     def __init__(self):
+         super(ImageModel, self).__init__()
+         self.model = models.resnet18(pretrained=True)
+         self.model.fc = nn.Linear(self.model.fc.in_features, 512)
+
+     def forward(self, x):
+         return self.model(x)
+
+ # Define MLP for text processing
+ class TextMLP(nn.Module):
+     def __init__(self, vocab_size):
+         super(TextMLP, self).__init__()
+         self.layers = nn.Sequential(
+             nn.Linear(vocab_size, 1024),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+             nn.Linear(1024, 512),
+             nn.ReLU(),
+             nn.Dropout(0.2),
+             nn.Linear(512, 512)
+         )
+
+     def forward(self, x):
+         return self.layers(x)
+
+ # Combined model
+ class CombinedModel(nn.Module):
+     def __init__(self, vocab_size):
+         super(CombinedModel, self).__init__()
+         self.image_model = ImageModel()
+         self.text_model = TextMLP(vocab_size)
+         self.fc = nn.Linear(1024, len(dataset['Model'].unique()))
+
+     def forward(self, image, text):
+         image_features = self.image_model(image)
+         text_features = self.text_model(text)
+         combined = torch.cat((image_features, text_features), dim=1)
+         return self.fc(combined)
+
+ # Create dataset instance and model
+ custom_dataset = CustomDataset(dataset)
+ model = CombinedModel(len(custom_dataset.vocab))
+
+ def get_recommendations(image):
+     model.eval()
+     with torch.no_grad():
+         # Process input image
+         transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor()
+         ])
+         image_tensor = transform(image).unsqueeze(0)
+
+         # Create dummy text vector (since we're only doing image-based recommendations)
+         dummy_text = torch.zeros((1, len(custom_dataset.vocab)))
+
+         # Get model output
+         output = model(image_tensor, dummy_text)
+         _, indices = torch.topk(output, 5)
+
+         # Get recommended images and their information
+         recommendations = []
+         for idx in indices[0]:
+             recommended_image = dataset[idx.item()]['image']
+             model_name = dataset[idx.item()]['Model']
+             recommendations.append((recommended_image, f"{model_name}"))
+
+     return recommendations
+
+ # Set up Gradio interface
+ interface = gr.Interface(
+     fn=get_recommendations,
+     inputs=gr.Image(type="pil"),
+     outputs=gr.Gallery(label="Recommended Images"),
+     title="Image Recommendation System",
+     description="Upload an image and get similar images with their model names."
  )

+ # Launch the app
+ interface.launch()
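The new version defines `CustomDataset` and `CombinedModel` but does not train them, so `get_recommendations` runs on randomly initialized heads (only the ResNet-18 backbone is pretrained). If training were added, a minimal loop over the dataset class above might look like the following sketch; the batch size, learning rate, epoch count, and device handling are assumptions for illustration, not part of the commit.

```python
# Hypothetical training loop for CombinedModel; hyperparameters are illustrative only.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"
loader = DataLoader(custom_dataset, batch_size=32, shuffle=True)  # custom_dataset from app.py

model.to(device)                                                  # model from app.py
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

model.train()
for epoch in range(3):                                            # arbitrary epoch count
    for images, text_vectors, labels in loader:
        images = images.to(device)
        text_vectors = text_vectors.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        loss = criterion(model(images, text_vectors), labels)
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: last batch loss = {loss.item():.4f}")
```

Two details worth checking in the committed code: `dataset['Model']` from the datasets library returns a plain Python list with no `.unique()` method, so the output size in `CombinedModel.__init__` would likely need `dataset.unique('Model')` or `len(set(dataset['Model']))`; and the indices from `torch.topk` over the classifier output are class indices, which `get_recommendations` then reuses as dataset row indices, so the gallery entries may not correspond to the predicted models. `models.resnet18(pretrained=True)` also triggers a deprecation warning on recent torchvision; the newer form is the `weights=` argument.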