Spaces:
Running
Running
import streamlit as st | |
import tensorflow as tf | |
import os | |
import torch | |
import cv2 | |
import numpy as np | |
import requests | |
import joblib | |
import sklearn | |
from PIL import Image | |
from sklearn.decomposition import PCA | |
from tensorflow.keras.models import load_model | |
from transformers import pipeline | |
# Hugging Face access token read from the environment; passed to the
# ``pipeline`` calls of the ViT tab to authenticate model downloads.
token = os.environ['token']

# Streamlit requires set_page_config to be the first st.* call of the script.
st.set_page_config(
    page_title = 'Patacotrón',
    layout = 'wide',
    menu_items = {
        "About" : 'Proyecto ideado para la investigación de "Clasificación de imágenes de una sola clase con algoritmos de Inteligencia Artificial".',
        "Report a Bug" : 'mailto:contact@patacotron.tech'
    }
)

st.sidebar.write("contact@patacotron.tech")

# One tab per family of classifiers.
cnn, vit, zero_shot, classic_ml = st.tabs(["CNN", "ViT", "Zero-Shot", "Machine Learning Clásico"])

# Folder holding the fitted PCA transformer and the scikit-learn classifiers.
classic_ml_root = "/home/user/app/classicML"
@st.cache_resource
def load_pca():
    """Load the fitted PCA transformer used by the classic-ML tab.

    The classic-ML tab calls this at module level, i.e. on every Streamlit
    rerun (every widget interaction). ``st.cache_resource`` keeps the loaded
    object across reruns and sessions so the pickle is read from disk only once.

    Returns:
        The object stored in ``<classic_ml_root>/pca_model.pkl``.
    """
    return joblib.load(os.path.join(classic_ml_root, "pca_model.pkl"))
def _predict(_model_list, _img, sklearn = False, weights = None):
    """Average the "patacón" probability of an ensemble of models for one image.

    Args:
        _model_list: models to ensemble. With ``sklearn=True`` each must expose
            ``predict_proba``; otherwise each must be a callable Keras model.
        _img: BGR image, as produced by ``cv2.imdecode`` in ``preprocess``.
        sklearn: use the PCA + scikit-learn path instead of the Keras path.
            (The name shadows the ``sklearn`` module inside this function; it is
            kept because callers pass it by keyword.)
        weights: optional per-model weights, one per entry of ``_model_list``.
            Defaults to equal weights.

    Returns:
        ``[score, raw_img]``. ``score`` is the weighted mean "Patacon-True"
        probability. ``raw_img`` is ``_img`` converted to RGB for display.

    Raises:
        ValueError: if ``_model_list`` is empty, or ``weights`` has the wrong
            length or sums to zero.
    """
    if not _model_list:
        raise ValueError("_predict needs at least one model")
    if weights is None:
        weights = [1.0] * len(_model_list)
    if len(weights) != len(_model_list):
        raise ValueError("weights must have one entry per model")
    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("weights must not sum to zero")

    raw_img = cv2.cvtColor(_img, cv2.COLOR_BGR2RGB)
    # NOTE(review): the models receive the resized *BGR* image, not raw_img.
    # This is kept as-is because it must match how the models were trained;
    # confirm against the training pipeline.
    img = cv2.resize(_img, (IMAGE_WIDTH, IMAGE_HEIGHT))

    y_gorrito = 0
    if sklearn:
        # PCA expects a 2-D (n_samples, n_features) input: one flattened image.
        data = pca.transform([img.flatten()])
        # NOTE(review): assumes predict_proba's column order matches the
        # module-level ``Categories`` list; verify against model.classes_.
        positive_idx = Categories.index("Patacon-True")
        for model, weight in zip(_model_list, weights):
            y_gorrito += model.predict_proba(data)[0][positive_idx] * weight
    else:
        # Same normalised batch of one image for every Keras model.
        batch = tf.expand_dims(img / 255., 0)
        for model, weight in zip(_model_list, weights):
            y_gorrito += tf.cast(model(batch), dtype=tf.float32) * weight
    return [y_gorrito / total_weight, raw_img]
#def _pca_predict(models, _img): | |
# y_gorrito = 0 | |
# raw_img = cv2.cvtColor(_img, cv2.COLOR_BGR2RGB) | |
# img = cv2.resize(_img, (IMAGE_WIDTH, IMAGE_HEIGHT)) | |
# fl_img =[img.flatten()] | |
# data = pca.transform(fl_img) | |
# for model in models: | |
# prediction = model.predict_proba(data) | |
# y_gorrito += prediction[0][Categories.index("Patacon-True")] | |
# return [y_gorrito / len(models), raw_img] | |
#def classic_ml_prediction(clfs, _img): | |
# y_gorrito = 0 | |
# raw_img = cv2.cvtColor(_img, cv2.COLOR_BGR2RGB) | |
# img = cv2.resize(_img, (IMAGE_WIDTH, IMAGE_HEIGHT)).flatten() | |
# data = pca.transform(img.reshape(1, -1)) | |
# for clf in clfs: | |
# y_gorrito += clf.predict(data) | |
# return [y_gorrito / len(clfs), raw_img] | |
def preprocess(file_uploader, module = 'cv2'):
    """Turn an uploaded image file into an object the models can consume.

    Args:
        file_uploader: the uploaded file-like object, e.g. the value returned
            by ``st.file_uploader``.
        module: selects the output type:
            ``'cv2'`` gives a BGR ``numpy`` array decoded with OpenCV;
            ``'pil'`` gives a ``PIL.Image``;
            any other value gives the raw bytes as a ``uint8`` numpy array.

    Returns:
        The decoded image or raw byte array, as selected by ``module``.

    Raises:
        ValueError: if OpenCV cannot decode the bytes as an image.
    """
    if module == 'pil':
        # PIL reads the file object itself; no need to buffer it first.
        return Image.open(file_uploader)
    # Read from the argument (not from any module-level upload variable).
    buffer = np.frombuffer(file_uploader.read(), np.uint8)
    if module == 'cv2':
        img = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imdecode returns None instead of raising when it cannot
            # decode the data (e.g. a format this OpenCV build can't read).
            raise ValueError("could not decode the uploaded file as an image")
        return img
    return buffer
def multiclass_prediction(classifier, important_class):
    """Summarise the output of a Hugging Face zero-shot pipeline.

    Args:
        classifier: list of dicts with ``'label'`` and ``'score'`` keys, as
            produced by the zero-shot image-classification pipeline.
        important_class: label whose score should be reported.

    Returns:
        A tuple ``(label, best_score, class_score)``:
        ``label`` is the top label, or a list of labels when several tie for
        the top score; ``best_score`` is that top score; ``class_score`` is the
        score of ``important_class``, or ``0.0`` if it does not appear.

    Raises:
        ValueError: if ``classifier`` is empty.
    """
    if not classifier:
        raise ValueError("classifier results are empty")
    score = max(result['score'] for result in classifier)
    labels = [result['label'] for result in classifier if result['score'] == score]
    # Default for a class missing from the results. If the label appears more
    # than once, the last occurrence wins.
    class_score = 0.0
    for result in classifier:
        if result['label'] == important_class:
            class_score = result['score']
    return (labels[0] if len(labels) == 1 else labels, score, class_score)
API_URL = "https://api-inference.huggingface.co/models"
headers = {"Authorization": f"Bearer {st.secrets['token']}"}


def query(data, models, max_retries = 10, retry_delay = 1.0, timeout = 30):
    """Classify an image through the Hugging Face Inference API.

    Args:
        data: raw image bytes sent as the request body.
        models: Hub id of the model to query (a single name, despite the
            plural parameter name).
        max_retries: extra attempts to make while the API answers with an
            ``"error"`` payload.
        retry_delay: seconds to sleep between attempts.
        timeout: per-request timeout in seconds, forwarded to ``requests.post``.

    Returns:
        ``payload[1]["score"]`` on success. Returns ``-1`` if the API reports
        ``"Internal Server Error"`` or still errors after all retries.
    """
    import time  # only needed for the retry back-off

    url = API_URL + "/" + models
    for attempt in range(max_retries + 1):
        response = requests.post(url, headers=headers, data=data, timeout=timeout)
        payload = response.json()
        # Failures come back as a dict with an "error" key; anything else is
        # treated as the prediction list.
        if not (isinstance(payload, dict) and "error" in payload):
            # NOTE(review): index 1 is assumed to be the class of interest;
            # verify the ordering of the API's prediction list.
            return payload[1]["score"]
        if payload["error"] == "Internal Server Error":
            return -1
        if attempt < max_retries:
            time.sleep(retry_delay)
    return -1
@st.cache_resource
def load_clip():
    """Build the CLIP zero-shot image-classification pipeline.

    ``st.cache_resource`` keeps the pipeline across reruns and sessions. The
    large checkpoint is therefore loaded once, which is what the Zero-Shot tab's
    spinner message ("después predice rápido") promises the user.

    Returns:
        A ``transformers`` zero-shot-image-classification pipeline for
        ``openai/clip-vit-large-patch14-336``.
    """
    classifier = pipeline("zero-shot-image-classification", model = 'openai/clip-vit-large-patch14-336')
    return classifier
with cnn:
    col_a, col_b = st.columns(2)
    with col_a:
        st.title("Redes neuronales convolucionales")
        st.caption("Los modelos no están en orden de eficacia, sino en orden de creación.")
        # Keras checkpoints live in <cwd>/models with names like "ptctrn_v<version>.h5".
        DIR = os.path.join(os.getcwd(), "models")
        common_prefix = "ptctrn_v"
        common_end = ".h5"
        model_dict = dict()
        # Build readable names for the multiselect from the bare file names.
        # Parsing the file name keeps this working from any working directory,
        # unlike splitting the full path on a hard-coded absolute prefix.
        # Sorted because os.listdir order is arbitrary.
        for model in sorted(os.listdir(DIR)):
            model_name = 'Patacotrón ' + model.split(common_prefix)[-1].split(common_end)[0]
            model_dict[model_name] = os.path.join(DIR, model)
        # Create a dropdown menu to select the model(s)
        model_choice = st.multiselect("Seleccione uno o varios modelos de clasificación", model_dict.keys())
        threshold = st.slider('¿Cuál va a ser el límite donde se considere patacón? (el valor recomendado es de 75%-80%)', 0, 100, 50, key = 'threshold_convnet')
        # Input size of the CNNs; _predict also reads these globals (classic-ML tab too).
        IMAGE_WIDTH = IMAGE_HEIGHT = 224
    with col_b:
        uploaded_file = st.file_uploader(key = 'conv_upload', label = 'Sube la imagen a clasificar',type= ['jpg','png', 'jpeg', 'jfif', 'webp', 'heic'])
        if st.button(key = 'convnet_button', label ='¿Hay un patacón en la imagen?'):
            if not model_choice:
                st.write('Debe elegir como mínimo un modelo.')
            elif uploaded_file is not None:
                img = preprocess(uploaded_file)
                with st.spinner('Cargando predicción...'):
                    selected_models = [load_model(model_dict[name]) for name in model_choice]
                    # Unweighted ensemble: the image is the second positional
                    # argument and every model contributes equally.
                    y_gorrito, raw_img = _predict(selected_models, img)
                if round(float(y_gorrito*100)) >= threshold:
                    st.success("¡Patacón Detectado!")
                else:
                    st.error("No se considera que haya un patacón en la imagen")
                st.caption(f'La probabilidad de que la imagen tenga un patacón es del: {round(float(y_gorrito * 100), 2)}%')
                st.image(raw_img)
            else:
                st.write('Revisa haber seleccionado los modelos y la imagen correctamente.')
with vit:
    def vit_ensemble(classifier_list, img):
        """Return the mean 'Patacon-True' score of several image-classification pipelines.

        Args:
            classifier_list: non-empty list of callables returning a list of
                ``{'label', 'score'}`` dicts for an image.
            img: the image passed to every classifier.
        """
        y_gorrito = 0
        for classifier in classifier_list:
            for prediction in classifier(img):
                if prediction['label'] == 'Patacon-True':
                    y_gorrito += prediction["score"]
        return y_gorrito / len(classifier_list)

    col_a, col_b = st.columns(2)
    with col_a:
        st.title('Visual Transformers')
        st.caption('One class is all you need!')
        # Base architecture shown to the user -> fine-tuned checkpoint on the HF Hub.
        model_dict = {
            'google/vit-base-patch16-224-in21k' : 'frncscp/patacoptimus-prime',
            'facebook/dinov2-base' : 'frncscp/dinotron',
            'facebook/convnext-large-224' : 'frncscp/pataconxt',
            'microsoft/focalnet-small' : 'frncscp/focalnet-small-patacon',
            'microsoft/swin-tiny-patch4-window7-224' : 'frncscp/patacoswin'
        }
        model_choice = st.multiselect("Seleccione un modelo de clasificación", model_dict.keys(), key = 'ViT_multiselect')
        uploaded_file = st.file_uploader(key = 'ViT_upload', label = 'Sube la imagen a clasificar',type= ['jpg','png', 'jpeg', 'jfif', 'webp', 'heic'])
        threshold = st.slider('¿Cuál va a ser el límite desde donde se considere patacón? (se recomienda por encima del 80%)', 0, 100, 80, key = 'threshold_vit')
    with col_b:
        if st.button(key = 'ViT_button', label ='¿Hay un patacón en la imagen?'):
            if not model_choice:
                # st.write (not print) so the message reaches the browser.
                st.write('Recuerda seleccionar al menos un modelo de clasificación')
            elif uploaded_file is not None:
                with st.spinner('Cargando predicción...'):
                    classifiers = [pipeline("image-classification", model= model_dict[choice], token = token) for choice in model_choice]
                    img = preprocess(uploaded_file, module = 'pil')
                    y_gorrito = vit_ensemble(classifiers, img)
                if round(float(y_gorrito * 100)) >= threshold:
                    st.success("¡Patacón Detectado!")
                else:
                    st.error("No se considera que haya un patacón en la imagen")
                st.caption(f'La probabilidad de que la imagen tenga un patacón es del: {round(float(y_gorrito * 100), 2)}%')
                st.image(img)
            else:
                st.write("Asegúrate de haber subido correctamente la imagen.")
with zero_shot:
    col_a, col_b = st.columns(2)
    zsloaded = []
    with col_a:
        st.title("Clasificación Zero-Shot")
        st.caption("Usando Clip de OpenAI")
        # Candidate prompts given to CLIP; the first one describes a patacón.
        labels_for_classification = [
            "A yellow deep fried smashed plantain",
            "A yellow corn dough",
            "A stuffed fried dough",
            "Fried food",
            "Fruit",
            "Anything",
        ]
        uploaded_file = st.file_uploader(key = 'ZS_upload', label = 'Sube la imagen a clasificar',type= ['jpg','png', 'jpeg', 'jfif', 'webp', 'heic'])
    with col_b:
        if st.button(key = 'ZS_button', label ='¿Hay un patacón en la imagen?'):
            # Guard first: nothing to classify without an upload.
            if uploaded_file is None:
                st.write("Asegúrate de haber subido correctamente la imagen.")
            else:
                with st.spinner('Cargando el modelo (puede demorar hasta un minuto, pero después predice rápido)'):
                    classifier = load_clip()
                with st.spinner('Cargando predicción...'):
                    img = preprocess(uploaded_file, module = 'pil')
                    zs_classifier = classifier(img, candidate_labels = labels_for_classification)
                    # A patacón is detected when its prompt is the single top-scoring label.
                    patacon_label = labels_for_classification[0]
                    label, _, y_gorrito = multiclass_prediction(zs_classifier, patacon_label)
                if label == patacon_label:
                    st.success("¡Patacón Detectado!")
                else:
                    st.error("No se considera que haya un patacón en la imagen")
                st.caption(f'La probabilidad de que la imagen tenga un patacón es del: {round(float(y_gorrito * 100), 2)}%')
                st.image(img)
with classic_ml:
    # Both globals are read by _predict on its scikit-learn path:
    # the PCA transformer, and the label list used to locate "Patacon-True".
    pca = load_pca()
    Categories=['Patacon-True','Patacon-False']
    col_a, col_b = st.columns(2)
    with col_a:
        st.title("Machine Learning Clásico")
        st.caption("Usando análisis por componentes principales")
        model_files = {
            'Máquina de vectores de soporte' : 'pca_svm.sav',
            'K-Nearest Neighbors' : 'pca_knn.sav',
            'Bosques Aleatorios' : 'pca_random_forest.sav',
        }
        # Display name -> absolute path of the pickled classifier.
        model_dict = {name: os.path.join(classic_ml_root, filename) for name, filename in model_files.items()}
        model_choice = st.multiselect("Seleccione un modelo de clasificación", model_dict.keys(), key = 'cML_multiselect')
        uploaded_file = st.file_uploader(key = 'cML_upload', label = 'Sube la imagen a clasificar',type= ['jpg','png', 'jpeg', 'jfif', 'webp', 'heic'])
        threshold = st.slider('¿Cuál va a ser el límite desde donde se considere patacón? (se recomienda por encima del 70%)', 0, 100, 70, key = 'threshold_cML')
    with col_b:
        if st.button(key = 'cML_button', label ='¿Hay un patacón en la imagen?'):
            if not model_choice:
                # st.write (not print) so the message reaches the browser.
                st.write('Recuerda seleccionar al menos un modelo de clasificación')
            elif uploaded_file is not None:
                with st.spinner('Cargando predicción...'):
                    img = preprocess(uploaded_file)
                    selected_models = [joblib.load(model_dict[name]) for name in model_choice]
                    y_gorrito, raw_img = _predict(selected_models, img, sklearn = True)
                if round(float(y_gorrito*100)) >= threshold:
                    st.success("¡Patacón Detectado!")
                else:
                    st.error("No se considera que haya un patacón en la imagen")
                st.caption(f'La probabilidad de que la imagen tenga un patacón es del: {round(float(y_gorrito * 100), 2)}%')
                st.image(raw_img)
            else:
                st.write('Revisa haber seleccionado los modelos y la imagen correctamente.')