import gradio as gr
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import spaces
from PIL import Image, ImageDraw
import os
import tempfile
import numpy as np
import requests

# Dictionary of model names and their corresponding HuggingFace model IDs
MODEL_OPTIONS = {
    "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
    "Medieval Base": "medieval-data/trocr-medieval-base",
    "Medieval Latin Caroline": "medieval-data/trocr-medieval-latin-caroline",
    "Medieval Castilian Hybrida": "medieval-data/trocr-medieval-castilian-hybrida",
    "Medieval Humanistica": "medieval-data/trocr-medieval-humanistica",
    "Medieval Textualis": "medieval-data/trocr-medieval-textualis",
    "Medieval Cursiva": "medieval-data/trocr-medieval-cursiva",
    "Medieval Semitextualis": "medieval-data/trocr-medieval-semitextualis",
    "Medieval Praegothica": "medieval-data/trocr-medieval-praegothica",
    "Medieval Semihybrida": "medieval-data/trocr-medieval-semihybrida",
    "Medieval Print": "medieval-data/trocr-medieval-print"
}


def load_model(model_name):
    """Load the TrOCR processor and model for the selected option."""
    model_id = MODEL_OPTIONS[model_name]
    processor = TrOCRProcessor.from_pretrained(model_id)
    model = VisionEncoderDecoderModel.from_pretrained(model_id)

    # Move model to GPU if available, else use CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    return processor, model


def detect_lines(image_path):
    """Detect text lines by calling the hosted Kraken segmentation API."""
    url = "https://wjbmattingly-kraken-api.hf.space/detect_lines"

    # Upload the image and specify the Kraken segmentation model to use
    with open(image_path, 'rb') as f:
        files = {'file': ('ms.jpg', f, 'image/jpeg')}
        data = {'model_name': 'catmus-medieval.mlmodel'}
        response = requests.post(url, files=files, data=data)

    response.raise_for_status()
    return response.json()["result"]["lines"]


def extract_line_images(image, lines):
    """Crop each detected line and mask out pixels outside its polygon."""
    line_images = []
    for line in lines:
        polygon = line['boundary']

        # Calculate the bounding box of the line polygon
        x_coords, y_coords = zip(*polygon)
        x1, y1 = int(min(x_coords)), int(min(y_coords))
        x2, y2 = int(max(x_coords)), int(max(y_coords))

        # Crop the line from the original image
        line_image = image.crop((x1, y1, x2, y2))

        # Create a mask for the polygon, shifted into the crop's coordinates
        mask = Image.new('L', (x2 - x1, y2 - y1), 0)
        adjusted_polygon = [(int(x - x1), int(y - y1)) for x, y in polygon]
        ImageDraw.Draw(mask).polygon(adjusted_polygon, outline=255, fill=255)

        # Apply the mask: keep pixels inside the polygon, white elsewhere
        line_array = np.array(line_image)
        mask_array = np.array(mask)
        masked_line = np.where(mask_array[:, :, np.newaxis] == 255, line_array, 255)

        # Convert back to a PIL image
        line_images.append(Image.fromarray(masked_line.astype('uint8'), 'RGB'))
    return line_images


def visualize_lines(image, lines):
    """Draw each detected line polygon on a copy of the input image."""
    output_image = image.copy()
    draw = ImageDraw.Draw(output_image)
    for line in lines:
        polygon = [(int(x), int(y)) for x, y in line['boundary']]
        draw.polygon(polygon, outline="red")
    return output_image


@spaces.GPU
def transcribe_lines(line_images, model_name):
    """Transcribe each line image with the selected TrOCR model."""
    processor, model = load_model(model_name)
    transcriptions = []
    for line_image in line_images:
        # Process the line image and move the tensor to the model's device
        pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(model.device)

        # Generate (greedy decoding, no beam search)
        generated_ids = model.generate(pixel_values)

        # Decode token IDs back into text
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcriptions.append(generated_text)
    return transcriptions


def process_document(image, model_name):
    # JPEG has no alpha channel, so normalize to RGB before saving
    image = image.convert("RGB")

    # Save the uploaded image temporarily
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        image.save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

    # Step 1: Detect lines
    lines = detect_lines(temp_file_path)

    # Visualize detected lines
    output_image = visualize_lines(image, lines)

    # Step 2: Extract line images
    line_images = extract_line_images(image, lines)

    # Step 3: Transcribe lines
    transcriptions = transcribe_lines(line_images, model_name)

    # Clean up the temporary file
    os.unlink(temp_file_path)

    return output_image, "\n".join(transcriptions)


# Gradio interface
def gradio_process_document(image, model_name):
    output_image, transcriptions = process_document(image, model_name)
    return output_image, transcriptions


with gr.Blocks() as iface:
    gr.Markdown("# Document OCR and Transcription")
    gr.Markdown("Upload an image and select a model to detect lines and transcribe the text.")

    with gr.Column():
        input_image = gr.Image(type="pil", label="Upload Image", height=300, width=300)  # constrain preview size
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="Medieval Base", label="Select Model")
        submit_button = gr.Button("Process")

    with gr.Row():
        output_image = gr.Image(type="pil", label="Detected Lines")
        output_text = gr.Textbox(label="Transcription")

    submit_button.click(
        fn=gradio_process_document,
        inputs=[input_image, model_dropdown],
        outputs=[output_image, output_text]
    )

iface.launch(debug=True)