File size: 3,724 Bytes
217866d
 
 
 
 
 
 
 
 
80d68c7
217866d
 
 
 
 
 
80d68c7
217866d
80d68c7
 
217866d
 
 
 
80d68c7
 
217866d
80d68c7
217866d
 
 
 
 
80d68c7
 
 
202d636
217866d
202d636
 
 
 
 
 
 
 
 
80d68c7
217866d
 
 
 
 
 
202d636
 
217866d
 
 
 
b96cef5
217866d
 
 
 
57c7f8e
217866d
57c7f8e
217866d
 
 
57c7f8e
217866d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80d68c7
 
 
 
 
 
 
 
217866d
 
 
 
80d68c7
 
 
 
217866d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import gradio as gr
import numpy as np
import requests
import tensorflow as tf
from fastapi import FastAPI
from io import BytesIO
from PIL import Image
from pydantic import BaseModel
from cat_breeds_dict import CAT_BREEDS, CAT_DESCRIPTIONS
from scripts.crop_image import crop_image


class Url(BaseModel):
    # Request body of the POST /predict_breed/ endpoint (see predict_api):
    # `link` is the URL of the image that the server downloads and classifies.
    link: str


# Trained Keras classifier, loaded once at import time. The path is relative, so
# it is resolved against the process's current working directory.
MODEL = tf.keras.models.load_model('./models/cats_18_EfficientNetB0.h5')
# URL path at which the Gradio UI is mounted on the FastAPI app.
GRADIO_PATH = '/'
# Side length (pixels) of the model input. Only index 1 is read and predict()
# resizes to (INPUT_SHAPE, INPUT_SHAPE), so a square input is assumed.
# NOTE(review): relies on layers[0].input_shape being a single tuple such as
# (None, H, W, C); some Keras versions return a list of tuples for an
# InputLayer — confirm against the pinned TensorFlow version.
INPUT_SHAPE = MODEL.layers[0].input_shape[1]
# Number of classes, i.e. the width of the final layer's output.
NUM_CLASSES = MODEL.layers[-1].output_shape[1]
# FastAPI application: the REST endpoint is registered on it, and the Gradio UI
# is mounted onto it further down in this module.
app = FastAPI()


def predict(image, api_mode=False):
    """Classify the cat breed shown in ``image``.

    Args:
        image: PIL image to classify, as delivered by the ``gr.Image``
            component (``type='pil'``) or by ``predict_api``.
        api_mode: When True, return a JSON-serialisable dict for the REST
            endpoint instead of the tuple of Gradio outputs.

    Returns:
        In API mode, a dict with ``'breed'`` (top-1 breed name),
        ``'description'`` (the breed description flattened to one line with
        '#' characters removed) and ``'predictions'`` (breed name -> score).
        Otherwise a 4-tuple matching the Gradio outputs: the breed -> score
        mapping, the Markdown description, and two updates that make the
        embedded map and the banner visible.
    """
    # The batch below is reshaped to (1, H, W, 3), so the image must have
    # exactly three channels. Uploaded PNGs with transparency (RGBA), palette
    # (P), grayscale (L) or CMYK images would otherwise make reshape() fail.
    image = image.convert('RGB')
    image = crop_image(image, INPUT_SHAPE, INPUT_SHAPE)
    image = image.resize((INPUT_SHAPE, INPUT_SHAPE))
    batch = np.asarray(image).reshape(1, INPUT_SHAPE, INPUT_SHAPE, 3)

    # predict() returns one row of scores per batch item; we sent one image.
    scores = MODEL.predict(batch)[0]
    predicted_breed = CAT_BREEDS[int(np.argmax(scores))]
    breed_description = CAT_DESCRIPTIONS[predicted_breed]

    # float() turns numpy scalars into plain floats for JSON / gr.Label.
    all_predictions = {
        CAT_BREEDS[i]: float(scores[i]) for i in range(NUM_CLASSES)
    }

    if api_mode:
        # Flatten the Markdown description into a single plain-text line.
        breed_description = ' '.join(breed_description.replace('\n', '.')
                                                      .replace('#', '')
                                                      .split())

        return {
            'breed': predicted_breed,
            'description': breed_description,
            'predictions': all_predictions
            }
    # gr.update() works across Gradio versions, whereas the component-specific
    # gr.HTML.update / gr.Markdown.update helpers were removed in Gradio 4.
    return (
        all_predictions,
        breed_description,
        gr.update(visible=True),
        gr.update(visible=True),
    )


@app.post('/predict_breed/')
def predict_api(url: Url):
    """Download the image at ``url.link`` and classify the cat breed in it.

    Returns:
        The dict produced by ``predict(image, api_mode=True)`` on success.
        If the download or image decoding fails, a dict of the form
        ``{'error': <reason>, 'exception': <message>}`` is returned instead of
        raising, so the client gets a JSON description of the problem.
    """
    # NOTE(review): the server fetches an arbitrary client-supplied URL, which
    # enables server-side request forgery (e.g. requests to internal hosts).
    # Consider restricting schemes/hosts if this endpoint is publicly exposed.
    try:
        # Without a timeout, requests can block forever on an unresponsive
        # host and tie up a worker thread.
        response = requests.get(url.link, timeout=10)
        # A 404/500 error page is not an image: report it as a bad link
        # instead of letting PIL fail on the HTML body.
        response.raise_for_status()
    except (requests.RequestException, ValueError) as e:
        # ValueError covers URL-parsing errors that can escape from urllib3.
        return {'error': 'Invalid link', 'exception': str(e)}

    try:
        image = Image.open(BytesIO(response.content))
        # Image.open() is lazy; decode now so corrupt or non-image content is
        # reported here rather than failing later inside predict().
        image.load()
    except (OSError, Image.DecompressionBombError) as e:
        return {'error': 'Invalid image', 'exception': str(e)}

    return predict(image, api_mode=True)


# Web UI: upload a photo, show the top predicted breeds and the breed
# description, then reveal a banner and an embedded Yandex map.
with gr.Blocks(css='./static/style.css', title="Cat Classifier") as gradio_ui:

    # Header with the app title and the authors' credits.
    gr.Markdown(
        """
        # Классификатор пород котов
        Разработано студентами Шершневым А.А, Онучиной М.К., Ивановым С.С, Шалаевой И.Г. и
        Ильиным С.С.
        Группы: РИМ-120906, РИМ-120907, РИМ-120908
        """,
        elem_id='md-text'
    )

    with gr.Row(elem_id='main-row'):

        # Left column: image upload and the button that triggers prediction.
        with gr.Column(scale=2, elem_id='first-col'):
            user_image = gr.Image(
                label='Загрузите фотографию котика сюда',
                type='pil',
                elem_id='user-image'
            )
            predict_button = gr.Button(value='Определить породу')

        # Right column: the top-5 breed scores.
        with gr.Column(scale=1, elem_id='second-col'):
            predicted_labels = gr.Label(
                num_top_classes=5,
                label='Результат определения породы',
                elem_id='predictions-text'
            )

    breed_description = gr.Markdown(elem_id='breed-description')
    # Hidden until the first prediction; predict() returns visibility updates.
    banner_text = gr.Markdown(
        """
        # <center>Места, которые будут Вам интересны</center>
        """, visible=False
    )
    embedded_map = gr.HTML('''
    <iframe src="https://yandex.ru/map-widget/v1/?um=constructor%3A8cead4799165c7f6356c4f269f2847032ef2803cb46871dbfd6dd68c09834f4c&amp;source=constructor" width="100%" height="500" frameborder="0"></iframe>
    ''', visible=False, elem_id='embedded-map')

    # Output order must match the 4-tuple returned by predict() in UI mode.
    predict_button.click(
        fn=predict,
        inputs=[user_image],
        outputs=[
            predicted_labels, breed_description,
            embedded_map, banner_text
            ]
    )

# Mount only after the `with gr.Blocks()` context has exited, so the Blocks is
# fully finalised before its config is built for FastAPI. This is the usage
# shown in Gradio's mount_gradio_app documentation.
app = gr.mount_gradio_app(app, gradio_ui, path=GRADIO_PATH)