query-parser / app.py
avsolatorio's picture
Use the predict_entities method
0864cfe
raw
history blame
5.75 kB
import spaces
import os
import json
import gradio as gr
import pycountry
import torch
from datetime import datetime
from typing import Dict, Union
from gliner import GLiNER
# Module-level state and defaults shared by the model loader, the warm-up
# loop, and the Gradio UI.

# Cache of loaded GLiNER models keyed by model name (filled by get_model).
_MODEL = {}

# Optional model cache directory taken from the CACHE_DIR environment
# variable; None when the variable is unset.
_CACHE_DIR = os.getenv("CACHE_DIR")

# Default score threshold passed to entity prediction.
THRESHOLD = 0.3

# Default entity labels (used for warm-up and as the UI's initial labels).
LABELS = [
    "country",
    "year",
    "statistical indicator",
    "geographic region",
]

# Example query (used for warm-up and as the UI's initial query text).
QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"

# GLiNER checkpoints offered in the UI and warmed up at startup.
MODELS = [
    "urchade/gliner_base",
    "urchade/gliner_medium-v2.1",
]

print(f"Cache directory: {_CACHE_DIR}")
def get_model(model_name: str = None):
    """Return the GLiNER model named ``model_name``, loading it on first use.

    Loaded models are memoized in the module-level ``_MODEL`` dict. When CUDA
    is available and the cached model is not on a CUDA device, the model is
    moved there and the moved model replaces the cached entry.

    Args:
        model_name: Hugging Face model id; ``None`` selects
            ``"urchade/gliner_base"``.

    Returns:
        The cached GLiNER model instance.
    """
    start = datetime.now()
    name = "urchade/gliner_base" if model_name is None else model_name

    model = _MODEL.get(name)
    if model is None:
        model = GLiNER.from_pretrained(name, cache_dir=_CACHE_DIR)
        _MODEL[name] = model

    # Checked on every call (not only right after loading), so a cached model
    # that is still off-GPU gets moved whenever CUDA is available.
    if torch.cuda.is_available():
        already_on_cuda = next(model.parameters()).device.type.startswith("cuda")
        if not already_on_cuda:
            model = model.to("cuda")
            _MODEL[name] = model

    print(f"{datetime.now()} :: get_model :: {datetime.now() - start}")
    return model
def get_country(country_name: str):
    """Look up the pycountry country records that fuzzily match ``country_name``.

    Args:
        country_name: Free-text country mention (e.g. an entity span's text).

    Returns:
        The list of matching ``pycountry`` country objects, in the order
        returned by ``pycountry.countries.search_fuzzy``, or ``None`` when the
        name is blank/``None`` or nothing matches.
    """
    # search_fuzzy does substring matching, and the empty string is a
    # substring of every country name, so a blank query would "match" every
    # country. Treat blank input as "no match" instead.
    if not country_name or not country_name.strip():
        return None
    try:
        return pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        # search_fuzzy signals "no match" by raising LookupError.
        return None
@spaces.GPU(enable_queue=True, duration=5)
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
    """Run GLiNER zero-shot NER on ``query`` for the given entity ``labels``.

    Args:
        model_name: Model id forwarded to ``get_model`` (``None`` selects its
            default model).
        query: Text to extract entities from.
        labels: Entity labels, either as a list or as one comma-separated
            string (the form the UI textbox provides).
        threshold: Score threshold forwarded to GLiNER; lower values return
            more (possibly false-positive) entities.
        nested_ner: When True, nested entities may be returned (GLiNER is
            called with ``flat_ner=not nested_ner``).

    Returns:
        The entity list produced by ``model.predict_entities``, or an empty
        list when the query is blank or no non-empty labels remain.
    """
    start = datetime.now()
    if isinstance(labels, str):
        # Drop empty pieces so inputs such as "country, year," or "a,,b" do
        # not send blank labels to the model.
        labels = [label.strip() for label in labels.split(",") if label.strip()]

    entities = []
    # Only load the model and run inference when there is something to extract.
    if labels and query and query.strip():
        model = get_model(model_name)
        entities = model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)
    print(f"{datetime.now()} :: predict_entities :: {datetime.now() - start}")
    return entities
def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Extract entities from ``query`` and attach country normalizations.

    Every entity predicted by ``predict_entities`` is returned. An entity
    labelled ``"country"`` that ``get_country`` resolves additionally gets a
    ``"normalized"`` key: a list with one plain dict per matching pycountry
    record. Entity dicts are updated in place.

    Args:
        query: Text to parse.
        labels: Entity labels (list or comma-separated string).
        threshold: Score threshold forwarded to ``predict_entities``.
        nested_ner: Whether nested entities may be returned.
        model_name: Model id forwarded to ``predict_entities``.

    Returns:
        ``{"query": query, "entities": [...]}``. The payload is also printed
        as JSON.
    """
    predicted = predict_entities(
        model_name=model_name,
        query=query,
        labels=labels,
        threshold=threshold,
        nested_ner=nested_ner,
    )
    entities = list(predicted)

    for entity in entities:
        if entity["label"] != "country":
            continue
        matches = get_country(entity["text"])
        if matches:
            entity["normalized"] = [dict(match) for match in matches]

    payload = {"query": query, "entities": entities}
    print(f"{datetime.now()} :: parse_query :: {json.dumps(payload)}\n")
    return payload
# Initialize models here: run one prediction per configured checkpoint at
# import time. This loads each model through get_model, which caches it,
# before the UI below is launched.
print("Initializing models...")
for warmup_model_name in MODELS:
    predict_entities(warmup_model_name, QUERY, LABELS, threshold=THRESHOLD)
# Gradio UI: a query box, model/label/threshold/nesting controls, and a JSON
# view of the parsed entities. Every control re-runs parse_query.
with gr.Blocks(title="GLiNER-query-parser") as demo:
    gr.Markdown(
        """
# GLiNER-based Query Parser (a zero-shot NER model)
This space demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entities to track, the model can then identify instances of these entities in the query. The parsed entities are then displayed in the output. A special case is the "country" entity, which is normalized using the `pycountry` library: each detected country is annotated with its matching ISO 3166-1 country records (including the alpha-2 and alpha-3 codes). The `urchade/gliner_medium-v2.1` model is licensed under the Apache 2.0 license; see each model card for the license of the other models.
## Links
* Models: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
* All GLiNER models: https://huggingface.co/models?library=gliner
* Paper: https://arxiv.org/abs/2311.08526
* Repository: https://github.com/urchade/GLiNER
"""
    )

    query = gr.Textbox(
        value=QUERY, label="query", placeholder="Enter your query here"
    )

    # Model choice, entity labels, threshold and nesting toggle in one row.
    with gr.Row():
        model_name = gr.Radio(
            choices=MODELS,
            value="urchade/gliner_base",
            label="Model",
        )
        entities = gr.Textbox(
            value=", ".join(LABELS),
            label="entities",
            placeholder="Enter the entities to detect here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=THRESHOLD,
            step=0.01,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Setting to True extracts nested entities",
            scale=0,
        )

    output = gr.JSON(label="Extracted entities")
    submit_btn = gr.Button("Submit")

    # Submitting: every trigger runs parse_query with the same inputs/output.
    # Registration order matches the original one-call-per-trigger wiring, so
    # the API endpoints Gradio derives from `parse_query` (e.g. the
    # "/parse_query" endpoint used by the client example below) keep their names.
    _parse_inputs = [query, entities, threshold, is_nested, model_name]
    for _trigger in (
        query.submit,
        entities.submit,
        threshold.release,
        submit_btn.click,
        is_nested.change,
        model_name.change,
    ):
        _trigger(fn=parse_query, inputs=_parse_inputs, outputs=output)

demo.queue(default_concurrency_limit=5)
demo.launch(debug=True)

# Example: calling the deployed Space's "/parse_query" endpoint with
# gradio_client. Kept as a bare string literal, so it has no effect when this
# file runs.
"""
from gradio_client import Client
client = Client("avsolatorio/query-parser")
result = client.predict(
query="gdp, m3, and child mortality of india and southeast asia 2024",
labels="country, year, statistical indicator, region",
threshold=0.3,
nested_ner=False,
api_name="/parse_query"
)
print(result)
"""