Spaces:
Running
Running
import os | |
import numpy as np | |
import tensorflow as tf | |
from fastapi import FastAPI, File, UploadFile | |
from fastapi.responses import JSONResponse | |
from io import BytesIO | |
from PIL import Image | |
from tensorflow.keras.preprocessing.image import img_to_array | |
from tensorflow.keras.applications import resnet50 | |
from tensorflow.keras.applications.resnet50 import preprocess_input | |
import uvicorn | |
# Initialize FastAPI app
app = FastAPI()

# Model and class information
model_path = "model.keras"
# Maps the index of the highest-scoring model output to its label.
class_indices = {0: 'glaucoma', 1: 'normal'}

# Load the model if it exists. `model` is always bound (to None when the file
# is missing) so later code can test for it instead of hitting a NameError.
if os.path.exists(model_path):
    model = tf.keras.models.load_model(model_path)
    print("Model loaded successfully.")
else:
    model = None
    print(f"Model file not found at {model_path}. Please upload the model.")
# Function to predict glaucoma in an image and return the class name
def predict_image(image_data):
    """Classify an encoded image and return its class name.

    Args:
        image_data: Raw bytes of an encoded image, as read from an upload.

    Returns:
        The label from ``class_indices`` for the highest-scoring model output,
        or the string ``"Error during prediction"`` if any step fails
        (undecodable image, model not loaded, inference error).
    """
    try:
        # globals().get() keeps this safe even if the module never bound
        # `model` (e.g. the model file was missing at startup).
        classifier = globals().get("model")
        if classifier is None:
            raise RuntimeError("model is not loaded")

        # Decode the image from binary data. Force three RGB channels so that
        # grayscale, palette and RGBA uploads produce the same array layout,
        # then resize to the 224x224 input size used below.
        with Image.open(BytesIO(image_data)) as raw_img:
            img = raw_img.convert("RGB").resize((224, 224))

        # Convert image to a single-item batch and apply ResNet50 preprocessing
        img_array = img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = preprocess_input(img_array)

        # Make prediction and map the top-scoring index to its class name
        prediction = classifier.predict(img_array)
        predicted_class = int(np.argmax(prediction[0]))
        return class_indices[predicted_class]
    except Exception as e:
        # Deliberate best-effort: callers receive a sentinel string rather
        # than an exception.
        print("Prediction error:", e)
        return "Error during prediction"
# Route for health check
# NOTE(review): no route decorator was visible on this handler, so it was never
# registered with the app. The "/health" path is an assumption - confirm
# against existing clients before deploying.
@app.get("/health")
async def api_health_check():
    """Return a fixed JSON payload indicating that the service is up."""
    return JSONResponse(content={"status": "Service is running"})
# Route for prediction using image via API
# NOTE(review): no route decorator was visible on this handler, so it was never
# registered with the app. The "/predict" path is an assumption - confirm
# against existing clients before deploying.
@app.post("/predict")
async def api_predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image file.

    Args:
        file: The uploaded image, sent as multipart form data.

    Returns:
        A JSON response ``{"prediction": <value>}`` holding whatever
        ``predict_image`` returned, or ``{"error": <message>}`` with HTTP
        status 500 if reading the upload or running the prediction raised.
    """
    try:
        # Read the image file as binary data
        image_data = await file.read()
        # Call the prediction function with the image data
        prediction = predict_image(image_data)
    except Exception as e:
        # Report failures with an error status instead of a misleading 200.
        return JSONResponse(status_code=500, content={"error": str(e)})
    return JSONResponse(content={"prediction": prediction})
# Start the uvicorn server only when this file is executed as a script.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",  # listen on all network interfaces
        port=7860,
    )