File size: 2,375 Bytes
e155f7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import json
import os
from functools import lru_cache
from typing import Mapping

import gradio as gr
from huggingface_hub import HfFileSystem, hf_hub_download
from imgutils.data import ImageTyping, load_image
from natsort import natsorted

from onnx_ import _open_onnx_model
from preprocess import _img_encode

# Shared Hugging Face Hub filesystem client; ``Classification.__init__`` uses it
# to glob ``<repository>/*/model.onnx`` when discovering the available models.
hfs = HfFileSystem()


@lru_cache()
def open_model_from_repo(repository, model):
    """Download and open an ONNX classifier plus its label list from a HF repository.

    Results are memoised per ``(repository, model)`` by ``functools.lru_cache``
    (default ``maxsize=128``), so the body runs at most once per key while the
    entry stays cached.

    :param repository: Hugging Face repository id containing the model directories.
    :param model: Name of the sub-directory holding ``model.onnx`` and ``meta.json``.
    :return: Tuple ``(runtime, labels)``: ``runtime`` is whatever
        ``_open_onnx_model`` returns for ``model.onnx``, and ``labels`` is the
        ``"labels"`` entry of ``meta.json``.
    """
    runtime = _open_onnx_model(hf_hub_download(repository, f'{model}/model.onnx'))
    # JSON is UTF-8 by spec; reading it with the platform's locale encoding
    # (e.g. cp1252 on Windows) would garble or reject non-ASCII labels.
    with open(hf_hub_download(repository, f'{model}/meta.json'), 'r', encoding='utf-8') as f:
        labels = json.load(f)['labels']

    return runtime, labels


class Classification:
    """A Gradio tab that classifies images with ONNX models from one HF repository.

    Every ``<repository>/<model>/model.onnx`` found on the Hub is offered as a
    selectable model in the tab's dropdown.
    """

    def __init__(self, title: str, repository: str, default_model=None, imgsize: int = 384):
        """
        :param title: Label of the Gradio tab.
        :param repository: Hugging Face repository id to scan for models.
        :param default_model: Model pre-selected in the dropdown; when falsy, the
            first discovered model in natural-sort order is used.
        :param imgsize: Initial value of the "Infer Size" slider.
        :raises IndexError: If no model is found and ``default_model`` is not given.
        """
        self.title = title
        self.repository = repository
        # Each match is ``<repository>/<model>/model.onnx``; keep only ``<model>``.
        self.models = natsorted([
            os.path.dirname(os.path.relpath(file, self.repository))
            for file in hfs.glob(f'{self.repository}/*/model.onnx')
        ])
        self.default_model = default_model or self.models[0]
        self.imgsize = imgsize

    def _open_onnx_model(self, model):
        """Return the cached ``(runtime, labels)`` pair for ``model`` in this repository."""
        return open_model_from_repo(self.repository, model)

    def _gr_classification(self, image: ImageTyping, model_name: str, size=384) -> Mapping[str, float]:
        """Classify ``image`` with ``model_name`` and return a label -> score mapping.

        :param image: Input image in any form accepted by ``imgutils.data.load_image``.
        :param model_name: Name of a model directory inside ``self.repository``.
        :param size: Side length of the square input fed to the model.
        :return: Dict mapping each label to the model's output value for that label.
        """
        image = load_image(image, mode='RGB')
        # The UI slider may deliver its value as a float (e.g. ``384.0``);
        # image resizing needs integer dimensions.
        size = int(size)
        input_ = _img_encode(image, size=(size, size))[None, ...]  # prepend a batch axis
        model, labels = self._open_onnx_model(model_name)
        output, = model.run(['output'], {'input': input_})

        # ``output[0]`` is the single batch item; ``.item()`` converts each
        # (presumably NumPy) scalar into a plain Python number for ``gr.Label``.
        return {label: score.item() for label, score in zip(labels, output[0])}

    def create_gr(self):
        """Build this classifier's tab (inputs, output, submit wiring) in the current Gradio context."""
        with gr.Tab(self.title):
            with gr.Row():
                with gr.Column():
                    gr_input_image = gr.Image(type='pil', label='Original Image')
                    gr_model = gr.Dropdown(self.models, value=self.default_model, label='Model')
                    # Start at the size configured for this classifier rather than a
                    # hard-coded 384; otherwise the ``imgsize`` argument has no effect.
                    gr_infer_size = gr.Slider(224, 640, value=self.imgsize, step=32, label='Infer Size')
                    gr_submit = gr.Button(value='Submit', variant='primary')

                with gr.Column():
                    gr_output = gr.Label(label='Classes')

                gr_submit.click(
                    self._gr_classification,
                    inputs=[gr_input_image, gr_model, gr_infer_size],
                    outputs=[gr_output],
                )