import glob
import json
import os
import time

import jsonlines
import openai
import pandas as pd
import tiktoken
from tqdm import tqdm

GPT_MODEL = 'gpt-3.5-turbo'
GPT_TOKEN_LIMIT = 1500
# Never hard-code API keys in source; read the key from the environment instead.
openai.api_key = os.environ["OPENAI_API_KEY"]

LAST_INDEX_FILE_ADDR = 'last_index.txt'
TOKEN_COUNT_FILE_ADDR = 'tikitoken_count.txt'

# Consecutive-error counter shared with get_keyphrase_by_gpt for crude back-off.
error_count = 0

def num_tokens(text: str, model: str = GPT_MODEL) -> int:
    """Return the number of tokens in a string."""
    encoding = tiktoken.encoding_for_model(model)
    return len(encoding.encode(text))
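# Illustrative usage: num_tokens('hello world') returns the token count under
# GPT_MODEL's encoding; the loop below uses this for the per-document budget.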


def extract_seen_ids():
    """Collect ids already tagged in previous runs (any ./tagged_data* file)."""
    seen_ids = set()
    for tagged_data_addr in glob.iglob('./tagged_data*'):
        with open(tagged_data_addr) as f:
            seen_ids.update(json.loads(line)['id'] for line in f)
    return seen_ids
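# Each line of a tagged_data* file is one JSON record shaped like the output
# written at the bottom of this script, e.g. (id value is illustrative):
#   {"id": 42, "keyphrases": ["...", "..."]}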


def get_keyphrase_by_gpt(document) -> str:
    """Ask the model for keyphrases; returns the raw response text (a JSON list)."""
    global error_count
    prompt = ('There is a popular NLP task named KPE (keyphrase extraction). '
              'Please act as a perfect Persian KPE model and extract keyphrases '
              'from the article below.')
    role_prompt = 'Return your answer as a JSON list.'
    message = prompt + '\n' + document
    messages = [
        {"role": "system", "content": role_prompt},
        {"role": "user", "content": message},
    ]
    try:
        response = openai.ChatCompletion.create(
            model=GPT_MODEL,
            messages=messages,
            temperature=0,  # deterministic output for reproducible tagging
        )
        response_message = response["choices"][0]["message"]["content"]
        error_count = 0  # reset the consecutive-error counter on success
        return response_message
    except Exception as e:
        if error_count > 3:
            raise e
        error_count += 1
        time.sleep(20)  # back off, then let the caller move on
        return '[]'  # valid empty JSON so the caller's json.loads still succeeds
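
# Note: the call above targets the pre-1.0 openai SDK. A minimal sketch of the
# equivalent request against openai>=1.0 (not used by this script) would be:
#
#   from openai import OpenAI
#   client = OpenAI()  # reads OPENAI_API_KEY from the environment
#   response = client.chat.completions.create(
#       model=GPT_MODEL, messages=messages, temperature=0)
#   response_message = response.choices[0].message.content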

input_data = pd.read_csv('truncated_wiki_plus_shuffled_41203.csv')

try:
    last_index = int(open(LAST_INDEX_FILE_ADDR).read())
    print('loaded last index:', last_index)
except (FileNotFoundError, ValueError):
    print('could not load last index; starting from 0')
    last_index = 0


try:
    token_count = int(open(TOKEN_COUNT_FILE_ADDR).read())
    print('loaded token count:', token_count)
except (FileNotFoundError, ValueError):
    print('could not load token count; starting from 0')
    token_count = 0
json_f_writer = jsonlines.open(f'tagged_data.jsonl_{last_index}', mode='w')
seen_ids = extract_seen_ids()
for index, row in tqdm(input_data.iterrows(), total=len(input_data)):
    text = row['truncated_text_300']
    doc_id = row['id']  # avoid shadowing the builtin id()

    # Resume support: skip rows already processed in a previous run.
    if index < last_index:
        print('skipping index:', index)
        continue
    
    # Deduplicate against ids already present in earlier tagged_data files.
    if doc_id in seen_ids:
        print('repeated id, skipping')
        continue
    
    # Enforce the per-request token budget.
    text_gpt_token_count = num_tokens(text, model=GPT_MODEL)
    if text_gpt_token_count > GPT_TOKEN_LIMIT:
        continue
    
    token_count += text_gpt_token_count
    keyphrases = get_keyphrase_by_gpt(text)
    try:
        keyphrases = json.loads(keyphrases)
        if not isinstance(keyphrases, list):
            print(str(index), ': not a list!')
    except json.JSONDecodeError:
        # keep the raw response string so the failure can be inspected later
        print(str(index), ': invalid json!')

    new_train_item = {'id': doc_id, 'keyphrases': keyphrases}
    json_f_writer.write(new_train_item)
    # Checkpoint after every document so an interrupted run can resume.
    with open(LAST_INDEX_FILE_ADDR, 'w') as last_index_f:
        last_index_f.write(str(index))
    with open(TOKEN_COUNT_FILE_ADDR, 'w') as token_count_f:
        token_count_f.write(str(token_count))

print(token_count)  # cumulative count of document tokens sent to the API
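
# Optional hardening (a sketch, not wired into the loop above): write each
# checkpoint file atomically so a crash mid-write cannot leave it truncated.
# os.replace is atomic when source and destination are on the same filesystem.
def write_checkpoint(path: str, value: str) -> None:
    tmp = path + '.tmp'
    with open(tmp, 'w') as f:
        f.write(value)
    os.replace(tmp, path)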