bgaspra committed on
Commit
18088b4
·
verified ·
1 Parent(s): 9cb9435

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -19
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import os
2
  import requests
3
  from tqdm import tqdm
4
  from datasets import load_dataset
@@ -15,13 +15,13 @@ import joblib
15
  from PIL import UnidentifiedImageError, Image
16
  import gradio as gr
17
 
18
- # Constants
19
- MAX_TEXT_LENGTH = 200
20
- EMBEDDING_DIM = 100
21
- IMAGE_SIZE = 224
22
- BATCH_SIZE = 32
23
 
24
- def load_and_preprocess_data(subset_size=2700):
25
  # Load dataset
26
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
27
  dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
@@ -36,7 +36,7 @@ def process_text_data(dataset_subset):
36
  text_data = [f"{sample['prompt']} {sample['negativePrompt']}" for sample in dataset_subset]
37
 
38
  # Tokenize text
39
- tokenizer = Tokenizer()
40
  tokenizer.fit_on_texts(text_data)
41
  sequences = tokenizer.texts_to_sequences(text_data)
42
  text_data_padded = pad_sequences(sequences, maxlen=MAX_TEXT_LENGTH)
@@ -56,7 +56,7 @@ def process_image_data(dataset_subset):
56
 
57
  try:
58
  # Download and save image
59
- response = requests.get(img_url)
60
  response.raise_for_status()
61
 
62
  if 'image' not in response.headers['Content-Type']:
@@ -74,7 +74,6 @@ def process_image_data(dataset_subset):
74
  valid_indices.append(idx)
75
 
76
  except Exception as e:
77
- print(f"Error processing image {img_url}: {e}")
78
  continue
79
 
80
  return np.array(image_data), valid_indices
@@ -83,20 +82,24 @@ def create_multimodal_model(num_words, num_classes):
83
  # Image input branch (CNN)
84
  image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
85
  cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
 
 
 
 
 
86
  cnn_features = cnn_base(image_input)
87
 
88
- # Text input branch (MLP)
89
  text_input = Input(shape=(MAX_TEXT_LENGTH,))
90
  embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
91
  flatten_text = Flatten()(embedding_layer)
92
- text_features = Dense(256, activation='relu')(flatten_text)
93
 
94
  # Combine features
95
  combined = Concatenate()([cnn_features, text_features])
96
 
97
- # Fully connected layers
98
- x = Dense(512, activation='relu')(combined)
99
- x = Dense(256, activation='relu')(x)
100
  output = Dense(num_classes, activation='softmax')(x)
101
 
102
  model = Model(inputs=[image_input, text_input], outputs=output)
@@ -122,22 +125,22 @@ def train_model():
122
 
123
  # Create and compile model
124
  model = create_multimodal_model(
125
- num_words=len(tokenizer.word_index) + 1,
126
  num_classes=len(label_encoder.classes_)
127
  )
128
 
129
  model.compile(
130
- optimizer='adam',
131
  loss='sparse_categorical_crossentropy',
132
  metrics=['accuracy']
133
  )
134
 
135
- # Train model
136
  history = model.fit(
137
  [image_data, text_data_padded],
138
  encoded_labels,
139
  batch_size=BATCH_SIZE,
140
- epochs=10,
141
  validation_split=0.2
142
  )
143
 
 
1
+ import os
2
  import requests
3
  from tqdm import tqdm
4
  from datasets import load_dataset
 
15
  from PIL import UnidentifiedImageError, Image
16
  import gradio as gr
17
 
18
+ # Optimized Constants
19
+ MAX_TEXT_LENGTH = 100 # Reduced from 200
20
+ EMBEDDING_DIM = 50 # Reduced from 100
21
+ IMAGE_SIZE = 160 # Reduced from 224
22
+ BATCH_SIZE = 64 # Increased from 32
23
 
24
+ def load_and_preprocess_data(subset_size=500): # Reduced from 2700
25
  # Load dataset
26
  dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
27
  dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
 
36
  text_data = [f"{sample['prompt']} {sample['negativePrompt']}" for sample in dataset_subset]
37
 
38
  # Tokenize text
39
+ tokenizer = Tokenizer(num_words=10000) # Added limit to vocabulary size
40
  tokenizer.fit_on_texts(text_data)
41
  sequences = tokenizer.texts_to_sequences(text_data)
42
  text_data_padded = pad_sequences(sequences, maxlen=MAX_TEXT_LENGTH)
 
56
 
57
  try:
58
  # Download and save image
59
+ response = requests.get(img_url, timeout=5) # Added timeout
60
  response.raise_for_status()
61
 
62
  if 'image' not in response.headers['Content-Type']:
 
74
  valid_indices.append(idx)
75
 
76
  except Exception as e:
 
77
  continue
78
 
79
  return np.array(image_data), valid_indices
 
82
  # Image input branch (CNN)
83
  image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
84
  cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
85
+
86
+ # Freeze most of the ResNet50 layers to reduce training time
87
+ for layer in cnn_base.layers[:-10]:
88
+ layer.trainable = False
89
+
90
  cnn_features = cnn_base(image_input)
91
 
92
+ # Text input branch (Simplified MLP)
93
  text_input = Input(shape=(MAX_TEXT_LENGTH,))
94
  embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
95
  flatten_text = Flatten()(embedding_layer)
96
+ text_features = Dense(128, activation='relu')(flatten_text) # Reduced from 256
97
 
98
  # Combine features
99
  combined = Concatenate()([cnn_features, text_features])
100
 
101
+ # Simplified fully connected layers
102
+ x = Dense(256, activation='relu')(combined) # Reduced from 512
 
103
  output = Dense(num_classes, activation='softmax')(x)
104
 
105
  model = Model(inputs=[image_input, text_input], outputs=output)
 
125
 
126
  # Create and compile model
127
  model = create_multimodal_model(
128
+ num_words=10000, # Limited vocabulary size
129
  num_classes=len(label_encoder.classes_)
130
  )
131
 
132
  model.compile(
133
+ optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
134
  loss='sparse_categorical_crossentropy',
135
  metrics=['accuracy']
136
  )
137
 
138
+ # Train model with reduced epochs
139
  history = model.fit(
140
  [image_data, text_data_padded],
141
  encoded_labels,
142
  batch_size=BATCH_SIZE,
143
+ epochs=3, # Reduced from 10
144
  validation_split=0.2
145
  )
146