|
import albumentations |
|
import cv2 |
|
import torch |
|
import timm |
|
import gradio as gr |
|
import numpy as np |
|
import os |
|
import random |
|
|
|
# All inference in this app runs on the CPU.
device = torch.device('cpu')

# Model output index -> disease name. The position of each name in the list
# below becomes its integer key (0..9). Presumably this order must match the
# class order used when the model was trained -- confirm before reordering.
labels = dict(enumerate([
    'bacterial_leaf_blight',
    'bacterial_leaf_streak',
    'bacterial_panicle_blight',
    'blast',
    'brown_spot',
    'dead_heart',
    'downy_mildew',
    'hispa',
    'normal',
    'tungro',
]))
|
|
|
def inference_fn(model, image=None):
    """Run *model* on a single image and return its per-class sigmoid scores.

    The model is switched to eval mode, the image is moved to the module-level
    ``device`` and given a leading batch dimension of 1, and the forward pass
    runs without autograd.

    Args:
        model: a torch module; presumably the classifier built in ``predict``
            and already placed on ``device`` -- confirm for other callers.
        image: one image tensor *without* a batch dimension (assumed CHW
            float32, as produced by ``predict`` -- TODO confirm for other
            callers).

    Returns:
        A flat numpy array with ``sigmoid(logits)`` for the single image.

    Raises:
        ValueError: if ``image`` is None (the previous code crashed with an
            AttributeError on ``None.to``).
    """
    if image is None:
        raise ValueError("inference_fn() requires an image tensor, got None")

    model.eval()
    batch = image.to(device).unsqueeze(0)  # add the batch dimension of 1

    with torch.no_grad():
        logits = model(batch)

    # NOTE(review): sigmoid scores do not sum to 1 across classes; if the model
    # was trained as single-label multi-class (CrossEntropyLoss), softmax would
    # give proper class probabilities. Kept as sigmoid to preserve behaviour.
    # Under no_grad no autograd graph is recorded, so .detach() is not needed.
    return logits.sigmoid().cpu().numpy().flatten()
|
|
|
|
|
# ImageNet channel statistics used to normalise inputs.
_MEAN = (0.485, 0.456, 0.406)
_STD = (0.229, 0.224, 0.225)

# Built once at import instead of on every request. Deterministic on purpose:
# the random Horizontal/VerticalFlip(p=0.5) previously applied here made the
# same uploaded image produce different scores from one request to the next.
_EVAL_TRANSFORMS = albumentations.Compose(
    [
        albumentations.Resize(256, 256),
        # p=1.0 is Normalize's default and means "always apply"; the
        # always_apply flag is deprecated in newer albumentations releases.
        albumentations.Normalize(_MEAN, _STD, max_pixel_value=255.0, p=1.0),
    ]
)

_MODEL_PATH = "paddy_model.pth"
_model_cache = {}  # weights path -> loaded model, so each path is read once


def _load_model(path=_MODEL_PATH):
    """Return the efficientnet_b0 classifier with weights from *path*, cached per path."""
    model = _model_cache.get(path)
    if model is None:
        # One output per entry in `labels`, so the head and the label map
        # cannot drift apart.
        model = timm.create_model('efficientnet_b0', pretrained=False,
                                  num_classes=len(labels))
        model.load_state_dict(torch.load(path, map_location=device))
        model.to(device)
        _model_cache[path] = model
    return model


def _preprocess(image):
    """Resize and normalise an HWC image array; return a CHW float32 tensor."""
    hwc = _EVAL_TRANSFORMS(image=image)["image"]
    chw = np.transpose(hwc, (2, 0, 1))  # HWC -> CHW, the layout torch models expect
    return torch.tensor(chw, dtype=torch.float32)


def predict(image=None) -> dict:
    """Classify one image and return a score for every disease label.

    Args:
        image: an HWC image array as supplied by the Gradio image input
            (assumed RGB uint8 -- confirm if the input component changes).

    Returns:
        ``{label_name: score}`` for every entry in ``labels``, in index order,
        where each score comes from ``inference_fn``.

    Raises:
        ValueError: if no image was supplied.
    """
    if image is None:
        raise ValueError("predict() requires an image, got None")

    model = _load_model()
    scores = inference_fn(model, _preprocess(image))
    return {name: float(scores[idx]) for idx, name in labels.items()}
|
|
|
|
|
# Web UI: a single image upload in, all ten label scores out.
# NOTE(review): gr.inputs / gr.outputs and `interpretation=` are legacy Gradio
# APIs (deprecated in Gradio 3, removed in Gradio 4) -- confirm the pinned
# gradio version before upgrading.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(),
    outputs=gr.outputs.Label(num_top_classes=10),
    examples=["200005.jpg", "200006.jpg"],
    interpretation='default',
)
demo.launch()