ShAnSantosh's picture
Update app.py
8a41460
import albumentations
import cv2
import torch
import timm
import gradio as gr
import numpy as np
import os
import random
# Every tensor and the model are placed on the CPU.
device = torch.device('cpu')

# Class index -> disease name. The tuple order defines the index, so
# position i is the name reported for the model's i-th output score.
labels = dict(enumerate((
    'bacterial_leaf_blight',
    'bacterial_leaf_streak',
    'bacterial_panicle_blight',
    'blast',
    'brown_spot',
    'dead_heart',
    'downy_mildew',
    'hispa',
    'normal',
    'tungro',
)))
def inference_fn(model, image=None):
    """Run a single forward pass and return per-output sigmoid scores.

    Args:
        model: A callable torch module; it is switched to eval mode here.
        image: A tensor for one sample (no batch axis). It is moved to the
            module-level ``device`` and a leading batch axis is added
            before the forward pass.

    Returns:
        A flat NumPy array holding the element-wise sigmoid of the model
        output. Because sigmoid is applied per element, the scores are
        independent of each other and need not sum to 1.
    """
    model.eval()
    batch = image.to(device).unsqueeze(0)
    # Gradients are never needed when serving predictions.
    with torch.no_grad():
        logits = model(batch)
    scores = torch.sigmoid(logits)
    return scores.detach().cpu().numpy().flatten()
# Ready-to-use models keyed by weights path, so the checkpoint is read from
# disk once per process instead of once per request.
_MODEL_CACHE = {}


def _load_model(weights_path="paddy_model.pth"):
    """Return the EfficientNet-B0 classifier, building it on first use.

    Args:
        weights_path: Path of the state-dict checkpoint to load.

    Returns:
        The model with weights loaded, placed on the module-level ``device``.
    """
    model = _MODEL_CACHE.get(weights_path)
    if model is None:
        model = timm.create_model(
            'efficientnet_b0', pretrained=False, num_classes=len(labels)
        )
        # ``device`` is already a torch.device; no need to wrap it again.
        model.load_state_dict(torch.load(weights_path, map_location=device))
        model.to(device)
        _MODEL_CACHE[weights_path] = model
    return model


def _preprocess(image):
    """Resize and normalize an image array into a channels-first tensor.

    Args:
        image: Image as a 3-D NumPy array with channels on the last axis
            (presumably H x W x 3 RGB as delivered by the Gradio image
            input -- confirm against the Gradio version in use).

    Returns:
        A float32 tensor with the channel axis moved to the front.
    """
    # These values match the standard ImageNet channel statistics.
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    # Deterministic pipeline only: the random flips that used to live here
    # are training-time augmentations and made the scores for the same
    # picture vary from call to call. Normalize runs with its default p=1.0,
    # so the redundant ``always_apply`` flag is not needed.
    pipeline = albumentations.Compose(
        [
            albumentations.Resize(256, 256),
            albumentations.Normalize(mean=mean, std=std, max_pixel_value=255.0),
        ]
    )
    array = pipeline(image=image)["image"]
    # Channels-last -> channels-first, the layout the model is fed.
    array = np.transpose(array, (2, 0, 1))
    return torch.tensor(array, dtype=torch.float32)


def predict(image=None) -> dict:
    """Classify one paddy-leaf image and return a score for every label.

    Args:
        image: Image array handed over by the Gradio image input.

    Returns:
        A dict mapping each name in ``labels`` to the model's sigmoid score
        for that class, as a Python float.

    Raises:
        ValueError: If no image was supplied.
    """
    if image is None:
        # Fail with a clear message instead of an opaque error from inside
        # the preprocessing library.
        raise ValueError("predict() requires an image, but none was provided.")

    tensor = _preprocess(image)
    scores = inference_fn(_load_model(), tensor)
    return {name: float(scores[idx]) for idx, name in labels.items()}
# Wire ``predict`` into a Gradio UI: an image upload as input and a label
# widget showing up to 10 class scores as output, with two bundled example
# images. The interface is built first and then served.
demo = gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(),
    outputs=gr.outputs.Label(num_top_classes=10),
    examples=["200005.jpg", "200006.jpg"],
    interpretation='default',
)
demo.launch()