vardaan123 commited on
Commit
7218171
·
1 Parent(s): e6a3d86

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -29,6 +29,7 @@ id2entity = {v: k for k, v in entity2id.items()}
29
  datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
30
  num_ent_id = len(entity2id)
31
  target_list = generate_target_list(datacsv, entity2id) # Assuming this function is defined elsewhere
 
32
 
33
  # Initialize your model here
34
  model = DistMult(num_ent_id, target_list, torch.device('cpu')) # Update arguments as necessary
@@ -37,6 +38,7 @@ model.eval()
37
  # Define your evaluation function
38
  def evaluate(img):
39
  transform_steps = transforms.Compose([
 
40
  transforms.Resize((448, 448)),
41
  transforms.ToTensor(),
42
  transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)
@@ -64,4 +66,4 @@ species_model = gr.Interface(
64
  description='Species Classification',
65
  article='Species Classification'
66
  )
67
- species_model.launch(share=True)
 
29
  datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
30
  num_ent_id = len(entity2id)
31
  target_list = generate_target_list(datacsv, entity2id) # Assuming this function is defined elsewhere
32
+ overall_id_to_name = json.load(open('overall_id_to_name.json'))
33
 
34
  # Initialize your model here
35
  model = DistMult(num_ent_id, target_list, torch.device('cpu')) # Update arguments as necessary
 
38
  # Define your evaluation function
39
  def evaluate(img):
40
  transform_steps = transforms.Compose([
41
+ transforms.ToPILImage(),
42
  transforms.Resize((448, 448)),
43
  transforms.ToTensor(),
44
  transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)
 
66
  description='Species Classification',
67
  article='Species Classification'
68
  )
69
+ species_model.launch(share=True)