bgaspra committed
Commit 3aa52df
1 Parent(s): 0f9a2bf

Update app.py

Files changed (1): app.py +196 -155
app.py CHANGED
@@ -1,164 +1,205 @@
- import torch
- import torch.nn as nn
- from torchvision import transforms, models
  import gradio as gr
- from PIL import Image
-
- # Model Architecture (same as before)
- class ModelRecommender(nn.Module):
-     def __init__(self, num_models, text_embedding_dim=768):
-         super(ModelRecommender, self).__init__()
-
-         # CNN for image processing
-         self.cnn = models.resnet18(pretrained=True)
-         self.cnn.fc = nn.Linear(512, 256)
-
-         # MLP for text processing
-         self.text_mlp = nn.Sequential(
-             nn.Linear(text_embedding_dim, 512),
-             nn.ReLU(),
-             nn.Linear(512, 256),
-             nn.ReLU()
-         )
-
-         # Combined layers
-         self.combined = nn.Sequential(
-             nn.Linear(512, 256),
-             nn.ReLU(),
-             nn.Dropout(0.5),
-             nn.Linear(256, num_models)
-         )
-
-     def forward(self, image, text_features):
-         # Process image
-         img_features = self.cnn(image)
-
-         # Process text
-         text_features = self.text_mlp(text_features)
-
-         # Combine features
-         combined = torch.cat((img_features, text_features), dim=1)
-
-         # Final prediction
-         output = self.combined(combined)
-         return output
-
- # Load model and dataset info
- def load_model():
-     # Load dataset info
-     dataset_info = torch.load('dataset_info.pth')
-     model_names = dataset_info['model_names']
-
-     # Initialize model
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-     model = ModelRecommender(len(model_names))
-
-     # Load model weights
-     checkpoint = torch.load('sd_recommender_model.pth', map_location=device)
-     model.load_state_dict(checkpoint['model_state_dict'])
-     model.to(device)
-     model.eval()
-
-     return model, model_names, device
-
- def calculate_euclidean_distance(features1, features2):
-     return np.linalg.norm(features1 - features2)
-
- def predict_image(image):
-     if not hasattr(predict_image, "model"):
-         predict_image.model, predict_image.model_names, predict_image.device = load_model()
-
-     transform = transforms.Compose([
-         transforms.Resize((224, 224)),
-         transforms.ToTensor(),
-         transforms.Normalize(mean=[0.485, 0.456, 0.406],
-                              std=[0.229, 0.224, 0.225])
-     ])
-
-     image_tensor = transform(image).unsqueeze(0).to(predict_image.device)
-     dummy_text_features = torch.zeros(1, 768).to(predict_image.device)
-
-     # Get image features
-     with torch.no_grad():
-         img_features = predict_image.model.cnn(image_tensor).cpu().numpy()
-         outputs = predict_image.model(image_tensor, dummy_text_features)
-         top5_prob, top5_indices = torch.topk(outputs, 5)
-
-     # Create HTML gallery
-     html_output = """
-     <style>
-     .model-gallery {
-         display: grid;
-         grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
-         gap: 20px;
-         padding: 20px;
-     }
-     .model-card {
-         border: 1px solid #ddd;
-         border-radius: 8px;
-         overflow: hidden;
-         background: white;
-         box-shadow: 0 2px 5px rgba(0,0,0,0.1);
-     }
-     .model-img {
-         width: 100%;
-         height: 200px;
-         object-fit: cover;
-     }
-     .model-info {
-         padding: 15px;
-     }
-     .model-name {
-         color: #2563eb;
-         text-decoration: none;
-         font-weight: bold;
-         font-size: 1.1em;
-         margin-bottom: 8px;
-         display: block;
-     }
-     .model-name:hover {
-         text-decoration: underline;
-     }
-     .distance {
-         color: #666;
-         font-size: 0.9em;
-     }
-     </style>
-     <div class="model-gallery">
-     """
-
-     # Generate cards for each model
-     for idx, (score, model_idx) in enumerate(zip(top5_prob[0], top5_indices[0])):
-         model_name = predict_image.model_names[model_idx.item()]
-         distance = calculate_euclidean_distance(img_features[0],
-                                                 torch.randn(512).numpy())  # Placeholder for actual features
-
-         civitai_url = f"https://civitai.com/search/models?sortBy=models_v9&query={model_name}"
-
-         html_output += f"""
-         <div class="model-card">
-             <img class="model-img" src="data:image/svg+xml,<svg xmlns='http://www.w3.org/2000/svg' width='250' height='200' viewBox='0 0 250 200'><rect width='100%' height='100%' fill='%23f0f0f0'/><text x='50%' y='50%' dominant-baseline='middle' text-anchor='middle' font-family='Arial' font-size='16' fill='%23666'>Model Preview</text></svg>" alt="{model_name}">
-             <div class="model-info">
-                 <a href="{civitai_url}" target="_blank" class="model-name">{model_name}</a>
-                 <div class="distance">Euclidean Distance: {distance:.4f}</div>
-             </div>
-         </div>
-         """
-
-     html_output += "</div>"
-
-     return html_output
-
- # Gradio Interface
- demo = gr.Interface(
-     fn=predict_image,
-     inputs=gr.Image(type="pil"),
-     outputs=gr.HTML(),
-     title="Stable Diffusion Model Recommender",
-     description="Upload an AI-generated image to get model recommendations",
-     examples=[["example1.jpg"], ["example2.jpg"]]
- )
-
- demo.launch()
+ import os
+ import requests
+ from tqdm import tqdm
+ from datasets import load_dataset
+ import numpy as np
+ import tensorflow as tf
+ from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
+ from tensorflow.keras.preprocessing import image
+ from tensorflow.keras.layers import Dense, Input, Concatenate, Embedding, Flatten
+ from tensorflow.keras.models import Model
+ from tensorflow.keras.preprocessing.text import Tokenizer
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
+ from sklearn.preprocessing import LabelEncoder
+ import joblib
+ from PIL import UnidentifiedImageError, Image
  import gradio as gr

+ # Constants
+ MAX_TEXT_LENGTH = 200
+ EMBEDDING_DIM = 100
+ IMAGE_SIZE = 224
+ BATCH_SIZE = 32
+
+ def load_and_preprocess_data(subset_size=2700):
+     # Load dataset
+     dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
+     dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
+
+     # Filter out NSFW content
+     dataset_subset = dataset_subset.filter(lambda x: not x['nsfw'])
+
+     return dataset_subset
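+ # Note: the steps below rely on these dataset fields: 'prompt', 'negativePrompt',
+ # 'url' (the image link), 'Model' (the label being predicted), and 'nsfw' (filtered above).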
+
+ def process_text_data(dataset_subset):
+     # Combine prompt and negative prompt
+     text_data = [f"{sample['prompt']} {sample['negativePrompt']}" for sample in dataset_subset]
+
+     # Tokenize text
+     tokenizer = Tokenizer()
+     tokenizer.fit_on_texts(text_data)
+     sequences = tokenizer.texts_to_sequences(text_data)
+     text_data_padded = pad_sequences(sequences, maxlen=MAX_TEXT_LENGTH)
+
+     return text_data_padded, tokenizer
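+ # Note: pad_sequences defaults to pre-padding and pre-truncation, so for prompts
+ # longer than MAX_TEXT_LENGTH only the last MAX_TEXT_LENGTH tokens are kept.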
+
+ def process_image_data(dataset_subset):
+     image_dir = 'civitai_images'
+     os.makedirs(image_dir, exist_ok=True)
+
+     image_data = []
+     valid_indices = []
+
+     for idx, sample in enumerate(tqdm(dataset_subset)):
+         img_url = sample['url']
+         img_path = os.path.join(image_dir, os.path.basename(img_url))

+         try:
+             # Download and save image
+             response = requests.get(img_url)
+             response.raise_for_status()
+
+             if 'image' not in response.headers['Content-Type']:
+                 continue
+
+             with open(img_path, 'wb') as f:
+                 f.write(response.content)
+
+             # Load and preprocess image
+             img = image.load_img(img_path, target_size=(IMAGE_SIZE, IMAGE_SIZE))
+             img_array = image.img_to_array(img)
+             img_array = preprocess_input(img_array)
+
+             image_data.append(img_array)
+             valid_indices.append(idx)
+
+         except Exception as e:
+             print(f"Error processing image {img_url}: {e}")
+             continue
+
+     return np.array(image_data), valid_indices
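+ # Note: resnet50.preprocess_input expects raw RGB values in [0, 255]; it converts
+ # to BGR and subtracts the ImageNet channel means, so no manual scaling is needed here.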
81
 
82
+ def create_multimodal_model(num_words, num_classes):
83
+ # Image input branch (CNN)
84
+ image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
85
+ cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
86
+ cnn_features = cnn_base(image_input)
87
+
88
+ # Text input branch (MLP)
89
+ text_input = Input(shape=(MAX_TEXT_LENGTH,))
90
+ embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
91
+ flatten_text = Flatten()(embedding_layer)
92
+ text_features = Dense(256, activation='relu')(flatten_text)
93
+
94
+ # Combine features
95
+ combined = Concatenate()([cnn_features, text_features])
96
+
97
+ # Fully connected layers
98
+ x = Dense(512, activation='relu')(combined)
99
+ x = Dense(256, activation='relu')(x)
100
+ output = Dense(num_classes, activation='softmax')(x)
101
+
102
+ model = Model(inputs=[image_input, text_input], outputs=output)
103
+ return model
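+ # Note: with pooling='avg' the ResNet50 branch outputs a 2048-dim vector, so the
+ # concatenated feature entering the dense head is 2048 + 256 = 2304 dimensions.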
+
+ def train_model():
+     # Load and preprocess data
+     dataset_subset = load_and_preprocess_data()
+
+     # Process text data
+     text_data_padded, tokenizer = process_text_data(dataset_subset)
+
+     # Process image data
+     image_data, valid_indices = process_image_data(dataset_subset)
+
+     # Get valid text data and labels
+     text_data_padded = text_data_padded[valid_indices]
+     model_names = [dataset_subset[i]['Model'] for i in valid_indices]
+
+     # Encode labels
+     label_encoder = LabelEncoder()
+     encoded_labels = label_encoder.fit_transform(model_names)
+
+     # Create and compile model
+     model = create_multimodal_model(
+         num_words=len(tokenizer.word_index) + 1,
+         num_classes=len(label_encoder.classes_)
+     )
+
+     model.compile(
+         optimizer='adam',
+         loss='sparse_categorical_crossentropy',
+         metrics=['accuracy']
+     )
+
+     # Train model
+     history = model.fit(
+         [image_data, text_data_padded],
+         encoded_labels,
+         batch_size=BATCH_SIZE,
+         epochs=10,
+         validation_split=0.2
+     )
+
+     # Save models and encoders
+     model.save('multimodal_model')
+     joblib.dump(tokenizer, 'tokenizer.pkl')
+     joblib.dump(label_encoder, 'label_encoder.pkl')
+
+     return model, tokenizer, label_encoder
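+ # Note: under TF 2.x, saving to a path with no extension writes a TensorFlow
+ # SavedModel directory, which the os.path.exists('multimodal_model') check below relies on.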
+
+ def get_recommendations(image_input, text_input, model, tokenizer, label_encoder, top_k=5):
+     # Preprocess image
+     img_array = image.img_to_array(image_input)
+     img_array = tf.image.resize(img_array, (IMAGE_SIZE, IMAGE_SIZE))
+     img_array = preprocess_input(img_array)
+     img_array = np.expand_dims(img_array, axis=0)
+
+     # Preprocess text
+     text_sequence = tokenizer.texts_to_sequences([text_input])
+     text_padded = pad_sequences(text_sequence, maxlen=MAX_TEXT_LENGTH)
+
+     # Get predictions
+     predictions = model.predict([img_array, text_padded])
+     top_indices = np.argsort(predictions[0])[-top_k:][::-1]
+
+     # Get recommended model names and confidence scores
+     recommendations = [
+         (label_encoder.inverse_transform([idx])[0], predictions[0][idx])
+         for idx in top_indices
+     ]
+
+     return recommendations
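+ # Example usage (hypothetical file name and prompt), once the saved artifacts exist:
+ #   model = tf.keras.models.load_model('multimodal_model')
+ #   tokenizer = joblib.load('tokenizer.pkl')
+ #   label_encoder = joblib.load('label_encoder.pkl')
+ #   recs = get_recommendations(Image.open('sample.png').convert('RGB'),
+ #                              'a castle at sunset', model, tokenizer, label_encoder)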
+
+ # Gradio interface
+ def create_gradio_interface():
+     # Load saved models
+     model = tf.keras.models.load_model('multimodal_model')
+     tokenizer = joblib.load('tokenizer.pkl')
+     label_encoder = joblib.load('label_encoder.pkl')
+
+     def predict(img, text):
+         recommendations = get_recommendations(img, text, model, tokenizer, label_encoder)
+         return "\n".join([f"Model: {name}, Confidence: {conf:.2f}" for name, conf in recommendations])
+
+     interface = gr.Interface(
+         fn=predict,
+         inputs=[
+             gr.Image(type="pil", label="Upload Image"),
+             gr.Textbox(label="Enter Prompt")
+         ],
+         outputs=gr.Textbox(label="Recommended Models"),
+         title="Multimodal Model Recommendation System",
+         description="Upload an image and enter a prompt to get model recommendations"
+     )
+
+     return interface
+
+ if __name__ == "__main__":
+     # Train model if not already trained
+     if not os.path.exists('multimodal_model'):
+         model, tokenizer, label_encoder = train_model()
+
+     # Launch Gradio interface
+     interface = create_gradio_interface()
+     interface.launch()