Spaces:
Sleeping
Sleeping
File size: 3,724 Bytes
217866d 80d68c7 217866d 80d68c7 217866d 80d68c7 217866d 80d68c7 217866d 80d68c7 217866d 80d68c7 202d636 217866d 202d636 80d68c7 217866d 202d636 217866d b96cef5 217866d 57c7f8e 217866d 57c7f8e 217866d 57c7f8e 217866d 80d68c7 217866d 80d68c7 217866d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import gradio as gr
import numpy as np
import requests
import tensorflow as tf
from fastapi import FastAPI
from io import BytesIO
from PIL import Image
from pydantic import BaseModel
from cat_breeds_dict import CAT_BREEDS, CAT_DESCRIPTIONS
from scripts.crop_image import crop_image
class Url(BaseModel):
link: str
MODEL = tf.keras.models.load_model('./models/cats_18_EfficientNetB0.h5')
GRADIO_PATH = '/'
INPUT_SHAPE = MODEL.layers[0].input_shape[1]
NUM_CLASSES = MODEL.layers[-1].output_shape[1]
app = FastAPI()
def predict(image, api_mode=False):
image = crop_image(image, INPUT_SHAPE, INPUT_SHAPE)
image = image.resize((INPUT_SHAPE, INPUT_SHAPE))
image = np.asarray(image)
image = image.reshape(1, INPUT_SHAPE, INPUT_SHAPE, 3)
prediction = MODEL.predict(image)[0]
predicted_breed = CAT_BREEDS[np.argmax(prediction)]
breed_description = CAT_DESCRIPTIONS[predicted_breed]
all_predictions = {
CAT_BREEDS[i]: float(prediction[i]) for i in range(NUM_CLASSES)
}
if api_mode:
breed_description = ' '.join(breed_description.replace('\n', '.')
.replace('#', '')
.split())
return {
'breed': predicted_breed,
'description': breed_description,
'predictions': all_predictions
}
return all_predictions, breed_description, gr.HTML.update(visible=True), gr.Markdown.update(visible=True)
@app.post('/predict_breed/')
def predict_api(url: Url):
try:
image = requests.get(url.link).content
except Exception as e:
return {'error': 'Invalid link', 'exception': str(e)}
image = Image.open(BytesIO(image))
return predict(image, api_mode=True)
with gr.Blocks(css='./static/style.css', title="Cat Classifier") as gradio_ui:
gr.Markdown(
"""
# Классификатор пород котов
Разработано студентами Шершневым А.А, Онучиной М.К., Ивановым С.С, Шалаевой И.Г. и
Ильиным С.С.
Группы: РИМ-120906, РИМ-120907, РИМ-120908
""",
elem_id='md-text'
)
with gr.Row(elem_id='main-row') as row:
with gr.Column(scale=2, elem_id='first-col') as col_1:
user_image = gr.Image(
label='Загрузите фотографию котика сюда',
type='pil',
elem_id='user-image'
)
predict_button = gr.Button(value='Определить породу')
with gr.Column(scale=1, elem_id='second-col') as col_2:
predicted_labels = gr.Label(
num_top_classes=5,
label='Результат определения породы',
elem_id='predictions-text'
)
breed_description = gr.Markdown(elem_id='breed-description')
banner_text = gr.Markdown(
"""
# <center>Места, которые будут Вам интересны</center>
""", visible=False
)
embedded_map = gr.HTML('''
<iframe src="https://yandex.ru/map-widget/v1/?um=constructor%3A8cead4799165c7f6356c4f269f2847032ef2803cb46871dbfd6dd68c09834f4c&source=constructor" width="100%" height="500" frameborder="0"></iframe>
''', visible=False, elem_id='embedded-map')
predict_button.click(
fn=predict,
inputs=[user_image],
outputs=[
predicted_labels, breed_description,
embedded_map, banner_text
]
)
app = gr.mount_gradio_app(app, gradio_ui, path=GRADIO_PATH)
|