Spaces:
Build error
Build error
Upload 2 files
Browse files
app.py
CHANGED
@@ -14,7 +14,8 @@ from utils import (
|
|
14 |
format_query,
|
15 |
get_flan_alpaca_xl_model,
|
16 |
generate_alpaca_ner_prompt,
|
17 |
-
|
|
|
18 |
format_entities_flan_alpaca,
|
19 |
generate_flant5_prompt_instruct_chunk_context,
|
20 |
generate_flant5_prompt_instruct_chunk_context_single,
|
@@ -56,9 +57,7 @@ col1, col2 = st.columns([3, 3], gap="medium")
|
|
56 |
with st.sidebar:
|
57 |
ner_choice = st.selectbox("Select NER Model", ["Alpaca", "Spacy"])
|
58 |
|
59 |
-
if ner_choice == "
|
60 |
-
ner_model, ner_tokenizer = get_flan_alpaca_xl_model()
|
61 |
-
else:
|
62 |
ner_model = get_spacy_model()
|
63 |
|
64 |
with col1:
|
@@ -70,7 +69,7 @@ with col1:
|
|
70 |
|
71 |
if ner_choice == "Alpaca":
|
72 |
ner_prompt = generate_alpaca_ner_prompt(query_text)
|
73 |
-
entity_text =
|
74 |
company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(entity_text)
|
75 |
else:
|
76 |
company_ent, quarter_ent, year_ent = extract_entities(query_text, ner_model)
|
|
|
14 |
format_query,
|
15 |
get_flan_alpaca_xl_model,
|
16 |
generate_alpaca_ner_prompt,
|
17 |
+
generate_entities_flan_alpaca_checkpoint,
|
18 |
+
generate_entities_flan_alpaca_inference_api,
|
19 |
format_entities_flan_alpaca,
|
20 |
generate_flant5_prompt_instruct_chunk_context,
|
21 |
generate_flant5_prompt_instruct_chunk_context_single,
|
|
|
57 |
with st.sidebar:
|
58 |
ner_choice = st.selectbox("Select NER Model", ["Alpaca", "Spacy"])
|
59 |
|
60 |
+
if ner_choice == "Spacy":
|
|
|
|
|
61 |
ner_model = get_spacy_model()
|
62 |
|
63 |
with col1:
|
|
|
69 |
|
70 |
if ner_choice == "Alpaca":
|
71 |
ner_prompt = generate_alpaca_ner_prompt(query_text)
|
72 |
+
entity_text = generate_entities_flan_alpaca_inference_api(ner_prompt)
|
73 |
company_ent, quarter_ent, year_ent = format_entities_flan_alpaca(entity_text)
|
74 |
else:
|
75 |
company_ent, quarter_ent, year_ent = extract_entities(query_text, ner_model)
|
utils.py
CHANGED
@@ -1,4 +1,6 @@
|
|
1 |
import re
|
|
|
|
|
2 |
|
3 |
import openai
|
4 |
import pandas as pd
|
@@ -513,8 +515,20 @@ Company - Cisco, Quarter - none, Year - none
|
|
513 |
### Response:"""
|
514 |
return prompt
|
515 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
516 |
|
517 |
-
def
|
518 |
model_inputs = tokenizer(prompt, return_tensors="pt")
|
519 |
input_ids = inputs["input_ids"]
|
520 |
generation_output = model.generate(
|
|
|
1 |
import re
|
2 |
+
import json
|
3 |
+
import requests
|
4 |
|
5 |
import openai
|
6 |
import pandas as pd
|
|
|
515 |
### Response:"""
|
516 |
return prompt
|
517 |
|
518 |
+
def generate_entities_flan_alpaca_inference_api(prompt):
|
519 |
+
API_URL = "https://api-inference.huggingface.co/models/declare-lab/flan-alpaca-xl"
|
520 |
+
payload = {
|
521 |
+
"inputs": prompt,
|
522 |
+
"parameters": {"do_sample": True, "temperature":0.1, "max_length":80},
|
523 |
+
"options": {"use_cache": True, "wait_for_model": True}
|
524 |
+
}
|
525 |
+
data = json.dumps(payload)
|
526 |
+
response = requests.request("POST", API_URL, data=data)
|
527 |
+
output = json.loads(response.content.decode("utf-8"))[0]["generated_text"]
|
528 |
+
return output
|
529 |
+
|
530 |
|
531 |
+
def generate_entities_flan_alpaca_checkpoint(model, tokenizer, prompt):
|
532 |
model_inputs = tokenizer(prompt, return_tensors="pt")
|
533 |
input_ids = inputs["input_ids"]
|
534 |
generation_output = model.generate(
|