updated needed files
- labeling.py +0 -125
- main.py +0 -71
- ner_data_construction.py +0 -70
- predict.py +0 -14
labeling.py
DELETED
@@ -1,125 +0,0 @@
# Labels articles from the truncated Wikipedia CSV with keyphrases via the OpenAI
# chat API, resuming from last_index.txt and appending results as JSON lines.
import jsonlines
import json
from tqdm import tqdm
import time
from openai import error as openai_error
import pandas as pd
import openai
import tiktoken
import os
import glob

GPT_MODEL = 'gpt-3.5-turbo'
GPT_TOKEN_LIMIT = 1500
# NOTE: the API key is hard-coded here (and in the commented alternative below);
# in practice it should be read from the environment instead.
os.environ["OPENAI_API_KEY"] = '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'
# os.environ["OPENAI_API_KEY"] = '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'
openai.api_key = os.environ["OPENAI_API_KEY"]

LAST_INDEX_FILE_ADDR = 'last_index.txt'
TOKEN_COUNT_FILE_ADDR = 'tikitoken_count.txt'

# Counter of consecutive API errors; get_keyphrase_by_gpt gives up after 3 retries.
error_count = 0


def num_tokens(text: str, model: str = GPT_MODEL) -> int:
    """Return the number of tokens in a string."""
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))


def extract_seen_ids():
    """Collect the ids already labeled in any existing tagged_data* file."""
    seen_ids = set()
    for tagged_data_addr in glob.iglob('./tagged_data*'):
        seen_ids.update([json.loads(line)['id'] for line in open(tagged_data_addr)])
    return seen_ids


def get_keyphrase_by_gpt(document) -> str:
    """Ask the chat model for the document's keyphrases as a JSON list."""
    global error_count
    # prompt = 'extract main keywords from below document as sorted list (sort by importance). you should not use numbers for counting them. you should generate less than 10 keywords.'
    # prompt = 'Output only valid JSON list. Please extract the main keywords from the following document. The keywords should be in a comma-separated list, sorted by their importance. Do not use numbers to count the keywords. Try to generate less than 10 keywords.'
    prompt = 'there is a popular NLP task named KPE (keyphrase Extraction). please extract keyphrases from below article as a perfect Persian KPE model. '
    role_prompt = 'return your answer using json list format'
    message = prompt + '\n' + document
    # message = document
    messages = [
        # {"role": "system", "content": "Output only valid JSON list"},
        {"role": "system", "content": role_prompt},
        {"role": "user", "content": message},
    ]
    try:
        response = openai.ChatCompletion.create(
            model=GPT_MODEL,
            messages=messages,
            temperature=0
        )
        response_message = response["choices"][0]["message"]["content"]
        error_count = 0
        return response_message
    except Exception as e:
        if error_count > 3:
            raise e
        error_count += 1
        time.sleep(20)
        return []


# input_data = [json.loads(line) for line in open('all_data.json').read().splitlines()]
# input_data = open('all_data.json')
input_data = pd.read_csv('truncated_wiki_plus_shuffled_41203.csv')
# print('len input data : ', len(input_data))

try:
    last_index = int(open(LAST_INDEX_FILE_ADDR).read())
    print('load last index: ', last_index)
except:
    print('error in loading last index')
    last_index = 0


try:
    token_count = int(open(TOKEN_COUNT_FILE_ADDR).read())
    print('load token count: ', token_count)
except:
    print('error in loading token_count')
    token_count = 0

json_f_writer = jsonlines.open(f'tagged_data.jsonl_{str(last_index)}', mode='w')
seen_ids = extract_seen_ids()
for _, row_tup in enumerate(tqdm(input_data.iterrows(), total=len(input_data))):
    index, row = row_tup
    text = row['truncated_text_300']
    id = row['id']

    # filter by last index
    if index < last_index:
        print('skipping index: ', index)
        continue

    # filter by seen ids
    if id in seen_ids:
        print('repeated id, skipping')
        continue

    # filter by gpt max token
    text_gpt_token_count = num_tokens(text, model=GPT_MODEL)
    if text_gpt_token_count > GPT_TOKEN_LIMIT:
        continue

    token_count += text_gpt_token_count
    keyphrases = get_keyphrase_by_gpt(text)
    try:
        keyphrases = json.loads(keyphrases)
        if type(keyphrases) != list:
            # if type(keyphrases) == str:
            #     keyphrases = keyphrases.split(',')
            # else:
            print(str(index), ': not a list!')
    except:
        print(str(index), ': invalid json!')

    new_train_item = {'id': id, 'keyphrases': keyphrases}
    json_f_writer.write(new_train_item)
    # persist progress so an interrupted run can resume
    last_index_f = open(LAST_INDEX_FILE_ADDR, 'w+')
    last_index_f.write(str(index))
    token_count_f = open(TOKEN_COUNT_FILE_ADDR, 'w+')
    token_count_f.write(str(token_count))

print(token_count)
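Each processed article appends one record to tagged_data.jsonl_<last_index>. As a rough sketch of the output format (the id and Persian keyphrases below are invented for illustration, not taken from the real dataset):

{"id": 12345, "keyphrases": ["پردازش زبان طبیعی", "استخراج عبارات کلیدی"]}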
main.py
DELETED
@@ -1,71 +0,0 @@
import uvicorn
import os
from typing import Union
from fastapi import FastAPI
from kpe import KPE
from fastapi.middleware.cors import CORSMiddleware
# from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi import APIRouter, Query
from sentence_transformers import SentenceTransformer
import utils
from ranker import get_sorted_keywords
from pydantic import BaseModel


app = FastAPI(
    title="AHD Persian KPE",
    # version=config.settings.VERSION,
    description="Keyphrase Extraction",
    openapi_url="/openapi.json",
    docs_url="/",
)

# Load the trained KPE model and the sentence-transformer used for ranking once at startup.
TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
kpe = KPE(trained_kpe_model=TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
ranker_transformer = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2', device='cpu')

# Sets all CORS enabled origins
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # str(origin) for origin in config.settings.BACKEND_CORS_ORIGINS
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class KpeParams(BaseModel):
    text: str
    count: int = 10000
    using_ner: bool = True
    return_sorted: bool = False


router = APIRouter()


@router.get("/")
def home():
    return "Welcome to AHD Keyphrase Extraction Service"


@router.post("/extract", description="extract keyphrase from persian documents")
async def extract(kpe_params: KpeParams):
    global kpe
    text = utils.normalize(kpe_params.text)
    kps = kpe.extract(text, using_ner=kpe_params.using_ner)
    if kpe_params.return_sorted:
        kps = get_sorted_keywords(ranker_transformer, text, kps)
    else:
        kps = [(kp, 1) for kp in kps]
    if len(kps) > kpe_params.count:
        kps = kps[:kpe_params.count]
    return kps


app.include_router(router)


if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=7201)
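For reference, a minimal client sketch for the /extract endpoint above. It assumes the service is running locally on port 7201 (as in the __main__ block) and that the requests package is available; the payload fields mirror the KpeParams model:

import requests

payload = {
    "text": "...",           # Persian document to extract keyphrases from (placeholder)
    "count": 10,             # return at most 10 keyphrases
    "using_ner": True,       # passed through to kpe.extract
    "return_sorted": True,   # rank results with get_sorted_keywords instead of a constant score of 1
}
response = requests.post("http://localhost:7201/extract", json=payload)
print(response.json())       # a list of (keyphrase, score) pairs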
ner_data_construction.py
DELETED
@@ -1,70 +0,0 @@
import pandas as pd
import json
import glob


def tag_document(keywords, tokens):

    # Initialize the tags list with all O's
    tags = ['O'] * len(tokens)

    # Loop over the keywords and tag the document
    for keyword in keywords:
        # Split the keyword into words
        keyword_words = keyword.split()
        # Skip empty keyphrases, which would otherwise raise an IndexError below
        if not keyword_words:
            continue

        # Loop over the words in the document
        for i in range(len(tokens)):
            # If the current word matches the first word of the keyword
            if tokens[i] == keyword_words[0]:
                match = True
                # Check if the rest of the words in the keyword match the following words in the document
                for j in range(1, len(keyword_words)):
                    if i + j >= len(tokens) or tokens[i + j] != keyword_words[j]:
                        match = False
                        break
                # If all the words in the keyword match the following words in the document, tag them as B-KEYWORD and I-KEYWORD
                if match:
                    tags[i] = 'B-KEYWORD'
                    for j in range(1, len(keyword_words)):
                        tags[i + j] = 'I-KEYWORD'

    return tags


def create_tner_dataset(all_tags, all_tokens, output_file_addr):
    """Write tagged sentences in CoNLL format: one token<TAB>tag per line, blank line between sentences."""
    output_f = open(output_file_addr, 'a+')
    for tags, tokens in zip(all_tags, all_tokens):
        for tag, tok in zip(tags, tokens):
            line = '\t'.join([tok, tag])
            output_f.write(line)
            output_f.write('\n')
        output_f.write('\n')


if __name__ == '__main__':

    data_df = pd.read_csv('truncated_wiki_plus_shuffled_41203.csv')
    id2document = data_df.set_index('id')['truncated_text_300'].to_dict()

    # tag documents!
    print('------------------ tag documents --------------------')
    all_tags = []
    all_tokens = []
    for tagged_data_addr in glob.iglob('./tagged_data*'):
        for line in open(tagged_data_addr):
            item = json.loads(line)
            if type(item['keyphrases']) == list:
                keywords = item['keyphrases']
                document = id2document[item['id']]
                tokens = document.split()
                tags = tag_document(keywords, tokens)
                assert len(tokens) == len(tags)
                all_tags.append(tags)
                all_tokens.append(tokens)
                print(len(keywords), len(tags), len(document.split()), len([t for t in tags if t[0] == 'B']))
    nerda_dataset = {'sentences': all_tokens, 'tags': all_tags}
    with open('nerda_dataset.json', 'w+') as f:
        json.dump(nerda_dataset, f)
    # create_tner_dataset(all_tags, all_tokens, output_file_addr='./sample_train.conll')
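To make the BIO tagging concrete, here is a small illustrative call to tag_document (the tokens and the keyphrase are invented for the example):

# Hypothetical 5-token document with one two-word keyphrase.
tokens = ["deep", "learning", "improves", "keyphrase", "extraction"]
keywords = ["keyphrase extraction"]
print(tag_document(keywords, tokens))
# -> ['O', 'O', 'O', 'B-KEYWORD', 'I-KEYWORD']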
predict.py
DELETED
@@ -1,14 +0,0 @@
import time
from kpe import KPE
import sys
import os


if __name__ == '__main__':
    TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model_10000.pt')
    text_addr = sys.argv[1]
    text = open(text_addr).read()
    kpe = KPE(trained_kpe_model=TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
    s = time.time()
    print(kpe.extract(text))
    print(time.time() - s)
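Usage is a single command, e.g. python predict.py article.txt, where article.txt is any plain-text file; the script prints the extracted keyphrases followed by the elapsed time in seconds.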