mrmft committed on
Commit 1a942ce
1 Parent(s): 4da642e

updated needed files

Files changed (4)
  1. labeling.py +0 -125
  2. main.py +0 -71
  3. ner_data_construction.py +0 -70
  4. predict.py +0 -14
labeling.py DELETED
@@ -1,125 +0,0 @@
- import jsonlines
- import json
- from tqdm import tqdm
- import time
- from openai import error as openai_error
- import pandas as pd
- import openai
- import time
- import tiktoken
- import os
- import glob
-
- GPT_MODEL = 'gpt-3.5-turbo'
- GPT_TOKEN_LIMIT = 1500
- os.environ["OPENAI_API_KEY"] = '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'
- # os.environ["OPENAI_API_KEY"] = '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'
- openai.api_key = os.environ["OPENAI_API_KEY"]
-
- LAST_INDEX_FILE_ADDR = 'last_index.txt'
- TOKEN_COUNT_FILE_ADDR = 'tikitoken_count.txt'
-
- def num_tokens(text: str, model: str = GPT_MODEL) -> int:
-     """Return the number of tokens in a string."""
-     encoding = tiktoken.encoding_for_model(model)
-     return len(encoding.encode(text))
-
-
- def extract_seen_ids():
-     seen_ids = set()
-     for tagged_data_addr in glob.iglob('./tagged_data*'):
-         seen_ids.update([json.loads(line)['id'] for line in open(tagged_data_addr)])
-     return seen_ids
-
-
- def get_keyphrase_by_gpt(document) -> str:
-     global error_count
-     # prompt = 'extract main keywords from below document as sorted list (sort by importance). you should not use numbers for counting them. you should generate less than 10 keywords.'
-     # prompt = 'Output only valid JSON list. Please extract the main keywords from the following document. The keywords should be in a comma-separated list, sorted by their importance. Do not use numbers to count the keywords. Try to generate less than 10 keywords.'
-     prompt = 'there is a popular NLP task named KPE (keyphrase Extraction). please extract keyphrases from below article as a perfect Persian KPE model. '
-     role_prompt = 'return your answer using json list format'
-     message = prompt + '\n' + document
-     # message = prompt + '\n' + document
-     # message = document
-     messages = [
-         # {"role": "system", "content": "Output only valid JSON list"},
-         {"role": "system", "content": role_prompt},
-         {"role": "user", "content": message},
-     ]
-     try:
-         response = openai.ChatCompletion.create(
-             model=GPT_MODEL,
-             messages=messages,
-             temperature=0
-         )
-         response_message = response["choices"][0]["message"]["content"]
-         error_count = 0
-         return response_message
-     except Exception as e:
-         if error_count > 3:
-             raise e
-         error_count += 1
-         time.sleep(20)
-         return []
-
- #input_data = [json.load(line) for line in open('all_data.json').read().splitlines())
- #input_data = open('all_data.json')
- input_data = pd.read_csv('truncated_wiki_plus_shuffled_41203.csv')
- #print('len input data : ', len(input_data))
- try:
-     last_index = int(open(LAST_INDEX_FILE_ADDR).read())
-     print('load last index: ', last_index)
- except:
-     print('error in loading last index')
-     last_index = 0
-
-
- try:
-     token_count = int(open(TOKEN_COUNT_FILE_ADDR).read())
-     print('load token count: ', token_count)
- except:
-     print('error in loading token_count')
-     token_count = 0
-
- json_f_writer = jsonlines.open(f'tagged_data.jsonl_{str(last_index)}', mode='w')
- seen_ids = extract_seen_ids()
- for _, row_tup in enumerate(tqdm(input_data.iterrows(),total=len(input_data))):
-     index, row = row_tup
-     text = row['truncated_text_300']
-     id = row['id']
-
-     #filter by last index
-     if index < last_index:
-         print('skipping index: ', index)
-         continue
-
-     #filter by seen ids
-     if id in seen_ids:
-         print('repated id and skip')
-         continue
-
-     #filter by gpt max token
-     text_gpt_token_count = num_tokens(text, model=GPT_MODEL)
-     if text_gpt_token_count > GPT_TOKEN_LIMIT:
-         continue
-
-     token_count += text_gpt_token_count
-     keyphrases = get_keyphrase_by_gpt(text)
-     try:
-         keyphrases = json.loads(keyphrases)
-         if type(keyphrases) != list:
-             # if type(keyphrases) == str:
-             #     keyphrases = keyphrases.split(',')
-             # else:
-             print(str(index), ': not a list!')
-     except:
-         print(str(index), ':invalid json!')
-
-     new_train_item = {'id': id, 'keyphrases':keyphrases}
-     json_f_writer.write(new_train_item)
-     last_index_f = open(LAST_INDEX_FILE_ADDR, 'w+')
-     last_index_f.write(str(index))
-     token_count_f = open(TOKEN_COUNT_FILE_ADDR, 'w+')
-     token_count_f.write(str(token_count))
-
- print(token_count)
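Note on the removed labeling script: get_keyphrase_by_gpt reads the module-level error_count inside its except block, but the counter is only ever assigned after a successful call, so a failure on the very first request raises NameError instead of retrying. Below is a minimal sketch of the same call-with-retry pattern with the counter handled locally. It assumes the legacy openai<1.0 SDK that the deleted script used, and get_keyphrases is a hypothetical helper name, not part of the repository.

import time
import openai

GPT_MODEL = 'gpt-3.5-turbo'

def get_keyphrases(document: str, max_retries: int = 3) -> str:
    """Ask the chat model for keyphrases, backing off and retrying on API errors."""
    messages = [
        {"role": "system", "content": "return your answer using json list format"},
        {"role": "user", "content": "please extract keyphrases from the article below.\n" + document},
    ]
    for attempt in range(max_retries + 1):
        try:
            # Legacy openai<1.0 chat interface, as used by the deleted labeling.py.
            response = openai.ChatCompletion.create(
                model=GPT_MODEL,
                messages=messages,
                temperature=0,
            )
            return response["choices"][0]["message"]["content"]
        except Exception:
            if attempt == max_retries:
                raise  # give up after repeated failures
            time.sleep(20)  # back off before retrying, e.g. on rate limits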
 
main.py DELETED
@@ -1,71 +0,0 @@
- import uvicorn
- import os
- from typing import Union
- from fastapi import FastAPI
- from kpe import KPE
- from fastapi.middleware.cors import CORSMiddleware
- # from fastapi.middleware.trustedhost import TrustedHostMiddleware
- from fastapi import APIRouter , Query
- from sentence_transformers import SentenceTransformer
- import utils
- from ranker import get_sorted_keywords
- from pydantic import BaseModel
-
-
- app = FastAPI(
-     title="AHD Persian KPE",
-     # version=config.settings.VERSION,
-     description="Keyphrase Extraction",
-     openapi_url="/openapi.json",
-     docs_url="/",
- )
-
- TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model', 'trained_model_10000.pt')
- kpe = KPE(trained_kpe_model= TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
- ranker_transformer = SentenceTransformer('paraphrase-multilingual-mpnet-base-v2', device='cpu')
- # Sets all CORS enabled origins
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"], #str(origin) for origin in config.settings.BACKEND_CORS_ORIGINS
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
-
-
-
- class KpeParams(BaseModel):
-     text:str
-     count:int=10000
-     using_ner:bool=True
-     return_sorted:bool=False
-
-
- router = APIRouter()
-
-
- @router.get("/")
- def home():
-     return "Welcome to AHD Keyphrase Extraction Service"
-
-
- @router.post("/extract", description="extract keyphrase from persian documents")
- async def extract(kpe_params: KpeParams):
-     global kpe
-     text = utils.normalize(kpe_params.text)
-     kps = kpe.extract(text, using_ner=kpe_params.using_ner)
-     if kpe_params.return_sorted:
-         kps = get_sorted_keywords(ranker_transformer, text, kps)
-     else:
-         kps = [(kp, 1) for kp in kps]
-     if len(kps) > kpe_params.count:
-         kps = kps[:kpe_params.count]
-     return kps
-
-
- app.include_router(router)
-
-
- if __name__ == "__main__":
-     uvicorn.run("main:app",host="0.0.0.0", port=7201)
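For reference, the removed /extract endpoint accepted a JSON body with text, count, using_ner and return_sorted fields and returned (keyphrase, score) pairs. A hedged example request against a locally running instance, assuming the port 7201 configured above and the requests library, might look like this:

import requests

# Hypothetical client call; assumes the service from the deleted main.py
# is still running locally on port 7201.
payload = {
    "text": "...",           # Persian document to analyse (placeholder)
    "count": 10,             # keep at most 10 keyphrases
    "using_ner": True,       # add NER spans to the candidate keyphrases
    "return_sorted": False,  # True would rank keyphrases with the sentence-transformer ranker
}
response = requests.post("http://localhost:7201/extract", json=payload)
response.raise_for_status()
print(response.json())       # list of [keyphrase, score] pairs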
 
ner_data_construction.py DELETED
@@ -1,70 +0,0 @@
- import pandas as pd
- import json
- import glob
-
-
- def tag_document(keywords, tokens):
-
-     # Initialize the tags list with all O's
-     tags = ['O'] * len(tokens)
-
-     # Loop over the keywords and tag the document
-     for keyword in keywords:
-         # Split the keyword into words
-         keyword_words = keyword.split()
-
-         # Loop over the words in the document
-         for i in range(len(tokens)):
-             # If the current word matches the first word of the keyword
-             if tokens[i] == keyword_words[0]:
-                 match = True
-                 # Check if the rest of the words in the keyword match the following words in the document
-                 for j in range(1, len(keyword_words)):
-                     if i+j >= len(tokens) or tokens[i+j] != keyword_words[j]:
-                         match = False
-                         break
-                 # If all the words in the keyword match the following words in the document, tag them as B-KEYWORD and I-KEYWORD
-                 if match:
-                     tags[i] = 'B-KEYWORD'
-                     for j in range(1, len(keyword_words)):
-                         tags[i+j] = 'I-KEYWORD'
-
-     return tags
-
-
- def create_tner_dataset(all_tags, all_tokens, output_file_addr):
-     output_f = open(output_file_addr, 'a+')
-     for tags, tokens in zip(all_tags, all_tokens):
-         for tag, tok in zip(tags, tokens):
-             line = '\t'.join([tok, tag])
-             output_f.write(line)
-             output_f.write('\n')
-         output_f.write('\n')
-
-
- if __name__ == '__main__':
-
-     data_df = pd.read_csv('truncated_wiki_plus_shuffled_41203.csv')
-     id2document = data_df.set_index('id')['truncated_text_300'].to_dict()
-
-
-     #tag documents!
-     print('------------------ tag documents --------------------')
-     all_tags = []
-     all_tokens = []
-     for tagged_data_addr in glob.iglob('./tagged_data*'):
-         for line in open(tagged_data_addr):
-             item = json.loads(line)
-             if type(item['keyphrases']) == list:
-                 keywords = item['keyphrases']
-                 document = id2document[item['id']]
-                 tokens = document.split()
-                 tags = tag_document(keywords, tokens)
-                 assert len(tokens) == len(tags)
-                 all_tags.append(tags)
-                 all_tokens.append(tokens)
-                 print(len(keywords), len(tags), len(document.split()), len([t for t in tags if t[0]== 'B']))
-     nerda_dataset = {'sentences':all_tokens, 'tags': all_tags}
-     with open('nerda_dataset.json', 'w+') as f:
-         json.dump(nerda_dataset, f)
-     # create_tner_dataset(all_tags, all_tokens, output_file_addr='./sample_train.conll')
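The removed tag_document converted GPT-produced keyphrases into token-level BIO labels by exact string matching. Below is a compact, self-contained sketch of the same scheme with a toy input/output pair; bio_tags is a hypothetical name, and it uses list slicing in place of the inner match loop.

from typing import List

def bio_tags(keywords: List[str], tokens: List[str]) -> List[str]:
    """Mark every exact token-level occurrence of a keyword with B-KEYWORD / I-KEYWORD."""
    tags = ['O'] * len(tokens)
    for keyword in keywords:
        words = keyword.split()
        for i in range(len(tokens) - len(words) + 1):
            if tokens[i:i + len(words)] == words:
                tags[i] = 'B-KEYWORD'
                for j in range(1, len(words)):
                    tags[i + j] = 'I-KEYWORD'
    return tags

tokens = 'deep learning and machine learning are related'.split()
print(bio_tags(['machine learning'], tokens))
# -> ['O', 'O', 'O', 'B-KEYWORD', 'I-KEYWORD', 'O', 'O']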
 
predict.py DELETED
@@ -1,14 +0,0 @@
- import time
- from kpe import KPE
- import sys
- import os
-
-
- if __name__ == '__main__':
-     TRAINED_MODEL_ADDR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model_10000.pt')
-     text_addr = sys.argv[1]
-     text = open(text_addr).read()
-     kpe = KPE(trained_kpe_model= TRAINED_MODEL_ADDR, flair_ner_model='flair/ner-english-ontonotes-large', device='cpu')
-     s =time.time()
-     print(kpe.extract(text))
-     print(time.time() - s)
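The removed predict.py was a small smoke-test CLI. A hedged sketch of the same flow with explicit argument parsing and timing is shown below; it assumes the KPE class and the trained_model_10000.pt checkpoint that the deleted scripts reference.

import argparse
import os
import time

from kpe import KPE  # KPE class assumed to be provided elsewhere in the repository

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Extract keyphrases from a text file')
    parser.add_argument('text_file', help='path to a UTF-8 text file')
    args = parser.parse_args()

    # Same hard-coded checkpoint name as the deleted script.
    model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'trained_model_10000.pt')
    kpe = KPE(trained_kpe_model=model_path,
              flair_ner_model='flair/ner-english-ontonotes-large',
              device='cpu')

    with open(args.text_file, encoding='utf-8') as f:
        text = f.read()

    start = time.perf_counter()
    print(kpe.extract(text))
    print(f'extraction took {time.perf_counter() - start:.2f}s')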