# NOTE(review): the two lines below were copy/paste residue from a web UI
# ("Spaces:" / "Runtime error"); kept as comments so the file parses.
# Runtime error
# --- dependencies & configuration -------------------------------------------
import glob
import json
import os
import time

import jsonlines
import openai
import pandas as pd
import tiktoken
from openai import error as openai_error  # NOTE(review): unused here; kept in case a sibling chunk needs it
from tqdm import tqdm

# Chat model used for both keyphrase extraction and token counting.
GPT_MODEL = 'gpt-3.5-turbo'
# Documents longer than this many tokens are skipped entirely.
GPT_TOKEN_LIMIT = 1500

# SECURITY: an OpenAI API key was previously hard-coded here (a leaked
# secret). The key must now be supplied via the OPENAI_API_KEY environment
# variable; the leaked keys should be revoked.
openai.api_key = os.environ["OPENAI_API_KEY"]

# Checkpoint files so an interrupted run can resume where it left off.
LAST_INDEX_FILE_ADDR = 'last_index.txt'
TOKEN_COUNT_FILE_ADDR = 'tikitoken_count.txt'
def num_tokens(text: str, model: str = GPT_MODEL) -> int:
    """Return the number of tokens *text* encodes to for *model*.

    Used to pre-filter documents against GPT_TOKEN_LIMIT before paying
    for an API call.
    """
    # tiktoken caches encodings internally, so repeated calls are cheap.
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))
def extract_seen_ids():
    """Collect document ids already tagged in any ./tagged_data* output file.

    Each file is expected to be jsonlines: one JSON object per line with an
    'id' field. Returns the set of those ids so already-processed documents
    can be skipped on resume.
    """
    seen_ids = set()
    for tagged_data_addr in glob.iglob('./tagged_data*'):
        # Context manager so the handle is closed; the original opened every
        # file and never closed it.
        with open(tagged_data_addr) as tagged_file:
            seen_ids.update(json.loads(line)['id'] for line in tagged_file)
    return seen_ids
def get_keyphrase_by_gpt(document) -> str:
    """Ask the chat model to extract Persian keyphrases from *document*.

    Returns the raw model response (expected to be a JSON-encoded list) on
    success. On failure, sleeps 20s as a crude back-off and returns an empty
    list; after more than 3 consecutive failures the exception is re-raised.

    NOTE(review): callers receive a ``str`` on success but a ``list`` on
    failure — downstream json.loads relies on the resulting TypeError.
    """
    global error_count
    # BUG FIX: the original read ``error_count`` in the except branch before
    # it was ever assigned, so the very first API failure raised NameError
    # instead of retrying. Initialise the counter lazily.
    if 'error_count' not in globals():
        error_count = 0
    prompt = 'there is a popular NLP task named KPE (keyphrase Extraction). please extract keyphrases from below article as a perfect Persian KPE model. '
    role_prompt = 'return your answer using json list format'
    messages = [
        {"role": "system", "content": role_prompt},
        {"role": "user", "content": prompt + '\n' + document},
    ]
    try:
        response = openai.ChatCompletion.create(
            model=GPT_MODEL,
            messages=messages,
            temperature=0,  # deterministic output for a tagging pipeline
        )
        error_count = 0  # a success resets the consecutive-failure counter
        return response["choices"][0]["message"]["content"]
    except Exception as e:
        if error_count > 3:
            raise e
        error_count += 1
        time.sleep(20)  # back off for rate limits / transient API errors
        return []
# ---------------------------------------------------------------------------
# Main tagging loop: stream the corpus, skip rows already processed, call the
# model, and checkpoint progress after every written document.
# ---------------------------------------------------------------------------
input_data = pd.read_csv('truncated_wiki_plus_shuffled_41203.csv')

# Resume support: the last processed row index and the running token count
# are persisted between runs. Missing/corrupt checkpoint files mean a fresh
# start (narrowed from the original bare ``except:``).
try:
    with open(LAST_INDEX_FILE_ADDR) as last_index_file:
        last_index = int(last_index_file.read())
    print('load last index: ', last_index)
except (OSError, ValueError):
    print('error in loading last index')
    last_index = 0

try:
    with open(TOKEN_COUNT_FILE_ADDR) as token_count_file:
        token_count = int(token_count_file.read())
    print('load token count: ', token_count)
except (OSError, ValueError):
    print('error in loading token_count')
    token_count = 0

# Each run writes its own output shard, suffixed with the resume index.
json_f_writer = jsonlines.open(f'tagged_data.jsonl_{str(last_index)}', mode='w')
seen_ids = extract_seen_ids()

for index, row in tqdm(input_data.iterrows(), total=len(input_data)):
    text = row['truncated_text_300']
    doc_id = row['id']  # renamed from ``id`` to avoid shadowing the builtin

    # Skip rows handled in a previous run (checkpoint resume).
    if index < last_index:
        print('skipping index: ', index)
        continue
    # Skip documents already present in an earlier output shard.
    if doc_id in seen_ids:
        print('repated id and skip')
        continue
    # Skip documents too long for the model's context budget.
    text_gpt_token_count = num_tokens(text, model=GPT_MODEL)
    if text_gpt_token_count > GPT_TOKEN_LIMIT:
        continue
    token_count += text_gpt_token_count

    keyphrases = get_keyphrase_by_gpt(text)
    try:
        # On the retry-failure path keyphrases is already a list, which makes
        # json.loads raise TypeError; invalid model output raises ValueError.
        keyphrases = json.loads(keyphrases)
        if not isinstance(keyphrases, list):
            print(str(index), ': not a list!')
    except (TypeError, ValueError):
        # NOTE(review): the raw (unparsed) value is still written below,
        # matching the original behaviour.
        print(str(index), ':invalid json!')

    json_f_writer.write({'id': doc_id, 'keyphrases': keyphrases})

    # Checkpoint after every document; ``with`` closes the handles that the
    # original leaked once per row.
    with open(LAST_INDEX_FILE_ADDR, 'w') as last_index_f:
        last_index_f.write(str(index))
    with open(TOKEN_COUNT_FILE_ADDR, 'w') as token_count_f:
        token_count_f.write(str(token_count))
    print(token_count)