Spaces:
Running
Running
avsolatorio
commited on
Commit
·
49fed2a
1
Parent(s):
2d67881
Initialize models
Browse filesSigned-off-by: Aivin V. Solatorio <avsolatorio@gmail.com>
app.py
CHANGED
@@ -11,8 +11,10 @@ from gliner import GLiNER
|
|
11 |
|
12 |
_MODEL = {}
|
13 |
_CACHE_DIR = os.environ.get("CACHE_DIR", None)
|
|
|
14 |
LABELS = ["country", "year", "statistical indicator", "geographic region"]
|
15 |
QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"
|
|
|
16 |
|
17 |
print(f"Cache directory: {_CACHE_DIR}")
|
18 |
|
@@ -36,6 +38,13 @@ def get_model(model_name: str = None):
|
|
36 |
return _MODEL[model_name]
|
37 |
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
def get_country(country_name: str):
|
40 |
try:
|
41 |
return pycountry.countries.search_fuzzy(country_name)
|
@@ -43,7 +52,7 @@ def get_country(country_name: str):
|
|
43 |
return None
|
44 |
|
45 |
|
46 |
-
@spaces.GPU(enable_queue=True, duration=
|
47 |
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
|
48 |
start = datetime.now()
|
49 |
model = get_model(model_name)
|
@@ -99,7 +108,7 @@ with gr.Blocks(title="GLiNER-query-parser") as demo:
|
|
99 |
)
|
100 |
with gr.Row() as row:
|
101 |
model_name = gr.Radio(
|
102 |
-
choices=
|
103 |
value="urchade/gliner_base",
|
104 |
label="Model",
|
105 |
)
|
@@ -112,7 +121,7 @@ with gr.Blocks(title="GLiNER-query-parser") as demo:
|
|
112 |
threshold = gr.Slider(
|
113 |
0,
|
114 |
1,
|
115 |
-
value=
|
116 |
step=0.01,
|
117 |
label="Threshold",
|
118 |
info="Lower threshold may extract more false-positive entities from the query.",
|
|
|
11 |
|
12 |
_MODEL = {}
|
13 |
_CACHE_DIR = os.environ.get("CACHE_DIR", None)
|
14 |
+
THRESHOLD = 0.3
|
15 |
LABELS = ["country", "year", "statistical indicator", "geographic region"]
|
16 |
QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"
|
17 |
+
MODELS = ["urchade/gliner_base", "urchade/gliner_medium-v2.1"]
|
18 |
|
19 |
print(f"Cache directory: {_CACHE_DIR}")
|
20 |
|
|
|
38 |
return _MODEL[model_name]
|
39 |
|
40 |
|
41 |
+
# Initialize model here.
|
42 |
+
print("Initializing models...")
|
43 |
+
for model_name in MODELS:
|
44 |
+
model = get_model(model_name=model_name)
|
45 |
+
model.predict_entities(QUERY, LABELS, threshold=THRESHOLD)
|
46 |
+
|
47 |
+
|
48 |
def get_country(country_name: str):
|
49 |
try:
|
50 |
return pycountry.countries.search_fuzzy(country_name)
|
|
|
52 |
return None
|
53 |
|
54 |
|
55 |
+
@spaces.GPU(enable_queue=True, duration=5)
|
56 |
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
|
57 |
start = datetime.now()
|
58 |
model = get_model(model_name)
|
|
|
108 |
)
|
109 |
with gr.Row() as row:
|
110 |
model_name = gr.Radio(
|
111 |
+
choices=MODELS,
|
112 |
value="urchade/gliner_base",
|
113 |
label="Model",
|
114 |
)
|
|
|
121 |
threshold = gr.Slider(
|
122 |
0,
|
123 |
1,
|
124 |
+
value=THRESHOLD,
|
125 |
step=0.01,
|
126 |
label="Threshold",
|
127 |
info="Lower threshold may extract more false-positive entities from the query.",
|