artificialguybr committed on
Commit
94a609f
·
verified ·
1 Parent(s): 5d72698

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -100
app.py CHANGED
@@ -1,9 +1,9 @@
1
  import gradio as gr
2
- import torch
3
  import logging
4
  import os
5
  import json
6
  from PIL import Image
 
7
  from surya.ocr import run_ocr
8
  from surya.detection import batch_text_detection
9
  from surya.layout import batch_layout_detection
@@ -11,16 +11,22 @@ from surya.ordering import batch_ordering
11
  from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
12
  from surya.model.recognition.model import load_model as load_rec_model
13
  from surya.model.recognition.processor import load_processor as load_rec_processor
14
- from surya.model.ordering.model import load_model as load_order_model
15
- from surya.model.ordering.processor import load_processor as load_order_processor
16
  from surya.settings import settings
 
 
17
 
18
  # Set up logging
19
- logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
20
  logger = logging.getLogger(__name__)
21
 
22
- # Load models and processors
23
- logger.info("Loading models and processors...")
 
 
 
 
 
 
24
  det_processor, det_model = load_det_processor(), load_det_model()
25
  rec_model, rec_processor = load_rec_model(), load_rec_processor()
26
  layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
@@ -28,105 +34,71 @@ layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOIN
28
  order_model = load_order_model()
29
  order_processor = load_order_processor()
30
 
31
- # Compile the OCR model for better performance
32
- logger.info("Compiling OCR model...")
33
- os.environ['RECOGNITION_STATIC_CACHE'] = 'true'
34
  rec_model.decoder.model = torch.compile(rec_model.decoder.model)
35
 
36
class SuryaJSONEncoder(json.JSONEncoder):
    """JSON encoder for result objects that are not natively serializable.

    ``json`` calls :meth:`default` only for values it cannot encode itself,
    and then re-encodes whatever ``default`` returns. The method therefore
    converts a single level and leaves the recursion to the encoder, so
    plain strings, numbers, lists and dicts nested inside an object never
    reach ``default``.
    """

    def default(self, obj):
        """Return a JSON-encodable stand-in for *obj*.

        Args:
            obj: A value the base encoder could not serialize.

        Returns:
            A placeholder string for PIL images, or a dict of the object's
            instance attributes.

        Raises:
            TypeError: If *obj* is neither a PIL object nor has ``__dict__``.
        """
        # Images are tested first because they also carry a ``__dict__``;
        # dumping their internals would be meaningless in the output.
        if type(obj).__module__.partition(".")[0] == "PIL":
            return "PIL.Image.Image object"
        if hasattr(obj, "__dict__"):
            # Shallow copy only: the encoder walks the returned dict and
            # calls default() again for any nested non-native value.
            return dict(vars(obj))
        return super().default(obj)
45
 
46
def process_image(image_path, langs):
    """Run OCR, line detection, layout analysis and reading order on an image.

    Args:
        image_path: Path of the image file to open with PIL.
        langs: Comma-separated language codes, e.g. ``"en,fr"``.

    Returns:
        A JSON string (encoded with ``SuryaJSONEncoder``) with the keys
        ``"ocr"``, ``"text_lines"``, ``"layout"`` and ``"reading_order"``
        for the stages that completed, plus ``"error"`` holding the
        exception message if a stage failed.

    Raises:
        Exception: Whatever ``Image.open`` raises for an unreadable file;
            only the analysis stages are guarded.
    """
    logger.info(f"Processing image: {image_path}")
    # Strip blanks so that "en, fr" does not yield the code " fr".
    lang_codes = [code.strip() for code in langs.split(",") if code.strip()]

    results = {}

    # The context manager releases the file handle once analysis is done.
    with Image.open(image_path) as image:
        try:
            # OCR
            logger.info("Performing OCR...")
            ocr_predictions = run_ocr([image], [lang_codes], det_model, det_processor, rec_model, rec_processor)
            results["ocr"] = ocr_predictions[0]

            # Text line detection
            logger.info("Detecting text lines...")
            line_predictions = batch_text_detection([image], det_model, det_processor)
            results["text_lines"] = line_predictions[0]

            # Layout analysis
            logger.info("Analyzing layout...")
            layout_predictions = batch_layout_detection([image], layout_model, layout_processor, line_predictions)
            results["layout"] = layout_predictions[0]

            # Reading order
            logger.info("Determining reading order...")
            logger.debug(f"Layout predictions: {layout_predictions}")

            # NOTE(review): the layout result may be a dict or an object
            # exposing ``bboxes``/``bbox`` attributes depending on the surya
            # version; both shapes are accepted here - confirm.
            layout = layout_predictions[0]
            if isinstance(layout, dict):
                boxes = layout.get("bboxes")
            else:
                boxes = getattr(layout, "bboxes", None)

            if boxes is not None:
                bboxes = [box["bbox"] if isinstance(box, dict) else box.bbox for box in boxes]
                order_predictions = batch_ordering([image], [bboxes], order_model, order_processor)
                results["reading_order"] = order_predictions[0]
            else:
                logger.warning("Layout predictions do not have the expected structure. Skipping reading order detection.")
                results["reading_order"] = "Reading order detection skipped due to unexpected layout prediction structure."

        except Exception as e:
            # Best effort: keep the stages that succeeded and report the failure.
            logger.error(f"Error processing image: {str(e)}", exc_info=True)
            results["error"] = str(e)

    logger.info("Processing complete.")
    return json.dumps(results, indent=2, cls=SuryaJSONEncoder)
86
 
87
def surya_ui(image, langs):
    """Gradio callback: analyse the uploaded image and return the result text.

    Args:
        image: File path supplied by the image input, or ``None`` when
            nothing was uploaded.
        langs: Comma-separated language codes forwarded to ``process_image``.

    Returns:
        The JSON string produced by ``process_image``, a prompt to upload an
        image, or an error message if processing raised.
    """
    if image is None:
        return "Please upload an image."

    try:
        return process_image(image, langs)
    except Exception as e:
        # UI boundary: log the traceback and show the message instead of crashing.
        logger.exception("Error in UI processing: %s", e)
        return f"An error occurred: {str(e)}"
 
 
 
 
 
 
 
 
 
97
 
98
- # Create Gradio interface
99
- iface = gr.Interface(
100
- fn=surya_ui,
101
- inputs=[
102
- gr.Image(type="filepath", label="Upload Image"),
103
- gr.Textbox(label="Languages (comma-separated, e.g., 'en,fr')", value="en")
104
- ],
105
- outputs=gr.Textbox(label="Results"),
106
- title="Surya Document Analysis",
107
- description="Upload an image to perform OCR, text line detection, layout analysis, and reading order detection.",
108
- theme="huggingface",
109
- css="""
110
- .gradio-container {
111
- font-family: 'IBM Plex Sans', sans-serif;
112
- }
113
- .gr-button {
114
- color: white;
115
- border-radius: 8px;
116
- background: linear-gradient(45deg, #ff9a9e 0%, #fad0c4 99%, #fad0c4 100%);
117
- }
118
- .gr-button:hover {
119
- background: linear-gradient(45deg, #fad0c4 0%, #ff9a9e 99%, #ff9a9e 100%);
120
- }
121
- .gr-form {
122
- border-radius: 12px;
123
- background-color: #ffffff;
124
- box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
125
- }
126
- """
127
- )
128
 
129
- # Launch the interface
130
  if __name__ == "__main__":
131
- logger.info("Starting Gradio interface...")
132
- iface.launch()
 
import json
import logging
import os

# Performance tuning for surya. These assignments are deliberately placed
# BEFORE the first surya import: surya.settings creates its `settings` object
# at import time, so variables exported afterwards would presumably be ignored
# (NOTE(review): confirm against the installed surya version).
os.environ["RECOGNITION_BATCH_SIZE"] = "512"
os.environ["DETECTOR_BATCH_SIZE"] = "36"
os.environ["ORDER_BATCH_SIZE"] = "32"
os.environ["RECOGNITION_STATIC_CACHE"] = "true"

import gradio as gr
import torch
from PIL import Image
from surya.ocr import run_ocr
from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.ordering import batch_ordering
from surya.model.detection.model import load_model as load_det_model, load_processor as load_det_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.settings import settings
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.model.ordering.model import load_model as load_order_model

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load models once at start-up; every workflow below shares these instances.
logger.info("Loading models...")
det_processor, det_model = load_det_processor(), load_det_model()
rec_model, rec_processor = load_rec_model(), load_rec_processor()
# The layout model uses the detection loaders with a different checkpoint.
layout_model = load_det_model(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
layout_processor = load_det_processor(checkpoint=settings.LAYOUT_MODEL_CHECKPOINT)
order_model = load_order_model()
order_processor = load_order_processor()

# Compile recognition model
logger.info("Compiling recognition model...")
rec_model.decoder.model = torch.compile(rec_model.decoder.model)
40
 
41
def _to_jsonable(obj):
    """Recursively convert a prediction structure into JSON-encodable data.

    The surya result objects are not plain dicts/lists, so ``json.dumps``
    cannot encode them directly. Objects are unpacked through
    ``model_dump()``/``dict()`` when they offer one (presumably pydantic
    models - TODO confirm), otherwise through their instance attributes.
    """
    if obj is None or isinstance(obj, (bool, int, float, str)):
        return obj
    if isinstance(obj, dict):
        return {str(key): _to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple, set, frozenset)):
        return [_to_jsonable(item) for item in obj]
    if type(obj).__module__.partition(".")[0] == "PIL":
        # Images embedded in results cannot be expressed in JSON.
        return "PIL.Image.Image object"
    for dumper_name in ("model_dump", "dict"):
        dumper = getattr(obj, dumper_name, None)
        if callable(dumper):
            return _to_jsonable(dumper())
    if hasattr(obj, "tolist"):
        # Array-like values and scalars that expose tolist().
        return _to_jsonable(obj.tolist())
    if hasattr(obj, "__dict__"):
        return {key: _to_jsonable(value) for key, value in vars(obj).items()}
    # Deliberate last resort so one odd value does not abort the whole response.
    return repr(obj)


def _error_json(message):
    """Return *message* wrapped as the JSON error document shown in the UI."""
    return json.dumps({"error": message}, indent=2)


def _upload_path(upload):
    """Return the filesystem path of an upload (file wrapper or plain path)."""
    return getattr(upload, "name", upload)


def _layout_bboxes(layout):
    """Extract the list of bounding boxes from one layout prediction.

    NOTE(review): accepts both dict-style and attribute-style results, since
    the shape depends on the surya version - confirm.
    """
    boxes = layout["bboxes"] if isinstance(layout, dict) else layout.bboxes
    return [box["bbox"] if isinstance(box, dict) else box.bbox for box in boxes]


def _detect_layout(page):
    """Run text-line detection followed by layout detection on *page*."""
    line_predictions = batch_text_detection([page], det_model, det_processor)
    return batch_layout_detection([page], layout_model, layout_processor, line_predictions)


def ocr_workflow(image, langs):
    """Run OCR on an uploaded image.

    Args:
        image: Upload from the file input (object with ``.name`` or a path);
            ``None`` when nothing was uploaded.
        langs: Comma-separated language codes, e.g. ``"en,fr"``.

    Returns:
        JSON string with the OCR predictions, or an ``{"error": ...}``
        document when no file was provided.
    """
    if image is None:
        return _error_json("Please upload an image.")
    logger.info(f"Starting OCR workflow with languages: {langs}")
    # Strip blanks so that "en, fr" does not yield the code " fr".
    lang_codes = [code.strip() for code in langs.split(",") if code.strip()]
    with Image.open(_upload_path(image)) as page:
        predictions = run_ocr([page], [lang_codes], det_model, det_processor, rec_model, rec_processor)
    logger.info("OCR workflow completed")
    return json.dumps(_to_jsonable(predictions), indent=2)


def text_detection_workflow(image):
    """Detect text lines in an uploaded image.

    Args:
        image: Upload from the file input, or ``None``.

    Returns:
        JSON string with the detection predictions, or an error document
        when no file was provided.
    """
    if image is None:
        return _error_json("Please upload an image.")
    logger.info("Starting text detection workflow")
    with Image.open(_upload_path(image)) as page:
        predictions = batch_text_detection([page], det_model, det_processor)
    logger.info("Text detection workflow completed")
    return json.dumps(_to_jsonable(predictions), indent=2)


def layout_analysis_workflow(image):
    """Run layout analysis (line detection + layout model) on an upload.

    Args:
        image: Upload from the file input, or ``None``.

    Returns:
        JSON string with the layout predictions, or an error document when
        no file was provided.
    """
    if image is None:
        return _error_json("Please upload an image.")
    logger.info("Starting layout analysis workflow")
    with Image.open(_upload_path(image)) as page:
        layout_predictions = _detect_layout(page)
    logger.info("Layout analysis workflow completed")
    return json.dumps(_to_jsonable(layout_predictions), indent=2)


def reading_order_workflow(image):
    """Determine the reading order of the layout regions of an upload.

    Args:
        image: Upload from the file input, or ``None``.

    Returns:
        JSON string with the ordering predictions, or an error document
        when no file was provided.
    """
    if image is None:
        return _error_json("Please upload an image.")
    logger.info("Starting reading order workflow")
    with Image.open(_upload_path(image)) as page:
        layout_predictions = _detect_layout(page)
        bboxes = _layout_bboxes(layout_predictions[0])
        order_predictions = batch_ordering([page], [bboxes], order_model, order_processor)
    logger.info("Reading order workflow completed")
    return json.dumps(_to_jsonable(order_predictions), indent=2)
72
+
73
# Tabbed UI: one tab per workflow. Uploads are limited to images because every
# workflow opens the file with PIL's Image.open, which cannot read PDFs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Surya Document Analysis")

    with gr.Tab("OCR"):
        gr.Markdown("## Optical Character Recognition")
        with gr.Row():
            ocr_input = gr.File(label="Upload Image", file_types=["image"])
            ocr_langs = gr.Textbox(label="Languages (comma-separated)", value="en")
        ocr_button = gr.Button("Run OCR")
        ocr_output = gr.JSON(label="OCR Results")
        ocr_button.click(ocr_workflow, inputs=[ocr_input, ocr_langs], outputs=ocr_output)

    with gr.Tab("Text Detection"):
        gr.Markdown("## Text Line Detection")
        det_input = gr.File(label="Upload Image", file_types=["image"])
        det_button = gr.Button("Run Text Detection")
        det_output = gr.JSON(label="Text Detection Results")
        det_button.click(text_detection_workflow, inputs=det_input, outputs=det_output)

    with gr.Tab("Layout Analysis"):
        gr.Markdown("## Layout Analysis and Reading Order")
        # Both buttons read the same upload and write to separate panels.
        layout_input = gr.File(label="Upload Image", file_types=["image"])
        layout_button = gr.Button("Run Layout Analysis")
        order_button = gr.Button("Determine Reading Order")
        layout_output = gr.JSON(label="Layout Analysis Results")
        order_output = gr.JSON(label="Reading Order Results")
        layout_button.click(layout_analysis_workflow, inputs=layout_input, outputs=layout_output)
        order_button.click(reading_order_workflow, inputs=layout_input, outputs=order_output)

if __name__ == "__main__":
    logger.info("Starting Gradio app...")
    demo.launch()