import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import subprocess
import json
from PIL import Image, ImageDraw
import os
import tempfile
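# Note: line segmentation is delegated to the external `kraken` CLI, invoked
# via subprocess below, so kraken must be installed alongside the Python
# packages above (on a Space, typically via requirements.txt).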
# Dictionary of model names and their corresponding HuggingFace model IDs
MODEL_OPTIONS = {
    "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
    "Medieval Base": "medieval-data/trocr-medieval-base",
    "Medieval Latin Caroline": "medieval-data/trocr-medieval-latin-caroline",
    "Medieval Castilian Hybrida": "medieval-data/trocr-medieval-castilian-hybrida",
    "Medieval Humanistica": "medieval-data/trocr-medieval-humanistica",
    "Medieval Textualis": "medieval-data/trocr-medieval-textualis",
    "Medieval Cursiva": "medieval-data/trocr-medieval-cursiva",
    "Medieval Semitextualis": "medieval-data/trocr-medieval-semitextualis",
    "Medieval Praegothica": "medieval-data/trocr-medieval-praegothica",
    "Medieval Semihybrida": "medieval-data/trocr-medieval-semihybrida",
    "Medieval Print": "medieval-data/trocr-medieval-print"
}
# Global variables caching the currently loaded model and processor
current_model = None
current_processor = None
current_model_name = None
def load_model(model_name):
    global current_model, current_processor, current_model_name
    if model_name != current_model_name:
        model_id = MODEL_OPTIONS[model_name]
        current_processor = TrOCRProcessor.from_pretrained(model_id)
        current_model = VisionEncoderDecoderModel.from_pretrained(model_id)
        current_model_name = model_name
        # Move model to GPU if available, else use CPU, and switch to eval mode
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        current_model = current_model.to(device)
        current_model.eval()
    return current_processor, current_model
def process_image(image, model_name):
    # Ensure a 3-channel image so it can be saved as JPEG and drawn on later
    image = image.convert("RGB")

    # Save the uploaded image to a temporary file for the kraken CLI
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_img:
        image.save(temp_img, format="JPEG")
        temp_img_path = temp_img.name

    # Run kraken for line detection; write the segmentation next to the
    # temporary image so concurrent requests don't overwrite each other
    lines_json_path = temp_img_path + ".json"
    subprocess.run(
        ["kraken", "-i", temp_img_path, lines_json_path, "binarize", "segment", "-bl"],
        check=True,
    )

    # Load the detected lines from the JSON file
    with open(lines_json_path, 'r') as f:
        lines_data = json.load(f)
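    # Illustrative sketch of the expected segmentation JSON (exact fields vary
    # by kraken version; only 'lines' and 'baseline' are relied on below, and
    # 'boundary' is used when present):
    # {
    #   "type": "baselines",
    #   "lines": [
    #     {"baseline": [[x0, y0], ..., [xn, yn]],
    #      "boundary": [[x, y], ...]}    # polygon outlining the line
    #   ]
    # }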
    processor, model = load_model(model_name)

    # Run inference on the same device the model was moved to
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Process each detected line
    transcriptions = []
    for line in lines_data['lines']:
        # Crop the line from the original image. The baseline endpoints alone
        # would give a zero-height crop, so use the bounding box of the
        # boundary polygon when kraken provides one, falling back to the baseline
        points = line.get('boundary') or line['baseline']
        xs = [p[0] for p in points]
        ys = [p[1] for p in points]
        line_image = image.crop((min(xs), min(ys), max(xs), max(ys)))

        # Prepare the line image for TrOCR (the processor expects an RGB image)
        pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device)
        # Generate the transcription (greedy decoding, no beam search)
        with torch.no_grad():
            generated_ids = model.generate(pixel_values)

        # Decode the generated token IDs to text
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcriptions.append(generated_text)
    # Clean up temporary files
    os.unlink(temp_img_path)
    os.unlink(lines_json_path)

    # Draw the detected baselines onto the image
    draw = ImageDraw.Draw(image)
    for line in lines_data['lines']:
        # Pillow expects a sequence of (x, y) tuples
        coords = [tuple(point) for point in line['baseline']]
        draw.line(coords, fill="red", width=2)

    return image, "\n".join(transcriptions)
# Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Medieval Document Transcription")
    gr.Markdown("Upload an image of a medieval document and select a model to transcribe it. The tool will detect lines and transcribe each line separately.")
    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model", value="Medieval Base")
    with gr.Row():
        output_image = gr.Image(type="pil", label="Detected Lines")
        transcription_output = gr.Textbox(label="Transcription", lines=10)
    submit_button = gr.Button("Transcribe")
    submit_button.click(fn=process_image, inputs=[input_image, model_dropdown], outputs=[output_image, transcription_output])

iface.launch(share=True)