Update app.py
Browse files
app.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
|
2 |
import requests
|
3 |
from tqdm import tqdm
|
4 |
from datasets import load_dataset
|
@@ -15,13 +15,13 @@ import joblib
|
|
15 |
from PIL import UnidentifiedImageError, Image
|
16 |
import gradio as gr
|
17 |
|
18 |
-
# Constants
|
19 |
-
MAX_TEXT_LENGTH = 200
|
20 |
-
EMBEDDING_DIM = 100
|
21 |
-
IMAGE_SIZE = 224
|
22 |
-
BATCH_SIZE = 32
|
23 |
|
24 |
-
def load_and_preprocess_data(subset_size=
|
25 |
# Load dataset
|
26 |
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
|
27 |
dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
|
@@ -36,7 +36,7 @@ def process_text_data(dataset_subset):
|
|
36 |
text_data = [f"{sample['prompt']} {sample['negativePrompt']}" for sample in dataset_subset]
|
37 |
|
38 |
# Tokenize text
|
39 |
-
tokenizer = Tokenizer()
|
40 |
tokenizer.fit_on_texts(text_data)
|
41 |
sequences = tokenizer.texts_to_sequences(text_data)
|
42 |
text_data_padded = pad_sequences(sequences, maxlen=MAX_TEXT_LENGTH)
|
@@ -56,7 +56,7 @@ def process_image_data(dataset_subset):
|
|
56 |
|
57 |
try:
|
58 |
# Download and save image
|
59 |
-
response = requests.get(img_url)
|
60 |
response.raise_for_status()
|
61 |
|
62 |
if 'image' not in response.headers['Content-Type']:
|
@@ -74,7 +74,6 @@ def process_image_data(dataset_subset):
|
|
74 |
valid_indices.append(idx)
|
75 |
|
76 |
except Exception as e:
|
77 |
-
print(f"Error processing image {img_url}: {e}")
|
78 |
continue
|
79 |
|
80 |
return np.array(image_data), valid_indices
|
@@ -83,20 +82,24 @@ def create_multimodal_model(num_words, num_classes):
|
|
83 |
# Image input branch (CNN)
|
84 |
image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
|
85 |
cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
|
|
|
|
|
|
|
|
|
|
|
86 |
cnn_features = cnn_base(image_input)
|
87 |
|
88 |
-
# Text input branch (MLP)
|
89 |
text_input = Input(shape=(MAX_TEXT_LENGTH,))
|
90 |
embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
|
91 |
flatten_text = Flatten()(embedding_layer)
|
92 |
-
text_features = Dense(
|
93 |
|
94 |
# Combine features
|
95 |
combined = Concatenate()([cnn_features, text_features])
|
96 |
|
97 |
-
#
|
98 |
-
x = Dense(
|
99 |
-
x = Dense(256, activation='relu')(x)
|
100 |
output = Dense(num_classes, activation='softmax')(x)
|
101 |
|
102 |
model = Model(inputs=[image_input, text_input], outputs=output)
|
@@ -122,22 +125,22 @@ def train_model():
|
|
122 |
|
123 |
# Create and compile model
|
124 |
model = create_multimodal_model(
|
125 |
-
num_words=
|
126 |
num_classes=len(label_encoder.classes_)
|
127 |
)
|
128 |
|
129 |
model.compile(
|
130 |
-
optimizer=
|
131 |
loss='sparse_categorical_crossentropy',
|
132 |
metrics=['accuracy']
|
133 |
)
|
134 |
|
135 |
-
# Train model
|
136 |
history = model.fit(
|
137 |
[image_data, text_data_padded],
|
138 |
encoded_labels,
|
139 |
batch_size=BATCH_SIZE,
|
140 |
-
epochs=10
|
141 |
validation_split=0.2
|
142 |
)
|
143 |
|
|
|
1 |
+
mport os
|
2 |
import requests
|
3 |
from tqdm import tqdm
|
4 |
from datasets import load_dataset
|
|
|
15 |
from PIL import UnidentifiedImageError, Image
|
16 |
import gradio as gr
|
17 |
|
18 |
+
# Optimized Constants
|
19 |
+
MAX_TEXT_LENGTH = 100 # Reduced from 200
|
20 |
+
EMBEDDING_DIM = 50 # Reduced from 100
|
21 |
+
IMAGE_SIZE = 160 # Reduced from 224
|
22 |
+
BATCH_SIZE = 64 # Increased from 32
|
23 |
|
24 |
+
def load_and_preprocess_data(subset_size=500): # Reduced from 2700
|
25 |
# Load dataset
|
26 |
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
|
27 |
dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
|
|
|
36 |
text_data = [f"{sample['prompt']} {sample['negativePrompt']}" for sample in dataset_subset]
|
37 |
|
38 |
# Tokenize text
|
39 |
+
tokenizer = Tokenizer(num_words=10000) # Added limit to vocabulary size
|
40 |
tokenizer.fit_on_texts(text_data)
|
41 |
sequences = tokenizer.texts_to_sequences(text_data)
|
42 |
text_data_padded = pad_sequences(sequences, maxlen=MAX_TEXT_LENGTH)
|
|
|
56 |
|
57 |
try:
|
58 |
# Download and save image
|
59 |
+
response = requests.get(img_url, timeout=5) # Added timeout
|
60 |
response.raise_for_status()
|
61 |
|
62 |
if 'image' not in response.headers['Content-Type']:
|
|
|
74 |
valid_indices.append(idx)
|
75 |
|
76 |
except Exception as e:
|
|
|
77 |
continue
|
78 |
|
79 |
return np.array(image_data), valid_indices
|
|
|
82 |
# Image input branch (CNN)
|
83 |
image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
|
84 |
cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
|
85 |
+
|
86 |
+
# Freeze most of the ResNet50 layers to reduce training time
|
87 |
+
for layer in cnn_base.layers[:-10]:
|
88 |
+
layer.trainable = False
|
89 |
+
|
90 |
cnn_features = cnn_base(image_input)
|
91 |
|
92 |
+
# Text input branch (Simplified MLP)
|
93 |
text_input = Input(shape=(MAX_TEXT_LENGTH,))
|
94 |
embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
|
95 |
flatten_text = Flatten()(embedding_layer)
|
96 |
+
text_features = Dense(128, activation='relu')(flatten_text) # Reduced from 256
|
97 |
|
98 |
# Combine features
|
99 |
combined = Concatenate()([cnn_features, text_features])
|
100 |
|
101 |
+
# Simplified fully connected layers
|
102 |
+
x = Dense(256, activation='relu')(combined) # Reduced from 512
|
|
|
103 |
output = Dense(num_classes, activation='softmax')(x)
|
104 |
|
105 |
model = Model(inputs=[image_input, text_input], outputs=output)
|
|
|
125 |
|
126 |
# Create and compile model
|
127 |
model = create_multimodal_model(
|
128 |
+
num_words=10000, # Limited vocabulary size
|
129 |
num_classes=len(label_encoder.classes_)
|
130 |
)
|
131 |
|
132 |
model.compile(
|
133 |
+
optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
|
134 |
loss='sparse_categorical_crossentropy',
|
135 |
metrics=['accuracy']
|
136 |
)
|
137 |
|
138 |
+
# Train model with reduced epochs
|
139 |
history = model.fit(
|
140 |
[image_data, text_data_padded],
|
141 |
encoded_labels,
|
142 |
batch_size=BATCH_SIZE,
|
143 |
+
epochs=3, # Reduced from 10
|
144 |
validation_split=0.2
|
145 |
)
|
146 |
|