bgaspra commited on
Commit
c1bc1fb
1 Parent(s): 9ec6191

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -44
app.py CHANGED
@@ -21,6 +21,9 @@ EMBEDDING_DIM = 50
21
  IMAGE_SIZE = 160
22
  BATCH_SIZE = 64
23
 
 
 
 
24
  def load_and_preprocess_data(subset_size=500):
25
  # Load dataset
26
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
@@ -29,13 +32,17 @@ def load_and_preprocess_data(subset_size=500):
29
  # Filter out NSFW content
30
  dataset_subset = dataset_subset.filter(lambda x: not x['nsfw'])
31
 
 
 
 
 
 
32
  return dataset_subset
33
 
34
  def process_text_data(dataset_subset):
35
- # Combine prompt and negative prompt
36
- text_data = [f"{sample['prompt']} {sample['negativePrompt']}" for sample in dataset_subset]
37
 
38
- # Tokenize text
39
  tokenizer = Tokenizer(num_words=10000)
40
  tokenizer.fit_on_texts(text_data)
41
  sequences = tokenizer.texts_to_sequences(text_data)
@@ -43,6 +50,14 @@ def process_text_data(dataset_subset):
43
 
44
  return text_data_padded, tokenizer
45
 
 
 
 
 
 
 
 
 
46
  def process_image_data(dataset_subset):
47
  image_dir = 'civitai_images'
48
  os.makedirs(image_dir, exist_ok=True)
@@ -55,7 +70,6 @@ def process_image_data(dataset_subset):
55
  img_path = os.path.join(image_dir, os.path.basename(img_url))
56
 
57
  try:
58
- # Download and save image
59
  response = requests.get(img_url, timeout=5)
60
  response.raise_for_status()
61
 
@@ -65,7 +79,6 @@ def process_image_data(dataset_subset):
65
  with open(img_path, 'wb') as f:
66
  f.write(response.content)
67
 
68
- # Load and preprocess image
69
  img = image.load_img(img_path, target_size=(IMAGE_SIZE, IMAGE_SIZE))
70
  img_array = image.img_to_array(img)
71
  img_array = preprocess_input(img_array)
@@ -79,26 +92,21 @@ def process_image_data(dataset_subset):
79
  return np.array(image_data), valid_indices
80
 
81
  def create_multimodal_model(num_words, num_classes):
82
- # Image input branch (CNN)
83
  image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
84
  cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
85
 
86
- # Freeze most of the ResNet50 layers
87
  for layer in cnn_base.layers[:-10]:
88
  layer.trainable = False
89
 
90
  cnn_features = cnn_base(image_input)
91
 
92
- # Text input branch
93
  text_input = Input(shape=(MAX_TEXT_LENGTH,))
94
  embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
95
  flatten_text = Flatten()(embedding_layer)
96
  text_features = Dense(128, activation='relu')(flatten_text)
97
 
98
- # Combine features
99
  combined = Concatenate()([cnn_features, text_features])
100
 
101
- # Simplified fully connected layers
102
  x = Dense(256, activation='relu')(combined)
103
  output = Dense(num_classes, activation='softmax')(x)
104
 
@@ -106,24 +114,18 @@ def create_multimodal_model(num_words, num_classes):
106
  return model
107
 
108
  def train_model():
109
- # Load and preprocess data
110
  dataset_subset = load_and_preprocess_data()
111
 
112
- # Process text data
113
  text_data_padded, tokenizer = process_text_data(dataset_subset)
114
 
115
- # Process image data
116
  image_data, valid_indices = process_image_data(dataset_subset)
117
 
118
- # Get valid text data and labels
119
  text_data_padded = text_data_padded[valid_indices]
120
  model_names = [dataset_subset[i]['Model'] for i in valid_indices]
121
 
122
- # Encode labels
123
  label_encoder = LabelEncoder()
124
  encoded_labels = label_encoder.fit_transform(model_names)
125
 
126
- # Create and compile model
127
  model = create_multimodal_model(
128
  num_words=10000,
129
  num_classes=len(label_encoder.classes_)
@@ -135,7 +137,6 @@ def train_model():
135
  metrics=['accuracy']
136
  )
137
 
138
- # Train model
139
  history = model.fit(
140
  [image_data, text_data_padded],
141
  encoded_labels,
@@ -144,68 +145,75 @@ def train_model():
144
  validation_split=0.2
145
  )
146
 
147
- # Save models and encoders with correct extensions
148
- model.save('multimodal_model.keras') # Changed from 'multimodal_model'
149
  joblib.dump(tokenizer, 'tokenizer.pkl')
150
  joblib.dump(label_encoder, 'label_encoder.pkl')
151
 
 
 
 
152
  return model, tokenizer, label_encoder
153
 
154
- def get_recommendations(image_input, text_input, model, tokenizer, label_encoder, top_k=5):
155
- # Preprocess image
156
  img_array = image.img_to_array(image_input)
157
  img_array = tf.image.resize(img_array, (IMAGE_SIZE, IMAGE_SIZE))
158
  img_array = preprocess_input(img_array)
159
  img_array = np.expand_dims(img_array, axis=0)
160
 
161
- # Preprocess text
162
- text_sequence = tokenizer.texts_to_sequences([text_input])
163
  text_padded = pad_sequences(text_sequence, maxlen=MAX_TEXT_LENGTH)
164
 
165
- # Get predictions
166
  predictions = model.predict([img_array, text_padded])
167
  top_indices = np.argsort(predictions[0])[-top_k:][::-1]
168
 
169
- # Get recommended model names and confidence scores
170
- recommendations = [
171
- (label_encoder.inverse_transform([idx])[0], predictions[0][idx])
172
- for idx in top_indices
173
- ]
 
 
 
174
 
175
  return recommendations
176
 
177
  def create_gradio_interface():
178
- # Load saved models with correct path
179
- model = tf.keras.models.load_model('multimodal_model.keras') # Changed from 'multimodal_model'
180
  tokenizer = joblib.load('tokenizer.pkl')
181
  label_encoder = joblib.load('label_encoder.pkl')
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
- def predict(img, text):
184
- recommendations = get_recommendations(img, text, model, tokenizer, label_encoder)
185
- return "\n".join([f"Model: {name}, Confidence: {conf:.2f}" for name, conf in recommendations])
186
 
187
  interface = gr.Interface(
188
  fn=predict,
189
- inputs=[
190
- gr.Image(type="pil", label="Upload Image"),
191
- gr.Textbox(label="Enter Prompt")
192
- ],
193
- outputs=gr.Textbox(label="Recommended Models"),
194
- title="Multimodal Model Recommendation System",
195
- description="Upload an image and enter a prompt to get model recommendations"
196
  )
197
 
198
  return interface
199
 
200
  if __name__ == "__main__":
201
- # Train model if not already trained
202
- if not os.path.exists('multimodal_model.keras'): # Changed from 'multimodal_model'
203
  print("Training new model...")
204
  model, tokenizer, label_encoder = train_model()
205
  print("Training completed!")
206
  else:
207
  print("Loading existing model...")
208
 
209
- # Launch Gradio interface
210
  interface = create_gradio_interface()
211
  interface.launch()
 
21
  IMAGE_SIZE = 160
22
  BATCH_SIZE = 64
23
 
24
+ # Store model examples
25
+ model_examples = {}
26
+
27
  def load_and_preprocess_data(subset_size=500):
28
  # Load dataset
29
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
 
32
  # Filter out NSFW content
33
  dataset_subset = dataset_subset.filter(lambda x: not x['nsfw'])
34
 
35
+ # Store example images for each model
36
+ for item in dataset_subset:
37
+ if item['Model'] not in model_examples:
38
+ model_examples[item['Model']] = item['url']
39
+
40
  return dataset_subset
41
 
42
  def process_text_data(dataset_subset):
43
+ # Combine prompt and negative prompt without user input
44
+ text_data = ["default prompt" for _ in dataset_subset]
45
 
 
46
  tokenizer = Tokenizer(num_words=10000)
47
  tokenizer.fit_on_texts(text_data)
48
  sequences = tokenizer.texts_to_sequences(text_data)
 
50
 
51
  return text_data_padded, tokenizer
52
 
53
+ def download_image(url):
54
+ try:
55
+ response = requests.get(url, timeout=5)
56
+ response.raise_for_status()
57
+ return Image.open(requests.get(url, stream=True).raw)
58
+ except:
59
+ return None
60
+
61
  def process_image_data(dataset_subset):
62
  image_dir = 'civitai_images'
63
  os.makedirs(image_dir, exist_ok=True)
 
70
  img_path = os.path.join(image_dir, os.path.basename(img_url))
71
 
72
  try:
 
73
  response = requests.get(img_url, timeout=5)
74
  response.raise_for_status()
75
 
 
79
  with open(img_path, 'wb') as f:
80
  f.write(response.content)
81
 
 
82
  img = image.load_img(img_path, target_size=(IMAGE_SIZE, IMAGE_SIZE))
83
  img_array = image.img_to_array(img)
84
  img_array = preprocess_input(img_array)
 
92
  return np.array(image_data), valid_indices
93
 
94
  def create_multimodal_model(num_words, num_classes):
 
95
  image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
96
  cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
97
 
 
98
  for layer in cnn_base.layers[:-10]:
99
  layer.trainable = False
100
 
101
  cnn_features = cnn_base(image_input)
102
 
 
103
  text_input = Input(shape=(MAX_TEXT_LENGTH,))
104
  embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
105
  flatten_text = Flatten()(embedding_layer)
106
  text_features = Dense(128, activation='relu')(flatten_text)
107
 
 
108
  combined = Concatenate()([cnn_features, text_features])
109
 
 
110
  x = Dense(256, activation='relu')(combined)
111
  output = Dense(num_classes, activation='softmax')(x)
112
 
 
114
  return model
115
 
116
  def train_model():
 
117
  dataset_subset = load_and_preprocess_data()
118
 
 
119
  text_data_padded, tokenizer = process_text_data(dataset_subset)
120
 
 
121
  image_data, valid_indices = process_image_data(dataset_subset)
122
 
 
123
  text_data_padded = text_data_padded[valid_indices]
124
  model_names = [dataset_subset[i]['Model'] for i in valid_indices]
125
 
 
126
  label_encoder = LabelEncoder()
127
  encoded_labels = label_encoder.fit_transform(model_names)
128
 
 
129
  model = create_multimodal_model(
130
  num_words=10000,
131
  num_classes=len(label_encoder.classes_)
 
137
  metrics=['accuracy']
138
  )
139
 
 
140
  history = model.fit(
141
  [image_data, text_data_padded],
142
  encoded_labels,
 
145
  validation_split=0.2
146
  )
147
 
148
+ model.save('multimodal_model.keras')
 
149
  joblib.dump(tokenizer, 'tokenizer.pkl')
150
  joblib.dump(label_encoder, 'label_encoder.pkl')
151
 
152
+ # Save model examples
153
+ joblib.dump(model_examples, 'model_examples.pkl')
154
+
155
  return model, tokenizer, label_encoder
156
 
157
+ def get_recommendations(image_input, model, tokenizer, label_encoder, top_k=5):
 
158
  img_array = image.img_to_array(image_input)
159
  img_array = tf.image.resize(img_array, (IMAGE_SIZE, IMAGE_SIZE))
160
  img_array = preprocess_input(img_array)
161
  img_array = np.expand_dims(img_array, axis=0)
162
 
163
+ # Use default text input
164
+ text_sequence = tokenizer.texts_to_sequences(["default prompt"])
165
  text_padded = pad_sequences(text_sequence, maxlen=MAX_TEXT_LENGTH)
166
 
 
167
  predictions = model.predict([img_array, text_padded])
168
  top_indices = np.argsort(predictions[0])[-top_k:][::-1]
169
 
170
+ recommendations = []
171
+ for idx in top_indices:
172
+ model_name = label_encoder.inverse_transform([idx])[0]
173
+ confidence = predictions[0][idx]
174
+ if model_name in model_examples:
175
+ example_image = download_image(model_examples[model_name])
176
+ if example_image:
177
+ recommendations.append((model_name, confidence, example_image))
178
 
179
  return recommendations
180
 
181
  def create_gradio_interface():
182
+ model = tf.keras.models.load_model('multimodal_model.keras')
 
183
  tokenizer = joblib.load('tokenizer.pkl')
184
  label_encoder = joblib.load('label_encoder.pkl')
185
+ model_examples_data = joblib.load('model_examples.pkl')
186
+
187
+ def predict(img):
188
+ recommendations = get_recommendations(img, model, tokenizer, label_encoder)
189
+ result_text = ""
190
+ result_images = []
191
+
192
+ for model_name, conf, example_img in recommendations:
193
+ result_text += f"Model: {model_name}\n"
194
+ result_images.append(example_img)
195
+
196
+ return [result_text] + result_images
197
 
198
+ outputs = [gr.Textbox(label="Recommended Models")] + [gr.Image(label=f"Example {i+1}") for i in range(5)]
 
 
199
 
200
  interface = gr.Interface(
201
  fn=predict,
202
+ inputs=gr.Image(type="pil", label="Upload Image"),
203
+ outputs=outputs,
204
+ title="AI Model Recommendation System",
205
+ description="Upload an image to get model recommendations with examples"
 
 
 
206
  )
207
 
208
  return interface
209
 
210
  if __name__ == "__main__":
211
+ if not os.path.exists('multimodal_model.keras'):
 
212
  print("Training new model...")
213
  model, tokenizer, label_encoder = train_model()
214
  print("Training completed!")
215
  else:
216
  print("Loading existing model...")
217
 
 
218
  interface = create_gradio_interface()
219
  interface.launch()