bgaspra committed
Commit 9ec6191 · verified · 1 Parent(s): 647c1be

Update app.py

Files changed (1)
  1. app.py +23 -20
app.py CHANGED
@@ -16,12 +16,12 @@ from PIL import UnidentifiedImageError, Image
 import gradio as gr
 
 # Optimized Constants
-MAX_TEXT_LENGTH = 100 # Reduced from 200
-EMBEDDING_DIM = 50 # Reduced from 100
-IMAGE_SIZE = 160 # Reduced from 224
-BATCH_SIZE = 64 # Increased from 32
+MAX_TEXT_LENGTH = 100
+EMBEDDING_DIM = 50
+IMAGE_SIZE = 160
+BATCH_SIZE = 64
 
-def load_and_preprocess_data(subset_size=500): # Reduced from 2700
+def load_and_preprocess_data(subset_size=500):
     # Load dataset
     dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
     dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
@@ -36,7 +36,7 @@ def process_text_data(dataset_subset):
     text_data = [f"{sample['prompt']} {sample['negativePrompt']}" for sample in dataset_subset]
 
     # Tokenize text
-    tokenizer = Tokenizer(num_words=10000) # Added limit to vocabulary size
+    tokenizer = Tokenizer(num_words=10000)
     tokenizer.fit_on_texts(text_data)
     sequences = tokenizer.texts_to_sequences(text_data)
     text_data_padded = pad_sequences(sequences, maxlen=MAX_TEXT_LENGTH)
@@ -56,7 +56,7 @@ def process_image_data(dataset_subset):
 
         try:
             # Download and save image
-            response = requests.get(img_url, timeout=5) # Added timeout
+            response = requests.get(img_url, timeout=5)
             response.raise_for_status()
 
             if 'image' not in response.headers['Content-Type']:
@@ -83,23 +83,23 @@ def create_multimodal_model(num_words, num_classes):
     image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
     cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
 
-    # Freeze most of the ResNet50 layers to reduce training time
+    # Freeze most of the ResNet50 layers
     for layer in cnn_base.layers[:-10]:
         layer.trainable = False
 
     cnn_features = cnn_base(image_input)
 
-    # Text input branch (Simplified MLP)
+    # Text input branch
     text_input = Input(shape=(MAX_TEXT_LENGTH,))
     embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
     flatten_text = Flatten()(embedding_layer)
-    text_features = Dense(128, activation='relu')(flatten_text) # Reduced from 256
+    text_features = Dense(128, activation='relu')(flatten_text)
 
     # Combine features
     combined = Concatenate()([cnn_features, text_features])
 
     # Simplified fully connected layers
-    x = Dense(256, activation='relu')(combined) # Reduced from 512
+    x = Dense(256, activation='relu')(combined)
     output = Dense(num_classes, activation='softmax')(x)
 
     model = Model(inputs=[image_input, text_input], outputs=output)
@@ -125,7 +125,7 @@ def train_model():
 
     # Create and compile model
     model = create_multimodal_model(
-        num_words=10000, # Limited vocabulary size
+        num_words=10000,
         num_classes=len(label_encoder.classes_)
     )
 
@@ -135,17 +135,17 @@ def train_model():
         metrics=['accuracy']
     )
 
-    # Train model with reduced epochs
+    # Train model
     history = model.fit(
         [image_data, text_data_padded],
         encoded_labels,
         batch_size=BATCH_SIZE,
-        epochs=3, # Reduced from 10
+        epochs=3,
         validation_split=0.2
    )
 
-    # Save models and encoders
-    model.save('multimodal_model')
+    # Save models and encoders with correct extensions
+    model.save('multimodal_model.keras') # Changed from 'multimodal_model'
     joblib.dump(tokenizer, 'tokenizer.pkl')
     joblib.dump(label_encoder, 'label_encoder.pkl')
 
@@ -174,10 +174,9 @@ def get_recommendations(image_input, text_input, model, tokenizer, label_encoder
 
     return recommendations
 
-# Gradio interface
 def create_gradio_interface():
-    # Load saved models
-    model = tf.keras.models.load_model('multimodal_model')
+    # Load saved models with correct path
+    model = tf.keras.models.load_model('multimodal_model.keras') # Changed from 'multimodal_model'
     tokenizer = joblib.load('tokenizer.pkl')
     label_encoder = joblib.load('label_encoder.pkl')
 
@@ -200,8 +199,12 @@ def create_gradio_interface():
 
 if __name__ == "__main__":
     # Train model if not already trained
-    if not os.path.exists('multimodal_model'):
+    if not os.path.exists('multimodal_model.keras'): # Changed from 'multimodal_model'
+        print("Training new model...")
         model, tokenizer, label_encoder = train_model()
+        print("Training completed!")
+    else:
+        print("Loading existing model...")
 
     # Launch Gradio interface
     interface = create_gradio_interface()
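
The functional change in this commit is the switch to the native Keras file format: 'multimodal_model' becomes 'multimodal_model.keras' in model.save, tf.keras.models.load_model, and the os.path.exists guard. With Keras 3 (the default in recent TensorFlow releases), model.save expects a .keras or .h5 extension and raises an error for a bare path, which is presumably what this change works around. Below is a minimal, self-contained sketch of the same save/load round trip; the two-input toy model, shapes, and names are placeholders standing in for the app's ResNet50 + embedding network, and only the .keras extension and the save/load_model calls mirror the commit.

# Sketch only: toy two-input model, not the app's actual architecture.
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.models import Model, load_model

# Small stand-ins for the image-feature and text-feature branches.
img_in = Input(shape=(8,), name="image_features")
txt_in = Input(shape=(4,), name="text_features")
combined = Concatenate()([img_in, txt_in])
hidden = Dense(16, activation="relu")(combined)
out = Dense(3, activation="softmax")(hidden)
model = Model(inputs=[img_in, txt_in], outputs=out)
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# Keras 3 requires a .keras (or .h5) extension here; a bare path such as
# 'multimodal_model' is rejected.
model.save("multimodal_model.keras")

# Reload and check that the round trip preserves predictions.
restored = load_model("multimodal_model.keras")
x = [np.random.rand(2, 8).astype("float32"),
     np.random.rand(2, 4).astype("float32")]
np.testing.assert_allclose(model.predict(x), restored.predict(x), rtol=1e-5)
print("Round trip OK")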