Update app.py
Browse files
app.py
CHANGED
@@ -16,12 +16,12 @@ from PIL import UnidentifiedImageError, Image
|
|
16 |
import gradio as gr
|
17 |
|
18 |
# Optimized Constants
|
19 |
-
MAX_TEXT_LENGTH = 100
|
20 |
-
EMBEDDING_DIM = 50
|
21 |
-
IMAGE_SIZE = 160
|
22 |
-
BATCH_SIZE = 64
|
23 |
|
24 |
-
def load_and_preprocess_data(subset_size=500):
|
25 |
# Load dataset
|
26 |
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
|
27 |
dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
|
@@ -36,7 +36,7 @@ def process_text_data(dataset_subset):
|
|
36 |
text_data = [f"{sample['prompt']} {sample['negativePrompt']}" for sample in dataset_subset]
|
37 |
|
38 |
# Tokenize text
|
39 |
-
tokenizer = Tokenizer(num_words=10000)
|
40 |
tokenizer.fit_on_texts(text_data)
|
41 |
sequences = tokenizer.texts_to_sequences(text_data)
|
42 |
text_data_padded = pad_sequences(sequences, maxlen=MAX_TEXT_LENGTH)
|
@@ -56,7 +56,7 @@ def process_image_data(dataset_subset):
|
|
56 |
|
57 |
try:
|
58 |
# Download and save image
|
59 |
-
response = requests.get(img_url, timeout=5)
|
60 |
response.raise_for_status()
|
61 |
|
62 |
if 'image' not in response.headers['Content-Type']:
|
@@ -83,23 +83,23 @@ def create_multimodal_model(num_words, num_classes):
|
|
83 |
image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
|
84 |
cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
|
85 |
|
86 |
-
# Freeze most of the ResNet50 layers
|
87 |
for layer in cnn_base.layers[:-10]:
|
88 |
layer.trainable = False
|
89 |
|
90 |
cnn_features = cnn_base(image_input)
|
91 |
|
92 |
-
# Text input branch
|
93 |
text_input = Input(shape=(MAX_TEXT_LENGTH,))
|
94 |
embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
|
95 |
flatten_text = Flatten()(embedding_layer)
|
96 |
-
text_features = Dense(128, activation='relu')(flatten_text)
|
97 |
|
98 |
# Combine features
|
99 |
combined = Concatenate()([cnn_features, text_features])
|
100 |
|
101 |
# Simplified fully connected layers
|
102 |
-
x = Dense(256, activation='relu')(combined)
|
103 |
output = Dense(num_classes, activation='softmax')(x)
|
104 |
|
105 |
model = Model(inputs=[image_input, text_input], outputs=output)
|
@@ -125,7 +125,7 @@ def train_model():
|
|
125 |
|
126 |
# Create and compile model
|
127 |
model = create_multimodal_model(
|
128 |
-
num_words=10000,
|
129 |
num_classes=len(label_encoder.classes_)
|
130 |
)
|
131 |
|
@@ -135,17 +135,17 @@ def train_model():
|
|
135 |
metrics=['accuracy']
|
136 |
)
|
137 |
|
138 |
-
# Train model
|
139 |
history = model.fit(
|
140 |
[image_data, text_data_padded],
|
141 |
encoded_labels,
|
142 |
batch_size=BATCH_SIZE,
|
143 |
-
epochs=3,
|
144 |
validation_split=0.2
|
145 |
)
|
146 |
|
147 |
-
# Save models and encoders
|
148 |
-
model.save('multimodal_model')
|
149 |
joblib.dump(tokenizer, 'tokenizer.pkl')
|
150 |
joblib.dump(label_encoder, 'label_encoder.pkl')
|
151 |
|
@@ -174,10 +174,9 @@ def get_recommendations(image_input, text_input, model, tokenizer, label_encoder
|
|
174 |
|
175 |
return recommendations
|
176 |
|
177 |
-
# Gradio interface
|
178 |
def create_gradio_interface():
|
179 |
-
# Load saved models
|
180 |
-
model = tf.keras.models.load_model('multimodal_model')
|
181 |
tokenizer = joblib.load('tokenizer.pkl')
|
182 |
label_encoder = joblib.load('label_encoder.pkl')
|
183 |
|
@@ -200,8 +199,12 @@ def create_gradio_interface():
|
|
200 |
|
201 |
if __name__ == "__main__":
|
202 |
# Train model if not already trained
|
203 |
-
if not os.path.exists('multimodal_model'):
|
|
|
204 |
model, tokenizer, label_encoder = train_model()
|
|
|
|
|
|
|
205 |
|
206 |
# Launch Gradio interface
|
207 |
interface = create_gradio_interface()
|
|
|
16 |
import gradio as gr
|
17 |
|
18 |
# Optimized Constants
|
19 |
+
MAX_TEXT_LENGTH = 100
|
20 |
+
EMBEDDING_DIM = 50
|
21 |
+
IMAGE_SIZE = 160
|
22 |
+
BATCH_SIZE = 64
|
23 |
|
24 |
+
def load_and_preprocess_data(subset_size=500):
|
25 |
# Load dataset
|
26 |
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
|
27 |
dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
|
|
|
36 |
text_data = [f"{sample['prompt']} {sample['negativePrompt']}" for sample in dataset_subset]
|
37 |
|
38 |
# Tokenize text
|
39 |
+
tokenizer = Tokenizer(num_words=10000)
|
40 |
tokenizer.fit_on_texts(text_data)
|
41 |
sequences = tokenizer.texts_to_sequences(text_data)
|
42 |
text_data_padded = pad_sequences(sequences, maxlen=MAX_TEXT_LENGTH)
|
|
|
56 |
|
57 |
try:
|
58 |
# Download and save image
|
59 |
+
response = requests.get(img_url, timeout=5)
|
60 |
response.raise_for_status()
|
61 |
|
62 |
if 'image' not in response.headers['Content-Type']:
|
|
|
83 |
image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
|
84 |
cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
|
85 |
|
86 |
+
# Freeze most of the ResNet50 layers
|
87 |
for layer in cnn_base.layers[:-10]:
|
88 |
layer.trainable = False
|
89 |
|
90 |
cnn_features = cnn_base(image_input)
|
91 |
|
92 |
+
# Text input branch
|
93 |
text_input = Input(shape=(MAX_TEXT_LENGTH,))
|
94 |
embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
|
95 |
flatten_text = Flatten()(embedding_layer)
|
96 |
+
text_features = Dense(128, activation='relu')(flatten_text)
|
97 |
|
98 |
# Combine features
|
99 |
combined = Concatenate()([cnn_features, text_features])
|
100 |
|
101 |
# Simplified fully connected layers
|
102 |
+
x = Dense(256, activation='relu')(combined)
|
103 |
output = Dense(num_classes, activation='softmax')(x)
|
104 |
|
105 |
model = Model(inputs=[image_input, text_input], outputs=output)
|
|
|
125 |
|
126 |
# Create and compile model
|
127 |
model = create_multimodal_model(
|
128 |
+
num_words=10000,
|
129 |
num_classes=len(label_encoder.classes_)
|
130 |
)
|
131 |
|
|
|
135 |
metrics=['accuracy']
|
136 |
)
|
137 |
|
138 |
+
# Train model
|
139 |
history = model.fit(
|
140 |
[image_data, text_data_padded],
|
141 |
encoded_labels,
|
142 |
batch_size=BATCH_SIZE,
|
143 |
+
epochs=3,
|
144 |
validation_split=0.2
|
145 |
)
|
146 |
|
147 |
+
# Save models and encoders with correct extensions
|
148 |
+
model.save('multimodal_model.keras') # Changed from 'multimodal_model'
|
149 |
joblib.dump(tokenizer, 'tokenizer.pkl')
|
150 |
joblib.dump(label_encoder, 'label_encoder.pkl')
|
151 |
|
|
|
174 |
|
175 |
return recommendations
|
176 |
|
|
|
177 |
def create_gradio_interface():
|
178 |
+
# Load saved models with correct path
|
179 |
+
model = tf.keras.models.load_model('multimodal_model.keras') # Changed from 'multimodal_model'
|
180 |
tokenizer = joblib.load('tokenizer.pkl')
|
181 |
label_encoder = joblib.load('label_encoder.pkl')
|
182 |
|
|
|
199 |
|
200 |
if __name__ == "__main__":
|
201 |
# Train model if not already trained
|
202 |
+
if not os.path.exists('multimodal_model.keras'): # Changed from 'multimodal_model'
|
203 |
+
print("Training new model...")
|
204 |
model, tokenizer, label_encoder = train_model()
|
205 |
+
print("Training completed!")
|
206 |
+
else:
|
207 |
+
print("Loading existing model...")
|
208 |
|
209 |
# Launch Gradio interface
|
210 |
interface = create_gradio_interface()
|