bgaspra commited on
Commit
d90a660
1 Parent(s): 0102517

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -56
app.py CHANGED
@@ -14,7 +14,6 @@ from sklearn.preprocessing import LabelEncoder
14
  import joblib
15
  from PIL import UnidentifiedImageError, Image
16
  import gradio as gr
17
- import html
18
 
19
  # Optimized Constants
20
  MAX_TEXT_LENGTH = 100
@@ -22,36 +21,28 @@ EMBEDDING_DIM = 50
22
  IMAGE_SIZE = 160
23
  BATCH_SIZE = 64
24
 
25
- # Store model examples
26
- model_examples = {}
27
-
28
  def load_and_preprocess_data(subset_size=10000):
 
29
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
30
  dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
31
- dataset_subset = dataset_subset.filter(lambda x: not x['nsfw'])
32
 
33
- for item in dataset_subset:
34
- if item['Model'] not in model_examples:
35
- model_examples[item['Model']] = item['url']
36
 
37
  return dataset_subset
38
 
39
  def process_text_data(dataset_subset):
40
- text_data = ["default prompt" for _ in dataset_subset]
 
 
 
41
  tokenizer = Tokenizer(num_words=10000)
42
  tokenizer.fit_on_texts(text_data)
43
  sequences = tokenizer.texts_to_sequences(text_data)
44
  text_data_padded = pad_sequences(sequences, maxlen=MAX_TEXT_LENGTH)
 
45
  return text_data_padded, tokenizer
46
 
47
- def download_image(url):
48
- try:
49
- response = requests.get(url, timeout=5)
50
- response.raise_for_status()
51
- return Image.open(requests.get(url, stream=True).raw)
52
- except:
53
- return None
54
-
55
  def process_image_data(dataset_subset):
56
  image_dir = 'civitai_images'
57
  os.makedirs(image_dir, exist_ok=True)
@@ -64,6 +55,7 @@ def process_image_data(dataset_subset):
64
  img_path = os.path.join(image_dir, os.path.basename(img_url))
65
 
66
  try:
 
67
  response = requests.get(img_url, timeout=5)
68
  response.raise_for_status()
69
 
@@ -73,6 +65,7 @@ def process_image_data(dataset_subset):
73
  with open(img_path, 'wb') as f:
74
  f.write(response.content)
75
 
 
76
  img = image.load_img(img_path, target_size=(IMAGE_SIZE, IMAGE_SIZE))
77
  img_array = image.img_to_array(img)
78
  img_array = preprocess_input(img_array)
@@ -86,21 +79,26 @@ def process_image_data(dataset_subset):
86
  return np.array(image_data), valid_indices
87
 
88
  def create_multimodal_model(num_words, num_classes):
89
- image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3), name='image_input')
 
90
  cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
91
 
 
92
  for layer in cnn_base.layers[:-10]:
93
  layer.trainable = False
94
 
95
  cnn_features = cnn_base(image_input)
96
 
97
- text_input = Input(shape=(MAX_TEXT_LENGTH,), name='text_input')
 
98
  embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
99
  flatten_text = Flatten()(embedding_layer)
100
  text_features = Dense(128, activation='relu')(flatten_text)
101
 
 
102
  combined = Concatenate()([cnn_features, text_features])
103
 
 
104
  x = Dense(256, activation='relu')(combined)
105
  output = Dense(num_classes, activation='softmax')(x)
106
 
@@ -108,16 +106,24 @@ def create_multimodal_model(num_words, num_classes):
108
  return model
109
 
110
  def train_model():
 
111
  dataset_subset = load_and_preprocess_data()
 
 
112
  text_data_padded, tokenizer = process_text_data(dataset_subset)
 
 
113
  image_data, valid_indices = process_image_data(dataset_subset)
114
 
 
115
  text_data_padded = text_data_padded[valid_indices]
116
  model_names = [dataset_subset[i]['Model'] for i in valid_indices]
117
 
 
118
  label_encoder = LabelEncoder()
119
  encoded_labels = label_encoder.fit_transform(model_names)
120
 
 
121
  model = create_multimodal_model(
122
  num_words=10000,
123
  num_classes=len(label_encoder.classes_)
@@ -129,6 +135,7 @@ def train_model():
129
  metrics=['accuracy']
130
  )
131
 
 
132
  history = model.fit(
133
  [image_data, text_data_padded],
134
  encoded_labels,
@@ -137,77 +144,68 @@ def train_model():
137
  validation_split=0.2
138
  )
139
 
140
- model.save('multimodal_model.keras')
 
141
  joblib.dump(tokenizer, 'tokenizer.pkl')
142
  joblib.dump(label_encoder, 'label_encoder.pkl')
143
- joblib.dump(model_examples, 'model_examples.pkl')
144
 
145
  return model, tokenizer, label_encoder
146
 
147
- def get_recommendations(image_input, model, tokenizer, label_encoder, top_k=5):
 
148
  img_array = image.img_to_array(image_input)
149
  img_array = tf.image.resize(img_array, (IMAGE_SIZE, IMAGE_SIZE))
150
  img_array = preprocess_input(img_array)
151
  img_array = np.expand_dims(img_array, axis=0)
152
 
153
- text_sequence = tokenizer.texts_to_sequences(["default prompt"])
 
154
  text_padded = pad_sequences(text_sequence, maxlen=MAX_TEXT_LENGTH)
155
 
156
- predictions = model.predict([img_array, text_padded], verbose=0)
 
157
  top_indices = np.argsort(predictions[0])[-top_k:][::-1]
158
 
159
- recommendations = []
160
- for idx in top_indices:
161
- model_name = label_encoder.inverse_transform([idx])[0]
162
- if model_name in model_examples:
163
- example_image = download_image(model_examples[model_name])
164
- if example_image:
165
- recommendations.append((model_name, example_image))
166
 
167
  return recommendations
168
 
169
- def create_gallery_html(recommendations):
170
- html_content = '<div style="display: grid; grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); gap: 20px; padding: 20px;">'
171
-
172
- for model_name, image in recommendations:
173
- # Konversi gambar ke base64 jika diperlukan
174
- html_content += f'''
175
- <div style="text-align: center; border: 1px solid #ddd; padding: 10px; border-radius: 8px;">
176
- <img src="{html.escape(image.url) if hasattr(image, 'url') else ''}"
177
- style="width: 100%; height: auto; border-radius: 4px;" />
178
- <p style="margin-top: 10px; font-weight: bold;">{html.escape(model_name)}</p>
179
- </div>
180
- '''
181
-
182
- html_content += '</div>'
183
- return html_content
184
-
185
  def create_gradio_interface():
186
- model = tf.keras.models.load_model('multimodal_model.keras')
 
187
  tokenizer = joblib.load('tokenizer.pkl')
188
  label_encoder = joblib.load('label_encoder.pkl')
189
 
190
- def predict(img):
191
- recommendations = get_recommendations(img, model, tokenizer, label_encoder)
192
- return gr.HTML(create_gallery_html(recommendations))
193
 
194
  interface = gr.Interface(
195
  fn=predict,
196
- inputs=gr.Image(type="pil", label="Upload Image"),
197
- outputs=gr.HTML(),
198
- title="AI Model Recommendation System",
199
- description="Upload an image to get similar model recommendations"
 
 
 
200
  )
201
 
202
  return interface
203
 
204
  if __name__ == "__main__":
205
- if not os.path.exists('multimodal_model.keras'):
 
206
  print("Training new model...")
207
  model, tokenizer, label_encoder = train_model()
208
  print("Training completed!")
209
  else:
210
  print("Loading existing model...")
211
 
 
212
  interface = create_gradio_interface()
213
  interface.launch()
 
14
  import joblib
15
  from PIL import UnidentifiedImageError, Image
16
  import gradio as gr
 
17
 
18
  # Optimized Constants
19
  MAX_TEXT_LENGTH = 100
 
21
  IMAGE_SIZE = 160
22
  BATCH_SIZE = 64
23
 
 
 
 
24
  def load_and_preprocess_data(subset_size=10000):
25
+ # Load dataset
26
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
27
  dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
 
28
 
29
+ # Filter out NSFW content
30
+ dataset_subset = dataset_subset.filter(lambda x: not x['nsfw'])
 
31
 
32
  return dataset_subset
33
 
34
  def process_text_data(dataset_subset):
35
+ # Combine prompt and negative prompt
36
+ text_data = [f"{sample['prompt']} {sample['negativePrompt']}" for sample in dataset_subset]
37
+
38
+ # Tokenize text
39
  tokenizer = Tokenizer(num_words=10000)
40
  tokenizer.fit_on_texts(text_data)
41
  sequences = tokenizer.texts_to_sequences(text_data)
42
  text_data_padded = pad_sequences(sequences, maxlen=MAX_TEXT_LENGTH)
43
+
44
  return text_data_padded, tokenizer
45
 
 
 
 
 
 
 
 
 
46
  def process_image_data(dataset_subset):
47
  image_dir = 'civitai_images'
48
  os.makedirs(image_dir, exist_ok=True)
 
55
  img_path = os.path.join(image_dir, os.path.basename(img_url))
56
 
57
  try:
58
+ # Download and save image
59
  response = requests.get(img_url, timeout=5)
60
  response.raise_for_status()
61
 
 
65
  with open(img_path, 'wb') as f:
66
  f.write(response.content)
67
 
68
+ # Load and preprocess image
69
  img = image.load_img(img_path, target_size=(IMAGE_SIZE, IMAGE_SIZE))
70
  img_array = image.img_to_array(img)
71
  img_array = preprocess_input(img_array)
 
79
  return np.array(image_data), valid_indices
80
 
81
  def create_multimodal_model(num_words, num_classes):
82
+ # Image input branch (CNN)
83
+ image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
84
  cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
85
 
86
+ # Freeze most of the ResNet50 layers
87
  for layer in cnn_base.layers[:-10]:
88
  layer.trainable = False
89
 
90
  cnn_features = cnn_base(image_input)
91
 
92
+ # Text input branch
93
+ text_input = Input(shape=(MAX_TEXT_LENGTH,))
94
  embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
95
  flatten_text = Flatten()(embedding_layer)
96
  text_features = Dense(128, activation='relu')(flatten_text)
97
 
98
+ # Combine features
99
  combined = Concatenate()([cnn_features, text_features])
100
 
101
+ # Simplified fully connected layers
102
  x = Dense(256, activation='relu')(combined)
103
  output = Dense(num_classes, activation='softmax')(x)
104
 
 
106
  return model
107
 
108
  def train_model():
109
+ # Load and preprocess data
110
  dataset_subset = load_and_preprocess_data()
111
+
112
+ # Process text data
113
  text_data_padded, tokenizer = process_text_data(dataset_subset)
114
+
115
+ # Process image data
116
  image_data, valid_indices = process_image_data(dataset_subset)
117
 
118
+ # Get valid text data and labels
119
  text_data_padded = text_data_padded[valid_indices]
120
  model_names = [dataset_subset[i]['Model'] for i in valid_indices]
121
 
122
+ # Encode labels
123
  label_encoder = LabelEncoder()
124
  encoded_labels = label_encoder.fit_transform(model_names)
125
 
126
+ # Create and compile model
127
  model = create_multimodal_model(
128
  num_words=10000,
129
  num_classes=len(label_encoder.classes_)
 
135
  metrics=['accuracy']
136
  )
137
 
138
+ # Train model
139
  history = model.fit(
140
  [image_data, text_data_padded],
141
  encoded_labels,
 
144
  validation_split=0.2
145
  )
146
 
147
+ # Save models and encoders with correct extensions
148
+ model.save('multimodal_model.keras') # Changed from 'multimodal_model'
149
  joblib.dump(tokenizer, 'tokenizer.pkl')
150
  joblib.dump(label_encoder, 'label_encoder.pkl')
 
151
 
152
  return model, tokenizer, label_encoder
153
 
154
+ def get_recommendations(image_input, text_input, model, tokenizer, label_encoder, top_k=5):
155
+ # Preprocess image
156
  img_array = image.img_to_array(image_input)
157
  img_array = tf.image.resize(img_array, (IMAGE_SIZE, IMAGE_SIZE))
158
  img_array = preprocess_input(img_array)
159
  img_array = np.expand_dims(img_array, axis=0)
160
 
161
+ # Preprocess text
162
+ text_sequence = tokenizer.texts_to_sequences([text_input])
163
  text_padded = pad_sequences(text_sequence, maxlen=MAX_TEXT_LENGTH)
164
 
165
+ # Get predictions
166
+ predictions = model.predict([img_array, text_padded])
167
  top_indices = np.argsort(predictions[0])[-top_k:][::-1]
168
 
169
+ # Get recommended model names and confidence scores
170
+ recommendations = [
171
+ (label_encoder.inverse_transform([idx])[0], predictions[0][idx])
172
+ for idx in top_indices
173
+ ]
 
 
174
 
175
  return recommendations
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  def create_gradio_interface():
178
+ # Load saved models with correct path
179
+ model = tf.keras.models.load_model('multimodal_model.keras') # Changed from 'multimodal_model'
180
  tokenizer = joblib.load('tokenizer.pkl')
181
  label_encoder = joblib.load('label_encoder.pkl')
182
 
183
+ def predict(img, text):
184
+ recommendations = get_recommendations(img, text, model, tokenizer, label_encoder)
185
+ return "\n".join([f"Model: {name}, Confidence: {conf:.2f}" for name, conf in recommendations])
186
 
187
  interface = gr.Interface(
188
  fn=predict,
189
+ inputs=[
190
+ gr.Image(type="pil", label="Upload Image"),
191
+ gr.Textbox(label="Enter Prompt")
192
+ ],
193
+ outputs=gr.Textbox(label="Recommended Models"),
194
+ title="Multimodal Model Recommendation System",
195
+ description="Upload an image and enter a prompt to get model recommendations"
196
  )
197
 
198
  return interface
199
 
200
  if __name__ == "__main__":
201
+ # Train model if not already trained
202
+ if not os.path.exists('multimodal_model.keras'): # Changed from 'multimodal_model'
203
  print("Training new model...")
204
  model, tokenizer, label_encoder = train_model()
205
  print("Training completed!")
206
  else:
207
  print("Loading existing model...")
208
 
209
+ # Launch Gradio interface
210
  interface = create_gradio_interface()
211
  interface.launch()