wjbmattingly committed
Commit 2c92317 · verified · Parent: 7898fe8

Upload finetune.py

Files changed (1): finetune.py (+136, -0)
finetune.py ADDED
@@ -0,0 +1,136 @@
import pandas as pd
from datasets import load_dataset
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, Seq2SeqTrainingArguments, Seq2SeqTrainer
from PIL import Image
import numpy as np

device = 'mps:0'  # Apple Silicon GPU; use 'cuda' or 'cpu' on other hardware

# Load the dataset and keep only Latin entries written in Caroline script
dataset = load_dataset("CATMuS/medieval", split='train')
# latin_dataset = dataset.filter(lambda example: example['language'] == 'Latin')
latin_dataset = dataset.filter(lambda example: example['language'] == 'Latin' and example['script_type'] == 'Caroline')

print(latin_dataset)

# Convert to a pandas DataFrame for easier manipulation
df = pd.DataFrame(latin_dataset)

# Split the data into training and testing sets
train_df, test_df = train_test_split(df, test_size=0.2)
train_df.reset_index(drop=True, inplace=True)
test_df.reset_index(drop=True, inplace=True)

# Define the dataset class
class HandwrittenTextDataset(Dataset):
    def __init__(self, df, processor, max_target_length=128):
        self.df = df
        self.processor = processor
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_data = self.df['im'][idx]
        text = self.df['text'][idx]

        # Convert array to PIL image
        image = Image.fromarray(np.array(image_data)).convert("RGB")

        # Prepare image (i.e., resize + normalize)
        pixel_values = self.processor(images=image, return_tensors="pt").pixel_values

        # Add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length,
                                          truncation=True).input_ids
        # Important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding

# Instantiate processor and datasets
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
train_dataset = HandwrittenTextDataset(df=train_df, processor=processor)
eval_dataset = HandwrittenTextDataset(df=test_df, processor=processor)

# Create corresponding dataloaders (the Trainer builds its own internally;
# eval_dataloader is used by the manual CER loop at the end of the script)
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=4)

# Load the model and move it to the target device so that generation
# below receives inputs and weights on the same device
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
model.to(device)

# Set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id

# Make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# Set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

# Training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    num_train_epochs=10,
    logging_steps=1000,
    save_steps=1000,
    evaluation_strategy="steps",
    save_total_limit=2,
    predict_with_generate=True,
    fp16=False,  # Set to True if using a compatible GPU
)

# Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Train the model
trainer.train()

# After training, save both the model and the processor
model.save_pretrained("./finetuned_model")
processor.save_pretrained("./finetuned_model")

# Character error rate (CER) metric; evaluate.load replaces the
# deprecated datasets.load_metric
import evaluate

cer_metric = evaluate.load("cer")

def compute_cer(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    # Restore PAD tokens in place of the -100 loss-masking value before decoding
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return cer

# Evaluation
model.eval()
valid_cer = 0.0
with torch.no_grad():
    for batch in eval_dataloader:
        # Run batch generation
        outputs = model.generate(batch["pixel_values"].to(device))
        # Compute metrics
        cer = compute_cer(pred_ids=outputs, label_ids=batch["labels"])
        valid_cer += cer

print("Validation CER:", valid_cer / len(eval_dataloader))
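
For reference, a minimal inference sketch (not part of this commit) showing how the checkpoint saved by the script above could be loaded to transcribe a single line image; the path "line.png" is a placeholder:

# Hypothetical usage example: load the fine-tuned checkpoint and run OCR on one image
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

processor = TrOCRProcessor.from_pretrained("./finetuned_model")
model = VisionEncoderDecoderModel.from_pretrained("./finetuned_model")
model.eval()

image = Image.open("line.png").convert("RGB")  # placeholder path to a text-line image
pixel_values = processor(images=image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values)
text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(text)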