fadliaulawi committed
Commit 63bec36
1 Parent(s): 8503206

Remove unnecessary function

Files changed (1)
  1. process.py +4 -166
process.py CHANGED
@@ -1,4 +1,3 @@
- from datetime import datetime
from dotenv import load_dotenv
from img2table.document import Image
from langchain.chains.combine_documents.map_reduce import MapReduceDocumentsChain
@@ -9,7 +8,7 @@ from langchain.prompts import PromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from pdf2image import convert_from_path
- from prompt import prompt_entity_gsd_chunk, prompt_entity_gsd_combine, prompt_entity_summ_chunk, prompt_entity_summ_combine, prompt_entities_chunk, prompt_entities_combine, prompt_entity_one_chunk, prompt_table, prompt_validation
+ from prompt import *
from table_detector import detection_transform, device, model, ocr, outputs_to_objects

import io
@@ -17,8 +16,6 @@ import json
import os
import pandas as pd
import re
- import requests
- import time
import torch

load_dotenv()
@@ -31,7 +28,7 @@ prompts = {

class Process():

-     def __init__(self, llm, llm_val):
+     def __init__(self, llm):

        if llm.startswith('gpt'):
            self.llm = ChatOpenAI(temperature=0, model_name=llm)
@@ -40,13 +37,6 @@ class Process():
        else:
            self.llm = ChatOpenAI(temperature=0, model_name=llm, api_key=os.environ['PERPLEXITY_API_KEY'], base_url="https://api.perplexity.ai")

-         if llm_val.startswith('gpt'):
-             self.llm_val = ChatOpenAI(temperature=0, model_name=llm_val)
-         elif llm_val.startswith('gemini'):
-             self.llm_val = ChatGoogleGenerativeAI(temperature=0, model=llm_val)
-         else:
-             self.llm_val = ChatOpenAI(temperature=0, model_name=llm_val, api_key=os.environ['PERPLEXITY_API_KEY'], base_url="https://api.perplexity.ai")
-
    def get_entity(self, data):

        chunks, types = data
@@ -97,9 +87,7 @@ class Process():

    def get_table(self, path):

-         start_time = datetime.now()
        images = convert_from_path(path)
-         print('PDF to Image', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
        tables = []

        # Loop pages
@@ -124,7 +112,6 @@ class Process():
                print(detected_tables[idx])
                tables.append(cropped_table)

-         print('Detect table from image', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
        genes = []
        snps = []
        diseases = []
@@ -147,29 +134,10 @@ class Process():
            for extracted_table in extracted_tables[1:]:
                df_table = pd.concat([df_table, extracted_table.df]).reset_index(drop=True)

-             df_table.loc[0] = df_table.loc[0].fillna('')
-
-             # Identify multiple rows (in dataframe) as one row (in image)
-             rows = []
-             indexes = []
-             for i in df_table.index:
-                 if not df_table.loc[i].isna().any():
-                     if len(indexes) > 0:
-                         rows.append(indexes)
-                         indexes = []
-                 indexes.append(i)
-             rows.append(indexes)
-
-             df_table_cleaned = pd.DataFrame(columns=df_table.columns)
-             for row in rows:
-                 row_str = df_table.loc[row[0]]
-                 for idx in row[1:]:
-                     row_str += ' ' + df_table.loc[idx].fillna('')
-                 row_str = row_str.str.strip()
-                 df_table_cleaned.loc[len(df_table_cleaned)] = row_str
+             df_table = df_table.fillna('')

            # Ask LLM with JSON data
-             json_table = df_table_cleaned.to_json(orient='records')
+             json_table = df_table.to_json(orient='records')
            str_json_table = json.dumps(json.loads(json_table), indent=2)

            result = self.llm.invoke(prompt_table.format(str_json_table)).content
@@ -191,135 +159,5 @@ class Process():
                snps.append(snp)
                diseases.append(res_disease)

-         print('OCR table to extract', round((datetime.now().timestamp() - start_time.timestamp()) / 60, 2), "minutes")
        print(genes, snps, diseases)
-
        return genes, snps, diseases
-
-     def validate(self, df):
-
-         df = df.fillna('')
-         df['Genes'] = df['Genes'].str.replace(' ', '').str.upper()
-         df['SNPs'] = df['SNPs'].str.lower()
-
-         # Check if there is two gene names
-         sym = [',', '/', '|']
-         for i in df.index:
-             gene = df.loc[i, 'Genes']
-             for s in sym:
-                 if s in gene:
-                     genes = gene.split(s)
-                     df.loc[i + 0.5] = df.loc[i]
-                     df = df.sort_index().reset_index(drop=True)
-                     df.loc[i, 'Genes'], df.loc[i + 1, 'Genes'] = genes[0], s.join(genes[1:])
-                     break
-
-         # Check if there is SNPs without 'rs'
-         for i in df.index:
-             safe = True
-             snp = df.loc[i, 'SNPs']
-             snp = snp.replace('l', '1')
-             if re.fullmatch('rs(\d)+|', snp):
-                 pass
-             elif re.fullmatch('ts(\d)+', snp):
-                 snp = 'r' + snp[1:]
-             elif re.fullmatch('s(\d)+', snp):
-                 snp = 'r' + snp
-             elif re.fullmatch('(\d)+', snp):
-                 snp = 'rs' + snp
-             else:
-                 safe = False
-                 df = df.drop(i)
-
-             if safe:
-                 df.loc[i, 'SNPs'] = snp
-
-         df.reset_index(drop=True, inplace=True)
-         df_clean = df.copy()
-
-         # # Validate genes and SNPs with APIs
-         def permutate(word):
-
-             if len(word) == 0:
-                 return ['']
-
-             change = []
-             res = permutate(word[1:])
-
-             if word[0] in mistakes:
-                 change = [mistakes[word[0]] + r for r in res]
-
-             return [word[0] + r for r in res] + change
-
-         def call(url):
-
-             while True:
-                 try:
-                     res = requests.get(url)
-                     time.sleep(1)
-                     break
-                 except Exception as e:
-                     print(e)
-
-             return res
-
-         mistakes = {'I': '1', 'O': '0'} # Common mistakes need to be maintained
-         dbsnp = {}
-
-         for i in df.index:
-             snp = df.loc[i, 'SNPs']
-             gene = df.loc[i, 'Genes']
-
-             if snp not in dbsnp:
-                 res = call(f'https://www.ebi.ac.uk/gwas/rest/api/singleNucleotidePolymorphisms/{snp}/')
-                 try:
-                     res = res.json()
-                     dbsnp[snp] = [r['gene']['geneName'] for r in res['genomicContexts']]
-                 except:
-                     dbsnp[snp] = []
-
-                 res = call(f'https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=snp&retmode=json&id={snp[2:]}').json()['result'][snp[2:]]
-                 if 'error' not in res:
-                     dbsnp[snp].extend([r['name'] for r in res['genes']])
-
-                 dbsnp[snp] = list(set(dbsnp[snp]))
-
-             if gene not in dbsnp[snp]:
-                 for other in permutate(gene):
-                     if other in dbsnp[snp]:
-                         df.loc[i, 'Genes'] = other
-                         print(f'{gene} corrected to {other}')
-                         break
-                 else:
-                     df = df.drop(i)
-
-         # df.reset_index(drop=True, inplace=True)
-         df_no_llm = df.copy()
-
-         # Validate genes and diseases with LLM (for each 50 rows)
-         idx = 0
-         results = []
-
-         while True:
-             json_table = df[['Genes', 'SNPs', 'Diseases']][idx:idx+50].to_json(orient='records')
-             str_json_table = json.dumps(json.loads(json_table), indent=2)
-
-             result = self.llm_val.invoke(input=prompt_validation.format(str_json_table)).content
-             print('val', idx)
-             print(result)
-
-             result = result[result.find('['):result.rfind(']')+1]
-             try:
-                 result = eval(result)
-             except SyntaxError:
-                 result = []
-
-             results.extend(result)
-             idx += 50
-             if idx not in df.index:
-                 break
-
-         df = pd.DataFrame(results)
-         df = df.merge(df_no_llm.head(1).drop(['Genes', 'SNPs', 'Diseases'], axis=1), 'cross')
-
-         return df, df_no_llm, df_clean
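
With this commit, Process is constructed from a single model name and no longer carries the validate step or its requests/time/datetime dependencies. A minimal usage sketch of the slimmed-down class, assuming the local process.py module; the model name and PDF path below are illustrative placeholders, not part of the commit:

from process import Process

# Placeholder model name; any 'gpt', 'gemini', or Perplexity-hosted name routes
# through the same __init__ branches shown in the diff above.
process = Process('gpt-4o-mini')

# get_table() converts the PDF to images, detects and OCRs tables, asks the LLM
# to extract entities, and returns three parallel lists.
genes, snps, diseases = process.get_table('paper.pdf')  # 'paper.pdf' is a placeholder path
print(genes, snps, diseases)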