navpan2 commited on
Commit
59651bd
·
verified ·
1 Parent(s): 1565713

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +69 -0
main.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from fastapi import FastAPI, File, UploadFile
5
+ from fastapi.responses import JSONResponse
6
+ from io import BytesIO
7
+ from PIL import Image
8
+ from tensorflow.keras.preprocessing.image import img_to_array
9
+ from tensorflow.keras.applications import resnet50
10
+ from tensorflow.keras.applications.resnet50 import preprocess_input
11
+ import uvicorn
12
+
13
+ # Initialize FastAPI app
14
+ app = FastAPI()
15
+
16
+ # Model and class information
17
+ model_path = "model.keras"
18
+ class_indices = {0: 'glaucoma', 1: 'normal'}
19
+
20
+ # Load the model if it exists
21
+ if os.path.exists(model_path):
22
+ model = tf.keras.models.load_model(model_path)
23
+ print("Model loaded successfully.")
24
+ else:
25
+ print(f"Model file not found at {model_path}. Please upload the model.")
26
+
27
+ # Function to predict glaucoma in an image and return the class name
28
+ def predict_image(image_data):
29
+ try:
30
+ # Load the image from binary data
31
+ img = Image.open(BytesIO(image_data))
32
+ # Resize the image to the target size
33
+ img = img.resize((224, 224))
34
+ # Convert image to array format for the model
35
+ img_array = img_to_array(img)
36
+ img_array = np.expand_dims(img_array, axis=0)
37
+ img_array = preprocess_input(img_array)
38
+
39
+ # Make prediction
40
+ prediction = model.predict(img_array)
41
+ predicted_class = np.argmax(prediction[0])
42
+ class_name = class_indices[predicted_class] # Map to class name
43
+ return class_name
44
+ except Exception as e:
45
+ print("Prediction error:", e)
46
+ return "Error during prediction"
47
+
48
+ # Route for health check
49
+ @app.get("/health")
50
+ async def api_health_check():
51
+ return JSONResponse(content={"status": "Service is running"})
52
+
53
+ # Route for prediction using image via API
54
+ @app.post("/predict")
55
+ async def api_predict_image(file: UploadFile = File(...)):
56
+ try:
57
+ # Read the image file as binary data
58
+ image_data = await file.read()
59
+
60
+ # Call the prediction function with the image data
61
+ prediction = predict_image(image_data)
62
+
63
+ return JSONResponse(content={"prediction": prediction})
64
+ except Exception as e:
65
+ return JSONResponse(content={"error": str(e)})
66
+
67
+ # Run the FastAPI app
68
+ if __name__ == "__main__":
69
+ uvicorn.run(app, host="0.0.0.0", port=7860)