# testapinavn / main.py
# navpan2's picture
# Create main.py
# 59651bd verified
import os
import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from io import BytesIO
from PIL import Image
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications import resnet50
from tensorflow.keras.applications.resnet50 import preprocess_input
import uvicorn
# Initialize FastAPI app
app = FastAPI()

# Model and class information
model_path = "model.keras"
# Label lookup keyed by output index (0 -> glaucoma, 1 -> normal).
class_indices = {0: 'glaucoma', 1: 'normal'}

# Load the model if it exists. `model` is always bound (None when the file is
# missing) so code that uses it can test `model is None` instead of hitting a
# NameError on an undefined global.
if os.path.exists(model_path):
    model = tf.keras.models.load_model(model_path)
    print("Model loaded successfully.")
else:
    model = None
    print(f"Model file not found at {model_path}. Please upload the model.")
# Function to predict glaucoma in an image and return the class name
def predict_image(image_data):
    """Classify raw image bytes and return the predicted class name.

    Args:
        image_data: Encoded image file contents (bytes), e.g. an upload body.

    Returns:
        The label from ``class_indices`` for the highest-scoring model output,
        or the string ``"Error during prediction"`` when anything fails (model
        not loaded, undecodable image, inference error). Exceptions are caught
        and printed rather than propagated.
    """
    try:
        # `model` may be unbound or None when the model file was missing at
        # import time; look it up dynamically and fail with a clear message.
        loaded_model = globals().get("model")
        if loaded_model is None:
            raise RuntimeError("model is not loaded")

        # Load the image from binary data and force 3-channel RGB: grayscale,
        # palette, RGBA and CMYK inputs would otherwise produce an array with
        # the wrong number of channels for the model.
        with Image.open(BytesIO(image_data)) as raw_img:
            img = raw_img.convert("RGB")

        # Resize the image to the target size
        # NOTE(review): assumes the model expects 224x224 input — confirm
        # against the trained model.
        img = img.resize((224, 224))

        # Convert image to a single-item batch for the model
        img_array = img_to_array(img)
        img_array = np.expand_dims(img_array, axis=0)
        img_array = preprocess_input(img_array)

        # Make prediction
        prediction = loaded_model.predict(img_array)
        # NOTE(review): argmax presumes one output score per class; a single
        # sigmoid output would always yield index 0 — confirm model head.
        predicted_class = int(np.argmax(prediction[0]))
        return class_indices[predicted_class]  # Map to class name
    except Exception as e:
        print("Prediction error:", e)
        return "Error during prediction"
# Health-check route
@app.get("/health")
async def api_health_check():
    """Return a fixed JSON status body for GET /health."""
    status_payload = {"status": "Service is running"}
    return JSONResponse(content=status_payload)
# Route for prediction using image via API
@app.post("/predict")
async def api_predict_image(file: UploadFile = File(...)):
    """Classify an uploaded image and return the predicted class name.

    On success the response is ``{"prediction": <class name>}`` with HTTP 200.
    Failures are returned as ``{"error": <message>}`` with a 4xx/5xx status so
    a client cannot mistake an error for a classification result.
    """
    try:
        # Read the image file as binary data
        image_data = await file.read()
        if not image_data:
            return JSONResponse(
                status_code=400,
                content={"error": "Uploaded file is empty"},
            )

        # Call the prediction function with the image data
        # NOTE(review): predict_image is synchronous and runs on the event
        # loop here; consider offloading it if request latency matters.
        prediction = predict_image(image_data)
    except Exception as e:
        return JSONResponse(status_code=500, content={"error": str(e)})

    # predict_image signals failure by returning this sentinel string instead
    # of raising; surface it as an error rather than as a class label.
    if prediction == "Error during prediction":
        return JSONResponse(status_code=500, content={"error": prediction})

    return JSONResponse(content={"prediction": prediction})
# Start the uvicorn server when this file is executed as a script
if __name__ == "__main__":
    uvicorn.run(
        app,
        port=7860,
        host="0.0.0.0",
    )