import os
import requests
from tqdm import tqdm
from datasets import load_dataset
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.layers import Dense, Input, Concatenate, Embedding, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.preprocessing import LabelEncoder
import joblib
from PIL import UnidentifiedImageError, Image
import gradio as gr
# Constants
MAX_TEXT_LENGTH = 200
EMBEDDING_DIM = 100
IMAGE_SIZE = 224
BATCH_SIZE = 32
def load_and_preprocess_data(subset_size=2700):
# Load dataset
dataset = load_dataset("thefcraft/civitai-stable-diffusion-337k")
dataset_subset = dataset['train'].shuffle(seed=42).select(range(subset_size))
    # Filter out NSFW content (note: this can shrink the subset below subset_size)
    dataset_subset = dataset_subset.filter(lambda x: not x['nsfw'])
return dataset_subset
def process_text_data(dataset_subset):
    # Combine prompt and negative prompt into a single text field
    # (negativePrompt can be None in this dataset, so fall back to '')
    text_data = [f"{sample['prompt']} {sample['negativePrompt'] or ''}" for sample in dataset_subset]
# Tokenize text
tokenizer = Tokenizer()
tokenizer.fit_on_texts(text_data)
sequences = tokenizer.texts_to_sequences(text_data)
text_data_padded = pad_sequences(sequences, maxlen=MAX_TEXT_LENGTH)
return text_data_padded, tokenizer
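
# Illustrative sketch (toy two-sample corpus, not part of the pipeline):
# Tokenizer assigns integer indices by word frequency, and pad_sequences
# left-pads every sequence to a fixed length:
#   tok = Tokenizer(); tok.fit_on_texts(["a cat", "a dog"])
#   tok.texts_to_sequences(["a cat"])  # -> [[1, 2]]
#   pad_sequences([[1, 2]], maxlen=5)  # -> [[0, 0, 0, 1, 2]]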
def process_image_data(dataset_subset):
image_dir = 'civitai_images'
os.makedirs(image_dir, exist_ok=True)
image_data = []
valid_indices = []
for idx, sample in enumerate(tqdm(dataset_subset)):
img_url = sample['url']
img_path = os.path.join(image_dir, os.path.basename(img_url))
        try:
            # Download and save the image; a timeout keeps one dead URL from
            # stalling the whole preprocessing loop
            response = requests.get(img_url, timeout=30)
            response.raise_for_status()
            # Skip responses that are not images; .get avoids a KeyError when
            # the Content-Type header is missing
            if 'image' not in response.headers.get('Content-Type', ''):
                continue
            with open(img_path, 'wb') as f:
                f.write(response.content)
# Load and preprocess image
img = image.load_img(img_path, target_size=(IMAGE_SIZE, IMAGE_SIZE))
img_array = image.img_to_array(img)
img_array = preprocess_input(img_array)
image_data.append(img_array)
valid_indices.append(idx)
except Exception as e:
print(f"Error processing image {img_url}: {e}")
continue
return np.array(image_data), valid_indices
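
# Contract note: image_data is a float array of shape
# (n_valid, IMAGE_SIZE, IMAGE_SIZE, 3) already run through preprocess_input,
# and valid_indices records which rows of the subset survived download and
# decoding, so the caller can align text data and labels accordingly.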
def create_multimodal_model(num_words, num_classes):
# Image input branch (CNN)
image_input = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
cnn_base = ResNet50(weights='imagenet', include_top=False, pooling='avg')
cnn_features = cnn_base(image_input)
# Text input branch (MLP)
text_input = Input(shape=(MAX_TEXT_LENGTH,))
embedding_layer = Embedding(num_words, EMBEDDING_DIM)(text_input)
flatten_text = Flatten()(embedding_layer)
text_features = Dense(256, activation='relu')(flatten_text)
# Combine features
combined = Concatenate()([cnn_features, text_features])
# Fully connected layers
x = Dense(512, activation='relu')(combined)
x = Dense(256, activation='relu')(x)
output = Dense(num_classes, activation='softmax')(x)
model = Model(inputs=[image_input, text_input], outputs=output)
return model
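
# Shape sketch (illustrative): with pooling='avg' the ResNet50 branch yields a
# 2048-dim vector per image and the text branch a 256-dim vector, so the
# Concatenate layer feeds a (batch, 2304) tensor into the dense head:
#   m = create_multimodal_model(num_words=10_000, num_classes=50)
#   m.summary()  # the concatenate layer should report (None, 2304)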
def train_model():
# Load and preprocess data
dataset_subset = load_and_preprocess_data()
# Process text data
text_data_padded, tokenizer = process_text_data(dataset_subset)
# Process image data
image_data, valid_indices = process_image_data(dataset_subset)
# Get valid text data and labels
text_data_padded = text_data_padded[valid_indices]
model_names = [dataset_subset[i]['Model'] for i in valid_indices]
# Encode labels
label_encoder = LabelEncoder()
encoded_labels = label_encoder.fit_transform(model_names)
# Create and compile model
model = create_multimodal_model(
num_words=len(tokenizer.word_index) + 1,
num_classes=len(label_encoder.classes_)
)
model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy']
)
# Train model
history = model.fit(
[image_data, text_data_padded],
encoded_labels,
batch_size=BATCH_SIZE,
epochs=10,
validation_split=0.2
)
    # Save model and encoders for reuse (SavedModel directory; Keras 3 would
    # require a '.keras' or '.h5' suffix here, so this assumes tf.keras 2.x)
model.save('multimodal_model')
joblib.dump(tokenizer, 'tokenizer.pkl')
joblib.dump(label_encoder, 'label_encoder.pkl')
return model, tokenizer, label_encoder
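
# Illustrative call pattern: train_model() both returns the artifacts and
# persists them, so a later run can reload instead of retraining:
#   model = tf.keras.models.load_model('multimodal_model')
#   tokenizer = joblib.load('tokenizer.pkl')
#   label_encoder = joblib.load('label_encoder.pkl')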
def get_recommendations(image_input, text_input, model, tokenizer, label_encoder, top_k=5):
    # Preprocess image: convert the PIL input to an array, resize to the
    # network's expected resolution, and apply ResNet50 preprocessing
    img_array = image.img_to_array(image_input)
    img_array = tf.image.resize(img_array, (IMAGE_SIZE, IMAGE_SIZE)).numpy()
    img_array = preprocess_input(img_array)
    img_array = np.expand_dims(img_array, axis=0)  # add batch dimension
# Preprocess text
text_sequence = tokenizer.texts_to_sequences([text_input])
text_padded = pad_sequences(text_sequence, maxlen=MAX_TEXT_LENGTH)
# Get predictions
predictions = model.predict([img_array, text_padded])
top_indices = np.argsort(predictions[0])[-top_k:][::-1]
# Get recommended model names and confidence scores
recommendations = [
(label_encoder.inverse_transform([idx])[0], predictions[0][idx])
for idx in top_indices
]
return recommendations
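
# Illustrative usage (the file name and prompt are hypothetical; labels come
# from the dataset's 'Model' column):
#   recs = get_recommendations(Image.open('sample.png'), 'a castle at dusk',
#                              model, tokenizer, label_encoder, top_k=3)
#   # -> [(model_name, confidence), ...] sorted by descending confidence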
# Gradio interface
def create_gradio_interface():
# Load saved models
model = tf.keras.models.load_model('multimodal_model')
tokenizer = joblib.load('tokenizer.pkl')
label_encoder = joblib.load('label_encoder.pkl')
    def predict(img, text):
        # Guard against empty inputs from the UI before running inference
        if img is None or not text:
            return "Please provide both an image and a prompt."
        recommendations = get_recommendations(img, text, model, tokenizer, label_encoder)
        return "\n".join(f"Model: {name}, Confidence: {conf:.2f}" for name, conf in recommendations)
interface = gr.Interface(
fn=predict,
inputs=[
gr.Image(type="pil", label="Upload Image"),
gr.Textbox(label="Enter Prompt")
],
outputs=gr.Textbox(label="Recommended Models"),
title="Multimodal Model Recommendation System",
description="Upload an image and enter a prompt to get model recommendations"
)
return interface
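
# Minimal sketch (assumes the saved artifacts already exist on disk):
#   create_gradio_interface().launch()  # pass share=True for a public link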
if __name__ == "__main__":
# Train model if not already trained
if not os.path.exists('multimodal_model'):
model, tokenizer, label_encoder = train_model()
# Launch Gradio interface
interface = create_gradio_interface()
interface.launch()