|
import gradio as gr |
|
import torch |
|
from huggingface_hub import from_pretrained_fastai |
|
from pathlib import Path |
|
|
|
# Sample images shown as clickable examples in the Gradio UI
# (examples/example_0.png ... examples/example_4.png).
examples = [f"examples/example_{idx}.png" for idx in range(5)]

# Hugging Face Hub repository that hosts the exported fastai learner.
repo_id = "hugginglearners/rice_image_classification"

# Base directory that get_x joins record file names onto.
path = Path("./")
|
|
|
def get_y(r):
    """Return the target label of record *r*, read from its ``"label"`` key."""
    target = r["label"]
    return target
|
|
|
def get_x(r):
    """Return the image path for record *r*: its ``"fname"`` joined onto the module-level ``path``."""
    file_name = r["fname"]
    return path / file_name
|
|
|
# Download and load the fastai learner from the Hub. NOTE(review): get_x/get_y
# above are never called directly in this file; presumably they must exist
# before this line so the pickled learner can resolve them — confirm.
learner = from_pretrained_fastai(repo_id=repo_id)
|
|
|
def inference(image):
    """Classify *image* and return a probability for every rice class.

    Args:
        image: The image passed in by the Gradio ``"image"`` input component;
            it is forwarded unchanged to ``learner.predict``.

    Returns:
        dict[str, float]: Maps each class name in the learner's vocabulary to
        its predicted probability, in the format ``gr.Label`` expects.
    """
    # learner.predict returns (decoded prediction, class index, probabilities);
    # only the probability vector is needed here.
    _, _, probs = learner.predict(image)
    # The class names come from the learner's own vocabulary. The original code
    # referenced an undefined `labels` name, which raised NameError on every call.
    labels = learner.dls.vocab
    return {str(label): float(prob) for label, prob in zip(labels, probs)}
|
|
|
# Build and launch the web UI.
# - `outputs` (not `output`) is the Interface parameter name; with the typo the
#   required `outputs` argument was missing and construction failed.
# - `gr.outputs.Label` was removed in Gradio 4; `gr.Label` is the current component.
# - Queueing is enabled with `.queue()` instead of the deprecated
#   `launch(enable_queue=...)` argument.
demo = gr.Interface(
    fn=inference,
    title="Rice image classification",
    description=(
        "Predict which type of rice an image shows: "
        "Arborio, Basmati, Ipsala, Jasmine, or Karacadag"
    ),
    inputs="image",
    examples=examples,
    outputs=gr.Label(num_top_classes=5, label="Prediction"),
    cache_examples=False,
    article="Author: <a href=\"https://www.linkedin.com/in/vumichien/\">Vu Minh Chien</a>",
)
demo.queue().launch(debug=True)