Spaces:
Running
Running
import spaces | |
import os | |
import json | |
import gradio as gr | |
import pycountry | |
import torch | |
from datetime import datetime | |
from typing import Dict, Union | |
from gliner import GLiNER | |
# Cache of loaded GLiNER models keyed by model name; filled on demand by get_model().
_MODEL = {}
# Optional Hugging Face download/cache location; None when CACHE_DIR is unset.
_CACHE_DIR = os.environ.get("CACHE_DIR")

# Defaults shared by the UI widgets and the startup warm-up call.
THRESHOLD = 0.3
LABELS = [
    "country",
    "year",
    "statistical indicator",
    "geographic region",
]
QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"
MODELS = [
    "urchade/gliner_base",
    "urchade/gliner_medium-v2.1",
]

print(f"Cache directory: {_CACHE_DIR}")
def get_model(model_name: str = None):
    """Return the GLiNER model for ``model_name``, loading it on first use.

    Loaded models are kept in the module-level ``_MODEL`` dict, so later
    calls with the same name reuse the same instance.

    Args:
        model_name: Hugging Face id of a GLiNER checkpoint. ``None`` selects
            ``"urchade/gliner_base"``.

    Returns:
        The cached GLiNER model, moved to CUDA when a GPU is available.
    """
    t0 = datetime.now()
    name = "urchade/gliner_base" if model_name is None else model_name

    model = _MODEL.get(name)
    if model is None:
        model = GLiNER.from_pretrained(name, cache_dir=_CACHE_DIR)
        _MODEL[name] = model

    if torch.cuda.is_available():
        # Inspect one parameter to see where the weights currently live.
        already_on_gpu = next(model.parameters()).device.type.startswith("cuda")
        if not already_on_gpu:
            model = model.to("cuda")
            _MODEL[name] = model

    print(f"{datetime.now()} :: get_model :: {datetime.now() - t0}")
    return model
def get_country(country_name: str):
    """Fuzzy-match ``country_name`` against pycountry's country database.

    Returns:
        The list of candidate country records returned by
        ``pycountry.countries.search_fuzzy``, or ``None`` when the lookup
        finds nothing.
    """
    try:
        matches = pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        # pycountry reports a failed lookup by raising, not by returning [].
        return None
    return matches
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
    """Run GLiNER zero-shot NER over ``query`` for the requested ``labels``.

    Args:
        model_name: Hugging Face id of the GLiNER checkpoint; ``None`` lets
            ``get_model`` pick its default.
        query: Text to extract entities from.
        labels: Entity types to look for, either as a list or as a
            comma-separated string (e.g. ``"country, year"``). For the string
            form, whitespace around each label is stripped and blank entries
            (from stray or trailing commas) are dropped.
        threshold: Minimum score an entity must reach to be returned.
        nested_ner: When True, allow nested spans (passed to the model as
            ``flat_ner=False``).

    Returns:
        The entity list produced by the model's ``predict_entities``, or an
        empty list when the text is blank or no usable label remains. In that
        case the model is not loaded or called.
    """
    start = datetime.now()
    if isinstance(labels, str):
        # "country, year," must not turn into ["country", "year", ""]; an
        # empty label would be sent to the model as an entity type.
        labels = [label.strip() for label in labels.split(",") if label.strip()]

    if not labels or not query or not query.strip():
        # Nothing to search for, or nothing to search in.
        entities = []
    else:
        model = get_model(model_name)
        entities = model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)

    print(f"{datetime.now()} :: predict_entities :: {datetime.now() - start}")
    return entities
def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Extract entities from ``query`` and enrich detected countries.

    Entities labelled ``"country"`` whose text ``get_country`` can match get
    an extra ``"normalized"`` key. Its value is the list of matching
    pycountry records, each converted to a plain dict. Every other entity is
    passed through unchanged.

    Returns:
        ``{"query": query, "entities": [...]}``. The payload is also logged
        to stdout as JSON.
    """
    detected = predict_entities(model_name=model_name, query=query, labels=labels, threshold=threshold, nested_ner=nested_ner)

    entities = []
    for ent in detected:
        matches = get_country(ent["text"]) if ent["label"] == "country" else None
        if matches:
            ent["normalized"] = [dict(match) for match in matches]
        entities.append(ent)

    payload = {"query": query, "entities": entities}
    print(f"{datetime.now()} :: parse_query :: {json.dumps(payload)}\n")
    return payload
# Warm-up at startup: run one sample prediction per offered checkpoint. This
# loads each model into the get_model() cache (on the GPU when available), so
# the first user request does not pay the loading cost.
print("Initializing models...")
for _warmup_model in MODELS:
    predict_entities(
        model_name=_warmup_model,
        query=QUERY,
        labels=LABELS,
        threshold=THRESHOLD,
    )
# Gradio UI: a query box plus controls for model, entity labels, threshold and
# nesting. The parse result is rendered as JSON.
with gr.Blocks(title="GLiNER-query-parser") as demo:
    gr.Markdown(
        """
        # GLiNER-based Query Parser (a zero-shot NER model)

        This space demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entities to track, the model can then identify instances of these entities in the query. The parsed entities are then displayed in the output. A special case is the "country" entity: each detected country is fuzzy-matched with the `pycountry` library, and the matching ISO 3166-1 country records (name, alpha-2 and alpha-3 codes, etc.) are attached under `normalized`. Check each model card linked below for the license that applies to that checkpoint.

        ## Links

        * Models: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
        * All GLiNER models: https://huggingface.co/models?library=gliner
        * Paper: https://arxiv.org/abs/2311.08526
        * Repository: https://github.com/urchade/GLiNER
        """
    )
    query = gr.Textbox(
        value=QUERY, label="query", placeholder="Enter your query here"
    )
    with gr.Row():
        model_name = gr.Radio(
            choices=MODELS,
            value="urchade/gliner_base",
            label="Model",
        )
        entities = gr.Textbox(
            value=", ".join(LABELS),
            label="entities",
            placeholder="Enter the entities to detect here (comma separated)",
            scale=2,
        )
        threshold = gr.Slider(
            0,
            1,
            value=THRESHOLD,
            step=0.01,
            label="Threshold",
            info="Lower threshold may extract more false-positive entities from the query.",
            scale=1,
        )
        is_nested = gr.Checkbox(
            value=False,
            label="Nested NER",
            info="Setting to True extracts nested entities",
            scale=0,
        )
    output = gr.JSON(label="Extracted entities")
    submit_btn = gr.Button("Submit")

    # Every interaction re-runs the parser with the current widget values.
    # The order of `parse_inputs` must match parse_query's positional
    # parameters. Listeners are registered in the same order as before, so the
    # auto-generated API endpoints (the first one being "/parse_query") keep
    # their names.
    parse_inputs = [query, entities, threshold, is_nested, model_name]
    for trigger in (
        query.submit,
        entities.submit,
        threshold.release,
        submit_btn.click,
        is_nested.change,
        model_name.change,
    ):
        trigger(fn=parse_query, inputs=parse_inputs, outputs=output)

demo.queue(default_concurrency_limit=5)
demo.launch(debug=True)

# Example: calling this Space's API from Python with `gradio_client`:
#
#   from gradio_client import Client
#
#   client = Client("avsolatorio/query-parser")
#   result = client.predict(
#       query="gdp, m3, and child mortality of india and southeast asia 2024",
#       labels="country, year, statistical indicator, region",
#       threshold=0.3,
#       nested_ner=False,
#       api_name="/parse_query",
#   )
#   print(result)