wjbmattingly committed on
Commit
0d4066e
1 Parent(s): 228b5a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -58
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
- from transformers import TrOCRProcessor, VisionEncoderDecoderModel
3
  import torch
 
4
  import subprocess
5
  import json
6
  from PIL import Image, ImageDraw
@@ -23,97 +23,133 @@ MODEL_OPTIONS = {
23
  "Medieval Print": "medieval-data/trocr-medieval-print"
24
  }
25
 
26
- # Global variables to store the current model and processor
27
- current_model = None
28
- current_processor = None
29
- current_model_name = None
30
-
31
  def load_model(model_name):
32
- global current_model, current_processor, current_model_name
 
 
33
 
34
- if model_name != current_model_name:
35
- model_id = MODEL_OPTIONS[model_name]
36
- current_processor = TrOCRProcessor.from_pretrained(model_id)
37
- current_model = VisionEncoderDecoderModel.from_pretrained(model_id)
38
- current_model_name = model_name
39
-
40
- # Move model to GPU if available, else use CPU
41
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
- current_model = current_model.to(device)
43
 
44
- return current_processor, current_model
45
-
46
- def process_image(image, model_name):
47
- # Save the uploaded image to a temporary file
48
- with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_img:
49
- image.save(temp_img, format="JPEG")
50
- temp_img_path = temp_img.name
51
 
 
52
  # Run Kraken for line detection
53
  lines_json_path = "lines.json"
54
- kraken_command = f"kraken -i {temp_img_path} {lines_json_path} binarize segment -bl"
55
  subprocess.run(kraken_command, shell=True, check=True)
56
 
57
  # Load the lines from the JSON file
58
  with open(lines_json_path, 'r') as f:
59
  lines_data = json.load(f)
60
 
61
- processor, model = load_model(model_name)
 
62
 
63
- # Determine device
64
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
65
 
66
- # Process each line
67
- transcriptions = []
68
- for line in lines_data['lines']:
69
- # Extract line coordinates
70
- x1, y1 = line['baseline'][0]
71
- x2, y2 = line['baseline'][-1]
 
 
72
 
73
  # Crop the line from the original image
74
  line_image = image.crop((x1, y1, x2, y2))
75
 
76
- # Convert to RGB mode (3 channels)
77
- line_image = line_image.convert('RGB')
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- # Prepare image for TrOCR
80
- pixel_values = processor(line_image, return_tensors="pt").pixel_values
81
- pixel_values = pixel_values.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  # Generate (no beam search)
84
- with torch.no_grad():
85
- generated_ids = model.generate(pixel_values)
86
 
87
  # Decode
88
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
89
  transcriptions.append(generated_text)
90
 
91
- # Clean up temporary files
92
- os.unlink(temp_img_path)
93
- os.unlink(lines_json_path)
94
 
95
- # Create an image with bounding boxes
96
- draw = ImageDraw.Draw(image)
97
- for line in lines_data['lines']:
98
- coords = line['baseline']
99
- draw.line(coords, fill="red", width=2)
100
 
101
- return image, "\n".join(transcriptions)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  # Gradio interface
 
 
 
 
104
  with gr.Blocks() as iface:
105
- gr.Markdown("# Medieval Document Transcription")
106
- gr.Markdown("Upload an image of a medieval document and select a model to transcribe it. The tool will detect lines and transcribe each line separately.")
107
 
108
- with gr.Row():
109
- input_image = gr.Image(type="pil", label="Input Image")
110
- model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model", value="Medieval Base")
 
111
 
112
  with gr.Row():
113
  output_image = gr.Image(type="pil", label="Detected Lines")
114
- transcription_output = gr.Textbox(label="Transcription", lines=10)
115
 
116
- submit_button = gr.Button("Transcribe")
117
- submit_button.click(fn=process_image, inputs=[input_image, model_dropdown], outputs=[output_image, transcription_output])
 
 
 
118
 
119
- iface.launch(share=True)
 
1
  import gradio as gr
 
2
  import torch
3
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
  import subprocess
5
  import json
6
  from PIL import Image, ImageDraw
 
23
  "Medieval Print": "medieval-data/trocr-medieval-print"
24
  }
25
 
 
 
 
 
 
26
  def load_model(model_name):
27
+ model_id = MODEL_OPTIONS[model_name]
28
+ processor = TrOCRProcessor.from_pretrained(model_id)
29
+ model = VisionEncoderDecoderModel.from_pretrained(model_id)
30
 
31
+ # Move model to GPU if available, else use CPU
32
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
33
+ model = model.to(device)
 
 
 
 
 
 
34
 
35
+ return processor, model
 
 
 
 
 
 
36
 
37
+ def detect_lines(image_path):
38
  # Run Kraken for line detection
39
  lines_json_path = "lines.json"
40
+ kraken_command = f"kraken -i {image_path} {lines_json_path} segment -bl"
41
  subprocess.run(kraken_command, shell=True, check=True)
42
 
43
  # Load the lines from the JSON file
44
  with open(lines_json_path, 'r') as f:
45
  lines_data = json.load(f)
46
 
47
+ # Clean up temporary file
48
+ os.unlink(lines_json_path)
49
 
50
+ return lines_data['lines']
 
51
 
52
+ def extract_line_images(image, lines):
53
+ line_images = []
54
+ for line in lines:
55
+ polygon = line['boundary']
56
+
57
+ # Calculate bounding box
58
+ x_coords, y_coords = zip(*polygon)
59
+ x1, y1, x2, y2 = int(min(x_coords)), int(min(y_coords)), int(max(x_coords)), int(max(y_coords))
60
 
61
  # Crop the line from the original image
62
  line_image = image.crop((x1, y1, x2, y2))
63
 
64
+ # Create a mask for the polygon
65
+ mask = Image.new('L', (x2-x1, y2-y1), 0)
66
+ adjusted_polygon = [(int(x-x1), int(y-y1)) for x, y in polygon]
67
+ ImageDraw.Draw(mask).polygon(adjusted_polygon, outline=255, fill=255)
68
+
69
+ # Convert images to numpy arrays
70
+ line_array = np.array(line_image)
71
+ mask_array = np.array(mask)
72
+
73
+ # Apply the mask
74
+ masked_line = np.where(mask_array[:,:,np.newaxis] == 255, line_array, 255)
75
+
76
+ # Convert back to PIL Image
77
+ masked_line_image = Image.fromarray(masked_line.astype('uint8'), 'RGB')
78
 
79
+ line_images.append(masked_line_image)
80
+
81
+ return line_images
82
+
83
+ def visualize_lines(image, lines):
84
+ output_image = image.copy()
85
+ draw = ImageDraw.Draw(output_image)
86
+ for line in lines:
87
+ polygon = [(int(x), int(y)) for x, y in line['boundary']]
88
+ draw.polygon(polygon, outline="red")
89
+ return output_image
90
+
91
+ def transcribe_lines(line_images, model_name):
92
+ processor, model = load_model(model_name)
93
+
94
+ transcriptions = []
95
+ for line_image in line_images:
96
+ # Process the line image
97
+ pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
98
 
99
  # Generate (no beam search)
100
+ generated_ids = model.generate(pixel_values)
 
101
 
102
  # Decode
103
  generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
104
  transcriptions.append(generated_text)
105
 
106
+ return transcriptions
 
 
107
 
108
+ def process_document(image, model_name):
109
+ # Save the uploaded image temporarily
110
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
111
+ image.save(temp_file, format="JPEG")
112
+ temp_file_path = temp_file.name
113
 
114
+ # Step 1: Detect lines
115
+ lines = detect_lines(temp_file_path)
116
+
117
+ # Visualize detected lines
118
+ output_image = visualize_lines(image, lines)
119
+
120
+ # Step 2: Extract line images
121
+ line_images = extract_line_images(image, lines)
122
+
123
+ # Step 3: Transcribe lines
124
+ transcriptions = transcribe_lines(line_images, model_name)
125
+
126
+ # Clean up temporary file
127
+ os.unlink(temp_file_path)
128
+
129
+ return output_image, "\n".join(transcriptions)
130
 
131
  # Gradio interface
132
+ def gradio_process_document(image, model_name):
133
+ output_image, transcriptions = process_document(image, model_name)
134
+ return output_image, transcriptions
135
+
136
  with gr.Blocks() as iface:
137
+ gr.Markdown("# Document OCR and Transcription")
138
+ gr.Markdown("Upload an image and select a model to detect lines and transcribe the text.")
139
 
140
+ with gr.Column():
141
+ input_image = gr.Image(type="pil", label="Upload Image", height=300, width=300) # Adjusted size here
142
+ model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="Medieval Base", label="Select Model")
143
+ submit_button = gr.Button("Process")
144
 
145
  with gr.Row():
146
  output_image = gr.Image(type="pil", label="Detected Lines")
147
+ output_text = gr.Textbox(label="Transcription")
148
 
149
+ submit_button.click(
150
+ fn=gradio_process_document,
151
+ inputs=[input_image, model_dropdown],
152
+ outputs=[output_image, output_text]
153
+ )
154
 
155
+ iface.launch()