artificialguybr commited on
Commit
ee94965
·
verified ·
1 Parent(s): 8f9ee63

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -24
app.py CHANGED
@@ -1,35 +1,50 @@
1
  import gradio as gr
 
 
 
2
  from surya.detection import batch_inference
3
- from surya.model.segformer import load_model, load_processor
 
 
4
  from surya.postprocessing.heatmap import draw_polys_on_image
5
 
6
- model, processor = load_model(), load_processor()
 
 
7
 
8
- HEADER = """
9
- # Surya OCR Demo
10
- This demo will let you try surya, a multilingual OCR model. It supports text detection now, but will support text recognition in the future. This HF Space will be updated.
11
- Notes:
12
- - This works best on documents with printed text.
13
- - Model and code by Vik Paruchuri.
14
- Learn more [here](https://github.com/VikParuchuri/surya).
15
- """.strip()
16
 
17
- def text_detection(img):
18
- preds = batch_inference([img], model, processor)[0]
19
- img = draw_polys_on_image(preds["polygons"], img)
20
- return img, preds
21
 
 
 
 
 
22
 
23
  with gr.Blocks() as app:
24
- gr.Markdown(HEADER)
25
- with gr.Row():
26
- input_image = gr.Image(label="Input Image", type="pil")
27
- output_image = gr.Image(label="Output Image", type="pil", interactive=False)
28
- text_detection_btn = gr.Button("Run Text Detection")
29
-
30
- json_output = gr.JSON(label="JSON Output")
31
- text_detection_btn.click(fn=text_detection, inputs=input_image, outputs=[output_image, json_output], api_name="text_detection")
32
-
 
 
 
 
 
 
 
 
33
 
34
  if __name__ == "__main__":
35
- app.launch()
 
1
  import gradio as gr
2
+ import json
3
+ from PIL import Image
4
+ from surya.ocr import run_ocr
5
  from surya.detection import batch_inference
6
+ from surya.model.segformer import load_model as load_det_model, load_processor as load_det_processor
7
+ from surya.model.recognition.model import load_model as load_rec_model
8
+ from surya.model.recognition.processor import load_processor as load_rec_processor
9
  from surya.postprocessing.heatmap import draw_polys_on_image
10
 
11
+ # Load models and processors
12
+ det_model, det_processor = load_det_model(), load_det_processor()
13
+ rec_model, rec_processor = load_rec_model(), load_rec_processor()
14
 
15
+ # Load languages from JSON
16
+ with open("languages.json", "r") as file:
17
+ languages = json.load(file)
18
+ language_options = [(code, language) for code, language in languages.items()]
 
 
 
 
19
 
20
+ def ocr_function(img, langs):
21
+ predictions = run_ocr([img], langs.split(','), det_model, det_processor, rec_model, rec_processor)[0]
22
+ img_with_text = draw_polys_on_image(predictions["polys"], img)
23
+ return img_with_text, predictions
24
 
25
+ def text_line_detection_function(img):
26
+ preds = batch_inference([img], det_model, det_processor)[0]
27
+ img_with_lines = draw_polys_on_image(preds["polygons"], img)
28
+ return img_with_lines, preds
29
 
30
  with gr.Blocks() as app:
31
+ gr.Markdown("# Surya OCR and Text Line Detection Demo")
32
+ with gr.Tab("OCR"):
33
+ with gr.Row():
34
+ ocr_input_image = gr.Image(label="Input Image for OCR", type="pil")
35
+ ocr_language_selector = gr.Dropdown(label="Select Language(s) for OCR", choices=language_options, value="en", type="str")
36
+ ocr_output_image = gr.Image(label="OCR Output Image", type="pil", interactive=False)
37
+ ocr_json_output = gr.JSON(label="OCR JSON Output")
38
+ ocr_button = gr.Button("Run OCR")
39
+ ocr_button.click(fn=ocr_function, inputs=[ocr_input_image, ocr_language_selector], outputs=[ocr_output_image, ocr_json_output])
40
+
41
+ with gr.Tab("Text Line Detection"):
42
+ with gr.Row():
43
+ detection_input_image = gr.Image(label="Input Image for Detection", type="pil")
44
+ detection_output_image = gr.Image(label="Detection Output Image", type="pil", interactive=False)
45
+ detection_json_output = gr.JSON(label="Detection JSON Output")
46
+ detection_button = gr.Button("Run Text Line Detection")
47
+ detection_button.click(fn=text_line_detection_function, inputs=detection_input_image, outputs=[detection_output_image, detection_json_output])
48
 
49
  if __name__ == "__main__":
50
+ app.launch()