t3_mejorado / app.py
Junior16's picture
Update app.py
1ae51b3 verified
from fastapi import FastAPI, File, UploadFile
import numpy as np
from PIL import Image
import io
import cv2
from datasets import load_dataset
app = FastAPI()
# Cargar el PlantVillage dataset predefinido
dataset = load_dataset("susnato/plant_disease_detection_processed")
# Aqu铆 puedes entrenar un modelo o usar uno preentrenado. Para simplificar, vamos a usar el dataset solo para mostrar ejemplos.
# Por ejemplo, podemos ver algunas im谩genes del dataset, pero en producci贸n deber铆as tener un modelo entrenado.
train_data = dataset["train"]
@app.post("/detect_disease/")
async def detect_disease(file: UploadFile = File(...)):
# Leer imagen cargada
image_bytes = await file.read()
image = Image.open(io.BytesIO(image_bytes))
img_np = np.array(image)
# Convertir la imagen a escala de grises
gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
edges = cv2.Canny(gray, 100, 200)
# Aqu铆 realizar铆as la predicci贸n usando tu modelo, en lugar de solo mostrar bordes.
# Simularemos un diagn贸stico simple usando el promedio de los bordes para ilustrar la idea.
disease_detected = "Enfermedad detectada" if np.mean(edges) > 50 else "Saludable"
# Visualizaci贸n de la primera imagen del dataset de ejemplo
example_image = train_data[0]['image'] # Imagen del dataset para ejemplo
example_label = train_data[0]['label'] # Etiqueta de la enfermedad de la imagen de ejemplo
return {
"diagnosis": disease_detected,
"example_image": example_image,
"example_label": example_label
}