HashamUllah commited on
Commit
874d2f6
·
verified ·
1 Parent(s): 25926f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -20
app.py CHANGED
@@ -1,13 +1,19 @@
1
- import gradio as gr
 
2
  import numpy as np
3
  import cv2
4
  import pickle
5
  from tensorflow.keras.models import load_model
6
  from tensorflow.keras.preprocessing.image import img_to_array
7
 
 
 
 
8
  # Load the model and the label binarizer
9
  model = load_model('cnn_model.h5')
 
10
  label_binarizer = pickle.load(open('label_transform.pkl', 'rb'))
 
11
 
12
  # Function to convert images to array
13
  def convert_image_to_array(image_dir):
@@ -22,14 +28,41 @@ def convert_image_to_array(image_dir):
22
  print(f"Error : {e}")
23
  return None
24
 
25
- def predict(image):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  try:
27
- # Convert the image to an array
28
- _, image_data = cv2.imencode('.jpg', image)
29
- image_array = convert_image_to_array(image_data.tobytes())
 
30
 
31
  if image_array.size == 0:
32
- return "Invalid image"
33
 
34
  # Normalize the image
35
  image_array = np.array(image_array, dtype=np.float16) / 255.0
@@ -41,19 +74,10 @@ def predict(image):
41
  prediction = model.predict(image_array)
42
  predicted_class = label_binarizer.inverse_transform(prediction)[0]
43
 
44
- return predicted_class
45
  except Exception as e:
46
- return str(e)
47
-
48
- # Create a Gradio interface
49
- iface = gr.Interface(
50
- fn=predict,
51
- inputs=gr.inputs.Image(type="numpy", label="Upload an image"),
52
- outputs="text",
53
- title="Image Classification",
54
- description="Upload an image to classify it."
55
- )
56
-
57
- # Launch the interface
58
  if __name__ == "__main__":
59
- iface.launch()
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.responses import JSONResponse
3
  import numpy as np
4
  import cv2
5
  import pickle
6
  from tensorflow.keras.models import load_model
7
  from tensorflow.keras.preprocessing.image import img_to_array
8
 
9
+ app = FastAPI()
10
+
11
+ print("app run")
12
  # Load the model and the label binarizer
13
  model = load_model('cnn_model.h5')
14
+ print("model loaded")
15
  label_binarizer = pickle.load(open('label_transform.pkl', 'rb'))
16
+ print("labels loaded")
17
 
18
  # Function to convert images to array
19
  def convert_image_to_array(image_dir):
 
28
  print(f"Error : {e}")
29
  return None
30
 
31
+ @app.post("/predict")
32
+ async def predict(file: UploadFile = File(...)):
33
+ try:
34
+ # Read the file and convert it to an array
35
+ image_data = await file.read()
36
+ image_array = convert_image_to_array(image_data)
37
+
38
+ if image_array.size == 0:
39
+ return JSONResponse(content={"error": "Invalid image"}, status_code=400)
40
+
41
+ # Normalize the image
42
+ image_array = np.array(image_array, dtype=np.float16) / 255.0
43
+
44
+ # Ensure the image_array has the correct shape (1, 256, 256, 3)
45
+ image_array = np.expand_dims(image_array, axis=0)
46
+
47
+ # Make a prediction
48
+ prediction = model.predict(image_array)
49
+ predicted_class = label_binarizer.inverse_transform(prediction)[0]
50
+
51
+ return {"prediction": predicted_class}
52
+ except Exception as e:
53
+ return JSONResponse(content={"error": str(e)}, status_code=500)
54
+
55
+ # Add a test GET endpoint to manually trigger the prediction
56
+ @app.get("/test-predict")
57
+ def test_predict():
58
  try:
59
+ image_path = 'crop_image2.jpg'
60
+ image = cv2.imread(image_path)
61
+ image_array = cv2.resize(image, (256, 256))
62
+ image_array = img_to_array(image_array)
63
 
64
  if image_array.size == 0:
65
+ return JSONResponse(content={"error": "Invalid image"}, status_code=400)
66
 
67
  # Normalize the image
68
  image_array = np.array(image_array, dtype=np.float16) / 255.0
 
74
  prediction = model.predict(image_array)
75
  predicted_class = label_binarizer.inverse_transform(prediction)[0]
76
 
77
+ return {"prediction": predicted_class}
78
  except Exception as e:
79
+ return JSONResponse(content={"error": str(e)}, status_code=500)
80
+
 
 
 
 
 
 
 
 
 
 
81
  if __name__ == "__main__":
82
+ import uvicorn
83
+ uvicorn.run(app, host="127.0.0.1", port=8000)