# File size: 4,470 Bytes
# Revision: 9c909e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from pathlib import Path

import pandas as pd
from PIL import Image
from torchmetrics.text import CharErrorRate
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Evaluate the baseline TrOCR model and several fine-tuned checkpoints on the
# first N_SAMPLES labelled rows of the AlphaPen test set. Each transcription is
# scored by character error rate (CER) against its label, compared
# case-insensitively. The labels, transcriptions and scores are written to
# DATASET_DIR / "inference_results.csv".

DATASET_DIR = Path("/mnt/data1/Datasets/AlphaPen")
IMAGE_DIR = Path("/mnt/data1/Datasets/OCR/Alphapen/clean_data")
# Each image on disk is named IMAGE_PREFIX + the CSV's `filename` value.
IMAGE_PREFIX = "final_cropped_rotated_"
N_SAMPLES = 50

# Baseline model
BASELINE_CHECKPOINT = "microsoft/trocr-base-handwritten"
# Finetuned models. The k-th entry (1-based) fills the report columns
# "Finetune_k" (transcription) and "cer_k" (its CER).
FINETUNED_CHECKPOINTS = (
    "hadrakey/alphapen_new_large_1",
    "hadrakey/alphapen_new_large_15000",
    "hadrakey/alphapen_new_large_30000",
    "hadrakey/alphapen_new_large_45000",
    "hadrakey/alphapen_new_large_60000",
    "hadrakey/alphapen_new_large_70000",
)

# Each evaluated model is stored as (transcription column, CER column, model).
finetuned = [
    (f"Finetune_{k}", f"cer_{k}", VisionEncoderDecoderModel.from_pretrained(checkpoint))
    for k, checkpoint in enumerate(FINETUNED_CHECKPOINTS, start=1)
]
baseline = (
    "Baseline",
    "cer_base",
    VisionEncoderDecoderModel.from_pretrained(BASELINE_CHECKPOINT),
)

# NOTE(review): every model shares the baseline's processor (image
# preprocessing + tokenizer). Confirm that the fine-tuned "large" checkpoints
# were trained with a compatible processor.
processor = TrOCRProcessor.from_pretrained(BASELINE_CHECKPOINT)

# Checked labels: drop incomplete rows, then renumber them 0..n-1. The old
# index is kept as an "index" column, which also ends up in the output CSV.
data = pd.read_csv(DATASET_DIR / "testing_data.csv")
data.dropna(inplace=True)
data.reset_index(inplace=True)
# .copy(): result columns are added to `sample` below, and adding columns to
# an iloc slice of `data` raises pandas' SettingWithCopyWarning.
sample = data.iloc[:N_SAMPLES, :].copy()

cer_metric = CharErrorRate()

evaluated = [baseline, *finetuned]
transcriptions = {text_col: [] for text_col, _, _ in evaluated}
cer_scores = {cer_col: [] for _, cer_col, _ in evaluated}

for filename, label in zip(sample["filename"], sample["text"]):
    # Convert inside the `with` so the file handle is closed once the RGB
    # copy is in memory.
    with Image.open(IMAGE_DIR / f"{IMAGE_PREFIX}{filename}") as img:
        image = img.convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values
    target = label.lower()

    # Every model sees the same preprocessed pixels.
    for text_col, cer_col, model in evaluated:
        generated_ids = model.generate(pixel_values)
        text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcriptions[text_col].append(text)
        # Per-image CER, compared case-insensitively (prediction first, label second).
        cer_scores[cer_col].append(cer_metric(text.lower(), target).detach().numpy())

# Report column order: transcriptions with the baseline first, then CERs with
# the baseline last.
for text_col, _, _ in evaluated:
    sample[text_col] = transcriptions[text_col]
for _, cer_col, _ in [*finetuned, baseline]:
    sample[cer_col] = cer_scores[cer_col]

sample.to_csv(DATASET_DIR / "inference_results.csv")