import gradio as gr import numpy as np import requests import tensorflow as tf from fastapi import FastAPI from io import BytesIO from PIL import Image from pydantic import BaseModel from cat_breeds_dict import CAT_BREEDS, CAT_DESCRIPTIONS from scripts.crop_image import crop_image class Url(BaseModel): link: str MODEL = tf.keras.models.load_model('./models/cats_18_EfficientNetB0.h5') GRADIO_PATH = '/' INPUT_SHAPE = MODEL.layers[0].input_shape[1] NUM_CLASSES = MODEL.layers[-1].output_shape[1] app = FastAPI() def predict(image, api_mode=False): image = crop_image(image, INPUT_SHAPE, INPUT_SHAPE) image = image.resize((INPUT_SHAPE, INPUT_SHAPE)) image = np.asarray(image) image = image.reshape(1, INPUT_SHAPE, INPUT_SHAPE, 3) prediction = MODEL.predict(image)[0] predicted_breed = CAT_BREEDS[np.argmax(prediction)] breed_description = CAT_DESCRIPTIONS[predicted_breed] all_predictions = { CAT_BREEDS[i]: float(prediction[i]) for i in range(NUM_CLASSES) } if api_mode: breed_description = ' '.join(breed_description.replace('\n', '.') .replace('#', '') .split()) return { 'breed': predicted_breed, 'description': breed_description, 'predictions': all_predictions } return all_predictions, breed_description, gr.HTML.update(visible=True), gr.Markdown.update(visible=True) @app.post('/predict_breed/') def predict_api(url: Url): try: image = requests.get(url.link).content except Exception as e: return {'error': 'Invalid link', 'exception': str(e)} image = Image.open(BytesIO(image)) return predict(image, api_mode=True) with gr.Blocks(css='./static/style.css', title="Cat Classifier") as gradio_ui: gr.Markdown( """ # Классификатор пород котов Разработано студентами Шершневым А.А, Онучиной М.К., Ивановым С.С, Шалаевой И.Г. и Ильиным С.С. Группы: РИМ-120906, РИМ-120907, РИМ-120908 """, elem_id='md-text' ) with gr.Row(elem_id='main-row') as row: with gr.Column(scale=2, elem_id='first-col') as col_1: user_image = gr.Image( label='Загрузите фотографию котика сюда', type='pil', elem_id='user-image' ) predict_button = gr.Button(value='Определить породу') with gr.Column(scale=1, elem_id='second-col') as col_2: predicted_labels = gr.Label( num_top_classes=5, label='Результат определения породы', elem_id='predictions-text' ) breed_description = gr.Markdown(elem_id='breed-description') banner_text = gr.Markdown( """ #
Места, которые будут Вам интересны
""", visible=False ) embedded_map = gr.HTML(''' ''', visible=False, elem_id='embedded-map') predict_button.click( fn=predict, inputs=[user_image], outputs=[ predicted_labels, breed_description, embedded_map, banner_text ] ) app = gr.mount_gradio_app(app, gradio_ui, path=GRADIO_PATH)