import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import subprocess
import json
from PIL import Image, ImageDraw
import os
import tempfile

# Dictionary of model names and their corresponding HuggingFace model IDs
MODEL_OPTIONS = {
    "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
    "Medieval Base": "medieval-data/trocr-medieval-base",
    "Medieval Latin Caroline": "medieval-data/trocr-medieval-latin-caroline",
    "Medieval Castilian Hybrida": "medieval-data/trocr-medieval-castilian-hybrida",
    "Medieval Humanistica": "medieval-data/trocr-medieval-humanistica",
    "Medieval Textualis": "medieval-data/trocr-medieval-textualis",
    "Medieval Cursiva": "medieval-data/trocr-medieval-cursiva",
    "Medieval Semitextualis": "medieval-data/trocr-medieval-semitextualis",
    "Medieval Praegothica": "medieval-data/trocr-medieval-praegothica",
    "Medieval Semihybrida": "medieval-data/trocr-medieval-semihybrida",
    "Medieval Print": "medieval-data/trocr-medieval-print"
}

# Global cache for the currently loaded model and processor, so switching
# models in the UI does not reload weights on every request
current_model = None
current_processor = None
current_model_name = None


def load_model(model_name):
    global current_model, current_processor, current_model_name
    if model_name != current_model_name:
        model_id = MODEL_OPTIONS[model_name]
        current_processor = TrOCRProcessor.from_pretrained(model_id)
        current_model = VisionEncoderDecoderModel.from_pretrained(model_id)
        current_model_name = model_name
        # Move model to GPU if available, else use CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        current_model = current_model.to(device)
    return current_processor, current_model


def process_image(image, model_name):
    # JPEG has no alpha channel and TrOCR expects RGB input, so normalize
    # the mode up front
    image = image.convert("RGB")

    # Save the uploaded image to a temporary file for the kraken CLI
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_img:
        image.save(temp_img, format="JPEG")
        temp_img_path = temp_img.name

    # Run kraken for baseline line segmentation (requires the kraken CLI
    # to be installed and on PATH); argument-list form avoids shell quoting
    # issues with the temporary path
    lines_json_path = "lines.json"
    subprocess.run(
        ["kraken", "-i", temp_img_path, lines_json_path,
         "binarize", "segment", "-bl"],
        check=True,
    )

    # Load the detected lines from the JSON file
    with open(lines_json_path, 'r') as f:
        lines_data = json.load(f)

    processor, model = load_model(model_name)

    # Determine device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Process each line
    transcriptions = []
    for line in lines_data['lines']:
        # Crop the line using the bounding box of its boundary polygon.
        # The baseline endpoints alone would yield a crop of (near-)zero
        # height, since a baseline is essentially a horizontal stroke.
        xs = [point[0] for point in line['boundary']]
        ys = [point[1] for point in line['boundary']]
        line_image = image.crop(
            (int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys)))
        )

        # Prepare image for TrOCR: the processor handles resizing and
        # pixel normalization itself, so pass the PIL crop directly
        # (manually rescaling to [0, 1] first would normalize twice)
        pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device)

        # Generate (greedy decoding, no beam search)
        with torch.no_grad():
            generated_ids = model.generate(pixel_values)

        # Decode
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcriptions.append(generated_text)

    # Clean up temporary files
    os.unlink(temp_img_path)
    os.unlink(lines_json_path)

    # Draw each detected baseline on the image
    draw = ImageDraw.Draw(image)
    for line in lines_data['lines']:
        coords = [tuple(point) for point in line['baseline']]
        draw.line(coords, fill="red", width=2)

    return image, "\n".join(transcriptions)
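
# A possible tweak (an assumption, not part of the original code): the
# model.generate call above uses greedy decoding. Beam search often yields
# better transcriptions at some speed cost, e.g.:
#
#   generated_ids = model.generate(pixel_values, num_beams=4, max_new_tokens=128)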
# Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Medieval Document Transcription")
    gr.Markdown(
        "Upload an image of a medieval document and select a model to "
        "transcribe it. The tool will detect lines and transcribe each "
        "line separately."
    )

    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        model_dropdown = gr.Dropdown(
            choices=list(MODEL_OPTIONS.keys()),
            label="Select Model",
            value="Medieval Base",
        )

    with gr.Row():
        output_image = gr.Image(type="pil", label="Detected Lines")
        transcription_output = gr.Textbox(label="Transcription", lines=10)

    submit_button = gr.Button("Transcribe")
    submit_button.click(
        fn=process_image,
        inputs=[input_image, model_dropdown],
        outputs=[output_image, transcription_output],
    )

# share=True asks Gradio for a temporary public link; drop it to serve
# on localhost only
iface.launch(share=True)
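
# Usage sketch (not part of the original app): the pipeline can also be
# called headlessly, bypassing the UI -- "page.jpg" is a hypothetical
# input file:
#
#   from PIL import Image
#   boxed_image, text = process_image(Image.open("page.jpg"), "Medieval Base")
#   print(text)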