katanaml commited on
Commit
ce1dd07
·
1 Parent(s): 7f8dcfd

Sparrow ML

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+
2
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.7-slim
2
+
3
+ WORKDIR /code
4
+
5
+ COPY requirements-fastapi.txt ./
6
+
7
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements-fastapi.txt
8
+
9
+ RUN useradd -m -u 1000 user
10
+
11
+ USER user
12
+
13
+ ENV HOME=/home/user \
14
+ PATH=/home/user/.local/bin:$PATH
15
+
16
+ WORKDIR $HOME/app
17
+
18
+ COPY --chown=user . $HOME/app/
19
+
20
+ CMD ["uvicorn", "endpoints:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: Sparrow Ml
3
  emoji: 🌍
4
  colorFrom: purple
5
  colorTo: indigo
 
1
  ---
2
+ title: Sparrow ML
3
  emoji: 🌍
4
  colorFrom: purple
5
  colorTo: indigo
__init__.py ADDED
File without changes
config.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseSettings
2
+ import os
3
+
4
+
5
+ class Settings(BaseSettings):
6
+ huggingface_key: str = os.environ.get("huggingface_key")
7
+ sparrow_key: str = os.environ.get("sparrow_key")
8
+ processor: str = "katanaml-org/invoices-donut-model-v1"
9
+ model: str = "katanaml-org/invoices-donut-model-v1"
10
+ inference_stats_file: str = "data/donut_inference_stats.json"
11
+ training_stats_file: str = "data/donut_training_stats.json"
12
+
13
+
14
+ settings = Settings()
data/donut_inference_stats.json ADDED
@@ -0,0 +1 @@
 
 
1
+ [[14.571558952331543, 21, "invoice_10.jpg", "katanaml-org/invoices-donut-model-v1", "2023-04-13 21:45:30"]]
data/donut_training_stats.json ADDED
@@ -0,0 +1 @@
 
 
1
+ [["2023-04-09 23:24:24", 0.1, 1260, "invoices-donut-model-v1"], ["2023-04-10 23:24:24", 0.2, 1360, "invoices-donut-model-v1"], ["2023-04-11 23:24:24", 0.85, 1750, "invoices-donut-model-v1"]]
endpoints.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from routers import inference, training
4
+
5
+ app = FastAPI(openapi_url="/api/v1/sparrow-ml/openapi.json", docs_url="/api/v1/sparrow-ml/docs")
6
+
7
+ app.add_middleware(
8
+ CORSMiddleware,
9
+ allow_origins=["*"],
10
+ allow_methods=["*"],
11
+ allow_headers=["*"],
12
+ allow_credentials=True,
13
+ )
14
+
15
+ app.include_router(inference.router, prefix="/api-inference/v1/sparrow-ml", tags=["Inference"])
16
+ app.include_router(training.router, prefix="/api-training/v1/sparrow-ml", tags=["Training"])
17
+
18
+
19
+ @app.get("/")
20
+ async def root():
21
+ return {"message": "Sparrow ML API"}
requirements-fastapi.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ datasets
3
+ sentencepiece
4
+ tensorboard
5
+ pytorch-lightning
6
+ Pillow
7
+ donut-python
8
+ fastapi==0.95.0
9
+ uvicorn[standard]
10
+ python-multipart
routers/__init__.py ADDED
File without changes
routers/donut_inference.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import time
3
+ import torch
4
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
5
+ from config import settings
6
+ from huggingface_hub import login
7
+
8
+
9
+ login(settings.huggingface_key)
10
+
11
+ processor = DonutProcessor.from_pretrained(settings.processor)
12
+ model = VisionEncoderDecoderModel.from_pretrained(settings.model)
13
+
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ model.to(device)
16
+
17
+ def process_document_donut(image):
18
+ start_time = time.time()
19
+
20
+ # prepare encoder inputs
21
+ pixel_values = processor(image, return_tensors="pt").pixel_values
22
+
23
+ # prepare decoder inputs
24
+ task_prompt = "<s_cord-v2>"
25
+ decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
26
+
27
+ # generate answer
28
+ outputs = model.generate(
29
+ pixel_values.to(device),
30
+ decoder_input_ids=decoder_input_ids.to(device),
31
+ max_length=model.decoder.config.max_position_embeddings,
32
+ early_stopping=True,
33
+ pad_token_id=processor.tokenizer.pad_token_id,
34
+ eos_token_id=processor.tokenizer.eos_token_id,
35
+ use_cache=True,
36
+ num_beams=1,
37
+ bad_words_ids=[[processor.tokenizer.unk_token_id]],
38
+ return_dict_in_generate=True,
39
+ )
40
+
41
+ # postprocess
42
+ sequence = processor.batch_decode(outputs.sequences)[0]
43
+ sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
44
+ sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
45
+
46
+ end_time = time.time()
47
+ processing_time = end_time - start_time
48
+
49
+ return processor.token2json(sequence), processing_time
routers/inference.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter, File, UploadFile, Form
2
+ from typing import Optional
3
+ from PIL import Image
4
+ import urllib.request
5
+ from io import BytesIO
6
+ from config import settings
7
+ import utils
8
+ import os
9
+ import json
10
+ from routers.donut_inference import process_document_donut
11
+
12
+
13
+ router = APIRouter()
14
+
15
+ def count_values(obj):
16
+ if isinstance(obj, dict):
17
+ count = 0
18
+ for value in obj.values():
19
+ count += count_values(value)
20
+ return count
21
+ elif isinstance(obj, list):
22
+ count = 0
23
+ for item in obj:
24
+ count += count_values(item)
25
+ return count
26
+ else:
27
+ return 1
28
+
29
+
30
+ @router.post("/inference")
31
+ async def run_inference(file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None),
32
+ model_in_use: str = Form('donut'), sparrow_key: str = Form(None)):
33
+
34
+ if sparrow_key != settings.sparrow_key:
35
+ return {"error": "Invalid Sparrow key."}
36
+
37
+ result = []
38
+ if file:
39
+ # Ensure the uploaded file is a JPG image
40
+ if file.content_type not in ["image/jpeg", "image/jpg"]:
41
+ return {"error": "Invalid file type. Only JPG images are allowed."}
42
+
43
+ image = Image.open(BytesIO(await file.read()))
44
+ processing_time = 0
45
+ if model_in_use == 'donut':
46
+ result, processing_time = process_document_donut(image)
47
+ utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file.filename, settings.model])
48
+ print(f"Processing time: {processing_time:.2f} seconds")
49
+ elif image_url:
50
+ # test image url: https://raw.githubusercontent.com/katanaml/sparrow/main/sparrow-data/docs/input/invoices/processed/images/invoice_10.jpg
51
+ with urllib.request.urlopen(image_url) as url:
52
+ image = Image.open(BytesIO(url.read()))
53
+
54
+ processing_time = 0
55
+ if model_in_use == 'donut':
56
+ result, processing_time = process_document_donut(image)
57
+ # parse file name from url
58
+ file_name = image_url.split("/")[-1]
59
+ utils.log_stats(settings.inference_stats_file, [processing_time, count_values(result), file_name, settings.model])
60
+ print(f"Processing time: {processing_time:.2f} seconds")
61
+ else:
62
+ result = {"info": "No input provided"}
63
+
64
+ return result
65
+
66
+
67
+ @router.get("/statistics")
68
+ async def get_statistics():
69
+ file_path = settings.inference_stats_file
70
+
71
+ # Check if the file exists, and read its content
72
+ if os.path.exists(file_path):
73
+ with open(file_path, 'r') as file:
74
+ try:
75
+ content = json.load(file)
76
+ except json.JSONDecodeError:
77
+ content = []
78
+ else:
79
+ content = []
80
+
81
+ return content
routers/training.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import APIRouter
2
+ from config import settings
3
+ import os
4
+ import json
5
+
6
+
7
+ router = APIRouter()
8
+
9
+
10
+ @router.get("/training")
11
+ async def run_training():
12
+ return {"message": "Sparrow ML training started"}
13
+
14
+
15
+ @router.get("/statistics")
16
+ async def get_statistics():
17
+ file_path = settings.training_stats_file
18
+
19
+ # Check if the file exists, and read its content
20
+ if os.path.exists(file_path):
21
+ with open(file_path, 'r') as file:
22
+ try:
23
+ content = json.load(file)
24
+ except json.JSONDecodeError:
25
+ content = []
26
+ else:
27
+ content = []
28
+
29
+ return content
utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from datetime import datetime
4
+
5
+
6
+ def log_stats(file_path, new_data):
7
+ # Check if the file exists, and read its content
8
+ if os.path.exists(file_path):
9
+ with open(file_path, 'r') as file:
10
+ try:
11
+ content = json.load(file)
12
+ except json.JSONDecodeError:
13
+ content = []
14
+ else:
15
+ content = []
16
+
17
+ # Get the current date and time
18
+ now = datetime.now()
19
+ # Format the date and time as a string
20
+ date_time_string = now.strftime("%Y-%m-%d %H:%M:%S")
21
+ new_data.append(date_time_string)
22
+
23
+ # Append the new data to the content
24
+ content.append(new_data)
25
+
26
+ # Write the updated content back to the file
27
+ with open(file_path, 'w') as file:
28
+ json.dump(content, file)
29
+