wjbmattingly committed
Commit: dabac75
Parent: 1561fc5

Update app.py

Files changed (1):
  1. app.py +76 -65
app.py CHANGED
@@ -1,11 +1,7 @@
 import gradio as gr
 from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 import torch
-import subprocess
-import json
-from PIL import Image, ImageDraw
-import os
-import tempfile
+import spaces
 
 # Dictionary of model names and their corresponding HuggingFace model IDs
 MODEL_OPTIONS = {
@@ -36,78 +32,93 @@ def load_model(model_name):
         current_model = VisionEncoderDecoderModel.from_pretrained(model_id)
         current_model_name = model_name
 
-        # Move model to GPU if available
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        current_model = current_model.to(device)
+        # Move model to GPU
+        current_model = current_model.to('cuda')
 
     return current_processor, current_model
 
+@spaces.GPU
 def process_image(image, model_name):
-    # Save the uploaded image to a temporary file
-    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_img:
-        image.save(temp_img, format="JPEG")
-        temp_img_path = temp_img.name
-
-    # Run Kraken for line detection
-    lines_json_path = "lines.json"
-    kraken_command = f"kraken -i {temp_img_path} {lines_json_path} binarize segment -bl"
-    subprocess.run(kraken_command, shell=True, check=True)
-
-    # Load the lines from the JSON file
-    with open(lines_json_path, 'r') as f:
-        lines_data = json.load(f)
-
     processor, model = load_model(model_name)
-
-    # Process each line
-    transcriptions = []
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    for line in lines_data['lines']:
-        # Extract line coordinates
-        x1, y1 = line['baseline'][0]
-        x2, y2 = line['baseline'][-1]
-
-        # Crop the line from the original image
-        line_image = image.crop((x1, y1, x2, y2))
-
-        # Prepare image for TrOCR
-        pixel_values = processor(line_image, return_tensors="pt").pixel_values
-        pixel_values = pixel_values.to(device)
-
-        # Generate (no beam search)
-        with torch.no_grad():
-            generated_ids = model.generate(pixel_values)
-
-        # Decode
-        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-        transcriptions.append(generated_text)
-
-    # Clean up temporary files
-    os.unlink(temp_img_path)
-    os.unlink(lines_json_path)
-
-    # Create an image with bounding boxes
-    draw = ImageDraw.Draw(image)
-    for line in lines_data['lines']:
-        coords = line['baseline']
-        draw.line(coords, fill="red", width=2)
-
-    return image, "\n".join(transcriptions)
+
+    # Prepare image
+    pixel_values = processor(image, return_tensors="pt").pixel_values
+
+    # Move input to GPU
+    pixel_values = pixel_values.to('cuda')
+
+    # Generate (no beam search)
+    with torch.no_grad():
+        generated_ids = model.generate(pixel_values)
+
+    # Decode
+    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    return generated_text
+
+# Base URL for the images
+base_url = "https://huggingface.co/medieval-data/trocr-medieval-base/resolve/main/images/"
+
+# List of example images and their corresponding models
+examples = [
+    [f"{base_url}caroline-1.png", "Medieval Latin Caroline"],
+    [f"{base_url}caroline-2.png", "Medieval Latin Caroline"],
+    [f"{base_url}cursiva-1.png", "Medieval Cursiva"],
+    [f"{base_url}cursiva-2.png", "Medieval Cursiva"],
+    [f"{base_url}cursiva-3.png", "Medieval Cursiva"],
+    [f"{base_url}humanistica-1.png", "Medieval Humanistica"],
+    [f"{base_url}humanistica-2.png", "Medieval Humanistica"],
+    [f"{base_url}humanistica-3.png", "Medieval Humanistica"],
+    [f"{base_url}hybrida-1.png", "Medieval Castilian Hybrida"],
+    [f"{base_url}hybrida-2.png", "Medieval Castilian Hybrida"],
+    [f"{base_url}hybrida-3.png", "Medieval Castilian Hybrida"],
+    [f"{base_url}praegothica-1.png", "Medieval Praegothica"],
+    [f"{base_url}praegothica-2.png", "Medieval Praegothica"],
+    [f"{base_url}praegothica-3.png", "Medieval Praegothica"],
+    [f"{base_url}print-1.png", "Medieval Print"],
+    [f"{base_url}print-2.png", "Medieval Print"],
+    [f"{base_url}print-3.png", "Medieval Print"],
+    [f"{base_url}semihybrida-1.png", "Medieval Semihybrida"],
+    [f"{base_url}semihybrida-2.png", "Medieval Semihybrida"],
+    [f"{base_url}semihybrida-3.png", "Medieval Semihybrida"],
+    [f"{base_url}semitextualis-1.png", "Medieval Semitextualis"],
+    [f"{base_url}semitextualis-2.png", "Medieval Semitextualis"],
+    [f"{base_url}semitextualis-3.png", "Medieval Semitextualis"],
+    [f"{base_url}textualis-1.png", "Medieval Textualis"],
+    [f"{base_url}textualis-2.png", "Medieval Textualis"],
+    [f"{base_url}textualis-3.png", "Medieval Textualis"],
+]
+
+# Custom CSS to make the image wider
+custom_css = """
+#image_upload {
+    max-width: 100% !important;
+    width: 100% !important;
+    height: auto !important;
+}
+#image_upload > div:first-child {
+    width: 100% !important;
+}
+#image_upload img {
+    max-width: 100% !important;
+    width: 100% !important;
+    height: auto !important;
+}
+"""
 
 # Gradio interface
-with gr.Blocks() as iface:
-    gr.Markdown("# Medieval Document Transcription")
-    gr.Markdown("Upload an image of a medieval document and select a model to transcribe it. The tool will detect lines and transcribe each line separately.")
+with gr.Blocks(css=custom_css) as iface:
+    gr.Markdown("# Medieval TrOCR Model Switcher")
+    gr.Markdown("Upload an image of medieval text and select a model to transcribe it. Note: This tool is designed to work on a single line of text at a time for optimal results.")
 
     with gr.Row():
-        input_image = gr.Image(type="pil", label="Input Image")
+        input_image = gr.Image(type="pil", label="Input Image", elem_id="image_upload")
         model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model", value="Medieval Base")
 
-    with gr.Row():
-        output_image = gr.Image(type="pil", label="Detected Lines")
-        transcription_output = gr.Textbox(label="Transcription", lines=10)
+    transcription_output = gr.Textbox(label="Transcription")
 
     submit_button = gr.Button("Transcribe")
-    submit_button.click(fn=process_image, inputs=[input_image, model_dropdown], outputs=[output_image, transcription_output])
+    submit_button.click(fn=process_image, inputs=[input_image, model_dropdown], outputs=transcription_output)
+
+    gr.Examples(examples, inputs=[input_image, model_dropdown], outputs=transcription_output)
 
 iface.launch()
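
The core technique this commit adopts is Hugging Face Spaces ZeroGPU: instead of probing torch.cuda.is_available() and managing a device object, the app imports the spaces package and decorates the GPU-bound function with @spaces.GPU, which attaches a GPU only for the duration of each call and releases it afterward. A minimal sketch of that pattern follows, assuming a ZeroGPU Space; the model ID is taken from the example-image URLs above, and the function name transcribe is illustrative, not part of this commit:

import spaces
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Load model and processor at import time; on ZeroGPU no GPU is attached yet,
# so everything starts on CPU.
processor = TrOCRProcessor.from_pretrained("medieval-data/trocr-medieval-base")
model = VisionEncoderDecoderModel.from_pretrained("medieval-data/trocr-medieval-base")

@spaces.GPU  # a GPU is allocated while this function runs, then released
def transcribe(image):
    # Move the model and inputs to CUDA inside the decorated call,
    # where a GPU is guaranteed to be available.
    pixel_values = processor(image, return_tensors="pt").pixel_values.to('cuda')
    with torch.no_grad():
        generated_ids = model.to('cuda').generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]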