Sangjun2 commited on
Commit
fb56a77
Β·
verified Β·
1 Parent(s): ca427b1

new_new_new_vaiv_app.py

Browse files
Files changed (1) hide show
  1. app.py +182 -983
app.py CHANGED
@@ -20,6 +20,9 @@ import time
20
  import logging
21
  import subprocess
22
  import spaces
 
 
 
23
 
24
  # Git LFS pull λͺ…λ Ήμ–΄ μ‹€ν–‰
25
  result = subprocess.run(['git', 'lfs', 'pull'], capture_output=True, text=True)
@@ -36,55 +39,26 @@ logger = logging.getLogger()
36
  warnings.filterwarnings('ignore')
37
  MAX_PATCHES = 512
38
  # Load the models and processor
39
- #device = torch.device("cpu")
40
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
 
42
  # Paths to the models
43
- ko_deplot_model_path = './deplot_model_ver_kor_24.7.25_refinetuning_epoch3.bin'
44
- aihub_deplot_model_path='./deplot_k.pt'
45
- t5_model_path = './ke_t5.pt'
46
 
47
  # Load first model ko-deplot
48
-
49
  def load_model1():
50
  processor1 = Pix2StructProcessor.from_pretrained('nuua/ko-deplot')
51
  model1 = Pix2StructForConditionalGeneration.from_pretrained('nuua/ko-deplot')
52
  model1.load_state_dict(torch.load(ko_deplot_model_path, map_location="cpu"))
53
  model1.to(torch.device("cuda"))
54
- return processor1,model1
55
-
56
- processor1,model1=load_model1()
57
-
58
- # Load second model aihub-deplot
59
-
60
- def load_model2():
61
- processor2 = AutoProcessor.from_pretrained("ybelkada/pix2struct-base")
62
- model2 = Pix2StructForConditionalGeneration.from_pretrained("ybelkada/pix2struct-base")
63
- model2.load_state_dict(torch.load(aihub_deplot_model_path, map_location="cpu"))
64
- model2.to(torch.device("cuda"))
65
- return processor2,model2
66
-
67
- processor2,model2=load_model2()
68
 
 
69
 
70
- #Load third model unichart
71
-
72
- def load_model3():
73
- unichart_model_path = "./unichart4/chartqa-checkpoint-epoch=2-161952"
74
- model3 = VisionEncoderDecoderModel.from_pretrained(unichart_model_path)
75
- processor3 = DonutProcessor.from_pretrained(unichart_model_path)
76
- model3.to(torch.device("cuda"))
77
- return processor3,model3
78
-
79
- processor3,model3=load_model3()
80
-
81
- #ko-deplot μΆ”λ‘ ν•¨μˆ˜
82
  # Function to format output
83
  def format_output(prediction):
84
  return prediction.replace('<0x0A>', '\n')
85
 
86
- # First model prediction ko-deplot
87
- @spaces.GPU(enable_queue=True,duration=100)
88
  def predict_model1(image):
89
  images = [image]
90
  inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
@@ -98,1003 +72,228 @@ def predict_model1(image):
98
  formatted_output = format_output(outputs[0])
99
  return formatted_output
100
 
101
-
102
- def replace_unk(text):
103
- # 1. '제λͺ©:', 'μœ ν˜•:' κΈ€μž μ•žμ— μžˆλŠ” <unk>λŠ” \n둜 λ°”κΏˆ
104
- text = re.sub(r'<unk>(?=제λͺ©:|μœ ν˜•:)', '\n', text)
105
- # 2. 'μ„Έλ‘œ ' λ˜λŠ” 'κ°€λ‘œ '와 'λŒ€ν˜•' 사이에 μžˆλŠ” <unk>λ₯Ό ""둜 λ°”κΏˆ
106
- text = re.sub(r'(?<=μ„Έλ‘œ |κ°€λ‘œ )<unk>(?=λŒ€ν˜•)', '', text)
107
- # 3. μˆ«μžμ™€ ν…μŠ€νŠΈ 사이에 μžˆλŠ” <unk>λ₯Ό \n둜 λ°”κΏˆ
108
- text = re.sub(r'(\d)<unk>([^\d])', r'\1\n\2', text)
109
- # 4. %, 원, 건, λͺ… 뒀에 λ‚˜μ˜€λŠ” <unk>λ₯Ό \n둜 λ°”κΏˆ
110
- text = re.sub(r'(?<=[%원건λͺ…\)])<unk>', '\n', text)
111
- # 5. μˆ«μžμ™€ 숫자 사이에 μžˆλŠ” <unk>λ₯Ό \n둜 λ°”κΏˆ
112
- text = re.sub(r'(\d)<unk>(\d)', r'\1\n\2', text)
113
- # 6. 'ν˜•'μ΄λΌλŠ” κΈ€μžμ™€ ' |' 사이에 μžˆλŠ” <unk>λ₯Ό \n둜 λ°”κΏˆ
114
- text = re.sub(r'ν˜•<unk>(?= \|)', 'ν˜•\n', text)
115
- # 7. λ‚˜λ¨Έμ§€ <unk>λ₯Ό λͺ¨λ‘ ""둜 λ°”κΏˆ
116
- text = text.replace('<unk>', '')
117
- return text
118
-
119
-
120
- @spaces.GPU(enable_queue=True,duration=100)
121
- def predict_model3(image):
122
- image=image.convert("RGB")
123
- input_prompt = "<extract_data_table> <s_answer>"
124
- decoder_input_ids = processor3.tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
125
- pixel_values = processor3(image, return_tensors="pt").pixel_values
126
- outputs = model3.generate(
127
- pixel_values.to(device),
128
- decoder_input_ids=decoder_input_ids.to(device),
129
- max_length=model3.decoder.config.max_position_embeddings,
130
- early_stopping=True,
131
- pad_token_id=processor3.tokenizer.pad_token_id,
132
- eos_token_id=processor3.tokenizer.eos_token_id,
133
- use_cache=True,
134
- num_beams=4,
135
- bad_words_ids=[[processor3.tokenizer.unk_token_id]],
136
- return_dict_in_generate=True,
137
- )
138
- sequence = processor3.batch_decode(outputs.sequences)[0]
139
- sequence = sequence.replace(processor3.tokenizer.eos_token, "").replace(processor3.tokenizer.pad_token, "")
140
- sequence = sequence.split("<s_answer>")[-1].strip()
141
-
142
- return sequence
143
- #function for converting aihub dataset labeling json file to ko-deplot data table
144
- def process_json_file(input_file):
145
- with open(input_file, 'r', encoding='utf-8') as file:
146
- data = json.load(file)
147
-
148
- # ν•„μš”ν•œ 데이터 μΆ”μΆœ
149
- chart_type = data['metadata']['chart_sub']
150
- title = data['annotations'][0]['title']
151
- x_axis = data['annotations'][0]['axis_label']['x_axis']
152
- y_axis = data['annotations'][0]['axis_label']['y_axis']
153
- legend = data['annotations'][0]['legend']
154
- data_labels = data['annotations'][0]['data_label']
155
- is_legend = data['annotations'][0]['is_legend']
156
-
157
- # μ›ν•˜λŠ” ν˜•μ‹μœΌλ‘œ λ³€ν™˜
158
- formatted_string = f"TITLE | {title} <0x0A> "
159
- if 'κ°€λ‘œ' in chart_type:
160
- if is_legend:
161
- # κ°€λ‘œ 차트 처리
162
- formatted_string += " | ".join(legend) + " <0x0A> "
163
- for i in range(len(y_axis)):
164
- row = [y_axis[i]]
165
- for j in range(len(legend)):
166
- if i < len(data_labels[j]):
167
- row.append(str(data_labels[j][i])) # 데이터 값을 λ¬Έμžμ—΄λ‘œ λ³€ν™˜
168
- else:
169
- row.append("") # 데이터가 μ—†λŠ” 경우 빈 λ¬Έμžμ—΄ μΆ”κ°€
170
- formatted_string += " | ".join(row) + " <0x0A> "
171
- else:
172
- # is_legendκ°€ False인 경우
173
- for i in range(len(y_axis)):
174
- row = [y_axis[i], str(data_labels[0][i])]
175
- formatted_string += " | ".join(row) + " <0x0A> "
176
- elif chart_type == "μ›ν˜•":
177
- # μ›ν˜• 차트 처리
178
- if legend:
179
- used_labels = legend
180
- else:
181
- used_labels = x_axis
182
-
183
- formatted_string += " | ".join(used_labels) + " <0x0A> "
184
- row = [data_labels[0][i] for i in range(len(used_labels))]
185
- formatted_string += " | ".join(row) + " <0x0A> "
186
- elif chart_type == "ν˜Όν•©ν˜•":
187
- # ν˜Όν•©ν˜• 차트 처리
188
- all_legends = [ann['legend'][0] for ann in data['annotations']]
189
- formatted_string += " | ".join(all_legends) + " <0x0A> "
190
-
191
- combined_data = []
192
- for i in range(len(x_axis)):
193
- row = [x_axis[i]]
194
- for ann in data['annotations']:
195
- if i < len(ann['data_label'][0]):
196
- row.append(str(ann['data_label'][0][i])) # 데이터 값을 λ¬Έμžμ—΄λ‘œ λ³€ν™˜
197
- else:
198
- row.append("") # 데이터가 μ—†λŠ” 경우 빈 λ¬Έμžμ—΄ μΆ”κ°€
199
- combined_data.append(" | ".join(row))
200
-
201
- formatted_string += " <0x0A> ".join(combined_data) + " <0x0A> "
202
- else:
203
- # 기타 차트 처리
204
- if is_legend:
205
- formatted_string += " | ".join(legend) + " <0x0A> "
206
- for i in range(len(x_axis)):
207
- row = [x_axis[i]]
208
- for j in range(len(legend)):
209
- if i < len(data_labels[j]):
210
- row.append(str(data_labels[j][i])) # 데이터 값을 λ¬Έμžμ—΄λ‘œ λ³€ν™˜
211
- else:
212
- row.append("") # 데이터가 μ—†λŠ” 경우 빈 λ¬Έμžμ—΄ μΆ”κ°€
213
- formatted_string += " | ".join(row) + " <0x0A> "
214
- else:
215
- for i in range(len(x_axis)):
216
- if i < len(data_labels[0]):
217
- formatted_string += f"{x_axis[i]} | {str(data_labels[0][i])} <0x0A> "
218
- else:
219
- formatted_string += f"{x_axis[i]} | <0x0A> " # 데이터가 μ—†λŠ” 경우 빈 λ¬Έμžμ—΄ μΆ”κ°€
220
-
221
- # λ§ˆμ§€λ§‰ "<0x0A> " 제거
222
- formatted_string = formatted_string[:-8]
223
- return format_output(formatted_string)
224
-
225
- def chart_data(data):
226
- datatable = []
227
- num = len(data)
228
- for n in range(num):
229
- title = data[n]['title'] if data[n]['is_title'] else ''
230
- legend = data[n]['legend'] if data[n]['is_legend'] else ''
231
- datalabel = data[n]['data_label'] if data[n]['is_datalabel'] else [0]
232
- unit = data[n]['unit'] if data[n]['is_unit'] else ''
233
- base = data[n]['base'] if data[n]['is_base'] else ''
234
- x_axis_title = data[n]['axis_title']['x_axis']
235
- y_axis_title = data[n]['axis_title']['y_axis']
236
- x_axis = data[n]['axis_label']['x_axis'] if data[n]['is_axis_label_x_axis'] else [0]
237
- y_axis = data[n]['axis_label']['y_axis'] if data[n]['is_axis_label_y_axis'] else [0]
238
-
239
- if len(legend) > 1:
240
- datalabel = np.array(datalabel).transpose().tolist()
241
-
242
- datatable.append([title, legend, datalabel, unit, base, x_axis_title, y_axis_title, x_axis, y_axis])
243
-
244
- return datatable
245
-
246
- def datatable(data, chart_type):
247
- data_table = ''
248
- num = len(data)
249
-
250
- if len(data) == 2:
251
- temp = []
252
- temp.append(f"λŒ€μƒ: {data[0][4]}")
253
- temp.append(f"제λͺ©: {data[0][0]}")
254
- temp.append(f"μœ ν˜•: {' '.join(chart_type[0:2])}")
255
- temp.append(f"{data[0][5]} | {data[0][1][0]}({data[0][3]}) | {data[1][1][0]}({data[1][3]})")
256
-
257
- x_axis = data[0][7]
258
- for idx, x in enumerate(x_axis):
259
- temp.append(f"{x} | {data[0][2][0][idx]} | {data[1][2][0][idx]}")
260
-
261
- data_table = '\n'.join(temp)
262
- else:
263
- for n in range(num):
264
- temp = []
265
-
266
- title, legend, datalabel, unit, base, x_axis_title, y_axis_title, x_axis, y_axis = data[n]
267
- legend = [element + f"({unit})" for element in legend]
268
-
269
- if len(legend) > 1:
270
- temp.append(f"λŒ€μƒ: {base}")
271
- temp.append(f"제λͺ©: {title}")
272
- temp.append(f"μœ ν˜•: {' '.join(chart_type[0:2])}")
273
- temp.append(f"{x_axis_title} | {' | '.join(legend)}")
274
-
275
- if chart_type[2] == "μ›ν˜•":
276
- datalabel = sum(datalabel, [])
277
- temp.append(f"{' | '.join([str(d) for d in datalabel])}")
278
- data_table = '\n'.join(temp)
279
- else:
280
- axis = y_axis if chart_type[2] == "κ°€λ‘œ λ§‰λŒ€ν˜•" else x_axis
281
- for idx, (x, d) in enumerate(zip(axis, datalabel)):
282
- temp_d = [str(e) for e in d]
283
- temp_d = " | ".join(temp_d)
284
- row = f"{x} | {temp_d}"
285
- temp.append(row)
286
- data_table = '\n'.join(temp)
287
- else:
288
- temp.append(f"λŒ€μƒ: {base}")
289
- temp.append(f"제λͺ©: {title}")
290
- temp.append(f"μœ ν˜•: {' '.join(chart_type[0:2])}")
291
- temp.append(f"{x_axis_title} | {unit}")
292
- axis = y_axis if chart_type[2] == "κ°€λ‘œ λ§‰λŒ€ν˜•" else x_axis
293
- datalabel = datalabel[0]
294
-
295
- for idx, x in enumerate(axis):
296
- row = f"{x} | {str(datalabel[idx])}"
297
- temp.append(row)
298
- data_table = '\n'.join(temp)
299
-
300
- return data_table
301
-
302
- #function for converting aihub dataset labeling json file to aihub-deplot data table
303
- def process_json_file2(input_file):
304
- with open(input_file, 'r', encoding='utf-8') as file:
305
- data = json.load(file)
306
- # ν•„μš”ν•œ 데이터 μΆ”μΆœ
307
- chart_multi = data['metadata']['chart_multi']
308
- chart_main = data['metadata']['chart_main']
309
- chart_sub = data['metadata']['chart_sub']
310
- chart_type = [chart_multi, chart_sub, chart_main]
311
- chart_annotations = data['annotations']
312
-
313
- charData = chart_data(chart_annotations)
314
- dataTable = datatable(charData, chart_type)
315
- return dataTable
316
-
317
- # RMS
318
- def _to_float(text): # λ‹¨μœ„ λ–Όκ³  숫자만..?
319
- try:
320
- if text.endswith("%"):
321
- # Convert percentages to floats.
322
- return float(text.rstrip("%")) / 100.0
323
- else:
324
- return float(text)
325
- except ValueError:
326
- return None
327
-
328
-
329
- def _get_relative_distance(
330
- target, prediction, theta = 1.0
331
- ):
332
- """Returns min(1, |target-prediction|/|target|)."""
333
- if not target:
334
- return int(not prediction)
335
- distance = min(abs((target - prediction) / target), 1)
336
- return distance if distance < theta else 1
337
-
338
- def anls_metric(target: str, prediction: str, theta: float = 0.5):
339
- edit_distance = editdistance.eval(target, prediction)
340
- normalize_ld = edit_distance / max(len(target), len(prediction))
341
- return 1 - normalize_ld if normalize_ld < theta else 0
342
-
343
- def _permute(values, indexes):
344
- return tuple(values[i] if i < len(values) else "" for i in indexes)
345
-
346
-
347
- @dataclasses.dataclass(frozen=True)
348
- class Table:
349
- """Helper class for the content of a markdown table."""
350
-
351
- base: Optional[str] = None
352
- title: Optional[str] = None
353
- chartType: Optional[str] = None
354
- headers: tuple[str, Ellipsis] = dataclasses.field(default_factory=tuple)
355
- rows: tuple[tuple[str, Ellipsis], Ellipsis] = dataclasses.field(default_factory=tuple)
356
-
357
- def permuted(self, indexes):
358
- """Builds a version of the table changing the column order."""
359
- return Table(
360
- base=self.base,
361
- title=self.title,
362
- chartType=self.chartType,
363
- headers=_permute(self.headers, indexes),
364
- rows=tuple(_permute(row, indexes) for row in self.rows),
365
- )
366
-
367
- def aligned(
368
- self, headers, text_theta = 0.5
369
- ):
370
- """Builds a column permutation with headers in the most correct order."""
371
- if len(headers) != len(self.headers):
372
- raise ValueError(f"Header length {headers} must match {self.headers}.")
373
- distance = []
374
- for h2 in self.headers:
375
- distance.append(
376
- [
377
- 1 - anls_metric(h1, h2, text_theta)
378
- for h1 in headers
379
- ]
380
- )
381
- cost_matrix = np.array(distance)
382
- row_ind, col_ind = optimize.linear_sum_assignment(cost_matrix)
383
- permutation = [idx for _, idx in sorted(zip(col_ind, row_ind))]
384
- score = (1 - cost_matrix)[permutation[1:], range(1, len(row_ind))].prod()
385
- return self.permuted(permutation), score
386
-
387
- def _parse_table(text, transposed = False): # ν‘œ 제λͺ©, μ—΄ 이름, ν–‰ μ°ΎκΈ°
388
- """Builds a table from a markdown representation."""
389
- lines = text.lower().splitlines()
390
- if not lines:
391
- return Table()
392
-
393
- if lines[0].startswith("λŒ€μƒ: "):
394
- base = lines[0][len("λŒ€μƒ: ") :].strip()
395
- offset = 1 #
396
- else:
397
- base = None
398
- offset = 0
399
- if lines[1].startswith("제λͺ©: "):
400
- title = lines[1][len("제λͺ©: ") :].strip()
401
- offset = 2 #
402
- else:
403
- title = None
404
- offset = 1
405
- if lines[2].startswith("μœ ν˜•: "):
406
- chartType = lines[2][len("μœ ν˜•: ") :].strip()
407
- offset = 3 #
408
- else:
409
- chartType = None
410
-
411
- if len(lines) < offset + 1:
412
- return Table(base=base, title=title, chartType=chartType)
413
-
414
- rows = []
415
- for line in lines[offset:]:
416
- rows.append(tuple(v.strip() for v in line.split(" | ")))
417
- if transposed:
418
- rows = [tuple(row) for row in itertools.zip_longest(*rows, fillvalue="")]
419
- return Table(base=base, title=title, chartType=chartType, headers=rows[0], rows=tuple(rows[1:]))
420
-
421
- def _get_table_datapoints(table):
422
- datapoints = {}
423
- if table.base is not None:
424
- datapoints["λŒ€μƒ"] = table.base
425
- if table.title is not None:
426
- datapoints["제λͺ©"] = table.title
427
- if table.chartType is not None:
428
- datapoints["μœ ν˜•"] = table.chartType
429
- if not table.rows or len(table.headers) <= 1:
430
- return datapoints
431
- for row in table.rows:
432
- for header, cell in zip(table.headers[1:], row[1:]):
433
- #print(f"{row[0]} {header} >> {cell}")
434
- datapoints[f"{row[0]} {header}"] = cell #
435
- return datapoints
436
-
437
- def _get_datapoint_metric( #
438
- target,
439
- prediction,
440
- text_theta=0.5,
441
- number_theta=0.1,
442
- ):
443
- """Computes a metric that scores how similar two datapoint pairs are."""
444
- key_metric = anls_metric(
445
- target[0], prediction[0], text_theta
446
- )
447
- pred_float = _to_float(prediction[1]) # μˆ«μžμΈμ§€ 확인
448
- target_float = _to_float(target[1])
449
- if pred_float is not None and target_float:
450
- return key_metric * (
451
- 1 - _get_relative_distance(target_float, pred_float, number_theta) # 숫자면 μƒλŒ€μ  거리값 계산
452
- )
453
- elif target[1] == prediction[1]:
454
- return key_metric
455
- else:
456
- return key_metric * anls_metric(
457
- target[1], prediction[1], text_theta
458
- )
459
-
460
- def _table_datapoints_precision_recall_f1( # 찐 계산
461
- target_table,
462
- prediction_table,
463
- text_theta = 0.5,
464
- number_theta = 0.1,
465
- ):
466
- """Calculates matching similarity between two tables as dicts."""
467
- target_datapoints = list(_get_table_datapoints(target_table).items())
468
- prediction_datapoints = list(_get_table_datapoints(prediction_table).items())
469
- if not target_datapoints and not prediction_datapoints:
470
- return 1, 1, 1
471
- if not target_datapoints:
472
- return 0, 1, 0
473
- if not prediction_datapoints:
474
- return 1, 0, 0
475
- distance = []
476
- for t, _ in target_datapoints:
477
- distance.append(
478
- [
479
- 1 - anls_metric(t, p, text_theta)
480
- for p, _ in prediction_datapoints
481
  ]
482
  )
483
- cost_matrix = np.array(distance)
484
- row_ind, col_ind = optimize.linear_sum_assignment(cost_matrix)
485
- score = 0
486
- for r, c in zip(row_ind, col_ind):
487
- score += _get_datapoint_metric(
488
- target_datapoints[r], prediction_datapoints[c], text_theta, number_theta
489
- )
490
- if score == 0:
491
- return 0, 0, 0
492
- precision = score / len(prediction_datapoints)
493
- recall = score / len(target_datapoints)
494
- return precision, recall, 2 * precision * recall / (precision + recall)
495
-
496
- def table_datapoints_precision_recall_per_point( # 각각 계산...
497
- targets,
498
- predictions,
499
- text_theta = 0.5,
500
- number_theta = 0.1,
501
- ):
502
- """Computes precisin recall and F1 metrics given two flattened tables.
503
- Parses each string into a dictionary of keys and values using row and column
504
- headers. Then we match keys between the two dicts as long as their relative
505
- levenshtein distance is below a threshold. Values are also compared with
506
- ANLS if strings or relative distance if they are numeric.
507
- Args:
508
- targets: list of list of strings.
509
- predictions: list of strings.
510
- text_theta: relative edit distance above this is set to the maximum of 1.
511
- number_theta: relative error rate above this is set to the maximum of 1.
512
- Returns:
513
- Dictionary with per-point precision, recall and F1
514
- """
515
- assert len(targets) == len(predictions)
516
- per_point_scores = {"precision": [], "recall": [], "f1": []}
517
- for pred, target in zip(predictions, targets):
518
- all_metrics = []
519
- for transposed in [True, False]:
520
- pred_table = _parse_table(pred, transposed=transposed)
521
- target_table = _parse_table(target, transposed=transposed)
522
-
523
- all_metrics.extend([_table_datapoints_precision_recall_f1(target_table, pred_table, text_theta, number_theta)])
524
 
525
- p, r, f = max(all_metrics, key=lambda x: x[-1])
526
- per_point_scores["precision"].append(p)
527
- per_point_scores["recall"].append(r)
528
- per_point_scores["f1"].append(f)
529
- return per_point_scores
530
 
531
- def table_datapoints_precision_recall( # deplot μ„±λŠ₯μ§€ν‘œ
532
- targets,
533
- predictions,
534
- text_theta = 0.5,
535
- number_theta = 0.1,
536
- ):
537
- """Aggregated version of table_datapoints_precision_recall_per_point().
538
- Same as table_datapoints_precision_recall_per_point() but returning aggregated
539
- scores instead of per-point scores.
540
- Args:
541
- targets: list of list of strings.
542
- predictions: list of strings.
543
- text_theta: relative edit distance above this is set to the maximum of 1.
544
- number_theta: relative error rate above this is set to the maximum of 1.
545
- Returns:
546
- Dictionary with aggregated precision, recall and F1
547
- """
548
- score_dict = table_datapoints_precision_recall_per_point(
549
- targets, predictions, text_theta, number_theta
550
- )
551
- return {
552
- "table_datapoints_precision": (
553
- sum(score_dict["precision"]) / len(targets)
554
- ),
555
- "table_datapoints_recall": (
556
- sum(score_dict["recall"]) / len(targets)
557
- ),
558
- "table_datapoints_f1": sum(score_dict["f1"]) / len(targets),
559
- }
560
-
561
- def evaluate_rms(generated_table,label_table):
562
- predictions=[generated_table]
563
- targets=[label_table]
564
- RMS = table_datapoints_precision_recall(targets, predictions)
565
- return RMS
566
-
567
- def ko_deplot_convert_to_dataframe(generated_table_str):
568
- lines = generated_table_str.strip().split(" \n")
569
- headers=[]
570
- data=[]
571
- for i in range(len(lines[1].split(" | "))):
572
- headers.append(f"{i}")
573
- for line in lines[1:len(lines)-1]:
574
- data.append(line.split("| "))
575
- df = pd.DataFrame(data, columns=headers)
576
- return df
577
-
578
- def ko_deplot_convert_to_dataframe2(label_table_str):
579
- lines = label_table_str.strip().split(" \n")
580
- headers=[]
581
  data=[]
582
- for i in range(len(lines[1].split(" | "))):
583
- headers.append(f"{i}")
584
- for line in lines[1:]:
585
- data.append(line.split("| "))
586
- df = pd.DataFrame(data, columns=headers)
587
- return df
588
-
589
- def aihub_deplot_convert_to_dataframe(table_str):
590
- lines = table_str.strip().split("\n")
591
- headers = []
592
- if(len(lines[3].split(" | "))>len(lines[4].split(" | "))):
593
- category=lines[3].split(" | ")
594
- del category[0]
595
- value=lines[4].split(" | ")
596
- df=pd.DataFrame({"λ²”λ‘€":category,"κ°’":value})
597
- return df
598
- else:
599
- for i in range(len(lines[3].split(" | "))):
600
- headers.append(f"{i}")
601
- data = [line.split(" | ") for line in lines[3:]]
602
- df = pd.DataFrame(data, columns=headers)
603
- return df
604
- def unichart_convert_to_dataframe(table_str):
605
- lines=table_str.split(" & ")
606
- headers=[]
607
- data=[]
608
- del lines[0]
609
- for i in range(len(lines[1].split(" | "))):
610
- headers.append(f"{i}")
611
- if lines[0]=="value":
612
- for line in lines[1:]:
613
- data.append(line.split(" | "))
614
- else:
615
- category=lines[0].split(" | ")
616
- category.insert(0," ")
617
- data.append(category)
618
- for line in lines[1:]:
619
- data.append(line.split(" | "))
620
- df=pd.DataFrame(data,columns=headers)
621
- return df
622
-
623
- class Highlighter:
624
- def __init__(self):
625
- self.row = 0
626
- self.col = 0
627
-
628
- def compare_and_highlight(self, pred_table_elem, target_table, pred_table_row, props=''):
629
- if self.row >= pred_table_row:
630
- self.col += 1
631
- self.row = 0
632
- if pred_table_elem != target_table.iloc[self.row, self.col]:
633
- self.row += 1
634
- return props
635
- else:
636
- self.row += 1
637
- return None
638
-
639
- # 1. 데이터 λ‘œλ“œ
640
- aihub_deplot_result_df = pd.read_csv('./aihub_deplot_result.csv')
641
- ko_deplot_result= './ko-deplot-base-pred-epoch3-refinetuning.json'
642
- unichart_result='./unichart_results.json'
643
-
644
- # 2. 체크해야 ν•˜λŠ” 이미지 파일 λ‘œλ“œ
645
- def load_image_checklist(file):
646
- with open(file, 'r') as f:
647
- #image_names = [f'"{line.strip()}"' for line in f]
648
- image_names = f.read().splitlines()
649
- return image_names
650
-
651
- # 3. ν˜„μž¬ 인덱슀λ₯Ό μΆ”μ ν•˜κΈ° μœ„ν•œ λ³€μˆ˜
652
- current_index = 0
653
- image_names = []
654
- def show_image(current_idx):
655
- image_name=image_names[current_idx]
656
- image_path = f"./top_20_percent_images/{image_name}.jpg"
657
- if not os.path.exists(image_path):
658
- image_path = f"./bottom_20_percent_images/{image_name}.jpg"
659
- return Image.open(image_path)
660
-
661
- # 4. λ²„νŠΌ 클릭 이벀트 ν•Έλ“€λŸ¬
662
- def non_real_time_check(file):
663
- highlighter1 = Highlighter()
664
- highlighter2 = Highlighter()
665
- highlighter3 = Highlighter()
666
- #global image_names, current_index
667
- #image_names = load_image_checklist(file)
668
- #current_index = 0
669
- #image=show_image(current_index)
670
- file_name =image_names[current_index].replace("Source","Label")
671
-
672
- json_path="./ko_deplot_labeling_data.json"
673
- with open(json_path, 'r', encoding='utf-8') as file:
674
- json_data = json.load(file)
675
- for key, value in json_data.items():
676
- if key == file_name:
677
- ko_deplot_labeling_str=value.get("txt").replace("<0x0A>","\n")
678
- ko_deplot_label_title=ko_deplot_labeling_str.split(" \n ")[0].replace("TITLE | ","제λͺ©:")
679
- break
680
-
681
- ko_deplot_rms_path="./ko_deplot_rms.txt"
682
- unichart_rms_path="./unichart_rms.txt"
683
-
684
- json_path="./unichart_labeling_data.json"
685
- with open(json_path, 'r', encoding='utf-8') as file:
686
- json_data = json.load(file)
687
- for entry in json_data:
688
- if entry["imgname"]==image_names[current_index]+".jpg":
689
- unichart_labeling_str=entry["label"]
690
- unichart_label_title=entry["label"].split(" & ")[0].split(" | ")[1]
691
-
692
- with open(ko_deplot_rms_path,'r',encoding='utf-8') as file:
693
- lines=file.readlines()
694
- flag=0
695
- for line in lines:
696
- parts=line.strip().split(", ")
697
- if(len(parts)==2 and parts[0]==image_names[current_index]):
698
- ko_deplot_rms=parts[1]
699
- flag=1
700
- break
701
- if(flag==0):
702
- ko_deplot_rms="none"
703
-
704
- with open(unichart_rms_path,'r',encoding='utf-8') as file:
705
- lines=file.readlines()
706
- flag=0
707
- for line in lines:
708
- parts=line.strip().split(": ")
709
- if(len(parts)==2 and parts[0]==image_names[current_index]+".jpg"):
710
- unichart_rms=parts[1]
711
- flag=1
712
- break
713
- if(flag==0):
714
- unichart_rms="none"
715
-
716
-
717
-
718
- ko_deplot_generated_title,ko_deplot_generated_table=ko_deplot_display_results(current_index)
719
- aihub_deplot_generated_table,aihub_deplot_label_table,aihub_deplot_generated_title,aihub_deplot_label_title=aihub_deplot_display_results(current_index)
720
- unichart_generated_table,unichart_generated_title=unichart_display_results(current_index)
721
- #ko_deplot_RMS=evaluate_rms(ko_deplot_generated_table,ko_deplot_labeling_str)
722
- aihub_deplot_RMS=evaluate_rms(aihub_deplot_generated_table,aihub_deplot_label_table)
723
-
724
-
725
- if flag == 1:
726
- value = [round(float(ko_deplot_rms), 1)]
727
  else:
728
- value = [0]
729
-
730
- ko_deplot_score_table = pd.DataFrame({
731
- 'category': ['f1'],
732
- 'value': value
733
- })
734
-
735
- value=[round(float(unichart_rms)/100,1)]
736
- unichart_score_table=pd.DataFrame({
737
- 'category':['f1'],
738
- 'value':value
739
- })
740
- aihub_deplot_score_table=pd.DataFrame({
741
- 'category': ['precision', 'recall', 'f1'],
742
- 'value': [
743
- round(aihub_deplot_RMS['table_datapoints_precision'],1),
744
- round(aihub_deplot_RMS['table_datapoints_recall'],1),
745
- round(aihub_deplot_RMS['table_datapoints_f1'],1)
746
- ]
747
- })
748
-
749
- #ko_deplot_generated_df=ko_deplot_convert_to_dataframe(ko_deplot_generated_table)
750
- #aihub_deplot_generated_df=aihub_deplot_convert_to_dataframe(aihub_deplot_generated_table)
751
- #unichart_generated_df=unichart_convert_to_dataframe(unichart_generated_table)
752
-
753
  try:
754
- ko_deplot_generated_df=ko_deplot_convert_to_dataframe(ko_deplot_generated_table)
755
- unichart_generated_df=unichart_convert_to_dataframe(unichart_generated_table)
 
 
 
 
 
 
 
 
756
  except Exception as e:
757
- return None,None,None,None,None,None,None,None,None,ko_deplot_generated_table,unichart_generated_table,1
758
- ko_deplot_labeling_df=ko_deplot_convert_to_dataframe2(ko_deplot_labeling_str)
759
- #aihub_deplot_labeling_df=aihub_deplot_convert_to_dataframe(aihub_deplot_label_table)
760
- unichart_labeling_df=unichart_convert_to_dataframe(unichart_labeling_str)
761
-
762
- ko_deplot_generated_df_row=ko_deplot_generated_df.shape[0]
763
- #aihub_deplot_generated_df_row=aihub_deplot_generated_df.shape[0]
764
- unichart_generated_df_row=unichart_generated_df.shape[0]
765
-
766
-
767
- styled_ko_deplot_table=ko_deplot_generated_df.style.applymap(highlighter1.compare_and_highlight,target_table=ko_deplot_labeling_df,pred_table_row=ko_deplot_generated_df_row,props='color:red')
768
-
769
-
770
- #styled_aihub_deplot_table=aihub_deplot_generated_df.style.applymap(highlighter2.compare_and_highlight,target_table=aihub_deplot_labeling_df,pred_table_row=aihub_deplot_generated_df_row,props='color:red')
771
-
772
-
773
- styled_unichart_table=unichart_generated_df.style.applymap(highlighter3.compare_and_highlight,target_table=unichart_labeling_df,pred_table_row=unichart_generated_df_row,props='color:red')
774
-
775
- #return ko_deplot_convert_to_dataframe(ko_deplot_generated_table), aihub_deplot_convert_to_dataframe(aihub_deplot_generated_table), aihub_deplot_convert_to_dataframe(label_table), ko_deplot_score_table, aihub_deplot_score_table
776
- return gr.DataFrame(styled_ko_deplot_table,label=ko_deplot_generated_title+"(VAIV_DePlot μΆ”λ‘  κ²°κ³Ό)"),None,gr.DataFrame(styled_unichart_table,label="제λͺ©:"+unichart_generated_title+"(VAIV_UniChart μΆ”λ‘  κ²°κ³Ό)"),gr.DataFrame(ko_deplot_labeling_df,label=ko_deplot_label_title+"(VAIV_DePlot μ •λ‹΅ ν…Œμ΄λΈ”)"),None,gr.DataFrame(unichart_labeling_df,label="제λͺ©:"+unichart_label_title+"(VAIV_UniChart μ •λ‹΅ ν…Œμ΄λΈ”)"),ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table,None,None,0
777
-
778
-
779
- def ko_deplot_display_results(index):
780
- filename=image_names[index]+".jpg"
781
- with open(ko_deplot_result, 'r', encoding='utf-8') as f:
782
- data = json.load(f)
783
- for entry in data:
784
- if entry['filename'].endswith(filename):
785
- #return entry['table']
786
- parts=entry['table'].split("\n",1)
787
- return parts[0].replace("TITLE | ","제λͺ©:"),entry['table']
788
-
789
- def aihub_deplot_display_results(index):
790
- if index < 0 or index >= len(image_names):
791
- return "Index out of range", None, None
792
- image_name = image_names[index]
793
- image_row = aihub_deplot_result_df[aihub_deplot_result_df['data_id'] == image_name]
794
- if not image_row.empty:
795
- generated_table = image_row['generated_table'].values[0]
796
- generated_title=generated_table.split("\n")[1]
797
- label_table = image_row['label_table'].values[0]
798
- label_title=label_table.split("\n")[1]
799
- return generated_table, label_table, generated_title, label_title
800
- else:
801
- return "No results found for the image", None, None
802
- def unichart_display_results(index):
803
- image_name=image_names[index]
804
- with open(unichart_result,'r',encoding='utf-8') as f:
805
- data=json.load(f)
806
- for entry in data:
807
- if entry['imgname']==image_name+".jpg":
808
- return entry['label'],entry['label'].split(" & ")[0].split(" | ")[1]
809
-
810
- def previous_image():
811
- global current_index
812
- if current_index>0:
813
- current_index-=1
814
- image=show_image(current_index)
815
- return image, image_names[current_index],gr.update(interactive=current_index>0), gr.update(interactive=current_index<len(image_names)-1)
816
-
817
- def next_image():
818
- global current_index
819
- if current_index<len(image_names)-1:
820
- current_index+=1
821
- image=show_image(current_index)
822
- return image, image_names[current_index],gr.update(interactive=current_index>0), gr.update(interactive=current_index<len(image_names)-1)
823
 
824
  def real_time_check(image_file):
825
- highlighter1 = Highlighter()
826
- highlighter2 = Highlighter()
827
- highlighter3=Highlighter()
828
  image = Image.open(image_file)
829
-
830
- result_model1 = predict_model1(image)
831
- parts=result_model1.split("\n")
832
  del parts[-1]
833
- result_model1="\n".join(parts)
834
- ko_deplot_generated_title=result_model1.split("\n")[0].split(" | ")[1]
835
- #ko_deplot_table=ko_deplot_convert_to_dataframe2(result_model1)
836
-
837
- result_model3=predict_model3(image)
838
- #unichart_table=unichart_convert_to_dataframe(result_model3)
839
- unichart_generated_title=result_model3.split(" & ")[0].split(" | ")[1]
840
-
841
  try:
842
- ko_deplot_table=ko_deplot_convert_to_dataframe2(result_model1)
843
- unichart_table=unichart_convert_to_dataframe(result_model3)
 
844
  except Exception as e:
845
- return None,None,None,None,None,None,None,None,None,result_model1,result_model3,1
846
-
847
- #aihub_labeling_data_json="./labeling_data/"+file_name+".json"
848
- if os.path.basename(image_file.name).startswith("C_Source"):
849
- image_base_name = os.path.basename(image_file.name).replace("Source","Label")
850
- file_name, _ = os.path.splitext(image_base_name)
851
- json_path="./ko_deplot_labeling_data.json"
852
- with open(json_path, 'r', encoding='utf-8') as file:
853
- json_data = json.load(file)
854
- for key, value in json_data.items():
855
- if key == file_name:
856
- ko_deplot_labeling_str=value.get("txt").replace("<0x0A>","\n")
857
- ko_deplot_label_title=ko_deplot_labeling_str.split(" \n ")[0].split(" | ")[1]
858
- break
859
-
860
- ko_deplot_label_table=ko_deplot_convert_to_dataframe2(ko_deplot_labeling_str)
861
-
862
- #aihub_deplot_labeling_str=process_json_file2(aihub_labeling_data_json)
863
- #aihub_deplot_label_title=aihub_deplot_labeling_str.split("\n")[1].split(":")[1]
864
-
865
- json_path="./unichart_labeling_data.json"
866
- with open(json_path, 'r', encoding='utf-8') as file:
867
- json_data = json.load(file)
868
- for entry in json_data:
869
- if entry["imgname"]==os.path.basename(image_file.name):
870
- unichart_labeling_str=entry["label"]
871
- unichart_label_title=entry["label"].split(" & ")[0].split(" | ")[1]
872
- unichart_label_table=unichart_convert_to_dataframe(unichart_labeling_str)
873
-
874
- ko_deplot_RMS=evaluate_rms(result_model1,ko_deplot_labeling_str)
875
- unichart_RMS=evaluate_rms(result_model3.replace("Characteristic","Title").replace("&","\n"),unichart_labeling_str.replace("Characteristic","Title").replace("&","\n"))
876
- ko_deplot_score_table=pd.DataFrame({
877
- 'category': ['precision', 'recall', 'f1'],
878
- 'value': [
879
- round(ko_deplot_RMS['table_datapoints_precision'],1),
880
- round(ko_deplot_RMS['table_datapoints_recall'],1),
881
- round(ko_deplot_RMS['table_datapoints_f1'],1)
882
- ]
883
- })
884
- unichart_score_table=pd.DataFrame({
885
- 'category': ['precision', 'recall', 'f1'],
886
- 'value': [
887
- round(unichart_RMS['table_datapoints_precision'],1),
888
- round(unichart_RMS['table_datapoints_recall'],1),
889
- round(unichart_RMS['table_datapoints_f1'],1)
890
- ]
891
- })
892
 
893
- ko_deplot_generated_df_row=ko_deplot_table.shape[0]
894
- unichart_generated_df_row=unichart_table.shape[0]
895
- styled_ko_deplot_table=ko_deplot_table.style.applymap(highlighter1.compare_and_highlight,target_table=ko_deplot_label_table,pred_table_row=ko_deplot_generated_df_row,props='color:red')
896
- styled_unichart_table=unichart_table.style.applymap(highlighter3.compare_and_highlight,target_table=unichart_label_table,pred_table_row=unichart_generated_df_row,props='color:red')
897
- return gr.DataFrame(styled_ko_deplot_table,label=ko_deplot_generated_title+"(VAIV_DePlot μΆ”λ‘  κ²°κ³Ό)") ,None,gr.DataFrame(styled_unichart_table,label=unichart_generated_title+"(VAIV_UniChart μΆ”λ‘  κ²°κ³Ό)"),gr.DataFrame(ko_deplot_label_table,label=ko_deplot_label_title+"(VAIV_DePlot μ •λ‹΅ ν…Œμ΄λΈ”)"),None,gr.DataFrame(unichart_label_table,label=unichart_label_title+"(VAIV_UniChart μ •λ‹΅ ν…Œμ΄λΈ”)"),ko_deplot_score_table,None,unichart_score_table,None,None,0
898
- else:
899
- return gr.DataFrame(ko_deplot_table,label=ko_deplot_generated_title+"(VAIV_DePlot μΆ”λ‘  κ²°κ³Ό)"),None,gr.DataFrame(unichart_table,label=unichart_generated_title+"(VAIV_UniChart μΆ”λ‘  κ²°κ³Ό)"),None,None,None,None,None,None,None,None,0
900
- def inference(mode,image_uploader,file_uploader):
901
- if(mode=="이미지 μ—…λ‘œλ“œ"):
902
- ko_deplot_table, aihub_deplot_table, unichart_table, ko_deplot_label_table,aihub_deplot_label_table,unichart_label_table,ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table,ko_deplot_generated_txt,unichart_generated_txt,flag= real_time_check(image_uploader)
903
  if flag==1:
904
- return ko_deplot_table, aihub_deplot_table, unichart_table,ko_deplot_label_table, aihub_deplot_label_table,unichart_label_table,ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table,gr.Text(ko_deplot_generated_txt,visible=True),gr.Text(unichart_generated_txt,visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False)
905
  else:
906
- return ko_deplot_table, aihub_deplot_table, unichart_table,ko_deplot_label_table, aihub_deplot_label_table,unichart_label_table,ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table,gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False)
907
  else:
908
- styled_ko_deplot_table,styled_aihub_deplot_table,styled_unichart_table,ko_deplot_label_table,aihub_deplot_label_table,unichart_label_table,ko_deplot_score_table,aihub_deplot_score_table, unichart_score_table,ko_deplot_generated_txt,unichart_generated_txt,flag=non_real_time_check(file_uploader)
909
  if flag==1:
910
- return styled_ko_deplot_table, styled_aihub_deplot_table, styled_unichart_table,ko_deplot_label_table,aihub_deplot_label_table,unichart_label_table,ko_deplot_score_table, aihub_deplot_score_table, unichart_score_table,gr.Text(ko_deplot_generated_txt,visible=True),gr.Text(unichart_generated_txt,visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False)
911
  else:
912
- return styled_ko_deplot_table, styled_aihub_deplot_table, styled_unichart_table,ko_deplot_label_table,aihub_deplot_label_table,unichart_label_table,ko_deplot_score_table, aihub_deplot_score_table, unichart_score_table,gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False)
913
- def interface_selector(selector):
914
- if selector == "이미지 μ—…λ‘œλ“œ":
915
- return gr.update(visible=True),gr.update(visible=False),gr.State("image_upload"),gr.update(visible=False),gr.update(visible=False),gr.File("./new_top_20_percent_images.txt"),"high score 차트"
916
- elif selector == "파일 μ—…λ‘œλ“œ":
917
- return gr.update(visible=False),gr.update(visible=True),gr.State("file_upload"), gr.update(visible=True),gr.update(visible=True),gr.File("./new_top_20_percent_images.txt"),"high score 차트"
918
-
919
- def file_selector(selector):
920
- if selector == "low score 차트":
921
- return gr.File("./new_bottom_20_percent_images.txt"),"전체"
922
- elif selector == "high score 차트":
923
- return gr.File("./new_top_20_percent_images.txt"),"전체"
924
- '''
925
- def update_results(model_type):
926
- if "ko_deplot" == model_type:
927
- return gr.update(visible=True),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False)
928
- elif "aihub_deplot" == model_type:
929
- return gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=False)
930
- elif "unichart"==model_type:
931
- return gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True),gr.update(visible=True),gr.update(visible=False),gr.update(visible=False),gr.update(visible=True)
932
- else:
933
- return gr.update(visible=True), gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True),gr.update(visible=True)
934
- '''
935
-
936
- def update_results(selected_models):
937
  # Create a visibility list initialized to False for all components
938
- visibility = [False] * 9
939
-
940
  # Update visibility based on the selected models
941
  if "VAIV_DePlot" in selected_models:
942
- visibility[0] = True # ko_deplot_generated_table
943
- visibility[1] = True # ko_deplot_score_table
944
- visibility[6] = True # ko_deplot_label_table
945
- '''
946
- if "aihub_deplot" in selected_models:
947
- visibility[2] = True # aihub_deplot_generated_table
948
- visibility[3] = True # aihub_deplot_score_table
949
- visibility[7] = True # aihub_deplot_label_table
950
- '''
951
- if "VAIV_UniChart" in selected_models:
952
- visibility[4] = True # unichart_generated_table
953
- visibility[5] = True # unichart_score_table
954
- visibility[8] = True # unichart_label_table
955
-
956
  if "all" in selected_models:
957
- visibility[0] = True # ko_deplot_generated_table
958
- visibility[1] = True # ko_deplot_score_table
959
- visibility[6] = True # ko_deplot_label_table
960
- visibility[4] = True # unichart_generated_table
961
- visibility[5] = True # unichart_score_table
962
- visibility[8] = True # unichart_label_table
963
-
 
964
  # Return gr.update for each component with the corresponding visibility status
965
  return tuple(gr.update(visible=v) for v in visibility)
966
 
 
 
 
 
 
967
 
968
  def display_image(image_file):
969
  image=Image.open(image_file)
970
  return image, os.path.basename(image_file)
971
 
972
- def display_image_in_file(image_checklist):
973
- global image_names, current_index
974
- image_names = load_image_checklist(image_checklist)
975
- image=show_image(current_index)
976
- return image,image_names[current_index]
977
-
978
- def update_file_based_on_chart_type(chart_type, all_file_path):
979
- with open(all_file_path, 'r', encoding='utf-8') as file:
980
- lines = file.readlines()
981
- filtered_lines=[]
982
- if chart_type == "전체":
983
- filtered_lines = lines
984
- elif chart_type == "일반 κ°€λ‘œ λ§‰λŒ€ν˜•":
985
- filtered_lines = [line for line in lines if "_horizontal bar_standard" in line]
986
- elif chart_type=="λˆ„μ  κ°€λ‘œ λ§‰λŒ€ν˜•":
987
- filtered_lines = [line for line in lines if "_horizontal bar_accumulation" in line]
988
- elif chart_type=="100% κΈ°μ€€ λˆ„μ  κ°€λ‘œ λ§‰λŒ€ν˜•":
989
- filtered_lines = [line for line in lines if "_horizontal bar_100per accumulation" in line]
990
- elif chart_type=="일반 μ„Έλ‘œ λ§‰λŒ€ν˜•":
991
- filtered_lines = [line for line in lines if "_vertical bar_standard" in line]
992
- elif chart_type=="λˆ„μ  μ„Έλ‘œ λ§‰λŒ€ν˜•":
993
- filtered_lines = [line for line in lines if "_vertical bar_accumulation" in line]
994
- elif chart_type=="100% κΈ°μ€€ λˆ„μ  μ„Έλ‘œ λ§‰λŒ€ν˜•":
995
- filtered_lines = [line for line in lines if "_vertical bar_100per accumulation" in line]
996
- elif chart_type=="μ„ ν˜•":
997
- filtered_lines = [line for line in lines if "_line_standard" in line]
998
- elif chart_type=="μ›ν˜•":
999
- filtered_lines = [line for line in lines if "_pie_standard" in line]
1000
- elif chart_type=="기타 λ°©μ‚¬ν˜•":
1001
- filtered_lines = [line for line in lines if "_etc_radial" in line]
1002
- elif chart_type=="기타 ν˜Όν•©ν˜•":
1003
- filtered_lines = [line for line in lines if "_etc_mix" in line]
1004
- # μƒˆλ‘œμš΄ νŒŒμΌμ— 기둝
1005
- new_file_path = "./filtered_chart_images.txt"
1006
- with open(new_file_path, 'w', encoding='utf-8') as file:
1007
- file.writelines(filtered_lines)
1008
-
1009
- return new_file_path
1010
 
1011
- def handle_chart_type_change(chart_type,all_file_path):
1012
- new_file_path = update_file_based_on_chart_type(chart_type, all_file_path)
1013
- global image_names, current_index
1014
- image_names = load_image_checklist(new_file_path)
1015
- current_index=0
1016
- image=show_image(current_index)
1017
- return image,image_names[current_index]
 
 
 
 
 
 
 
 
 
 
1018
 
1019
  css = """
1020
  .dataframe-class {
1021
- height: 300px; /* 높이λ₯Ό κ³ μ • */
1022
  overflow-y: auto !important; /* μŠ€ν¬λ‘€μ„ κ°€λŠ₯ν•˜κ²Œ */
 
1023
  }
1024
  """
1025
 
1026
  with gr.Blocks(css=css) as iface:
1027
- mode=gr.State("image_upload")
 
 
1028
  with gr.Row():
1029
  with gr.Column():
1030
- #mode_label=gr.Text("이미지 μ—…λ‘œλ“œκ°€ μ„ νƒλ˜μ—ˆμŠ΅λ‹ˆλ‹€.")
1031
- upload_option = gr.Radio(choices=["이미지 μ—…λ‘œλ“œ", "파일 μ—…λ‘œλ“œ"], value="이미지 μ—…λ‘œλ“œ", label="μ—…λ‘œλ“œ μ˜΅μ…˜")
1032
- #with gr.Row():
1033
- #image_button = gr.Button("이미지 μ—…λ‘œλ“œ")
1034
- #file_button = gr.Button("파일 μ—…λ‘œλ“œ")
1035
-
1036
- # 이미지와 파일 μ—…λ‘œλ“œ μ»΄ν¬λ„ŒνŠΈ (μ΄ˆκΈ°μ—λŠ” μˆ¨κΉ€ μƒνƒœ)
1037
- # global image_uploader,file_uploader
1038
- image_uploader= gr.File(file_count="single",file_types=["image"],visible=True)
1039
- file_uploader= gr.File(file_count="single", file_types=[".txt"], visible=False)
1040
- file_upload_option=gr.Radio(choices=["low score 차트","high score 차트"],label="파일 μ—…λ‘œλ“œ μ˜΅μ…˜",visible=False)
1041
- chart_type = gr.Dropdown(["일반 κ°€λ‘œ λ§‰λŒ€ν˜•","λˆ„μ  κ°€λ‘œ λ§‰λŒ€ν˜•","100% κΈ°μ€€ λˆ„μ  κ°€λ‘œ λ§‰λŒ€ν˜•", "일반 μ„Έλ‘œ λ§‰λŒ€ν˜•","λˆ„μ  μ„Έλ‘œ λ§‰λŒ€ν˜•","100% κΈ°μ€€ λˆ„μ  μ„Έλ‘œ λ§‰λŒ€ν˜•","μ„ ν˜•", "μ›ν˜•", "기타 λ°©μ‚¬ν˜•", "기타 ν˜Όν•©ν˜•", "전체"], label="Chart Type", value="all")
1042
- model_type=gr.Dropdown(["VAIV_DePlot","VAIV_UniChart","all"],value="VAIV_DePlot",label="model",multiselect=True)
1043
- image_displayer=gr.Image(visible=True)
1044
  with gr.Row():
1045
- pre_button=gr.Button("이전",interactive="False")
1046
- next_button=gr.Button("λ‹€μŒ")
1047
- image_name=gr.Text("이미지 이름",visible=False)
1048
- #image_button.click(interface_selector, inputs=gr.State("이미지 μ—…λ‘œλ“œ"), outputs=[image_uploader,file_uploader,mode,mode_label,image_name])
1049
- #file_button.click(interface_selector, inputs=gr.State("파일 μ—…λ‘œλ“œ"), outputs=[image_uploader, file_uploader,mode,mode_label,image_name])
1050
- inference_button=gr.Button("μΆ”λ‘ ")
1051
- with gr.Column():
1052
- ko_deplot_generated_table=gr.DataFrame(visible=True,label="VAIV_DePlot μΆ”λ‘  κ²°κ³Ό",elem_classes="dataframe-class")
1053
- aihub_deplot_generated_table=gr.DataFrame(visible=False,label="aihub-deplot μΆ”λ‘  κ²°κ³Ό",elem_classes="dataframe-class")
1054
- unichart_generated_table=gr.DataFrame(visible=False,label="VAIV_UniChart μΆ”λ‘  κ²°κ³Ό",elem_classes="dataframe-class")
1055
- ko_deplot_generated_txt=gr.Text(visible=False,label="VAIV_DePlot μΆ”λ‘  κ²°κ³Ό")
1056
- unichart_generated_txt=gr.Text(visible=False,label="VAIV_UniChart μΆ”λ‘  κ²°κ³Ό")
1057
- with gr.Column():
1058
- ko_deplot_label_table=gr.DataFrame(visible=True,label="VAIV_DePlot μ •λ‹΅ν…Œμ΄λΈ”",elem_classes="dataframe-class")
1059
- aihub_deplot_label_table=gr.DataFrame(visible=False,label="aihub-deplot μ •λ‹΅ν…Œμ΄λΈ”",elem_classes="dataframe-class")
1060
- unichart_label_table=gr.DataFrame(visible=False,label="VAIV_UniChart μ •λ‹΅ν…Œμ΄λΈ”",elem_classes="dataframe-class")
1061
  with gr.Column():
1062
- ko_deplot_score_table=gr.DataFrame(visible=True,label="VAIV_DePlot 점수",elem_classes="dataframe-class")
1063
- aihub_deplot_score_table=gr.DataFrame(visible=False,label="aihub_deplot 점수",elem_classes="dataframe-class")
1064
- unichart_score_table=gr.DataFrame(visible=False,label="VAIV_UniChart 점수",elem_classes="dataframe-class")
1065
- model_type.change(
1066
- update_results,
1067
- inputs=[model_type],
1068
- outputs=[ko_deplot_generated_table,ko_deplot_score_table,aihub_deplot_generated_table,aihub_deplot_score_table,unichart_generated_table,unichart_score_table,ko_deplot_label_table,aihub_deplot_label_table,unichart_label_table]
1069
- )
1070
-
1071
- upload_option.change(
1072
- interface_selector,
1073
- inputs=[upload_option],
1074
- outputs=[image_uploader, file_uploader, mode, image_name,file_upload_option,file_uploader,file_upload_option]
1075
- )
1076
 
1077
- file_upload_option.change(
1078
- file_selector,
1079
- inputs=[file_upload_option],
1080
- outputs=[file_uploader,chart_type]
 
 
 
 
 
 
1081
  )
1082
 
1083
- chart_type.change(handle_chart_type_change, inputs=[chart_type,file_uploader],outputs=[image_displayer,image_name])
1084
  image_uploader.upload(display_image,inputs=[image_uploader],outputs=[image_displayer,image_name])
1085
- file_uploader.change(display_image_in_file,inputs=[file_uploader],outputs=[image_displayer,image_name])
1086
- pre_button.click(previous_image, outputs=[image_displayer,image_name,pre_button,next_button])
1087
- next_button.click(next_image, outputs=[image_displayer,image_name,pre_button,next_button])
1088
- inference_button.click(inference,inputs=[upload_option,image_uploader,file_uploader],outputs=[ko_deplot_generated_table, aihub_deplot_generated_table, unichart_generated_table, ko_deplot_label_table, aihub_deplot_label_table, unichart_label_table, ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table,ko_deplot_generated_txt,unichart_generated_txt,ko_deplot_generated_table, aihub_deplot_generated_table, unichart_generated_table, ko_deplot_label_table, aihub_deplot_label_table, unichart_label_table, ko_deplot_score_table, aihub_deplot_score_table,unichart_score_table])
1089
 
1090
- if __name__ == "__main__":
1091
- print("Launching Gradio interface...")
1092
- sys.stdout.flush() # stdout 버퍼λ₯Ό λΉ„μ›λ‹ˆλ‹€.
1093
- iface.launch(share=True)
1094
- #iface.launch(share=False,server_name="115.145.230.14",server_port=8080)
1095
- time.sleep(2) # Gradio URL이 좜λ ₯될 λ•ŒκΉŒμ§€ μž μ‹œ κΈ°λ‹€λ¦½λ‹ˆλ‹€.
1096
- sys.stdout.flush() # λ‹€μ‹œ stdout 버퍼λ₯Ό λΉ„μ›λ‹ˆλ‹€.
1097
- # Gradioκ°€ μ œκ³΅ν•˜λŠ” URLs을 νŒŒμΌμ— κΈ°λ‘ν•©λ‹ˆλ‹€.
1098
- with open("gradio_url.log", "w") as f:
1099
- print(iface.local_url, file=f)
1100
- print(iface.share_url, file=f)
 
20
  import logging
21
  import subprocess
22
  import spaces
23
+ import openai
24
+ import base64
25
+ from io import StringIO
26
 
27
  # Git LFS pull λͺ…λ Ήμ–΄ μ‹€ν–‰
28
  result = subprocess.run(['git', 'lfs', 'pull'], capture_output=True, text=True)
 
39
  warnings.filterwarnings('ignore')
40
  MAX_PATCHES = 512
41
  # Load the models and processor
 
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
 
44
  # Paths to the models
45
+ ko_deplot_model_path = './deplot_model_ver_24.11.21_korean_only(exclude NUUA)_epoch1.bin'
 
 
46
 
47
  # Load first model ko-deplot
 
48
  def load_model1():
49
  processor1 = Pix2StructProcessor.from_pretrained('nuua/ko-deplot')
50
  model1 = Pix2StructForConditionalGeneration.from_pretrained('nuua/ko-deplot')
51
  model1.load_state_dict(torch.load(ko_deplot_model_path, map_location="cpu"))
52
  model1.to(torch.device("cuda"))
53
+ return processor1, model1
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
+ processor1, model1 = load_model1()
56
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  # Function to format output
58
  def format_output(prediction):
59
  return prediction.replace('<0x0A>', '\n')
60
 
61
+ # First model prediction: ko-deplot
 
62
  def predict_model1(image):
63
  images = [image]
64
  inputs = processor1(images=images, text="What is the title of the chart", return_tensors="pt", padding=True)
 
72
  formatted_output = format_output(outputs[0])
73
  return formatted_output
74
 
75
+ # Set your OpenAI API key
76
+ openai.api_key = "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"
77
+
78
+ # Function to encode the image as base64
79
+ def encode_image(image_path):
80
+ with open(image_path, "rb") as image_file:
81
+ return base64.b64encode(image_file.read()).decode("utf-8")
82
+
83
+ # Second model prediction: gpt-4o-mini
84
+ def predict_model2(image):
85
+ # Encode the uploaded image to base64
86
+ image_data = encode_image(image)
87
+
88
+ # Prepare the request content
89
+ response = openai.ChatCompletion.create(
90
+ model="gpt-4o-mini",
91
+ messages=[
92
+ {
93
+ "role": "user",
94
+ "content": [
95
+ {
96
+ "type": "text",
97
+ "text": "please extract chart title and chart data manually and present them as a table. you should only provide title and table without adding any additional comments such as **Chart Title:** ."
98
+ },
99
+ {
100
+ "type": "image_url",
101
+ "image_url": {
102
+ "url": f"data:image/jpeg;base64,{image_data}"
103
+ }
104
+ }
105
+ ]
106
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  ]
108
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
+ # Return the table data from the response
111
+ return response.choices[0]["message"]["content"]
 
 
 
112
 
113
+ def ko_deplot_convert_to_dataframe(label_table_str): #function that converts text generated by ko-deplot to pandas dataframe
114
+ lines = label_table_str.strip().split("\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  data=[]
116
+ title= lines[0].split(" | ")[1]
117
+
118
+ if(len(lines[1].split("|")) == len(lines[2].split("|"))):
119
+ headers=lines[1].split(" | ")
120
+ for line in lines[2:]:
121
+ data.append(line.split(" | "))
122
+ df = pd.DataFrame(data, columns=headers)
123
+ return df, title
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
  else:
125
+ legend_row=lines[1].split("|")
126
+ legend_row.insert(0," ")
127
+ for line in lines[2:]:
128
+ data.append(line.split(" | "))
129
+ df = pd.DataFrame(data, columns=legend_row)
130
+ return df, title
131
+
132
+ def gpt_convert_to_dataframe(table_text): #function that converts text generated by gpt to pandas dataframe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  try:
134
+ # Split the text into lines
135
+ lines = table_text.strip().split("\n")
136
+ title=lines[0]
137
+ lines.pop(1)
138
+ lines.pop(2)
139
+ # Process the remaining lines to create the DataFrame
140
+ data = [line.split("|")[1:-1] for line in lines[1:]] # Split by | and remove empty first/last items
141
+ dataframe = pd.DataFrame(data[1:], columns=[col.strip() for col in data[0]]) # Use the first row as headers
142
+
143
+ return dataframe, title
144
  except Exception as e:
145
+ return f"Error converting table to DataFrame: {e}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  def real_time_check(image_file):
 
 
 
148
  image = Image.open(image_file)
149
+ ko_deplot_generated_txt = predict_model1(image)
150
+ parts=ko_deplot_generated_txt.split("\n")
 
151
  del parts[-1]
152
+ ko_deplot_generated_txt="\n".join(parts)
153
+ gpt_generated_txt=predict_model2(image_file)
 
 
 
 
 
 
154
  try:
155
+ ko_deplot_generated_df, ko_deplot_generated_title=ko_deplot_convert_to_dataframe(ko_deplot_generated_txt)
156
+ gpt_generated_df, gpt_generated_title=gpt_convert_to_dataframe(gpt_generated_txt)
157
+ return gr.DataFrame(ko_deplot_generated_df, label= ko_deplot_generated_title), gr.DataFrame(gpt_generated_df, label= gpt_generated_title), None,None,0
158
  except Exception as e:
159
+ return None,None,ko_deplot_generated_txt,gpt_generated_txt,1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
+ flag = 0 #flag to check whether exception happens or not. if flag is 1, it means that exception(generated txt cannot be converted to pandas dataframe) happens.
162
+ def inference(image_uploader,mode_selector):
163
+ if(mode_selector=="파일 μ—…λ‘œλ“œ"):
164
+ ko_deplot_generated_df, gpt_generated_df,ko_deplot_generated_txt, gpt_generated_txt, flag= real_time_check(image_uploader)
 
 
 
 
 
 
165
  if flag==1:
166
+ return gr.update(visible=False), gr.update(visible=False), gr.Text(ko_deplot_generated_txt,visible=True),gr.Text(gpt_generated_txt,visible=True)
167
  else:
168
+ return ko_deplot_generated_df, gpt_generated_df, gr.update(visible=False),gr.update(visible=False)
169
  else:
170
+ ko_deplot_generated_df, gpt_generated_df,ko_deplot_generated_txt, gpt_generated_txt, flag= real_time_check(image_files[current_image_index])
171
  if flag==1:
172
+ return gr.update(visible=False), gr.update(visible=False), gr.Text(ko_deplot_generated_txt,visible=True),gr.Text(gpt_generated_txt,visible=True)
173
  else:
174
+ return ko_deplot_generated_df, gpt_generated_df, gr.update(visible=False),gr.update(visible=False)
175
+
176
+ def toggle_model(selected_models,flag):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  # Create a visibility list initialized to False for all components
178
+ visibility = [False] * 6
 
179
  # Update visibility based on the selected models
180
  if "VAIV_DePlot" in selected_models:
181
+ visibility[4]= True
182
+ if flag:
183
+ visibility[2]= True
184
+ else:
185
+ visibility[0]= True
186
+ if "gpt-4o-mini" in selected_models:
187
+ visibility[5]= True
188
+ if flag:
189
+ visibility[3]= True
190
+ else:
191
+ visibility[1]= True
 
 
 
192
  if "all" in selected_models:
193
+ visibility[4]=True
194
+ visibility[5]=True
195
+ if flag:
196
+ visibility[2]= True
197
+ visibility[3]= True
198
+ else:
199
+ visibility[0]= True
200
+ visibility[1]= True
201
  # Return gr.update for each component with the corresponding visibility status
202
  return tuple(gr.update(visible=v) for v in visibility)
203
 
204
+ def toggle_mode(mode):
205
+ if mode == "파일 μ—…λ‘œλ“œ":
206
+ return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
207
+ else:
208
+ return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
209
 
210
  def display_image(image_file):
211
  image=Image.open(image_file)
212
  return image, os.path.basename(image_file)
213
 
214
+ # Function to display the images in the folder sequentially
215
+ image_files = []
216
+ current_image_index = 0
217
+ image_files_cnt=0
218
+
219
+ def display_folder_images(image_file_path_list):
220
+ global image_files, current_image_index,image_files_cnt
221
+ image_files = image_file_path_list
222
+ image_files_cnt=len(image_files)
223
+ current_image_index = 0
224
+ if image_files:
225
+ return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=False), gr.update(interactive=True)
226
+ return None, "No images found"
227
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
+ def next_image():
230
+ global current_image_index
231
+ if image_files:
232
+ current_image_index = (current_image_index + 1)
233
+ prev_disabled = current_image_index == 0
234
+ next_disabled = current_image_index == (len(image_files) - 1)
235
+ return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=not prev_disabled), gr.update(interactive= not next_disabled)
236
+ return None, "No images found"
237
+
238
+ def prev_image():
239
+ global current_image_index
240
+ if image_files:
241
+ current_image_index = (current_image_index - 1)
242
+ prev_disabled = current_image_index == 0
243
+ next_disabled = current_image_index == (len(image_files) - 1)
244
+ return Image.open(image_files[current_image_index]), os.path.basename(image_files[current_image_index]), gr.update(interactive=not prev_disabled), gr.update(interactive= not next_disabled)
245
+ return None, "No images found"
246
 
247
  css = """
248
  .dataframe-class {
 
249
  overflow-y: auto !important; /* μŠ€ν¬λ‘€μ„ κ°€λŠ₯ν•˜κ²Œ */
250
+ height: 250px
251
  }
252
  """
253
 
254
  with gr.Blocks(css=css) as iface:
255
+ with gr.Row():
256
+ gr.Markdown("<h1 style='text-align: center;'>SKKU-VAIV Automatic chart understanding evaluation tool</h1>")
257
+ gr.Markdown("<hr style='border: 1px solid #ddd;' />")
258
  with gr.Row():
259
  with gr.Column():
260
+ mode_selector = gr.Radio(["파일 μ—…λ‘œλ“œ", "폴더 μ—…λ‘œλ“œ"], label="Upload Mode", value="파일 μ—…λ‘œλ“œ")
261
+ image_uploader = gr.File(file_count="single", file_types=["image"], visible=True)
262
+ folder_uploader = gr.File(file_count="directory", file_types=["image"], visible=False, height=50)
263
+ model_type=gr.Dropdown(["VAIV_DePlot","gpt-4o-mini","all"],value="VAIV_DePlot",label="model",multiselect=True)
264
+ image_displayer = gr.Image(visible=True)
265
+ image_name = gr.Text("", visible=True)
 
 
 
 
 
 
 
 
266
  with gr.Row():
267
+ prev_button = gr.Button("이전", visible=False, interactive=False)
268
+ next_button = gr.Button("λ‹€μŒ", visible=False, interactive=False)
269
+ inference_button = gr.Button("μΆ”λ‘ ")
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  with gr.Column():
271
+ md1 = gr.Markdown("# VAIV_DePlot Inference Result")
272
+ ko_deplot_generated_df = gr.DataFrame(visible=True, elem_classes="dataframe-class")
273
+ ko_deplot_generated_txt = gr.Text(visible=False)
274
+ with gr.Column():
275
+ md2 = gr.Markdown("# gpt-4o-mini Inference Result", visible=False)
276
+ gpt_generated_df = gr.DataFrame(visible=False, elem_classes="dataframe-class")
277
+ gpt_generated_txt = gr.Text(visible=False)
278
+ #label_df = gr.DataFrame(visible=False, label="Ground Truth Table", elem_classes="dataframe-class",scale=1)
 
 
 
 
 
 
279
 
280
+ model_type.change(
281
+ toggle_model,
282
+ inputs=[model_type, gr.State(flag)],
283
+ outputs=[ko_deplot_generated_df,gpt_generated_df,ko_deplot_generated_txt,gpt_generated_txt,md1,md2]
284
+ )
285
+
286
+ mode_selector.change(
287
+ toggle_mode,
288
+ inputs=[mode_selector],
289
+ outputs=[image_uploader, folder_uploader, prev_button, next_button]
290
  )
291
 
 
292
  image_uploader.upload(display_image,inputs=[image_uploader],outputs=[image_displayer,image_name])
293
+ folder_uploader.upload(display_folder_images, inputs=[folder_uploader], outputs=[image_displayer, image_name, prev_button, next_button])
294
+ prev_button.click(prev_image, outputs=[image_displayer, image_name, prev_button, next_button])
295
+ next_button.click(next_image, outputs=[image_displayer, image_name, prev_button, next_button])
296
+ inference_button.click(inference,inputs=[image_uploader,mode_selector],outputs=[ko_deplot_generated_df, gpt_generated_df, ko_deplot_generated_txt, gpt_generated_txt])
297
 
298
+ if __name__ == "__main__":
299
+ iface.launch(share=True)