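"""Tag Persian-language documents with GPT-extracted keyphrases.

Reads documents from a CSV file, asks gpt-3.5-turbo to extract keyphrases for
each row, and appends the results to a JSONL file. Progress (last processed row
index and cumulative token count) is checkpointed to disk so an interrupted
run can be resumed.
"""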
import glob
import json
import os
import time

import jsonlines
import openai
import pandas as pd
import tiktoken
from tqdm import tqdm

GPT_MODEL = 'gpt-3.5-turbo'
GPT_TOKEN_LIMIT = 1500  # skip documents longer than this many tokens

# Read the API key from the environment rather than hard-coding a secret.
openai.api_key = os.environ["OPENAI_API_KEY"]

LAST_INDEX_FILE_ADDR = 'last_index.txt'
TOKEN_COUNT_FILE_ADDR = 'tikitoken_count.txt'

# Count of consecutive API failures; used by get_keyphrase_by_gpt to back off
# and eventually abort.
error_count = 0

def num_tokens(text: str, model: str = GPT_MODEL) -> int:
    """Return the number of tokens in a string."""
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))

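# For reference, tiktoken resolves gpt-3.5-turbo to the cl100k_base encoding,
# so this count matches the chat API's tokenizer (per-message formatting
# overhead is not included).
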
def extract_seen_ids():
    """Collect ids already tagged in previous runs (any tagged_data* file)."""
    seen_ids = set()
    for tagged_data_addr in glob.iglob('./tagged_data*'):
        with open(tagged_data_addr) as f:
            seen_ids.update(json.loads(line)['id'] for line in f)
    return seen_ids

def get_keyphrase_by_gpt(document) -> str:
    """Ask the chat model for the document's keyphrases as a JSON list."""
    global error_count
    prompt = ('There is a popular NLP task named KPE (Keyphrase Extraction). '
              'Please extract keyphrases from the article below as a perfect '
              'Persian KPE model. ')
    role_prompt = 'Return your answer in JSON list format.'
    message = prompt + '\n' + document
    messages = [
        {"role": "system", "content": role_prompt},
        {"role": "user", "content": message},
    ]
    try:
        response = openai.ChatCompletion.create(
            model=GPT_MODEL,
            messages=messages,
            temperature=0,
        )
        response_message = response["choices"][0]["message"]["content"]
        error_count = 0  # reset the failure counter on success
        return response_message
    except Exception as e:
        # Back off and retry; give up after more than 3 consecutive failures.
        if error_count > 3:
            raise e
        error_count += 1
        time.sleep(20)
        return get_keyphrase_by_gpt(document)
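
# Note: openai.ChatCompletion.create is the pre-1.0 openai-python interface.
# A minimal sketch of the equivalent call with openai>=1.0 (not used by this
# script):
#
#   from openai import OpenAI
#   client = OpenAI()  # reads OPENAI_API_KEY from the environment
#   response = client.chat.completions.create(
#       model=GPT_MODEL, messages=messages, temperature=0)
#   response_message = response.choices[0].message.content
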
# Input documents: one row per article, with an id and a truncated text column.
input_data = pd.read_csv('truncated_wiki_plus_shuffled_41203.csv')
# Resume support: restore the last processed row index and the running token count.
try:
    last_index = int(open(LAST_INDEX_FILE_ADDR).read())
    print('loaded last index:', last_index)
except (OSError, ValueError):
    print('could not load last index, starting from 0')
    last_index = 0
try:
    token_count = int(open(TOKEN_COUNT_FILE_ADDR).read())
    print('loaded token count:', token_count)
except (OSError, ValueError):
    print('could not load token count, starting from 0')
    token_count = 0
json_f_writer = jsonlines.open(f'tagged_data.jsonl_{last_index}', mode='w')
seen_ids = extract_seen_ids()
for index, row in tqdm(input_data.iterrows(), total=len(input_data)):
    text = row['truncated_text_300']
    doc_id = row['id']
    # Skip rows before the checkpointed index.
    if index < last_index:
        print('skipping index:', index)
        continue
    # Skip ids already tagged in a previous run.
    if doc_id in seen_ids:
        print('repeated id, skipping')
        continue
    # Skip documents that exceed the per-request token budget.
    text_gpt_token_count = num_tokens(text, model=GPT_MODEL)
    if text_gpt_token_count > GPT_TOKEN_LIMIT:
        continue
    token_count += text_gpt_token_count
    keyphrases = get_keyphrase_by_gpt(text)
    try:
        keyphrases = json.loads(keyphrases)
        if not isinstance(keyphrases, list):
            print(str(index), ': not a list!')
    except json.JSONDecodeError:
        # Keep the raw response string so the item can be repaired later.
        print(str(index), ': invalid json!')
    new_train_item = {'id': doc_id, 'keyphrases': keyphrases}
    json_f_writer.write(new_train_item)
    # Checkpoint progress after every item.
    with open(LAST_INDEX_FILE_ADDR, 'w') as last_index_f:
        last_index_f.write(str(index))
    with open(TOKEN_COUNT_FILE_ADDR, 'w') as token_count_f:
        token_count_f.write(str(token_count))
    print(token_count)
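
# A minimal sketch of loading the output back for inspection (assumes the
# tagged_data.jsonl_* files written above exist in the working directory):
#
#   tagged = [json.loads(line)
#             for path in glob.glob('tagged_data.jsonl_*')
#             for line in open(path)]
#   print(len(tagged), tagged[0] if tagged else None)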