Spaces:
Sleeping
Sleeping
ZeroCommand
commited on
Commit
·
263c173
1
Parent(s):
09f3a52
change model_type to inference_type
Browse files- app.py +3 -3
- config.yaml +1 -1
- utils.py +8 -8
app.py
CHANGED
@@ -11,7 +11,7 @@ import json
|
|
11 |
from transformers.pipelines import TextClassificationPipeline
|
12 |
|
13 |
from text_classification import check_column_mapping_keys_validity, text_classification_fix_column_mapping
|
14 |
-
from utils import read_scanners, write_scanners,
|
15 |
|
16 |
HF_REPO_ID = 'HF_REPO_ID'
|
17 |
HF_SPACE_ID = 'SPACE_ID'
|
@@ -266,7 +266,7 @@ with gr.Blocks(theme=theme) as iface:
|
|
266 |
''')
|
267 |
with gr.Row():
|
268 |
run_local = gr.Checkbox(value=True, label="Run in this Space")
|
269 |
-
use_inference =
|
270 |
run_inference = gr.Checkbox(value=use_inference, label="Run with Inference API")
|
271 |
|
272 |
with gr.Row() as advanced_row:
|
@@ -347,7 +347,7 @@ with gr.Blocks(theme=theme) as iface:
|
|
347 |
outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe, feature_mapping_dataframe])
|
348 |
scanners.change(write_scanners, inputs=scanners)
|
349 |
run_inference.change(
|
350 |
-
|
351 |
inputs=[run_inference]
|
352 |
)
|
353 |
|
|
|
11 |
from transformers.pipelines import TextClassificationPipeline
|
12 |
|
13 |
from text_classification import check_column_mapping_keys_validity, text_classification_fix_column_mapping
|
14 |
+
from utils import read_scanners, write_scanners, read_inference_type, write_inference_type, convert_column_mapping_to_json
|
15 |
|
16 |
HF_REPO_ID = 'HF_REPO_ID'
|
17 |
HF_SPACE_ID = 'SPACE_ID'
|
|
|
266 |
''')
|
267 |
with gr.Row():
|
268 |
run_local = gr.Checkbox(value=True, label="Run in this Space")
|
269 |
+
use_inference = read_inference_type('./config.yaml')[0] == 'hf_inference_api'
|
270 |
run_inference = gr.Checkbox(value=use_inference, label="Run with Inference API")
|
271 |
|
272 |
with gr.Row() as advanced_row:
|
|
|
347 |
outputs=[run_btn, loading_row, preview_row, example_input, example_labels, id2label_mapping_dataframe, feature_mapping_dataframe])
|
348 |
scanners.change(write_scanners, inputs=scanners)
|
349 |
run_inference.change(
|
350 |
+
write_inference_type,
|
351 |
inputs=[run_inference]
|
352 |
)
|
353 |
|
config.yaml
CHANGED
@@ -6,5 +6,5 @@ detectors:
|
|
6 |
- underconfidence
|
7 |
- overconfidence
|
8 |
- spurious_correlation
|
9 |
-
|
10 |
- hf_inference_api
|
|
|
6 |
- underconfidence
|
7 |
- overconfidence
|
8 |
- spurious_correlation
|
9 |
+
inference_type:
|
10 |
- hf_inference_api
|
utils.py
CHANGED
@@ -26,23 +26,23 @@ def write_scanners(scanners):
|
|
26 |
yaml.dump(config, f, Dumper=Dumper)
|
27 |
|
28 |
# read model_type from yaml file
|
29 |
-
def
|
30 |
-
|
31 |
with open(path, "r") as f:
|
32 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
33 |
-
|
34 |
-
return
|
35 |
|
36 |
# write model_type to yaml file
|
37 |
-
def
|
38 |
with open(YAML_PATH, "r") as f:
|
39 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
40 |
if use_inference:
|
41 |
-
config["
|
42 |
else:
|
43 |
-
config["
|
44 |
with open(YAML_PATH, "w") as f:
|
45 |
-
# save
|
46 |
yaml.dump(config, f, Dumper=Dumper)
|
47 |
|
48 |
# convert column mapping dataframe to json
|
|
|
26 |
yaml.dump(config, f, Dumper=Dumper)
|
27 |
|
28 |
# read model_type from yaml file
|
29 |
+
def read_inference_type(path):
|
30 |
+
inference_type = ""
|
31 |
with open(path, "r") as f:
|
32 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
33 |
+
inference_type = config.get("inference_type", None)
|
34 |
+
return inference_type
|
35 |
|
36 |
# write model_type to yaml file
|
37 |
+
def write_inference_type(use_inference):
|
38 |
with open(YAML_PATH, "r") as f:
|
39 |
config = yaml.load(f, Loader=yaml.FullLoader)
|
40 |
if use_inference:
|
41 |
+
config["inference_type"] = ['hf_inference_api']
|
42 |
else:
|
43 |
+
config["inference_type"] = ['hf_pipeline']
|
44 |
with open(YAML_PATH, "w") as f:
|
45 |
+
# save inference_type to inference_type in yaml
|
46 |
yaml.dump(config, f, Dumper=Dumper)
|
47 |
|
48 |
# convert column mapping dataframe to json
|