# NOTE: page-header residue from where this file was copied ("Spaces: Sleeping Sleeping"); not part of the program.
import os | |
import pandas as pd | |
import torch | |
import gradio as gr | |
from model import DistMult | |
from PIL import Image | |
from torchvision import transforms | |
import json | |
from tqdm import tqdm | |
# Per-channel mean / standard deviation handed to transforms.Normalize when an
# uploaded image is preprocessed. The numbers are the widely used ImageNet
# statistics, listed in (presumably) R, G, B order — TODO confirm they match
# how the image encoder was trained.
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN = [
    0.485,  # channel 0 (red)
    0.456,  # channel 1 (green)
    0.406,  # channel 2 (blue)
]
_DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD = [
    0.229,  # channel 0 (red)
    0.224,  # channel 1 (green)
    0.225,  # channel 2 (blue)
]
def generate_target_list(data, entity2id, progress=True):
    """Collect the distinct target entity ids for image -> id triples.

    Selects the rows of *data* whose head datatype is ``"image"`` and whose
    tail datatype is ``"id"``, maps each tail value ``t`` to its entity id via
    *entity2id*, and keeps each id once, in order of first appearance.

    Args:
        data: DataFrame with at least the columns ``datatype_h``,
            ``datatype_t`` and ``t``.
        entity2id: mapping from entity-key string to entity id. Tail values
            are normalized with ``str(int(float(t)))`` before lookup, so
            ``7``, ``7.0`` and ``"7.0"`` all resolve to the key ``"7"``.
        progress: wrap the scan in a tqdm progress bar (the default, as
            before). Pass ``False`` for quiet / non-interactive use.

    Returns:
        A ``torch.long`` tensor of shape ``(num_targets, 1)``; the shape is
        ``(0, 1)`` when no row matches.

    Raises:
        KeyError: if a normalized tail value is not a key of *entity2id*.
        ValueError: if a tail value cannot be parsed as a number.
    """
    mask = (data["datatype_h"] == "image") & (data["datatype_t"] == "id")
    tails = data.loc[mask, "t"]
    if progress:
        tails = tqdm(tails)
    # dict.fromkeys de-duplicates while preserving first-seen order in O(n);
    # the previous `not in list` membership test made this loop O(n^2).
    categories = dict.fromkeys(entity2id[str(int(float(t)))] for t in tails)
    return torch.tensor(list(categories), dtype=torch.long).unsqueeze(-1)
# Load the entity vocabulary and triple dataset, then build the model.
# entity2id: JSON object, so keys are strings; values are presumably integer
# ids (generate_target_list packs them into a torch.long tensor).
with open('entity2id_subtree.json', 'r', encoding='utf-8') as _entity_file:
    entity2id = json.load(_entity_file)
id2entity = {v: k for k, v in entity2id.items()}  # inverse lookup: id -> key
datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
num_ent_id = len(entity2id)
# Column tensor of the distinct entity ids that appear as tails of
# (image -> id) triples; these are the candidate classes for prediction.
target_list = generate_target_list(datacsv, entity2id)

# NOTE(review): `args` is not defined anywhere in this file, so this line
# raises NameError as written. It must be a config object carrying whatever
# hyperparameters DistMult expects — TODO supply it (e.g. match the training config).
# NOTE(review): no trained weights are loaded here (no load_state_dict call),
# so predictions come from a freshly initialized model unless DistMult loads
# a checkpoint itself — confirm.
model = DistMult(args, num_ent_id, target_list, torch.device('cpu'))
model.eval()  # inference mode: disables dropout / uses running norm stats, if any
# Inference entry point used by the Gradio interface.
def evaluate(img):
    """Classify one image and return the predicted species name.

    Args:
        img: the uploaded image. Gradio's (legacy) Image input hands over a
            numpy array by default (per Gradio docs); a PIL image is also
            accepted.

    Returns:
        The species name looked up for the top-scoring target entity.
    """
    # torchvision's Resize only accepts PIL images or tensors; a numpy array
    # goes down the PIL path and raises TypeError, so convert it first.
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    # Normalize below uses 3-channel statistics; force RGB so RGBA/greyscale
    # uploads do not produce a channel-count mismatch.
    img = img.convert("RGB")

    transform_steps = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
        transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD),
    ])
    h = transform_steps(img).unsqueeze(0)  # add batch dim -> (1, 3, 448, 448)
    r = torch.tensor([[3]])  # fixed relation id 3, shape (1, 1)

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        outputs = model.forward_ce(h, r, triple_type=('image', 'id'))
    y_pred = outputs.argmax(-1).cpu()
    pred_label = target_list[y_pred].item()
    # NOTE(review): `overall_id_to_name` is not defined in this file; it must
    # be provided at module level (an id -> species-name mapping) or this
    # line raises NameError — TODO confirm where it is loaded.
    species_label = overall_id_to_name[str(id2entity[pred_label])]
    return species_label
# Web front end: one uploaded image in, the predicted species label out.
# NOTE(review): `gr.inputs.*` is Gradio's legacy input namespace (deprecated in
# 3.x, removed in 4.x) — migrate to `gr.Image` when upgrading Gradio. Per the
# Gradio 3.x docs, `shape` crops/resizes the upload to 200x200, and evaluate()
# then resizes it to 448x448, which discards detail — confirm this is intended.
_app_text = 'Species Classification'
_image_input = gr.inputs.Image(shape=(200, 200))

species_model = gr.Interface(
    fn=evaluate,
    inputs=_image_input,
    outputs="label",
    title=_app_text,
    description=_app_text,
    article=_app_text,
)
species_model.launch()