ketanmore commited on
Commit
0838c13
·
verified ·
1 Parent(s): 21f72c4

Delete layout-fine-tune.ipynb

Browse files
Files changed (1) hide show
  1. layout-fine-tune.ipynb +0 -187
layout-fine-tune.ipynb DELETED
@@ -1,187 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "markdown",
5
- "metadata": {},
6
- "source": [
7
- "# Loading Packages"
8
- ]
9
- },
10
- {
11
- "cell_type": "code",
12
- "execution_count": null,
13
- "metadata": {},
14
- "outputs": [],
15
- "source": [
16
- "import os\n",
17
- "import torch\n",
18
- "import torch.nn as nn\n",
19
- "import torch.optim as optim\n",
20
- "from torch.utils.data import DataLoader\n",
21
- "# from transformers import SegformerConfig\n",
22
- "# from surya.model.detection.segformer import SegformerForRegressionMask\n",
23
- "from surya.input.processing import prepare_image_detection\n",
24
- "from surya.model.detection.segformer import load_processor , load_model\n",
25
- "from datasets import load_dataset\n",
26
- "from tqdm import tqdm\n",
27
- "from torch.utils.tensorboard import SummaryWriter\n",
28
- "import torch.nn.functional as F\n",
29
- "import numpy as np \n",
30
- "from surya.layout import parallel_get_regions"
31
- ]
32
- },
33
- {
34
- "cell_type": "markdown",
35
- "metadata": {},
36
- "source": [
37
- "# Initializing The Dataset And Model"
38
- ]
39
- },
40
- {
41
- "cell_type": "code",
42
- "execution_count": null,
43
- "metadata": {},
44
- "outputs": [],
45
- "source": [
46
- "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
47
- "dataset = load_dataset(\"vikp/publaynet_bench\", split=\"train[:100]\") # You can choose you own dataset\n",
48
- "model = load_model(\"vikp/surya_layout2\") "
49
- ]
50
- },
51
- {
52
- "cell_type": "markdown",
53
- "metadata": {},
54
- "source": [
55
- "# Helper Functions, Loss Function And Optimizer"
56
- ]
57
- },
58
- {
59
- "cell_type": "code",
60
- "execution_count": null,
61
- "metadata": {},
62
- "outputs": [],
63
- "source": [
64
- "\n",
65
- "optimizer = optim.Adam(model.parameters(), lr=0.00001)\n",
66
- "log_dir = \"logs\"\n",
67
- "checkpoint_dir = \"checkpoints\"\n",
68
- "os.makedirs(log_dir, exist_ok=True)\n",
69
- "os.makedirs(checkpoint_dir, exist_ok=True)\n",
70
- "writer = SummaryWriter(log_dir=log_dir)\n",
71
- "\n",
72
- "def logits_to_bboxes(logits,image) : # This function is useful for converting the logits(mask) into bounding boxes.(The model does not provide bounding boxes.)\n",
73
- " correct_shape = (300, 300) \n",
74
- " logits_temp = F.interpolate(logits, size=correct_shape, mode='bilinear', align_corners=False)\n",
75
- " logits_temp = logits_temp.cpu().detach().numpy().astype(np.float32)\n",
76
- "\n",
77
- " heatmap_count = logits_temp.shape[1]\n",
78
- " heatmaps = [logits_temp[i][k] for i in range(logits_temp.shape[0]) for k in range(heatmap_count)]\n",
79
- " regions = parallel_get_regions(heatmaps=heatmaps, orig_size=image.size, id2label=model.config.id2label)\n",
80
- "\n",
81
- " final_bboxes = []\n",
82
- " for i in regions.bboxes :\n",
83
- " final_bboxes.append(i.bbox)\n",
84
- " return final_bboxes\n",
85
- "\n",
86
- "\n",
87
- "def loss_function(): # This model does not have inbuild loss function, So we have to define it according to our dataset and the Requirements.\n",
88
- " pass"
89
- ]
90
- },
91
- {
92
- "cell_type": "markdown",
93
- "metadata": {},
94
- "source": [
95
- "# Fine-Tuning Process"
96
- ]
97
- },
98
- {
99
- "cell_type": "code",
100
- "execution_count": null,
101
- "metadata": {},
102
- "outputs": [],
103
- "source": [
104
- "num_epochs = 5\n",
105
- "for epoch in range(num_epochs):\n",
106
- " model.train()\n",
107
- " running_loss = 0.0\n",
108
- " avg_loss = 0.0\n",
109
- "\n",
110
- " for idx, item in enumerate(tqdm(dataset, desc=f\"Epoch {epoch + 1}/{num_epochs}\")):\n",
111
- "\n",
112
- " images = [prepare_image_detection(img=item['image'], processor=load_processor())]\n",
113
- " images = torch.stack(images, dim=0).to(model.dtype).to(model.device)\n",
114
- " \n",
115
- " optimizer.zero_grad()\n",
116
- " outputs = model(pixel_values=images)\n",
117
- "\n",
118
- " predicted_boxes = logits_to_bboxes(outputs.logits, item['image'])\n",
119
- " target_boxes = item['bboxes']\n",
120
- "\n",
121
- " loss = loss_function(predicted_boxes,target_boxes)\n",
122
- "\n",
123
- " loss.backward()\n",
124
- " optimizer.step()\n",
125
- " running_loss += loss.item()\n",
126
- "\n",
127
- " avg_loss = 0.9 * avg_loss + 0.1 * loss.item() if idx > 0 else loss.item()\n",
128
- "\n",
129
- " avg_loss = running_loss / len(dataset)\n",
130
- " writer.add_scalar('Training Loss', avg_loss, epoch + 1)\n",
131
- " print(f\"Average Loss for Epoch {epoch + 1}: {avg_loss:.4f}\")\n",
132
- "\n",
133
- " torch.save(model.state_dict(), os.path.join(checkpoint_dir, f\"model_epoch_{epoch + 1}.pth\"))"
134
- ]
135
- },
136
- {
137
- "cell_type": "markdown",
138
- "metadata": {},
139
- "source": [
140
- "# Loading The Checkpoint "
141
- ]
142
- },
143
- {
144
- "cell_type": "code",
145
- "execution_count": null,
146
- "metadata": {},
147
- "outputs": [],
148
- "source": [
149
- "checkpoint_path = 'checkpoints/model_epoch_350.pth' \n",
150
- "state_dict = torch.load(checkpoint_path,weights_only=True)\n",
151
- "\n",
152
- "model.load_state_dict(state_dict)"
153
- ]
154
- },
155
- {
156
- "cell_type": "code",
157
- "execution_count": null,
158
- "metadata": {},
159
- "outputs": [],
160
- "source": [
161
- "model.to('cpu')\n",
162
- "model.save_pretrained(\"fine-tuned-surya-model-layout\")"
163
- ]
164
- }
165
- ],
166
- "metadata": {
167
- "kernelspec": {
168
- "display_name": "Python 3",
169
- "language": "python",
170
- "name": "python3"
171
- },
172
- "language_info": {
173
- "codemirror_mode": {
174
- "name": "ipython",
175
- "version": 3
176
- },
177
- "file_extension": ".py",
178
- "mimetype": "text/x-python",
179
- "name": "python",
180
- "nbconvert_exporter": "python",
181
- "pygments_lexer": "ipython3",
182
- "version": "3.10.14"
183
- }
184
- },
185
- "nbformat": 4,
186
- "nbformat_minor": 2
187
- }