Clement Vachet committed on
Commit
51588cf
·
0 Parent(s):

Initial commit

Browse files
Files changed (3) hide show
  1. .gitignore +165 -0
  2. app.py +116 -0
  3. utils.py +86 -0
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Environment file
2
+ config_api.env
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
113
+ .pdm.toml
114
+ .pdm-python
115
+ .pdm-build/
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ .idea/
app.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import os
4
+ import requests
5
+ import json
6
+ import utils
7
+
8
+ from dotenv import load_dotenv, find_dotenv
9
+
10
# Candidate object-detection models (Hugging Face model ids)
list_models = [
    "facebook/detr-resnet-50",
    "facebook/detr-resnet-101",
    "hustvl/yolos-tiny",
    "hustvl/yolos-small",
]
# Repo-less short names, used as the `model` query parameter in API calls
list_models_simple = [os.path.basename(m) for m in list_models]

# ECS API base URLs -- populated by initialize_api_endpoints() at startup
AWS_DETR_URL = None
AWS_YOLOS_URL = None
17
+
18
+
19
# Initialize API URLs from env file or global settings
def initialize_api_endpoints():
    """Populate the module-level API base URLs.

    Loads `config_api.env` (if found) so that AWS_DETR_URL / AWS_YOLOS_URL
    can be supplied via the environment; otherwise falls back to local
    containers on ports 8000 (DETR) and 8001 (YOLOS).
    """
    env_path = find_dotenv('config_api.env')
    if env_path:
        load_dotenv(dotenv_path=env_path)
        print("config_api.env file loaded successfully.")
    else:
        print("config_api.env file not found.")

    # Use AWS ECS endpoints from the environment, or local containers by default.
    # NOTE: defaults use 127.0.0.1 rather than 0.0.0.0 -- 0.0.0.0 is a *bind*
    # address, not a valid connect target on all platforms (fails on Windows).
    global AWS_DETR_URL, AWS_YOLOS_URL
    AWS_DETR_URL = os.getenv("AWS_DETR_URL", default="http://127.0.0.1:8000")
    AWS_YOLOS_URL = os.getenv("AWS_YOLOS_URL", default="http://127.0.0.1:8001")
33
+
34
+
35
# Retrieve correct endpoint based on model_type
def retrieve_api_endpoint(model_type):
    """Return the API base URL backing the given model name.

    Any name containing "detr" maps to the DETR service; everything
    else goes to the YOLOS service.
    """
    return AWS_DETR_URL if "detr" in model_type else AWS_YOLOS_URL
43
+
44
+
45
#@spaces.GPU
def detect(image_path, model_id, threshold):
    """Run object detection on an image through the remote API.

    Args:
        image_path: filesystem path of the input image.
        model_id: index into `list_models` selecting the detection model.
        threshold: score threshold forwarded to output generation.

    Returns:
        Tuple (output_json, output_pil_img) for the gradio components.
    """
    print("\n Object detection...")
    print("\t ML model:", list_models[model_id])

    model_name = list_models_simple[model_id]
    with open(image_path, 'rb') as image_file:
        payload = image_file.read()

    # API call for object prediction, with the model type as a query parameter
    base_url = retrieve_api_endpoint(model_name)
    API_Endpoint = f"{base_url}/api/v1/detect?model={model_name}"
    print("\t API_Endpoint: ", API_Endpoint)

    response = requests.post(API_Endpoint, files={"image": payload})
    if response.status_code == 200:
        # The API returns a JSON-encoded string; decode it into a dict
        response_string = response.json()
        print('\t API response', response_string)
        response_dict = json.loads(response_string)
    else:
        response_dict = {"Error": response.status_code}
        # NOTE(review): gr.Error is instantiated but never raised, so no
        # error popup is shown -- confirm whether this is intentional.
        gr.Error(f"\t API Error: {response.status_code}")

    # Generate gradio output components: image and json
    output_json, output_pil_img = utils.generate_gradio_outputs(image_path, response_dict, threshold)
    return output_json, output_pil_img
73
+
74
+
75
def demo():
    """Build and launch the Gradio UI for the object-detection demo."""
    initialize_api_endpoints()
    # Named `app` (not `demo`) to avoid shadowing this function's own name
    with gr.Blocks(theme="base") as app:
        gr.Markdown("# Object detection task - use of ECS endpoints")
        gr.Markdown(
            """
            This web application uses transformer models to detect objects on images.
            Machine learning models were trained on the COCO dataset.
            You can load an image and see the predictions for the objects detected.

            Note: This web application uses AWS ECS endpoints as back-end APIs to run these ML models.
            """
        )

        with gr.Row():
            with gr.Column():
                # type="index" -> the callback receives an int index into list_models
                model_id = gr.Radio(
                    list_models,
                    label="Detection models",
                    value=list_models[0],
                    type="index",
                    info="Choose your detection model",
                )
            with gr.Column():
                threshold = gr.Slider(0, 1.0, value=0.9, label='Detection threshold', info="Choose your detection threshold")

        with gr.Row():
            input_image = gr.Image(label="Input image", type="filepath")
            output_image = gr.Image(label="Output image", type="pil")
            output_json = gr.JSON(label="JSON output", min_height=240, max_height=300)

        with gr.Row():
            submit_btn = gr.Button("Submit")
            clear_button = gr.ClearButton()

        gr.Examples(['samples/savanna.jpg', 'samples/boats.jpg'], inputs=input_image)

        submit_btn.click(fn=detect, inputs=[input_image, model_id, threshold], outputs=[output_json, output_image])
        clear_button.click(
            lambda: [None, None, None],
            inputs=None,
            outputs=[input_image, output_image, output_json],
            queue=False,
        )

    app.queue().launch(debug=True)


if __name__ == "__main__":
    demo()
utils.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import matplotlib.pyplot as plt
3
+ import io
4
+
5
+
6
# COCO class labels, indexed by model label id ('N/A' marks unused ids)
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]
# RGB palette (0-1 floats) cycled over detected boxes
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]


# Update JSON dictionary with rounded values and classes
def generate_output_json(json_dict):
    """Round scores/box coordinates to 3 decimals and map label ids to names.

    Mutates `json_dict` in place and returns it. Dicts missing the expected
    keys (e.g. {"Error": status_code} produced by a failed API call) are
    returned unchanged instead of raising KeyError.
    """
    if not all(key in json_dict for key in ('scores', 'labels', 'boxes')):
        # Error payload or malformed response -- pass through untouched
        return json_dict
    json_dict['scores'] = [round(score, 3) for score in json_dict['scores']]
    json_dict['boxes'] = [[round(coord, 3) for coord in box] for box in json_dict['boxes']]
    json_dict['labels'] = [CLASSES[label] for label in json_dict['labels']]
    return json_dict
39
+
40
+
41
# Generate matplotlib figure from prediction scores and boxes
def generate_output_figure(image_path, results, threshold):
    """Plot detections over the input image and return the matplotlib figure.

    Args:
        image_path: path of the original image, drawn as the background.
        results: dict with parallel 'scores', 'labels', 'boxes' lists
            (as produced by generate_output_json); missing keys are
            treated as empty so error payloads don't crash plotting.
        threshold: only detections with score > threshold get a box.
    """
    pil_img = Image.open(image_path)

    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    # Repeat the palette so zip never runs out of colors; cycling by
    # detection index is deterministic, unlike hash(label) which varies
    # per process (string hash randomization)
    colors = COLORS * 100

    print("\t Detailed information...")
    scores = results.get("scores", [])
    labels = results.get("labels", [])
    boxes = results.get("boxes", [])
    for score, label, box, c in zip(scores, labels, boxes, colors):
        print(
            f"\t\t Detected {label} with confidence "
            f"{score} at location {box}"
        )

        if score > threshold:
            # box is [xmin, ymin, xmax, ymax]
            ax.add_patch(
                plt.Rectangle((box[0], box[1]), box[2] - box[0], box[3] - box[1], fill=False, color=c, linewidth=3)
            )
            text = f"{label}: {score:0.2f}"
            ax.text(box[0], box[1], text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
    plt.axis("off")

    return plt.gcf()
68
+
69
+
70
# Generate PIL image from matplotlib figure
def generate_output_image(output_figure):
    """Render a matplotlib figure into a PIL image.

    The figure is closed after rendering: pyplot keeps every figure alive
    until explicitly closed, so repeated calls would otherwise leak memory.
    """
    buf = io.BytesIO()
    output_figure.savefig(buf, bbox_inches="tight")
    plt.close(output_figure)
    buf.seek(0)
    # PIL reads lazily from buf; buf stays referenced by the image object
    output_pil_img = Image.open(buf)

    return output_pil_img
80
+
81
+
82
def generate_gradio_outputs(image_path, response_dict, threshold):
    """Turn an API response into gradio outputs: (JSON dict, PIL image)."""
    output_json = generate_output_json(response_dict)
    figure = generate_output_figure(image_path, output_json, threshold)
    return output_json, generate_output_image(figure)