Upload 24 files
- .README.md.swp +0 -0
- .idea/.gitignore +8 -0
- README.md +159 -0
- requirements.txt +21 -0
- setup.py +37 -0
- sparrow_parse/__init__.py +1 -0
- sparrow_parse/__main__.py +6 -0
- sparrow_parse/extractors/__init__.py +0 -0
- sparrow_parse/extractors/vllm_extractor.py +229 -0
- sparrow_parse/helpers/__init__.py +0 -0
- sparrow_parse/helpers/image_optimizer.py +59 -0
- sparrow_parse/helpers/pdf_optimizer.py +79 -0
- sparrow_parse/images/graph.png +0 -0
- sparrow_parse/processors/__init__.py +0 -0
- sparrow_parse/processors/table_structure_processor.py +275 -0
- sparrow_parse/text_extraction.py +30 -0
- sparrow_parse/vllm/__init__.py +0 -0
- sparrow_parse/vllm/huggingface_inference.py +60 -0
- sparrow_parse/vllm/inference_base.py +30 -0
- sparrow_parse/vllm/inference_factory.py +25 -0
- sparrow_parse/vllm/infra/qwen2_vl_7b/app.py +155 -0
- sparrow_parse/vllm/infra/qwen2_vl_7b/requirements.txt +9 -0
- sparrow_parse/vllm/local_gpu_inference.py +16 -0
- sparrow_parse/vllm/mlx_inference.py +140 -0
.README.md.swp
ADDED
Binary file (1.02 kB)
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
README.md
ADDED
@@ -0,0 +1,159 @@
# Sparrow Parse

## Description

This module implements the Sparrow Parse [library](https://pypi.org/project/sparrow-parse/) with helpful methods for data pre-processing, parsing and extracting information. The library relies on Visual LLM functionality and Table Transformers, and is part of Sparrow. Check the main [README](https://github.com/katanaml/sparrow).

## Install

```
pip install sparrow-parse
```

## Parsing and extraction

### Sparrow Parse VL (vision-language model) extractor with local MLX or Hugging Face Cloud GPU infra

```
# run locally: python -m sparrow_parse.extractors.vllm_extractor

from sparrow_parse.vllm.inference_factory import InferenceFactory
from sparrow_parse.extractors.vllm_extractor import VLLMExtractor

extractor = VLLMExtractor()

config = {
    "method": "mlx",  # Could be 'huggingface', 'mlx' or 'local_gpu'
    "model_name": "mlx-community/Qwen2-VL-72B-Instruct-4bit",
}

# Use the factory to get the correct instance
factory = InferenceFactory(config)
model_inference_instance = factory.get_inference_instance()

input_data = [
    {
        "file_path": "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/bonds_table.jpg",
        "text_input": "retrieve all data. return response in JSON format"
    }
]

# Now you can run inference without knowing which implementation is used
results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, tables_only=False,
                                                   generic_query=False,
                                                   crop_size=80,
                                                   debug_dir=None,
                                                   debug=True,
                                                   mode=None)

for i, result in enumerate(results_array):
    print(f"Result for page {i + 1}:", result)
print(f"Number of pages: {num_pages}")
```

Use `tables_only=True` if you want to extract only tables.

Use `crop_size=N` (where `N` is an integer) to crop N pixels from all borders of the input images. This can be helpful for removing unwanted borders or frame artifacts from scanned documents.

Use `mode="static"` if you want to simulate an LLM call without executing the LLM backend.

The `run_inference` method returns the results and the number of pages processed.

To run with the Hugging Face backend, use these config values:

```
config = {
    "method": "huggingface",  # Could be 'huggingface' or 'local_gpu'
    "hf_space": "katanaml/sparrow-qwen2-vl-7b",
    "hf_token": os.getenv('HF_TOKEN'),
}
```

Note: the GPU backend `katanaml/sparrow-qwen2-vl-7b` is private. To run the config above, you need to create your own backend on a Hugging Face Space using the [code](https://github.com/katanaml/sparrow/tree/main/sparrow-data/parse/sparrow_parse/vllm/infra/qwen2_vl_7b) from Sparrow Parse.

## PDF pre-processing

```
from sparrow_parse.helpers.pdf_optimizer import PDFOptimizer

pdf_optimizer = PDFOptimizer()

num_pages, output_files, temp_dir = pdf_optimizer.split_pdf_to_pages(file_path,
                                                                     debug_dir,
                                                                     convert_to_images)
```

Example:

*file_path* - `/data/invoice_1.pdf`

*debug_dir* - set to a non-`None` value for debugging purposes only

*convert_to_images* - default `False`, splits the document into separate per-page PDF files; set to `True` to convert pages to images

## Image cropping

```
from sparrow_parse.helpers.image_optimizer import ImageOptimizer

image_optimizer = ImageOptimizer()

cropped_file_path = image_optimizer.crop_image_borders(file_path, temp_dir, debug_dir, crop_size)
```

Example:

*file_path* - `/data/invoice_1.jpg`

*temp_dir* - directory to store cropped files

*debug_dir* - set to a non-`None` value for debugging purposes only

*crop_size* - number of pixels to crop from each border

## Library build

Create a Python virtual environment

```
python -m venv .env_sparrow_parse
```

Install Python libraries

```
pip install -r requirements.txt
```

Build the package

```
pip install setuptools wheel
python setup.py sdist bdist_wheel
```

Upload to PyPI

```
pip install twine
twine upload dist/*
```

## Commercial usage

Sparrow is available under the GPL 3.0 license, promoting freedom to use, modify, and distribute the software while ensuring any modifications remain open source under the same license. This aligns with our commitment to supporting the open-source community and fostering collaboration.

Additionally, we recognize the diverse needs of organizations, including small to medium-sized enterprises (SMEs). Therefore, Sparrow is also offered for free commercial use to organizations with gross revenue below $5 million USD in the past 12 months, enabling them to leverage Sparrow without the financial burden often associated with high-quality software solutions.

For businesses that exceed this revenue threshold or require usage terms not accommodated by the GPL 3.0 license—such as integrating Sparrow into proprietary software without the obligation to disclose source code modifications—we offer dual licensing options. Dual licensing allows Sparrow to be used under a separate proprietary license, offering greater flexibility for commercial applications and proprietary integrations. This model supports both the project's sustainability and the business's needs for confidentiality and customization.

If your organization is seeking to utilize Sparrow under a proprietary license, or if you are interested in custom workflows, consulting services, or dedicated support and maintenance options, please contact us at abaranovskis@redsamuraiconsulting.com. We're here to provide tailored solutions that meet your unique requirements, ensuring you can maximize the benefits of Sparrow for your projects and workflows.

## Author

[Katana ML](https://katanaml.io), [Andrej Baranovskij](https://github.com/abaranovskis-redsamurai)

## License

Licensed under the GPL 3.0. Copyright 2020-2025 Katana ML, Andrej Baranovskij. [Copy of the license](https://github.com/katanaml/sparrow/blob/main/LICENSE).
requirements.txt
ADDED
@@ -0,0 +1,21 @@
rich
# mlx-vlm==0.1.12 works with transformers from source only
# git+https://github.com/huggingface/transformers.git
# transformers==4.48.2
torchvision==0.21.0
torch==2.6.0
sentence-transformers==3.3.1
numpy==2.1.3
pypdf==5.2.0
gradio_client
pdf2image
# mlx==0.22.0; sys_platform == "darwin" and platform_machine == "arm64"
mlx>=0.22.0; sys_platform == "darwin" and platform_machine == "arm64"
mlx-vlm==0.1.12; sys_platform == "darwin" and platform_machine == "arm64"


# Force reinstall:
# pip install --force-reinstall -r requirements.txt

# For pdf2image, an additional step is required:
# brew install poppler
setup.py
ADDED
@@ -0,0 +1,37 @@
from setuptools import setup, find_packages

with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

with open("requirements.txt", "r", encoding="utf-8") as fh:
    requirements = fh.read().splitlines()

setup(
    name="sparrow-parse",
    version="0.5.4",
    author="Andrej Baranovskij",
    author_email="andrejus.baranovskis@gmail.com",
    description="Sparrow Parse is a Python package (part of Sparrow) for parsing and extracting information from documents.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/katanaml/sparrow/tree/main/sparrow-data/parse",
    project_urls={
        "Homepage": "https://github.com/katanaml/sparrow/tree/main/sparrow-data/parse",
        "Repository": "https://github.com/katanaml/sparrow",
    },
    classifiers=[
        "Operating System :: OS Independent",
        "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
        "Topic :: Software Development",
        "Programming Language :: Python :: 3.10",
    ],
    entry_points={
        'console_scripts': [
            'sparrow-parse=sparrow_parse:main',
        ],
    },
    keywords="llm, vllm, ocr, vision",
    packages=find_packages(),
    python_requires='>=3.10',
    install_requires=requirements,
)
sparrow_parse/__init__.py
ADDED
@@ -0,0 +1 @@
__version__ = '0.5.4'
sparrow_parse/__main__.py
ADDED
@@ -0,0 +1,6 @@
def main():
    print('Sparrow Parse is a Python package for parsing and extracting information from documents.')


if __name__ == "__main__":
    main()
sparrow_parse/extractors/__init__.py
ADDED
File without changes
sparrow_parse/extractors/vllm_extractor.py
ADDED
@@ -0,0 +1,229 @@
import json
from sparrow_parse.vllm.inference_factory import InferenceFactory
from sparrow_parse.helpers.pdf_optimizer import PDFOptimizer
from sparrow_parse.helpers.image_optimizer import ImageOptimizer
from sparrow_parse.processors.table_structure_processor import TableDetector
from rich import print
import os
import tempfile
import shutil


class VLLMExtractor(object):
    def __init__(self):
        pass

    def run_inference(self, model_inference_instance, input_data, tables_only=False,
                      generic_query=False, crop_size=None, debug_dir=None, debug=False, mode=None):
        """
        Main entry point for processing input data using a model inference instance.
        Handles generic queries, PDFs, and table extraction.
        """
        if generic_query:
            input_data[0]["text_input"] = "retrieve document data. return response in JSON format"

        if debug:
            print("Input data:", input_data)

        file_path = input_data[0]["file_path"]
        if self.is_pdf(file_path):
            return self._process_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode)

        return self._process_non_pdf(model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir)


    def _process_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir, mode):
        """
        Handles processing and inference for PDF files, including page splitting and optional table extraction.
        """
        pdf_optimizer = PDFOptimizer()
        num_pages, output_files, temp_dir = pdf_optimizer.split_pdf_to_pages(input_data[0]["file_path"],
                                                                             debug_dir, convert_to_images=True)

        results = self._process_pages(model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir)

        # Clean up temporary directory
        shutil.rmtree(temp_dir, ignore_errors=True)
        return results, num_pages


    def _process_non_pdf(self, model_inference_instance, input_data, tables_only, crop_size, debug, debug_dir):
        """
        Handles processing and inference for non-PDF files, with optional table extraction.
        """
        file_path = input_data[0]["file_path"]

        if tables_only:
            return self._extract_tables(model_inference_instance, file_path, input_data, debug, debug_dir), 1
        else:
            temp_dir = tempfile.mkdtemp()

            if crop_size:
                if debug:
                    print(f"Cropping image borders by {crop_size} pixels.")
                image_optimizer = ImageOptimizer()
                cropped_file_path = image_optimizer.crop_image_borders(file_path, temp_dir, debug_dir, crop_size)
                input_data[0]["file_path"] = cropped_file_path

            file_path = input_data[0]["file_path"]
            input_data[0]["file_path"] = [file_path]
            results = model_inference_instance.inference(input_data)

            shutil.rmtree(temp_dir, ignore_errors=True)

            return results, 1

    def _process_pages(self, model_inference_instance, output_files, input_data, tables_only, crop_size, debug, debug_dir):
        """
        Processes individual pages (PDF split) and handles table extraction or inference.

        Args:
            model_inference_instance: The model inference object.
            output_files: List of file paths for the split PDF pages.
            input_data: Input data for inference.
            tables_only: Whether to only process tables.
            crop_size: Size for cropping image borders.
            debug: Debug flag for logging.
            debug_dir: Directory for saving debug information.

        Returns:
            List of results from the processing or inference.
        """
        results_array = []

        if tables_only:
            if debug:
                print(f"Processing {len(output_files)} pages for table extraction.")
            # Process each page individually for table extraction
            for i, file_path in enumerate(output_files):
                tables_result = self._extract_tables(
                    model_inference_instance, file_path, input_data, debug, debug_dir, page_index=i
                )
                # Since _extract_tables returns a list with one JSON string, unpack it
                results_array.extend(tables_result)  # Unpack the single JSON string
        else:
            if debug:
                print(f"Processing {len(output_files)} pages for inference at once.")

            temp_dir = tempfile.mkdtemp()
            cropped_files = []

            if crop_size:
                if debug:
                    print(f"Cropping image borders by {crop_size} pixels from {len(output_files)} images.")

                image_optimizer = ImageOptimizer()

                # Process each file in the output_files array
                for file_path in output_files:
                    cropped_file_path = image_optimizer.crop_image_borders(
                        file_path,
                        temp_dir,
                        debug_dir,
                        crop_size
                    )
                    cropped_files.append(cropped_file_path)

                # Use the cropped files for inference
                input_data[0]["file_path"] = cropped_files
            else:
                # If no cropping needed, use original files directly
                input_data[0]["file_path"] = output_files

            # Process all files at once
            results = model_inference_instance.inference(input_data)
            results_array.extend(results)

            # Clean up temporary directory
            shutil.rmtree(temp_dir, ignore_errors=True)

        return results_array


    def _extract_tables(self, model_inference_instance, file_path, input_data, debug, debug_dir, page_index=None):
        """
        Detects and processes tables from an input file.
        """
        table_detector = TableDetector()
        cropped_tables = table_detector.detect_tables(file_path, local=False, debug_dir=debug_dir, debug=debug)
        results_array = []
        temp_dir = tempfile.mkdtemp()

        for i, table in enumerate(cropped_tables):
            table_index = f"page_{page_index + 1}_table_{i + 1}" if page_index is not None else f"table_{i + 1}"
            print(f"Processing {table_index} for document {file_path}")

            output_filename = os.path.join(temp_dir, f"{table_index}.jpg")
            table.save(output_filename, "JPEG")

            input_data[0]["file_path"] = [output_filename]
            result = self._run_model_inference(model_inference_instance, input_data)
            results_array.append(result)

        shutil.rmtree(temp_dir, ignore_errors=True)

        # Merge results_array elements into a single JSON structure
        merged_results = {"page_tables": results_array}

        # Format the merged results as a JSON string with indentation
        formatted_results = json.dumps(merged_results, indent=4)

        # Return the formatted JSON string wrapped in a list
        return [formatted_results]


    @staticmethod
    def _run_model_inference(model_inference_instance, input_data):
        """
        Runs model inference and handles JSON decoding.
        """
        result = model_inference_instance.inference(input_data)[0]
        try:
            return json.loads(result) if isinstance(result, str) else result
        except json.JSONDecodeError:
            return {"message": "Invalid JSON format in LLM output", "valid": "false"}


    @staticmethod
    def is_pdf(file_path):
        """Checks if a file is a PDF based on its extension."""
        return file_path.lower().endswith('.pdf')


if __name__ == "__main__":
    # run locally: python -m sparrow_parse.extractors.vllm_extractor

    extractor = VLLMExtractor()

    # # export HF_TOKEN="hf_"
    # config = {
    #     "method": "mlx",  # Could be 'huggingface', 'mlx' or 'local_gpu'
    #     "model_name": "mlx-community/Qwen2.5-VL-7B-Instruct-8bit",
    #     # "hf_space": "katanaml/sparrow-qwen2-vl-7b",
    #     # "hf_token": os.getenv('HF_TOKEN'),
    #     # Additional fields for local GPU inference
    #     # "device": "cuda", "model_path": "model.pth"
    # }
    #
    # # Use the factory to get the correct instance
    # factory = InferenceFactory(config)
    # model_inference_instance = factory.get_inference_instance()
    #
    # input_data = [
    #     {
    #         "file_path": "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/bonds_table.png",
    #         "text_input": "retrieve document data. return response in JSON format"
    #     }
    # ]
    #
    # # Now you can run inference without knowing which implementation is used
    # results_array, num_pages = extractor.run_inference(model_inference_instance, input_data, tables_only=False,
    #                                                    generic_query=False,
    #                                                    crop_size=0,
    #                                                    debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/",
    #                                                    debug=True,
    #                                                    mode=None)
    #
    # for i, result in enumerate(results_array):
    #     print(f"Result for page {i + 1}:", result)
    # print(f"Number of pages: {num_pages}")
sparrow_parse/helpers/__init__.py
ADDED
File without changes
sparrow_parse/helpers/image_optimizer.py
ADDED
@@ -0,0 +1,59 @@
from PIL import Image
import os


class ImageOptimizer(object):
    def __init__(self):
        pass

    def crop_image_borders(self, file_path, temp_dir, debug_dir=None, crop_size=60):
        """
        Crops all four borders of an image by the specified size.

        Args:
            file_path (str): Path to the input image
            temp_dir (str): Temporary directory to store the cropped image
            debug_dir (str, optional): Directory to save a debug copy of the cropped image
            crop_size (int): Number of pixels to crop from each border

        Returns:
            str: Path to the cropped image in temp_dir
        """
        try:
            # Open the image
            with Image.open(file_path) as img:
                # Get image dimensions
                width, height = img.size

                # Calculate the crop box
                left = crop_size
                top = crop_size
                right = width - crop_size
                bottom = height - crop_size

                # Ensure we're not trying to crop more than the image size
                if right <= left or bottom <= top:
                    raise ValueError("Crop size is too large for the image dimensions")

                # Perform the crop
                cropped_img = img.crop((left, top, right, bottom))

                # Get original filename without path
                filename = os.path.basename(file_path)
                name, ext = os.path.splitext(filename)

                # Save cropped image in temp_dir
                output_path = os.path.join(temp_dir, f"{name}_cropped{ext}")
                cropped_img.save(output_path)

                # If debug_dir is provided, save a debug copy
                if debug_dir:
                    os.makedirs(debug_dir, exist_ok=True)
                    debug_path = os.path.join(debug_dir, f"{name}_cropped_debug{ext}")
                    cropped_img.save(debug_path)
                    print(f"Debug cropped image saved to: {debug_path}")

                return output_path

        except Exception as e:
            raise Exception(f"Error processing image: {str(e)}")
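
A minimal usage sketch for `crop_image_borders`, assuming `sparrow-parse` and Pillow are installed; the input path is a placeholder, and the caller owns the temporary directory, mirroring how `VLLMExtractor` uses it:

from sparrow_parse.helpers.image_optimizer import ImageOptimizer
import shutil
import tempfile

file_path = "/data/invoice_1.jpg"  # hypothetical input image

temp_dir = tempfile.mkdtemp()  # cropped output is written into a throwaway directory
try:
    optimizer = ImageOptimizer()
    cropped_path = optimizer.crop_image_borders(file_path, temp_dir, debug_dir=None, crop_size=40)
    print(f"Cropped image written to: {cropped_path}")
finally:
    shutil.rmtree(temp_dir, ignore_errors=True)  # caller is responsible for cleaning up temp_dir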
sparrow_parse/helpers/pdf_optimizer.py
ADDED
@@ -0,0 +1,79 @@
import pypdf
from pdf2image import convert_from_path
import os
import tempfile
import shutil


class PDFOptimizer(object):
    def __init__(self):
        pass

    def split_pdf_to_pages(self, file_path, debug_dir=None, convert_to_images=False):
        # Create a temporary directory
        temp_dir = tempfile.mkdtemp()
        output_files = []

        if not convert_to_images:
            # Open the PDF file
            with open(file_path, 'rb') as pdf_file:
                reader = pypdf.PdfReader(pdf_file)
                number_of_pages = len(reader.pages)

                # Split the PDF into separate files per page
                for page_num in range(number_of_pages):
                    writer = pypdf.PdfWriter()
                    writer.add_page(reader.pages[page_num])

                    output_filename = os.path.join(temp_dir, f'page_{page_num + 1}.pdf')
                    with open(output_filename, 'wb') as output_file:
                        writer.write(output_file)
                    output_files.append(output_filename)

                    if debug_dir:
                        # Save each page to the debug folder
                        debug_output_filename = os.path.join(debug_dir, f'page_{page_num + 1}.pdf')
                        with open(debug_output_filename, 'wb') as output_file:
                            writer.write(output_file)

                # Return the number of pages, the list of file paths, and the temporary directory
                return number_of_pages, output_files, temp_dir
        else:
            # Convert the PDF to images
            images = convert_from_path(file_path, dpi=300)
            base_name = os.path.splitext(os.path.basename(file_path))[0]

            # Save the images to the temporary directory
            for i, image in enumerate(images):
                output_filename = os.path.join(temp_dir, f'{base_name}_page_{i + 1}.jpg')
                image.save(output_filename, 'JPEG')
                output_files.append(output_filename)

                if debug_dir:
                    # Save each image to the debug folder
                    os.makedirs(debug_dir, exist_ok=True)
                    debug_output_filename = os.path.join(debug_dir, f'{base_name}_page_{i + 1}_debug.jpg')
                    image.save(debug_output_filename, 'JPEG')
                    print(f"Debug image saved to: {debug_output_filename}")

            # Return the number of pages, the list of file paths, and the temporary directory
            return len(images), output_files, temp_dir


if __name__ == "__main__":
    pdf_optimizer = PDFOptimizer()

    # debug_dir = "/Users/andrejb/infra/shared/katana-git/sparrow/sparrow-ml/llm/data/"
    # # Ensure the output directory exists
    # os.makedirs(output_directory, exist_ok=True)
    #
    # # Split the optimized PDF into separate pages
    # num_pages, output_files, temp_dir = pdf_optimizer.split_pdf_to_pages("/Users/andrejb/infra/shared/katana-git/sparrow/sparrow-ml/llm/data/oracle_10k_2014_q1_small.pdf",
    #                                                                      debug_dir,
    #                                                                      True)
    #
    # print(f"Number of pages: {num_pages}")
    # print(f"Output files: {output_files}")
    # print(f"Temporary directory: {temp_dir}")
    #
    # shutil.rmtree(temp_dir, ignore_errors=True)
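
A minimal usage sketch for `split_pdf_to_pages`, assuming `sparrow-parse`, `pypdf`, `pdf2image`, and the `poppler` binary are installed; the PDF path is a placeholder:

from sparrow_parse.helpers.pdf_optimizer import PDFOptimizer
import shutil

pdf_path = "/data/invoice_1.pdf"  # hypothetical input document

optimizer = PDFOptimizer()
# convert_to_images=True renders each page as a 300 DPI JPEG instead of a per-page PDF
num_pages, page_files, temp_dir = optimizer.split_pdf_to_pages(pdf_path, debug_dir=None, convert_to_images=True)
try:
    print(f"{num_pages} pages written:")
    for path in page_files:
        print(" ", path)
finally:
    shutil.rmtree(temp_dir, ignore_errors=True)  # the split files live in temp_dir, so clean up when done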
sparrow_parse/images/graph.png
ADDED
sparrow_parse/processors/__init__.py
ADDED
File without changes
sparrow_parse/processors/table_structure_processor.py
ADDED
@@ -0,0 +1,275 @@
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich import print
from transformers import AutoModelForObjectDetection
import torch
from PIL import Image
from torchvision import transforms
import os


class TableDetector(object):
    _model = None  # Static variable to hold the table detection model
    _device = None  # Static variable to hold the device information

    def __init__(self):
        pass

    class MaxResize(object):
        def __init__(self, max_size=800):
            self.max_size = max_size

        def __call__(self, image):
            width, height = image.size
            current_max_size = max(width, height)
            scale = self.max_size / current_max_size
            resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))

            return resized_image

    @classmethod
    def _initialize_model(cls, invoke_pipeline_step, local):
        """
        Static method to initialize the table detection model if not already initialized.
        """
        if cls._model is None:
            # Use invoke_pipeline_step to load the model
            cls._model, cls._device = invoke_pipeline_step(
                lambda: cls.load_table_detection_model(),
                "Loading table detection model...",
                local
            )
            print("Table detection model initialized.")


    def detect_tables(self, file_path, local=True, debug_dir=None, debug=False):
        # Ensure the model is initialized using invoke_pipeline_step
        self._initialize_model(self.invoke_pipeline_step, local)

        # Use the static model and device
        model, device = self._model, self._device

        outputs, image = self.invoke_pipeline_step(
            lambda: self.prepare_image(file_path, model, device),
            "Preparing image for table detection...",
            local
        )

        objects = self.invoke_pipeline_step(
            lambda: self.identify_tables(model, outputs, image),
            "Identifying tables in the image...",
            local
        )

        cropped_tables = self.invoke_pipeline_step(
            lambda: self.crop_tables(file_path, image, objects, debug, debug_dir),
            "Cropping tables from the image...",
            local
        )

        return cropped_tables


    @staticmethod
    def load_table_detection_model():
        model = AutoModelForObjectDetection.from_pretrained("microsoft/table-transformer-detection", revision="no_timm")

        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)

        return model, device


    def prepare_image(self, file_path, model, device):
        image = Image.open(file_path).convert("RGB")

        detection_transform = transforms.Compose([
            self.MaxResize(800),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        pixel_values = detection_transform(image).unsqueeze(0)
        pixel_values = pixel_values.to(device)

        with torch.no_grad():
            outputs = model(pixel_values)

        return outputs, image

    def identify_tables(self, model, outputs, image):
        id2label = model.config.id2label
        id2label[len(model.config.id2label)] = "no object"

        objects = self.outputs_to_objects(outputs, image.size, id2label)
        return objects


    def crop_tables(self, file_path, image, objects, debug, debug_dir):
        tokens = []
        detection_class_thresholds = {
            "table": 0.5,
            "table rotated": 0.5,
            "no object": 10
        }
        crop_padding = 30

        tables_crops = self.objects_to_crops(image, tokens, objects, detection_class_thresholds, padding=crop_padding)

        cropped_tables = []

        if len(tables_crops) == 0:
            if debug:
                print("No tables detected in: ", file_path)

            return None
        elif len(tables_crops) > 1:
            for i, table_crop in enumerate(tables_crops):
                if debug:
                    print("Table detected in:", file_path, "-", i + 1)

                cropped_table = table_crop['image'].convert("RGB")
                cropped_tables.append(cropped_table)

                if debug_dir:
                    file_name_table = self.append_filename(file_path, debug_dir, f"table_cropped_{i + 1}")
                    cropped_table.save(file_name_table)
        else:
            if debug:
                print("Table detected in: ", file_path)

            cropped_table = tables_crops[0]['image'].convert("RGB")
            cropped_tables.append(cropped_table)

            if debug_dir:
                file_name_table = self.append_filename(file_path, debug_dir, "table_cropped")
                cropped_table.save(file_name_table)

        return cropped_tables

    # for output bounding box post-processing
    @staticmethod
    def box_cxcywh_to_xyxy(x):
        x_c, y_c, w, h = x.unbind(-1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
        return torch.stack(b, dim=1)

    def rescale_bboxes(self, out_bbox, size):
        img_w, img_h = size
        b = self.box_cxcywh_to_xyxy(out_bbox)
        b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
        return b

    def outputs_to_objects(self, outputs, img_size, id2label):
        m = outputs.logits.softmax(-1).max(-1)
        pred_labels = list(m.indices.detach().cpu().numpy())[0]
        pred_scores = list(m.values.detach().cpu().numpy())[0]
        pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
        pred_bboxes = [elem.tolist() for elem in self.rescale_bboxes(pred_bboxes, img_size)]

        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if not class_label == 'no object':
                objects.append({'label': class_label, 'score': float(score),
                                'bbox': [float(elem) for elem in bbox]})

        return objects

    def objects_to_crops(self, img, tokens, objects, class_thresholds, padding=10):
        """
        Process the bounding boxes produced by the table detection model into
        cropped table images and cropped tokens.
        """

        table_crops = []
        for obj in objects:
            if obj['score'] < class_thresholds[obj['label']]:
                continue

            cropped_table = {}

            bbox = obj['bbox']
            bbox = [bbox[0] - padding, bbox[1] - padding, bbox[2] + padding, bbox[3] + padding]

            cropped_img = img.crop(bbox)

            table_tokens = [token for token in tokens if self.iob(token['bbox'], bbox) >= 0.5]
            for token in table_tokens:
                token['bbox'] = [token['bbox'][0] - bbox[0],
                                 token['bbox'][1] - bbox[1],
                                 token['bbox'][2] - bbox[0],
                                 token['bbox'][3] - bbox[1]]

            # If table is predicted to be rotated, rotate cropped image and tokens/words:
            if obj['label'] == 'table rotated':
                cropped_img = cropped_img.rotate(270, expand=True)
                for token in table_tokens:
                    bbox = token['bbox']
                    bbox = [cropped_img.size[0] - bbox[3] - 1,
                            bbox[0],
                            cropped_img.size[0] - bbox[1] - 1,
                            bbox[2]]
                    token['bbox'] = bbox

            cropped_table['image'] = cropped_img
            cropped_table['tokens'] = table_tokens

            table_crops.append(cropped_table)

        return table_crops


    @staticmethod
    def append_filename(file_path, debug_dir, word):
        directory, filename = os.path.split(file_path)
        name, ext = os.path.splitext(filename)
        new_filename = f"{name}_{word}{ext}"
        return os.path.join(debug_dir, new_filename)

    @staticmethod
    def iob(boxA, boxB):
        # Determine the coordinates of the intersection rectangle
        xA = max(boxA[0], boxB[0])
        yA = max(boxA[1], boxB[1])
        xB = min(boxA[2], boxB[2])
        yB = min(boxA[3], boxB[3])

        # Compute the area of intersection rectangle
        interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)

        # Compute the area of both the prediction and ground-truth rectangles
        boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
        boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

        # Compute the intersection over box (IoB)
        iob = interArea / float(boxAArea)

        return iob


    @staticmethod
    def invoke_pipeline_step(task_call, task_description, local):
        if local:
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                transient=False,
            ) as progress:
                progress.add_task(description=task_description, total=None)
                ret = task_call()
        else:
            print(task_description)
            ret = task_call()

        return ret


if __name__ == "__main__":
    table_detector = TableDetector()

    # file_path = "/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/bonds_table.png"
    # cropped_tables = table_detector.detect_tables(file_path, local=True, debug_dir="/Users/andrejb/Work/katana-git/sparrow/sparrow-ml/llm/data/", debug=True)

    # for i, cropped_table in enumerate(cropped_tables):
    #     file_name_table = table_detector.append_filename(file_path, "cropped_" + str(i))
    #     cropped_table.save(file_name_table)
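
To illustrate the `iob` score used above (intersection area divided by the area of the first box, the criterion for assigning a token to a table crop), a small hedged example; the boxes are made up, and importing the module assumes `torch`, `torchvision`, `transformers`, and `rich` are installed:

from sparrow_parse.processors.table_structure_processor import TableDetector

token_box = [0, 0, 10, 10]   # hypothetical word bounding box (x1, y1, x2, y2)
table_box = [5, 5, 15, 15]   # hypothetical padded table crop box

# Intersection is the region [5, 5, 10, 10], counted inclusively as 6 x 6 = 36 pixels,
# so IoB = 36 / 121, roughly 0.30, below the 0.5 threshold: the token would not be assigned.
score = TableDetector.iob(token_box, table_box)
print(f"IoB = {score:.2f}")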
sparrow_parse/text_extraction.py
ADDED
@@ -0,0 +1,30 @@
from mlx_vlm import load, apply_chat_template, generate
from mlx_vlm.utils import load_image

# For test purposes, we will use a sample image

# Load model and processor
qwen_vl_model, qwen_vl_processor = load("mlx-community/Qwen2.5-VL-7B-Instruct-8bit")
qwen_vl_config = qwen_vl_model.config

image = load_image("images/graph.png")

messages = [
    {"role": "system", "content": "You are an expert at extracting text from images. Format your response in json."},
    {"role": "user", "content": "Extract the names, labels and y coordinates from the image."}
]

# Apply chat template
prompt = apply_chat_template(qwen_vl_processor, qwen_vl_config, messages)

# Generate text
qwen_vl_output = generate(
    qwen_vl_model,
    qwen_vl_processor,
    prompt,
    image,
    max_tokens=1000,
    temperature=0.7,
)

print(qwen_vl_output)
sparrow_parse/vllm/__init__.py
ADDED
File without changes
sparrow_parse/vllm/huggingface_inference.py
ADDED
@@ -0,0 +1,60 @@
from gradio_client import Client, handle_file
from sparrow_parse.vllm.inference_base import ModelInference
import json
import os
import ast


class HuggingFaceInference(ModelInference):
    def __init__(self, hf_space, hf_token):
        self.hf_space = hf_space
        self.hf_token = hf_token


    def process_response(self, output_text):
        json_string = output_text

        json_string = json_string.strip("[]'")
        json_string = json_string.replace("```json\n", "").replace("\n```", "")
        json_string = json_string.replace("'", "")

        try:
            formatted_json = json.loads(json_string)
            return json.dumps(formatted_json, indent=2)
        except json.JSONDecodeError as e:
            print("Failed to parse JSON:", e)
            return output_text


    def inference(self, input_data, mode=None):
        if mode == "static":
            simple_json = self.get_simple_json()
            return [simple_json]

        client = Client(self.hf_space, hf_token=self.hf_token)

        # Extract and prepare the absolute paths for all file paths in input_data
        file_paths = [
            os.path.abspath(file_path)
            for data in input_data
            for file_path in data["file_path"]
        ]

        # Validate file existence and prepare files for the Gradio client
        image_files = [handle_file(path) for path in file_paths if os.path.exists(path)]

        results = client.predict(
            input_imgs=image_files,
            text_input=input_data[0]["text_input"],  # Single shared text input for all images
            api_name="/run_inference"  # Specify the Gradio API endpoint
        )

        # Convert the string into a Python list
        parsed_results = ast.literal_eval(results)

        results_array = []
        for page_output in parsed_results:
            page_result = self.process_response(page_output)
            results_array.append(page_result)

        return results_array
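
The `mode="static"` branch above returns before a Gradio client is ever created, so it can be exercised without a token or a running Space. A minimal sketch, assuming `sparrow-parse` and `gradio_client` are installed; the Space name and file path are placeholders:

from sparrow_parse.vllm.huggingface_inference import HuggingFaceInference

# Placeholder config; no real Space is contacted when mode="static"
backend = HuggingFaceInference(hf_space="user/your-backend-space", hf_token=None)

input_data = [
    {
        "file_path": ["/data/invoice_1.jpg"],  # hypothetical path, unused in static mode
        "text_input": "retrieve all data. return response in JSON format"
    }
]

# Returns the canned JSON from ModelInference.get_simple_json(), useful for wiring tests
results = backend.inference(input_data, mode="static")
print(results[0])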
sparrow_parse/vllm/inference_base.py
ADDED
@@ -0,0 +1,30 @@
from abc import ABC, abstractmethod
import json


class ModelInference(ABC):
    @abstractmethod
    def inference(self, input_data, mode=None):
        """This method should be implemented by subclasses."""
        pass

    def get_simple_json(self):
        # Define a simple data structure
        data = {
            "table": [
                {
                    "description": "Revenues",
                    "latest_amount": 12453,
                    "previous_amount": 11445
                },
                {
                    "description": "Operating expenses",
                    "latest_amount": 9157,
                    "previous_amount": 8822
                }
            ]
        }

        # Convert the dictionary to a JSON string
        json_data = json.dumps(data, indent=4)
        return json_data
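
Every backend (Hugging Face, MLX, local GPU) implements this small contract: take the `input_data` list and return one response string per page. A hedged sketch of a custom backend, using a made-up `EchoInference` class that is not part of the package and assuming `sparrow-parse` is installed:

from sparrow_parse.vllm.inference_base import ModelInference
import json


class EchoInference(ModelInference):
    """Toy backend: echoes the query and file list instead of calling a model."""

    def inference(self, input_data, mode=None):
        if mode == "static":
            return [self.get_simple_json()]  # reuse the canned JSON from the base class

        return [
            json.dumps({"query": data["text_input"], "files": data["file_path"]}, indent=2)
            for data in input_data
        ]


backend = EchoInference()
print(backend.inference([{"file_path": ["page_1.jpg"], "text_input": "retrieve all data"}]))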
sparrow_parse/vllm/inference_factory.py
ADDED
@@ -0,0 +1,25 @@
from sparrow_parse.vllm.huggingface_inference import HuggingFaceInference
from sparrow_parse.vllm.local_gpu_inference import LocalGPUInference
from sparrow_parse.vllm.mlx_inference import MLXInference


class InferenceFactory:
    def __init__(self, config):
        self.config = config

    def get_inference_instance(self):
        if self.config["method"] == "huggingface":
            return HuggingFaceInference(hf_space=self.config["hf_space"], hf_token=self.config["hf_token"])
        elif self.config["method"] == "local_gpu":
            model = self._load_local_model()  # Replace with actual model loading logic
            return LocalGPUInference(model=model, device=self.config.get("device", "cuda"))
        elif self.config["method"] == "mlx":
            return MLXInference(model_name=self.config["model_name"])
        else:
            raise ValueError(f"Unknown method: {self.config['method']}")

    def _load_local_model(self):
        # Example: Load a PyTorch model (replace with actual loading code)
        # model = torch.load('model.pth')
        # return model
        raise NotImplementedError("Model loading logic not implemented")
sparrow_parse/vllm/infra/qwen2_vl_7b/app.py
ADDED
@@ -0,0 +1,155 @@
import gradio as gr
import spaces
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from PIL import Image
from datetime import datetime
import os

# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

DESCRIPTION = "[Sparrow Qwen2-VL-7B Backend](https://github.com/katanaml/sparrow)"


def array_to_image_path(image_filepath, max_width=1250, max_height=1750):
    if image_filepath is None:
        raise ValueError("No image provided. Please upload an image before submitting.")

    # Open the uploaded image using its filepath
    img = Image.open(image_filepath)

    # Extract the file extension from the uploaded file
    input_image_extension = image_filepath.split('.')[-1].lower()  # Extract extension from filepath

    # Set file extension based on the original file, otherwise default to PNG
    if input_image_extension in ['jpg', 'jpeg', 'png']:
        file_extension = input_image_extension
    else:
        file_extension = 'png'  # Default to PNG if extension is unavailable or invalid

    # Get the current dimensions of the image
    width, height = img.size

    # Initialize new dimensions to current size
    new_width, new_height = width, height

    # Check if the image exceeds the maximum dimensions
    if width > max_width or height > max_height:
        # Calculate the new size, maintaining the aspect ratio
        aspect_ratio = width / height

        if width > max_width:
            new_width = max_width
            new_height = int(new_width / aspect_ratio)

        if new_height > max_height:
            new_height = max_height
            new_width = int(new_height * aspect_ratio)

    # Generate a unique filename using timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"image_{timestamp}.{file_extension}"

    # Save the image
    img.save(filename)

    # Get the full path of the saved image
    full_path = os.path.abspath(filename)

    return full_path, new_width, new_height


# Initialize the model and processor globally to optimize performance
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype="auto",
    device_map="auto"
)

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")


@spaces.GPU
def run_inference(input_imgs, text_input):
    results = []

    for image in input_imgs:
        # Convert each image to the required format
        image_path, width, height = array_to_image_path(image)

        try:
            # Prepare messages for each image
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image_path,
                            "resized_height": height,
                            "resized_width": width
                        },
                        {
                            "type": "text",
                            "text": text_input
                        }
                    ]
                }
            ]

            # Prepare inputs for the model
            text = processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            image_inputs, video_inputs = process_vision_info(messages)
            inputs = processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to("cuda")

            # Generate inference output
            generated_ids = model.generate(**inputs, max_new_tokens=4096)
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            raw_output = processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )

            results.append(raw_output[0])
            print("Processed: " + image)
        finally:
            # Clean up the temporary image file
            os.remove(image_path)

    return results


css = """
#output {
    height: 500px;
    overflow: auto;
    border: 1px solid #ccc;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Tab(label="Qwen2-VL-7B Input"):
        with gr.Row():
            with gr.Column():
                input_imgs = gr.Files(file_types=["image"], label="Upload Document Images")
                text_input = gr.Textbox(label="Query")
                submit_btn = gr.Button(value="Submit", variant="primary")
            with gr.Column():
                output_text = gr.Textbox(label="Response")

        submit_btn.click(run_inference, [input_imgs, text_input], [output_text])

demo.queue(api_open=True)
demo.launch(debug=True)
sparrow_parse/vllm/infra/qwen2_vl_7b/requirements.txt
ADDED
@@ -0,0 +1,9 @@
numpy==1.24.4
Pillow==10.3.0
Requests==2.31.0
torch
torchvision
git+https://github.com/huggingface/transformers.git
accelerate
qwen-vl-utils
gradio==4.44.1
sparrow_parse/vllm/local_gpu_inference.py
ADDED
@@ -0,0 +1,16 @@
import torch
from sparrow_parse.vllm.inference_base import ModelInference


class LocalGPUInference(ModelInference):
    def __init__(self, model, device='cuda'):
        self.model = model
        self.device = device
        self.model.to(self.device)

    def inference(self, input_data, mode=None):
        self.model.eval()  # Set the model to evaluation mode
        with torch.no_grad():  # No need to calculate gradients
            input_tensor = torch.tensor(input_data).to(self.device)
            output = self.model(input_tensor)
        return output.cpu().numpy()  # Convert the output back to NumPy if necessary
sparrow_parse/vllm/mlx_inference.py
ADDED
@@ -0,0 +1,140 @@
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_image
from sparrow_parse.vllm.inference_base import ModelInference
import os
import json
from rich import print


class MLXInference(ModelInference):
    """
    A class for performing inference using the MLX model.
    Handles image preprocessing, response formatting, and model interaction.
    """

    def __init__(self, model_name):
        """
        Initialize the inference class with the given model name.

        :param model_name: Name of the model to load.
        """
        self.model_name = model_name
        print(f"MLXInference initialized for model: {model_name}")


    @staticmethod
    def _load_model_and_processor(model_name):
        """
        Load the model and processor for inference.

        :param model_name: Name of the model to load.
        :return: Tuple containing the loaded model and processor.
        """
        model, processor = load(model_name)
        print(f"Loaded model: {model_name}")
        return model, processor


    def process_response(self, output_text):
        """
        Process and clean the model's raw output to format as JSON.

        :param output_text: Raw output text from the model.
        :return: A formatted JSON string or the original text in case of errors.
        """
        try:
            cleaned_text = (
                output_text.strip("[]'")
                .replace("```json\n", "")
                .replace("\n```", "")
                .replace("'", "")
            )
            formatted_json = json.loads(cleaned_text)
            return json.dumps(formatted_json, indent=2)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON in MLX inference backend: {e}")
            return output_text


    def load_image_data(self, image_filepath, max_width=1250, max_height=1750):
        """
        Load and resize image while maintaining its aspect ratio.

        :param image_filepath: Path to the image file.
        :param max_width: Maximum allowed width of the image.
        :param max_height: Maximum allowed height of the image.
        :return: Tuple containing the image object and its new dimensions.
        """
        image = load_image(image_filepath)  # Assuming load_image is defined elsewhere
        width, height = image.size

        # Calculate new dimensions while maintaining the aspect ratio
        if width > max_width or height > max_height:
            aspect_ratio = width / height
            new_width = min(max_width, int(max_height * aspect_ratio))
            new_height = min(max_height, int(max_width / aspect_ratio))
            return image, new_width, new_height

        return image, width, height


    def inference(self, input_data, mode=None):
        """
        Perform inference on input data using the specified model.

        :param input_data: A list of dictionaries containing image file paths and text inputs.
        :param mode: Optional mode for inference ("static" for simple JSON output).
        :return: List of processed model responses.
        """
        if mode == "static":
            return [self.get_simple_json()]

        # Load the model and processor
        model, processor = self._load_model_and_processor(self.model_name)
        config = model.config

        # Prepare absolute file paths
        file_paths = self._extract_file_paths(input_data)

        results = []
        for file_path in file_paths:
            image, width, height = self.load_image_data(file_path)

            # Prepare messages for the chat model
            messages = [
                {"role": "system", "content": "You are an expert at extracting structured text from image documents."},
                {"role": "user", "content": input_data[0]["text_input"]},
            ]

            # Generate and process response
            prompt = apply_chat_template(processor, config, messages)  # Assuming defined
            response = generate(
                model,
                processor,
                prompt,
                image,
                resize_shape=(width, height),
                max_tokens=4000,
                temperature=0.0,
                verbose=False
            )
            results.append(self.process_response(response))

            print("Inference completed successfully for: ", file_path)

        return results

    @staticmethod
    def _extract_file_paths(input_data):
        """
        Extract and resolve absolute file paths from input data.

        :param input_data: List of dictionaries containing image file paths.
        :return: List of absolute file paths.
        """
        return [
            os.path.abspath(file_path)
            for data in input_data
            for file_path in data.get("file_path", [])
        ]