Update app.py
Browse files
app.py
CHANGED
@@ -32,7 +32,6 @@ LABEL_MAPS = {
|
|
32 |
}
|
33 |
|
34 |
def tube_mask_generator(mask_ratio):
|
35 |
-
# mask_ratio=0.8
|
36 |
window_size = (
|
37 |
num_frames // 2,
|
38 |
input_size // patch_size[0],
|
@@ -50,11 +49,11 @@ def tube_mask_generator(mask_ratio):
|
|
50 |
|
51 |
|
52 |
def get_model(data_type):
|
53 |
-
# data_type = 'K400'
|
54 |
ft_model = keras.models.load_model(MODELS[data_type][0])
|
55 |
pt_model = keras.models.load_model(MODELS[data_type][1])
|
56 |
|
57 |
label_map = LABEL_MAPS.get(data_type)
|
|
|
58 |
label_map = {v: k for k, v in label_map.items()}
|
59 |
|
60 |
return ft_model, pt_model, label_map
|
|
|
32 |
}
|
33 |
|
34 |
def tube_mask_generator(mask_ratio):
|
|
|
35 |
window_size = (
|
36 |
num_frames // 2,
|
37 |
input_size // patch_size[0],
|
|
|
49 |
|
50 |
|
51 |
def get_model(data_type):
|
|
|
52 |
ft_model = keras.models.load_model(MODELS[data_type][0])
|
53 |
pt_model = keras.models.load_model(MODELS[data_type][1])
|
54 |
|
55 |
label_map = LABEL_MAPS.get(data_type)
|
56 |
+
label_map = K400_label_map
|
57 |
label_map = {v: k for k, v in label_map.items()}
|
58 |
|
59 |
return ft_model, pt_model, label_map
|