dnth commited on
Commit
aaeef06
1 Parent(s): a261053

Upload 2 files

Browse files
Files changed (2) hide show
  1. 10_gradio_app.py +53 -0
  2. requirements.txt +91 -0
10_gradio_app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ from urllib.request import urlopen
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import onnxruntime as ort
7
+ import torch
8
+ from PIL import Image
9
+
10
+ from imagenet_classes import IMAGENET2012_CLASSES
11
+
12
+
13
+ def read_image(image: Image.Image):
14
+ image = image.convert("RGB")
15
+ img_numpy = np.array(image).astype(np.float32)
16
+ img_numpy = img_numpy.transpose(2, 0, 1)
17
+ img_numpy = np.expand_dims(img_numpy, axis=0)
18
+ return img_numpy
19
+
20
+
21
+ providers = ["CPUExecutionProvider"]
22
+
23
+ session = ort.InferenceSession("merged_model_compose.onnx", providers=providers)
24
+
25
+ input_name = session.get_inputs()[0].name
26
+ output_name = session.get_outputs()[0].name
27
+
28
+
29
+ def predict(img):
30
+ output = session.run([output_name], {input_name: read_image(img)})
31
+ output = torch.from_numpy(output[0])
32
+
33
+ top5_probabilities, top5_class_indices = torch.topk(output.softmax(dim=1), k=5)
34
+
35
+ im_classes = list(IMAGENET2012_CLASSES.values())
36
+ class_names = [im_classes[i] for i in top5_class_indices[0]]
37
+
38
+ results = {
39
+ name: float(prob) for name, prob in zip(class_names, top5_probabilities[0])
40
+ }
41
+ return results
42
+
43
+
44
+ iface = gr.Interface(
45
+ fn=predict,
46
+ inputs=gr.Image(type="pil"),
47
+ outputs=gr.Label(num_top_classes=5),
48
+ title="Image Classification with ONNX TensorRT",
49
+ description="Upload an image to classify it using the ONNX TensorRT model.",
50
+ )
51
+
52
+ if __name__ == "__main__":
53
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ annotated-types==0.7.0
3
+ anyio==4.6.0
4
+ certifi==2024.8.30
5
+ charset-normalizer==3.3.2
6
+ click==8.1.7
7
+ coloredlogs==15.0.1
8
+ contourpy==1.3.0
9
+ cupy-cuda12x==13.3.0
10
+ cycler==0.12.1
11
+ fastapi==0.115.0
12
+ fastrlock==0.8.2
13
+ ffmpy==0.4.0
14
+ filelock==3.16.1
15
+ flatbuffers==24.3.25
16
+ fonttools==4.54.1
17
+ fsspec==2024.9.0
18
+ gradio==4.44.1
19
+ gradio_client==1.3.0
20
+ h11==0.14.0
21
+ httpcore==1.0.6
22
+ httpx==0.27.2
23
+ huggingface-hub==0.25.1
24
+ humanfriendly==10.0
25
+ idna==3.10
26
+ importlib_resources==6.4.5
27
+ Jinja2==3.1.4
28
+ kiwisolver==1.4.7
29
+ markdown-it-py==3.0.0
30
+ MarkupSafe==2.1.5
31
+ matplotlib==3.9.2
32
+ mdurl==0.1.2
33
+ mpmath==1.3.0
34
+ networkx==3.3
35
+ numpy==2.1.1
36
+ nvidia-cublas-cu12==12.1.3.1
37
+ nvidia-cuda-cupti-cu12==12.1.105
38
+ nvidia-cuda-nvrtc-cu12==12.1.105
39
+ nvidia-cuda-runtime-cu12==12.1.105
40
+ nvidia-cudnn-cu12==9.1.0.70
41
+ nvidia-cufft-cu12==11.0.2.54
42
+ nvidia-curand-cu12==10.3.2.106
43
+ nvidia-cusolver-cu12==11.4.5.107
44
+ nvidia-cusparse-cu12==12.1.0.106
45
+ nvidia-nccl-cu12==2.20.5
46
+ nvidia-nvjitlink-cu12==12.6.68
47
+ nvidia-nvtx-cu12==12.1.105
48
+ onnx==1.16.2
49
+ onnxruntime-gpu==1.19.2
50
+ onnxsim==0.4.36
51
+ opencv-python==4.10.0.84
52
+ orjson==3.10.7
53
+ packaging==24.1
54
+ pandas==2.2.3
55
+ pillow==10.4.0
56
+ protobuf==5.28.2
57
+ pydantic==2.9.2
58
+ pydantic_core==2.23.4
59
+ pydub==0.25.1
60
+ Pygments==2.18.0
61
+ pyparsing==3.1.4
62
+ python-dateutil==2.9.0.post0
63
+ python-multipart==0.0.12
64
+ pytz==2024.2
65
+ PyYAML==6.0.2
66
+ requests==2.32.3
67
+ rich==13.8.1
68
+ ruff==0.6.9
69
+ safetensors==0.4.5
70
+ semantic-version==2.10.0
71
+ shellingham==1.5.4
72
+ six==1.16.0
73
+ sniffio==1.3.1
74
+ starlette==0.38.6
75
+ sympy==1.13.3
76
+ tensorrt==10.1.0
77
+ tensorrt-cu12==10.1.0
78
+ tensorrt-cu12-bindings==10.1.0
79
+ tensorrt-cu12-libs==10.1.0
80
+ timm==1.0.9
81
+ tomlkit==0.12.0
82
+ torch==2.4.1
83
+ torchvision==0.19.1
84
+ tqdm==4.66.5
85
+ triton==3.0.0
86
+ typer==0.12.5
87
+ typing_extensions==4.12.2
88
+ tzdata==2024.2
89
+ urllib3==2.2.3
90
+ uvicorn==0.31.0
91
+ websockets==12.0