"""Gradio app for historical document OCR: Kraken segments a page image into
lines, then a selected TrOCR model transcribes each line."""

import gradio as gr
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import subprocess
import json
from PIL import Image, ImageDraw
import os
import tempfile
import numpy as np

# Dictionary of display names and their corresponding Hugging Face model IDs
MODEL_OPTIONS = {
    "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
    "Medieval Base": "medieval-data/trocr-medieval-base",
    "Medieval Latin Caroline": "medieval-data/trocr-medieval-latin-caroline",
    "Medieval Castilian Hybrida": "medieval-data/trocr-medieval-castilian-hybrida",
    "Medieval Humanistica": "medieval-data/trocr-medieval-humanistica",
    "Medieval Textualis": "medieval-data/trocr-medieval-textualis",
    "Medieval Cursiva": "medieval-data/trocr-medieval-cursiva",
    "Medieval Semitextualis": "medieval-data/trocr-medieval-semitextualis",
    "Medieval Praegothica": "medieval-data/trocr-medieval-praegothica",
    "Medieval Semihybrida": "medieval-data/trocr-medieval-semihybrida",
    "Medieval Print": "medieval-data/trocr-medieval-print"
}


def load_model(model_name):
    model_id = MODEL_OPTIONS[model_name]
    processor = TrOCRProcessor.from_pretrained(model_id)
    model = VisionEncoderDecoderModel.from_pretrained(model_id)
    # Move the model to GPU if available, else use CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    return processor, model


def detect_lines(image_path):
    # Run Kraken's baseline segmenter for line detection.
    # Arguments are passed as a list (no shell) so paths with spaces are safe.
    lines_json_path = "lines.json"
    subprocess.run(
        ["kraken", "-i", image_path, lines_json_path, "segment", "-bl"],
        check=True,
    )

    # Load the detected lines from the JSON file Kraken wrote
    with open(lines_json_path, 'r') as f:
        lines_data = json.load(f)

    # Clean up the intermediate file
    os.unlink(lines_json_path)

    return lines_data['lines']


def extract_line_images(image, lines):
    line_images = []
    for line in lines:
        polygon = line['boundary']

        # Calculate the polygon's bounding box
        x_coords, y_coords = zip(*polygon)
        x1, y1 = int(min(x_coords)), int(min(y_coords))
        x2, y2 = int(max(x_coords)), int(max(y_coords))

        # Crop the line from the original image
        line_image = image.crop((x1, y1, x2, y2))

        # Create a mask for the polygon, shifted into the crop's coordinates
        mask = Image.new('L', (x2 - x1, y2 - y1), 0)
        adjusted_polygon = [(int(x - x1), int(y - y1)) for x, y in polygon]
        ImageDraw.Draw(mask).polygon(adjusted_polygon, outline=255, fill=255)

        # Apply the mask: pixels outside the polygon are set to white (255)
        line_array = np.array(line_image)
        mask_array = np.array(mask)
        masked_line = np.where(mask_array[:, :, np.newaxis] == 255, line_array, 255)

        # Convert back to a PIL image
        masked_line_image = Image.fromarray(masked_line.astype('uint8'), 'RGB')
        line_images.append(masked_line_image)
    return line_images


def visualize_lines(image, lines):
    # Draw each detected line polygon in red on a copy of the page
    output_image = image.copy()
    draw = ImageDraw.Draw(output_image)
    for line in lines:
        polygon = [(int(x), int(y)) for x, y in line['boundary']]
        draw.polygon(polygon, outline="red")
    return output_image


def transcribe_lines(line_images, model_name):
    processor, model = load_model(model_name)
    transcriptions = []
    for line_image in line_images:
        # Preprocess the line image into model inputs
        pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
        # Keep inputs on the same device as the model (fixes a CPU/GPU mismatch)
        pixel_values = pixel_values.to(model.device)
        # Generate with greedy decoding (no beam search)
        generated_ids = model.generate(pixel_values)
        # Decode token IDs back to text
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcriptions.append(generated_text)
    return transcriptions


def process_document(image, model_name):
    # Ensure an RGB image: RGBA or grayscale uploads cannot be saved as JPEG
    # and would break the 3-channel masking in extract_line_images
    image = image.convert("RGB")

    # Save the uploaded image temporarily so Kraken can read it from disk
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        image.save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

    # Step 1: Detect lines
    lines = detect_lines(temp_file_path)

    # Visualize the detected lines
    output_image = visualize_lines(image, lines)

    # Step 2: Extract line images
    line_images = extract_line_images(image, lines)

    # Step 3: Transcribe the lines
    transcriptions = transcribe_lines(line_images, model_name)

    # Clean up the temporary file
    os.unlink(temp_file_path)

    return output_image, "\n".join(transcriptions)


# Gradio interface
def gradio_process_document(image, model_name):
    output_image, transcriptions = process_document(image, model_name)
    return output_image, transcriptions


with gr.Blocks() as iface:
    gr.Markdown("# Document OCR and Transcription")
    gr.Markdown("Upload an image and select a model to detect lines and transcribe the text.")

    with gr.Column():
        input_image = gr.Image(type="pil", label="Upload Image", height=300, width=300)
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="Medieval Base", label="Select Model")
        submit_button = gr.Button("Process")

    with gr.Row():
        output_image = gr.Image(type="pil", label="Detected Lines")
        output_text = gr.Textbox(label="Transcription")

    submit_button.click(
        fn=gradio_process_document,
        inputs=[input_image, model_dropdown],
        outputs=[output_image, output_text]
    )

iface.launch()