Spaces:
Sleeping
Sleeping
Commit
·
7218171
1
Parent(s):
e6a3d86
Update app.py
Browse files
app.py
CHANGED
@@ -29,6 +29,7 @@ id2entity = {v: k for k, v in entity2id.items()}
|
|
29 |
datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
|
30 |
num_ent_id = len(entity2id)
|
31 |
target_list = generate_target_list(datacsv, entity2id) # Assuming this function is defined elsewhere
|
|
|
32 |
|
33 |
# Initialize your model here
|
34 |
model = DistMult(num_ent_id, target_list, torch.device('cpu')) # Update arguments as necessary
|
@@ -37,6 +38,7 @@ model.eval()
|
|
37 |
# Define your evaluation function
|
38 |
def evaluate(img):
|
39 |
transform_steps = transforms.Compose([
|
|
|
40 |
transforms.Resize((448, 448)),
|
41 |
transforms.ToTensor(),
|
42 |
transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)
|
@@ -64,4 +66,4 @@ species_model = gr.Interface(
|
|
64 |
description='Species Classification',
|
65 |
article='Species Classification'
|
66 |
)
|
67 |
-
species_model.launch(share=True)
|
|
|
29 |
datacsv = pd.read_csv('dataset_subtree.csv', low_memory=False)
|
30 |
num_ent_id = len(entity2id)
|
31 |
target_list = generate_target_list(datacsv, entity2id) # Assuming this function is defined elsewhere
|
32 |
+
overall_id_to_name = json.load(open('overall_id_to_name.json'))
|
33 |
|
34 |
# Initialize your model here
|
35 |
model = DistMult(num_ent_id, target_list, torch.device('cpu')) # Update arguments as necessary
|
|
|
38 |
# Define your evaluation function
|
39 |
def evaluate(img):
|
40 |
transform_steps = transforms.Compose([
|
41 |
+
transforms.ToPILImage(),
|
42 |
transforms.Resize((448, 448)),
|
43 |
transforms.ToTensor(),
|
44 |
transforms.Normalize(_DEFAULT_IMAGE_TENSOR_NORMALIZATION_MEAN, _DEFAULT_IMAGE_TENSOR_NORMALIZATION_STD)
|
|
|
66 |
description='Species Classification',
|
67 |
article='Species Classification'
|
68 |
)
|
69 |
+
species_model.launch(share=True)
|