import glob
import json
import os
import time

import jsonlines
import openai
import pandas as pd
import tiktoken
from tqdm import tqdm

GPT_MODEL = 'gpt-3.5-turbo'
GPT_TOKEN_LIMIT = 1500

# Read the key from the environment instead of hard-coding it in the source.
openai.api_key = os.environ["OPENAI_API_KEY"]

LAST_INDEX_FILE_ADDR = 'last_index.txt'
TOKEN_COUNT_FILE_ADDR = 'tikitoken_count.txt'

# Consecutive-failure counter used by get_keyphrase_by_gpt; it must exist
# before the first call, otherwise the except branch raises a NameError.
error_count = 0


def num_tokens(text: str, model: str = GPT_MODEL) -> int:
    """Return the number of tokens in a string."""
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))


def extract_seen_ids() -> set:
    """Collect ids already tagged in previous runs from every output shard."""
    seen_ids = set()
    for tagged_data_addr in glob.iglob('./tagged_data*'):
        with open(tagged_data_addr, encoding='utf-8') as f:
            seen_ids.update(json.loads(line)['id'] for line in f)
    return seen_ids


def get_keyphrase_by_gpt(document: str) -> str:
    """Ask the chat model for keyphrases; return the raw (ideally JSON) reply."""
    global error_count
    # Earlier prompt variants, kept for reference:
    # prompt = ('extract main keywords from below document as sorted list '
    #           '(sort by importance). you should not use numbers for counting '
    #           'them. you should generate less than 10 keywords.')
    # prompt = ('Output only valid JSON list. Please extract the main keywords '
    #           'from the following document. The keywords should be in a '
    #           'comma-separated list, sorted by their importance. Do not use '
    #           'numbers to count the keywords. Try to generate less than 10 '
    #           'keywords.')
    prompt = ('there is a popular NLP task named KPE (keyphrase Extraction). '
              'please extract keyphrases from below article as a perfect '
              'Persian KPE model. ')
    role_prompt = 'return your answer using json list format'
    message = prompt + '\n' + document
    messages = [
        {"role": "system", "content": role_prompt},
        {"role": "user", "content": message},
    ]
    try:
        response = openai.ChatCompletion.create(
            model=GPT_MODEL,
            messages=messages,
            temperature=0,
        )
        response_message = response["choices"][0]["message"]["content"]
        error_count = 0
        return response_message
    except Exception:
        # Tolerate up to three consecutive failures (rate limits, timeouts),
        # backing off between attempts; after that, re-raise.
        if error_count > 3:
            raise
        error_count += 1
        time.sleep(20)
        return '[]'  # parses to an empty keyphrase list downstream


input_data = pd.read_csv('truncated_wiki_plus_shuffled_41203.csv')

# Resume support: restore the last processed row index and the running token
# count from the checkpoint files, falling back to zero on a fresh start.
try:
    last_index = int(open(LAST_INDEX_FILE_ADDR).read())
    print('loaded last index:', last_index)
except (FileNotFoundError, ValueError):
    print('could not load last index, starting from 0')
    last_index = 0

try:
    token_count = int(open(TOKEN_COUNT_FILE_ADDR).read())
    print('loaded token count:', token_count)
except (FileNotFoundError, ValueError):
    print('could not load token count, starting from 0')
    token_count = 0

# Each run writes to its own shard, named after the index it resumed from.
json_f_writer = jsonlines.open(f'tagged_data.jsonl_{last_index}', mode='w')

seen_ids = extract_seen_ids()

for index, row in tqdm(input_data.iterrows(), total=len(input_data)):
    text = row['truncated_text_300']
    doc_id = row['id']

    # Filter rows before the last checkpointed index.
    if index < last_index:
        print('skipping index:', index)
        continue

    # Filter ids already tagged in earlier shards.
    if doc_id in seen_ids:
        print('repeated id, skipping')
        continue

    # Filter documents that exceed the model's token budget.
    text_gpt_token_count = num_tokens(text, model=GPT_MODEL)
    if text_gpt_token_count > GPT_TOKEN_LIMIT:
        continue
    token_count += text_gpt_token_count

    keyphrases = get_keyphrase_by_gpt(text)
    try:
        keyphrases = json.loads(keyphrases)
        if not isinstance(keyphrases, list):
            print(str(index), ': not a list!')
    except (json.JSONDecodeError, TypeError):
        # Keep the raw reply so problematic rows can be inspected later.
        print(str(index), ': invalid json!')

    json_f_writer.write({'id': doc_id, 'keyphrases': keyphrases})

    # Checkpoint progress after every row.
    with open(LAST_INDEX_FILE_ADDR, 'w') as last_index_f:
        last_index_f.write(str(index))
    with open(TOKEN_COUNT_FILE_ADDR, 'w') as token_count_f:
        token_count_f.write(str(token_count))

json_f_writer.close()
print(token_count)
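

# --- optional inspection helper (sketch, not part of the pipeline) ----------
# A minimal, hypothetical utility for reading the tagged_data.jsonl_* shards
# written above back into memory, e.g. to audit how many records ended up
# with a proper keyphrase list. The name `load_tagged_shards` and the glob
# pattern default are assumptions; only the {'id', 'keyphrases'} record shape
# comes from the script itself. Relies on the glob/json imports at the top.
def load_tagged_shards(pattern: str = './tagged_data*') -> list:
    """Return all records from every JSONL shard matching `pattern`."""
    records = []
    for shard_addr in glob.iglob(pattern):
        with open(shard_addr, encoding='utf-8') as f:
            records.extend(json.loads(line) for line in f if line.strip())
    return records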