import torch from torchvision import transforms, datasets from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights from torchvision.transforms.functional import InterpolationMode from PIL import Image import gradio as gr import requests # Define el modelo y carga los pesos guardados model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT) model.classifier[1] = torch.nn.Linear(in_features=1280, out_features=101) #model.load_state_dict(torch.load('./Model_Food_ProyectoIA'), map_location=torch.device('cpu')) model.load_state_dict(torch.load('./Model_Food_ProyectoIA', map_location=torch.device('cpu'))) model.eval() # Poner el modelo en modo evaluación # Mueve el modelo a la GPU si está disponible device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') model.to(device) # Define las transformaciones transform_preprocess = transforms.Compose([ transforms.Resize(256, interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Cargar el conjunto de datos Food-101 para obtener la lista de clases #image_path = '../../compartida/vision-project/' #food101_dataset = datasets.Food101(image_path, split='train') #classes = food101_dataset.classes # Función para cargar el archivo de texto desde una URL #def load_remote_dataset(url): # response = requests.get(url) # response.raise_for_status() # Asegúrate de que la solicitud fue exitosa # return response.text # Leer el archivo localmente with open("clases.txt", "r") as f: classes = f.read().strip().split("\n") # URL del archivo de texto en el Space de Hugging Face #file_url = "https://huggingface.co/spaces/Alan7/ProyectoComputerVision/blob/main/clases.txt" # Carga el archivo de texto desde la URL #file_content = load_remote_dataset(file_url) # Lee el contenido del archivo y divide por saltos de línea #classes = file_content.strip().split("\n") # Función para predecir la clase de una nueva imagen def predict_image(image): image = Image.fromarray(image).convert('RGB') # Convertir la imagen cargada a PIL image = transform_preprocess(image).unsqueeze(0) # Preprocesar y añadir dimensión de batch image = image.to(device) # Mover la imagen a la GPU si está disponible with torch.no_grad(): output = model(image) # Realizar la predicción prediction = torch.nn.functional.softmax(output[0], dim=0) # Aplicar softmax para obtener probabilidades confidences = {classes[i]: float(prediction[i]) for i in range(101)} # Crear diccionario de clases y probabilidades return confidences # Devolver las probabilidades de cada clase # Crear la interfaz de Gradio interface = gr.Interface( fn=predict_image, inputs=gr.Image(type="numpy"), outputs=gr.Label(num_top_classes=3), title="Food101 Classifier", description="Sube una imagen de comida y el modelo clasificará la imagen.", examples=["hamb.jpg","ptts.jpg","lechuga.jpg"] # Reemplaza con rutas de ejemplo ) # Iniciar la interfaz interface.launch(share=False)