woods-today committed on
Commit 370c710
1 Parent(s): 5233da8

Working on it

__pycache__/utils.cpython-311.pyc ADDED
Binary file (1.54 kB)
 
endpoints.py CHANGED
@@ -1,6 +1,6 @@
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from routers import inference, training
+from routers import training
 from huggingface_hub import login
 from config import settings
 import torch
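
For orientation, here is a minimal sketch of how the trimmed endpoints.py might be wired after this change. Only the imports above come from the diff; the app creation, the CORS policy, and the settings.huggingface_key attribute name are assumptions, not part of this commit.

# Hypothetical assembly of endpoints.py after this commit (only the imports
# are taken from the diff; everything else is assumed for illustration).
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from routers import training
from huggingface_hub import login
from config import settings
import torch  # imported in the original module

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # assumed: permissive CORS for development
    allow_methods=["*"],
    allow_headers=["*"],
)

# Log in to the Hugging Face Hub; the token attribute name is an assumption.
login(token=settings.huggingface_key)

# Only the training router remains now that routers/inference.py is removed.
app.include_router(training.router)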
requirements-fastapi.txt CHANGED
@@ -12,4 +12,5 @@ diffusers==0.10.2
 torch
 scipy
 ftfy
-accelerate
+accelerate
+uuid
routers/__pycache__/training.cpython-311.pyc ADDED
Binary file (4.24 kB)
 
routers/donut_evaluate.py DELETED
@@ -1,90 +0,0 @@
-from transformers import DonutProcessor, VisionEncoderDecoderModel
-import locale
-
-import re
-import json
-import torch
-from tqdm.auto import tqdm
-import numpy as np
-from donut import JSONParseEvaluator
-from datasets import load_dataset
-from functools import lru_cache
-import os
-import time
-from config import settings
-
-locale.getpreferredencoding = lambda: "UTF-8"
-
-
-@lru_cache(maxsize=1)
-def prepare_model():
-    processor = DonutProcessor.from_pretrained(settings.processor)
-    model = VisionEncoderDecoderModel.from_pretrained(settings.model)
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    model.eval()
-    model.to(device)
-
-    dataset = load_dataset(settings.dataset, split="test")
-
-    return processor, model, device, dataset
-
-
-def run_evaluate_donut():
-    worker_pid = os.getpid()
-    print(f"Handling evaluation request with worker PID: {worker_pid}")
-
-    start_time = time.time()
-
-    output_list = []
-    accs = []
-
-    processor, model, device, dataset = prepare_model()
-
-    for idx, sample in tqdm(enumerate(dataset), total=len(dataset)):
-        # prepare encoder inputs
-        pixel_values = processor(sample["image"].convert("RGB"), return_tensors="pt").pixel_values
-        pixel_values = pixel_values.to(device)
-        # prepare decoder inputs
-        task_prompt = "<s_cord-v2>"
-        decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
-        decoder_input_ids = decoder_input_ids.to(device)
-
-        # autoregressively generate sequence
-        outputs = model.generate(
-            pixel_values,
-            decoder_input_ids=decoder_input_ids,
-            max_length=model.decoder.config.max_position_embeddings,
-            early_stopping=True,
-            pad_token_id=processor.tokenizer.pad_token_id,
-            eos_token_id=processor.tokenizer.eos_token_id,
-            use_cache=True,
-            num_beams=1,
-            bad_words_ids=[[processor.tokenizer.unk_token_id]],
-            return_dict_in_generate=True,
-        )
-
-        # turn into JSON
-        seq = processor.batch_decode(outputs.sequences)[0]
-        seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
-        seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
-        seq = processor.token2json(seq)
-
-        ground_truth = json.loads(sample["ground_truth"])
-        ground_truth = ground_truth["gt_parse"]
-        evaluator = JSONParseEvaluator()
-        score = evaluator.cal_acc(seq, ground_truth)
-
-        accs.append(score)
-        output_list.append(seq)
-
-    end_time = time.time()
-    processing_time = end_time - start_time
-
-    scores = {"accuracies": accs, "mean_accuracy": np.mean(accs)}
-    print(scores, f"length : {len(accs)}")
-    print("Mean accuracy:", np.mean(accs))
-    print(f"Evaluation done, worker PID: {worker_pid}")
-
-    return scores, np.mean(accs), processing_time
 
routers/donut_inference.py DELETED
@@ -1,60 +0,0 @@
-import re
-import time
-import torch
-from transformers import DonutProcessor, VisionEncoderDecoderModel
-from config import settings
-from functools import lru_cache
-import os
-
-
-@lru_cache(maxsize=1)
-def load_model():
-    processor = DonutProcessor.from_pretrained(settings.processor)
-    model = VisionEncoderDecoderModel.from_pretrained(settings.model)
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model.to(device)
-
-    return processor, model, device
-
-
-def process_document_donut(image):
-    worker_pid = os.getpid()
-    print(f"Handling inference request with worker PID: {worker_pid}")
-
-    start_time = time.time()
-
-    processor, model, device = load_model()
-
-    # prepare encoder inputs
-    pixel_values = processor(image, return_tensors="pt").pixel_values
-
-    # prepare decoder inputs
-    task_prompt = "<s_cord-v2>"
-    decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
-
-    # generate answer
-    outputs = model.generate(
-        pixel_values.to(device),
-        decoder_input_ids=decoder_input_ids.to(device),
-        max_length=model.decoder.config.max_position_embeddings,
-        early_stopping=True,
-        pad_token_id=processor.tokenizer.pad_token_id,
-        eos_token_id=processor.tokenizer.eos_token_id,
-        use_cache=True,
-        num_beams=1,
-        bad_words_ids=[[processor.tokenizer.unk_token_id]],
-        return_dict_in_generate=True,
-    )
-
-    # postprocess
-    sequence = processor.batch_decode(outputs.sequences)[0]
-    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
-    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
-
-    end_time = time.time()
-    processing_time = end_time - start_time
-
-    print(f"Inference done, worker PID: {worker_pid}")
-
-    return processor.token2json(sequence), processing_time
 
routers/donut_training.py DELETED
@@ -1,393 +0,0 @@
-# !pip install -q git+https://github.com/huggingface/transformers.git datasets sentencepiece
-# !pip install -q pytorch-lightning==1.9.5 wandb
-
-from config import settings
-from datasets import load_dataset
-from transformers import VisionEncoderDecoderConfig
-from transformers import DonutProcessor, VisionEncoderDecoderModel
-
-import json
-import random
-from typing import Any, List, Tuple
-
-import torch
-from torch.utils.data import Dataset
-
-from torch.utils.data import DataLoader
-
-import re
-from nltk import edit_distance
-import numpy as np
-import os
-import time
-
-import pytorch_lightning as pl
-from functools import lru_cache
-
-from pytorch_lightning.loggers import WandbLogger
-from pytorch_lightning.callbacks import Callback
-from config import settings
-
-added_tokens = []
-
-dataset_name = settings.dataset
-base_config_name = settings.base_config
-base_processor_name = settings.base_processor
-base_model_name = settings.base_model
-model_name = settings.model
-
-@lru_cache(maxsize=1)
-def prepare_job():
-    print("Preparing job...")
-
-    dataset = load_dataset(dataset_name)
-
-    max_length = 768
-    image_size = [1280, 960]
-
-    # update image_size of the encoder
-    # during pre-training, a larger image size was used
-    config = VisionEncoderDecoderConfig.from_pretrained(base_config_name)
-    config.encoder.image_size = image_size  # (height, width)
-    # update max_length of the decoder (for generation)
-    config.decoder.max_length = max_length
-    # TODO we should actually update max_position_embeddings and interpolate the pre-trained ones:
-    # https://github.com/clovaai/donut/blob/0acc65a85d140852b8d9928565f0f6b2d98dc088/donut/model.py#L602
-
-    processor = DonutProcessor.from_pretrained(base_processor_name)
-    model = VisionEncoderDecoderModel.from_pretrained(base_model_name, config=config)
-
-    return model, processor, dataset, config, image_size, max_length
-
-
-class DonutDataset(Dataset):
-    """
-    DonutDataset which is saved in huggingface datasets format. (see details in https://huggingface.co/docs/datasets)
-    Each row, consists of image path(png/jpg/jpeg) and gt data (json/jsonl/txt),
-    and it will be converted into input_tensor(vectorized image) and input_ids(tokenized string).
-    Args:
-        dataset_name_or_path: name of dataset (available at huggingface.co/datasets) or the path containing image files and metadata.jsonl
-        max_length: the max number of tokens for the target sequences
-        split: whether to load "train", "validation" or "test" split
-        ignore_id: ignore_index for torch.nn.CrossEntropyLoss
-        task_start_token: the special token to be fed to the decoder to conduct the target task
-        prompt_end_token: the special token at the end of the sequences
-        sort_json_key: whether or not to sort the JSON keys
-    """
-
-    def __init__(
-        self,
-        dataset_name_or_path: str,
-        max_length: int,
-        split: str = "train",
-        ignore_id: int = -100,
-        task_start_token: str = "<s>",
-        prompt_end_token: str = None,
-        sort_json_key: bool = True,
-    ):
-        super().__init__()
-
-        model, processor, dataset, config, image_size, p1 = prepare_job()
-
-        self.max_length = max_length
-        self.split = split
-        self.ignore_id = ignore_id
-        self.task_start_token = task_start_token
-        self.prompt_end_token = prompt_end_token if prompt_end_token else task_start_token
-        self.sort_json_key = sort_json_key
-
-        self.dataset = load_dataset(dataset_name_or_path, split=self.split)
-        self.dataset_length = len(self.dataset)
-
-        self.gt_token_sequences = []
-        for sample in self.dataset:
-            ground_truth = json.loads(sample["ground_truth"])
-            if "gt_parses" in ground_truth:  # when multiple ground truths are available, e.g., docvqa
-                assert isinstance(ground_truth["gt_parses"], list)
-                gt_jsons = ground_truth["gt_parses"]
-            else:
-                assert "gt_parse" in ground_truth and isinstance(ground_truth["gt_parse"], dict)
-                gt_jsons = [ground_truth["gt_parse"]]
-
-            self.gt_token_sequences.append(
-                [
-                    self.json2token(
-                        gt_json,
-                        update_special_tokens_for_json_key=self.split == "train",
-                        sort_json_key=self.sort_json_key,
-                    )
-                    + processor.tokenizer.eos_token
-                    for gt_json in gt_jsons  # load json from list of json
-                ]
-            )
-
-        self.add_tokens([self.task_start_token, self.prompt_end_token])
-        self.prompt_end_token_id = processor.tokenizer.convert_tokens_to_ids(self.prompt_end_token)
-
-    def json2token(self, obj: Any, update_special_tokens_for_json_key: bool = True, sort_json_key: bool = True):
-        """
-        Convert an ordered JSON object into a token sequence
-        """
-        if type(obj) == dict:
-            if len(obj) == 1 and "text_sequence" in obj:
-                return obj["text_sequence"]
-            else:
-                output = ""
-                if sort_json_key:
-                    keys = sorted(obj.keys(), reverse=True)
-                else:
-                    keys = obj.keys()
-                for k in keys:
-                    if update_special_tokens_for_json_key:
-                        self.add_tokens([fr"<s_{k}>", fr"</s_{k}>"])
-                    output += (
-                        fr"<s_{k}>"
-                        + self.json2token(obj[k], update_special_tokens_for_json_key, sort_json_key)
-                        + fr"</s_{k}>"
-                    )
-                return output
-        elif type(obj) == list:
-            return r"<sep/>".join(
-                [self.json2token(item, update_special_tokens_for_json_key, sort_json_key) for item in obj]
-            )
-        else:
-            obj = str(obj)
-            if f"<{obj}/>" in added_tokens:
-                obj = f"<{obj}/>"  # for categorical special tokens
-            return obj
-
-    def add_tokens(self, list_of_tokens: List[str]):
-        """
-        Add special tokens to tokenizer and resize the token embeddings of the decoder
-        """
-        model, processor, dataset, config, image_size, p1 = prepare_job()
-
-        newly_added_num = processor.tokenizer.add_tokens(list_of_tokens)
-        if newly_added_num > 0:
-            model.decoder.resize_token_embeddings(len(processor.tokenizer))
-            added_tokens.extend(list_of_tokens)
-
-    def __len__(self) -> int:
-        return self.dataset_length
-
-    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Load image from image_path of given dataset_path and convert into input_tensor and labels
-        Convert gt data into input_ids (tokenized string)
-        Returns:
-            input_tensor : preprocessed image
-            input_ids : tokenized gt_data
-            labels : masked labels (model doesn't need to predict prompt and pad token)
-        """
-
-        model, processor, dataset, config, image_size, p1 = prepare_job()
-
-        sample = self.dataset[idx]
-
-        # inputs
-        pixel_values = processor(sample["image"], random_padding=self.split == "train",
-                                 return_tensors="pt").pixel_values
-        pixel_values = pixel_values.squeeze()
-
-        # targets
-        target_sequence = random.choice(self.gt_token_sequences[idx])  # can be more than one, e.g., DocVQA Task 1
-        input_ids = processor.tokenizer(
-            target_sequence,
-            add_special_tokens=False,
-            max_length=self.max_length,
-            padding="max_length",
-            truncation=True,
-            return_tensors="pt",
-        )["input_ids"].squeeze(0)
-
-        labels = input_ids.clone()
-        labels[labels == processor.tokenizer.pad_token_id] = self.ignore_id  # model doesn't need to predict pad token
-        # labels[: torch.nonzero(labels == self.prompt_end_token_id).sum() + 1] = self.ignore_id  # model doesn't need to predict prompt (for VQA)
-        return pixel_values, labels, target_sequence
-
-
-def build_data_loaders():
-    print("Building data loaders...")
-
-    model, processor, dataset, config, image_size, max_length = prepare_job()
-
-    # we update some settings which differ from pretraining; namely the size of the images + no rotation required
-    # source: https://github.com/clovaai/donut/blob/master/config/train_cord.yaml
-    processor.feature_extractor.size = image_size[::-1]  # should be (width, height)
-    processor.feature_extractor.do_align_long_axis = False
-
-    train_dataset = DonutDataset(dataset_name, max_length=max_length,
-                                 split="train", task_start_token="<s_cord-v2>", prompt_end_token="<s_cord-v2>",
-                                 sort_json_key=False,  # cord dataset is preprocessed, so no need for this
-                                 )
-
-    val_dataset = DonutDataset(dataset_name, max_length=max_length,
-                               split="validation", task_start_token="<s_cord-v2>", prompt_end_token="<s_cord-v2>",
-                               sort_json_key=False,  # cord dataset is preprocessed, so no need for this
-                               )
-
-    model.config.pad_token_id = processor.tokenizer.pad_token_id
-    model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s_cord-v2>'])[0]
-
-    # feel free to increase the batch size if you have a lot of memory
-    # I'm fine-tuning on Colab and given the large image size, batch size > 1 is not feasible
-    # Set num_workers=4
-    train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
-    val_dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
-
-    return train_dataloader, val_dataloader, max_length
-
-
-class DonutModelPLModule(pl.LightningModule):
-    def __init__(self, config, processor, model):
-        super().__init__()
-        self.config = config
-        self.processor = processor
-        self.model = model
-
-        self.train_dataloader, self.val_dataloader, self.max_length = build_data_loaders()
-
-    def training_step(self, batch, batch_idx):
-        pixel_values, labels, _ = batch
-
-        outputs = self.model(pixel_values, labels=labels)
-        loss = outputs.loss
-        self.log_dict({"train_loss": loss}, sync_dist=True)
-        return loss
-
-    def validation_step(self, batch, batch_idx, dataset_idx=0):
-        pixel_values, labels, answers = batch
-        batch_size = pixel_values.shape[0]
-        # we feed the prompt to the model
-        decoder_input_ids = torch.full((batch_size, 1), self.model.config.decoder_start_token_id, device=self.device)
-
-        outputs = self.model.generate(pixel_values,
-                                      decoder_input_ids=decoder_input_ids,
-                                      max_length=self.max_length,
-                                      early_stopping=True,
-                                      pad_token_id=self.processor.tokenizer.pad_token_id,
-                                      eos_token_id=self.processor.tokenizer.eos_token_id,
-                                      use_cache=True,
-                                      num_beams=1,
-                                      bad_words_ids=[[self.processor.tokenizer.unk_token_id]],
-                                      return_dict_in_generate=True, )
-
-        predictions = []
-        for seq in self.processor.tokenizer.batch_decode(outputs.sequences):
-            seq = seq.replace(self.processor.tokenizer.eos_token, "").replace(self.processor.tokenizer.pad_token, "")
-            seq = re.sub(r"<.*?>", "", seq, count=1).strip()  # remove first task start token
-            predictions.append(seq)
-
-        scores = list()
-        for pred, answer in zip(predictions, answers):
-            pred = re.sub(r"(?:(?<=>) | (?=</s_))", "", pred)
-            # NOT NEEDED ANYMORE
-            # answer = re.sub(r"<.*?>", "", answer, count=1)
-            answer = answer.replace(self.processor.tokenizer.eos_token, "")
-            scores.append(edit_distance(pred, answer) / max(len(pred), len(answer)))
-
-            if self.config.get("verbose", False) and len(scores) == 1:
-                print(f"Prediction: {pred}")
-                print(f"    Answer: {answer}")
-                print(f" Normed ED: {scores[0]}")
-
-        return scores
-
-    def validation_epoch_end(self, validation_step_outputs):
-        # I set this to 1 manually
-        # (previously set to len(self.config.dataset_name_or_paths))
-        num_of_loaders = 1
-        if num_of_loaders == 1:
-            validation_step_outputs = [validation_step_outputs]
-        assert len(validation_step_outputs) == num_of_loaders
-        cnt = [0] * num_of_loaders
-        total_metric = [0] * num_of_loaders
-        val_metric = [0] * num_of_loaders
-        for i, results in enumerate(validation_step_outputs):
-            for scores in results:
-                cnt[i] += len(scores)
-                total_metric[i] += np.sum(scores)
-            val_metric[i] = total_metric[i] / cnt[i]
-            val_metric_name = f"val_metric_{i}th_dataset"
-            self.log_dict({val_metric_name: val_metric[i]}, sync_dist=True)
-        self.log_dict({"val_metric": np.sum(total_metric) / np.sum(cnt)}, sync_dist=True)
-
-    def configure_optimizers(self):
-        # TODO add scheduler
-        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.get("lr"))
-
-        return optimizer
-
-    def train_dataloader(self):
-        return self.train_dataloader
-
-    def val_dataloader(self):
-        return self.val_dataloader
-
-
-class PushToHubCallback(Callback):
-    def on_train_epoch_end(self, trainer, pl_module):
-        print(f"Pushing model to the hub, epoch {trainer.current_epoch}")
-        pl_module.model.push_to_hub(model_name,
-                                    commit_message=f"Training in progress, epoch {trainer.current_epoch}")
-
-    def on_train_end(self, trainer, pl_module):
-        print(f"Pushing model to the hub after training")
-        pl_module.processor.push_to_hub(model_name,
-                                        commit_message=f"Training done")
-        pl_module.model.push_to_hub(model_name,
-                                    commit_message=f"Training done")
-
-
-def run_training_donut(max_epochs_param, val_check_interval_param, warmup_steps_param):
-    worker_pid = os.getpid()
-    print(f"Handling training request with worker PID: {worker_pid}")
-
-    start_time = time.time()
-
-    # Set epochs = 30
-    # Set num_training_samples_per_epoch = training set size
-    # Set val_check_interval = 0.4
-    # Set warmup_steps: 425 / 8 = 54, 54 * 10 = 540, 540 * 0.15 = 81
-    config_params = {"max_epochs": max_epochs_param,
-                     "val_check_interval": val_check_interval_param,  # how many times we want to validate during an epoch
-                     "check_val_every_n_epoch": 1,
-                     "gradient_clip_val": 1.0,
-                     "num_training_samples_per_epoch": 425,
-                     "lr": 3e-5,
-                     "train_batch_sizes": [8],
-                     "val_batch_sizes": [1],
-                     # "seed":2022,
-                     "num_nodes": 1,
-                     "warmup_steps": warmup_steps_param,  # 425 / 8 = 54, 54 * 10 = 540, 540 * 0.15 = 81
-                     "result_path": "./result",
-                     "verbose": False,
-                     }
-
-    model, processor, dataset, config, image_size, p1 = prepare_job()
-
-    model_module = DonutModelPLModule(config, processor, model)
-
-    # wandb_logger = WandbLogger(project="sparrow", name="invoices-donut-v5")
-
-    # trainer = pl.Trainer(
-    #     accelerator="gpu",
-    #     devices=1,
-    #     max_epochs=config_params.get("max_epochs"),
-    #     val_check_interval=config_params.get("val_check_interval"),
-    #     check_val_every_n_epoch=config_params.get("check_val_every_n_epoch"),
-    #     gradient_clip_val=config_params.get("gradient_clip_val"),
-    #     precision=16,  # we'll use mixed precision
-    #     num_sanity_val_steps=0,
-    #     # logger=wandb_logger,
-    #     callbacks=[PushToHubCallback()],
-    # )
-
-    # trainer.fit(model_module)
-
-    end_time = time.time()
-    processing_time = end_time - start_time
-
-    print(f"Training done, worker PID: {worker_pid}")
-
-    return processing_time
 
routers/inference.py DELETED
@@ -1,81 +0,0 @@
-from fastapi import APIRouter, File, UploadFile, Form
-from typing import Optional
-from PIL import Image
-import urllib.request
-from io import BytesIO
-from config import settings
-import utils
-import os
-import json
-from routers.donut_inference import process_document_donut
-
-
-router = APIRouter()
-
-def count_values(obj):
-    if isinstance(obj, dict):
-        count = 0
-        for value in obj.values():
-            count += count_values(value)
-        return count
-    elif isinstance(obj, list):
-        count = 0
-        for item in obj:
-            count += count_values(item)
-        return count
-    else:
-        return 1
-
-
-@router.post("/inference")
-async def run_inference(file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None),
-                        model_in_use: str = Form('donut'), sparrow_key: str = Form(None)):
-
-    if sparrow_key != settings.sparrow_key:
-        return {"error": "Invalid Sparrow key."}
-
-    result = []
-    if file:
-        # Ensure the uploaded file is a JPG image
-        if file.content_type not in ["image/jpeg", "image/jpg"]:
-            return {"error": "Invalid file type. Only JPG images are allowed."}
-
-        image = Image.open(BytesIO(await file.read()))
-        processing_time = 0
-        if model_in_use == 'donut':
-            result, processing_time = process_document_donut(image)
-        utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, settings.model])
-        print(f"Processing time: {processing_time:.2f} seconds")
-    elif image_url:
-        # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
-        with urllib.request.urlopen(image_url) as url:
-            image = Image.open(BytesIO(url.read()))
-
-        processing_time = 0
-        if model_in_use == 'donut':
-            result, processing_time = process_document_donut(image)
-        # parse file name from url
-        file_name = image_url.split("/")[-1]
-        utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, settings.model])
-        print(f"Processing time inference: {processing_time:.2f} seconds")
-    else:
-        result = {"info": "No input provided"}
-
-    return result
-
-
-@router.get("/statistics")
-async def get_statistics():
-    file_path = settings.inference_stats_file
-
-    # Check if the file exists, and read its content
-    if os.path.exists(file_path):
-        with open(file_path, 'r') as file:
-            try:
-                content = json.load(file)
-            except json.JSONDecodeError:
-                content = []
-    else:
-        content = []
-
-    return content
 
routers/training.py CHANGED
@@ -2,8 +2,6 @@ from fastapi import APIRouter, Form, BackgroundTasks
 from config import settings
 import os
 import json
-from routers.donut_evaluate import run_evaluate_donut
-from routers.donut_training import run_training_donut
 import utils
 import torch
 import requests
@@ -11,6 +9,7 @@ from PIL import Image
 from io import BytesIO
 from pydantic import BaseModel
 import base64
+import uuid
 
 from diffusers import StableDiffusionImg2ImgPipeline
 
@@ -26,25 +25,26 @@
     prompt: str
    strength: float
     guidance_scale: float
+    resizeW: int
+    resizeH: int
 
 @router.post("/perform-action")
 async def performAction(actionBody: ActionBody):
 
     response = requests.get(actionBody.url)
     init_image = Image.open(BytesIO(response.content)).convert("RGB")
-    init_image = init_image.resize((768, 512))
+    init_image = init_image.resize((actionBody.resizeW, actionBody.resizeH))
     images = pipe(prompt=actionBody.prompt, image=init_image, strength=actionBody.strength, guidance_scale=actionBody.guidance_scale).images
     print(images)
-    print(images[0])
-
     buffered = BytesIO()
     images[0].save(buffered, format="JPEG")
     img_str = base64.b64encode(buffered.getvalue())
-
-    # images[0].save("fantasy_landscape.png")
+    imgUUID = str(uuid.uuid4())
+    images[0].save(imgUUID + ".png")
 
     return {
-        "image": img_str
+        "imageName": imgUUID + ".png",
+        "image": "data:image/jpeg;base64," + img_str.decode("utf-8")
     }
 
 
@@ -67,5 +67,5 @@ async def hifunction():
     # images[0].save("fantasy_landscape.png")
 
     return {
-        "image": img_str
+        "image": "data:image/jpeg;base64," + img_str.decode("utf-8")
     }
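
For context, here is a minimal sketch of how the updated /perform-action endpoint fits together once these hunks are applied. The router creation, the pipeline setup (model id, device handling), and the url field of ActionBody are not visible in this diff and are assumptions for illustration; only the request fields and the image handling mirror the change.

# Hypothetical assembly of routers/training.py after this commit; the pipeline
# model id and device handling are assumed, not taken from the diff.
import base64
import uuid
from io import BytesIO

import requests
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from fastapi import APIRouter
from PIL import Image
from pydantic import BaseModel

router = APIRouter()

# Assumed pipeline setup; the diff only shows the import.
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)


class ActionBody(BaseModel):
    url: str
    prompt: str
    strength: float
    guidance_scale: float
    resizeW: int
    resizeH: int


@router.post("/perform-action")
async def performAction(actionBody: ActionBody):
    # Fetch the source image and resize it to the requested dimensions.
    response = requests.get(actionBody.url)
    init_image = Image.open(BytesIO(response.content)).convert("RGB")
    init_image = init_image.resize((actionBody.resizeW, actionBody.resizeH))

    # Run img2img generation with the supplied prompt and parameters.
    images = pipe(
        prompt=actionBody.prompt,
        image=init_image,
        strength=actionBody.strength,
        guidance_scale=actionBody.guidance_scale,
    ).images

    # Encode the first result as a base64 data URI and save a copy under a UUID name.
    buffered = BytesIO()
    images[0].save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    img_name = str(uuid.uuid4()) + ".png"
    images[0].save(img_name)

    return {
        "imageName": img_name,
        "image": "data:image/jpeg;base64," + img_str,
    }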
  }