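# app.py: a Gradio Space for medieval document transcription.
# Kraken segments the page into text lines, then a TrOCR model
# transcribes each line separately.
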
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import spaces
import subprocess
import json
from PIL import Image, ImageDraw
import os
import tempfile

# Dictionary of model names and their corresponding HuggingFace model IDs
MODEL_OPTIONS = {
    "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
    "Medieval Base": "medieval-data/trocr-medieval-base",
    "Medieval Latin Caroline": "medieval-data/trocr-medieval-latin-caroline",
    "Medieval Castilian Hybrida": "medieval-data/trocr-medieval-castilian-hybrida",
    "Medieval Humanistica": "medieval-data/trocr-medieval-humanistica",
    "Medieval Textualis": "medieval-data/trocr-medieval-textualis",
    "Medieval Cursiva": "medieval-data/trocr-medieval-cursiva",
    "Medieval Semitextualis": "medieval-data/trocr-medieval-semitextualis",
    "Medieval Praegothica": "medieval-data/trocr-medieval-praegothica",
    "Medieval Semihybrida": "medieval-data/trocr-medieval-semihybrida",
    "Medieval Print": "medieval-data/trocr-medieval-print",
}
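
# All of the checkpoints above share the TrOCR (VisionEncoderDecoder)
# architecture, so load_model() below can load any of them interchangeably.
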
# Cache the current model and processor globally so a model is only
# (re)loaded when the user selects a different one between requests
current_model = None
current_processor = None
current_model_name = None

def load_model(model_name):
    global current_model, current_processor, current_model_name

    if model_name != current_model_name:
        model_id = MODEL_OPTIONS[model_name]
        current_processor = TrOCRProcessor.from_pretrained(model_id)
        current_model = VisionEncoderDecoderModel.from_pretrained(model_id)
        current_model_name = model_name

        # Move model to GPU
        current_model = current_model.to('cuda')

    return current_processor, current_model
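
# On ZeroGPU Spaces, @spaces.GPU allocates a GPU only for the duration of the call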
@spaces.GPU
def process_image(image, model_name):
    # Save the uploaded image to a temporary file; convert to RGB first,
    # since JPEG cannot store an alpha channel
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_img:
        image.convert("RGB").save(temp_img, format="JPEG")
        temp_img_path = temp_img.name

    # Run Kraken to binarize the page and segment it into text lines with
    # the baseline segmenter (-bl); the segmentation is written out as JSON
    lines_json_path = "lines.json"
    kraken_command = ["kraken", "-i", temp_img_path, lines_json_path, "binarize", "segment", "-bl"]
    subprocess.run(kraken_command, check=True)
    # Load the detected lines from the JSON file
    with open(lines_json_path, 'r') as f:
        lines_data = json.load(f)

    processor, model = load_model(model_name)
    # Process each detected line
    transcriptions = []
    for line in lines_data['lines']:
        # Compute a bounding box for the line. The baseline endpoints alone
        # would give a crop of (nearly) zero height, so prefer the line's
        # boundary polygon when Kraken provides one.
        if line.get('boundary'):
            xs = [point[0] for point in line['boundary']]
            ys = [point[1] for point in line['boundary']]
            x1, y1, x2, y2 = min(xs), min(ys), max(xs), max(ys)
        else:
            x1, y1 = line['baseline'][0]
            x2, y2 = line['baseline'][-1]

        # Crop the line from the original image
        line_image = image.crop((x1, y1, x2, y2))

        # Prepare the line image for TrOCR
        pixel_values = processor(line_image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to('cuda')

        # Generate (greedy decoding, no beam search)
        with torch.no_grad():
            generated_ids = model.generate(pixel_values)

        # Decode token IDs back to text
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcriptions.append(generated_text)
    # Clean up temporary files
    os.unlink(temp_img_path)
    os.unlink(lines_json_path)

    # Draw the detected baselines onto the image for display
    draw = ImageDraw.Draw(image)
    for line in lines_data['lines']:
        coords = [tuple(point) for point in line['baseline']]
        draw.line(coords, fill="red", width=2)

    return image, "\n".join(transcriptions)

# Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Medieval Document Transcription")
    gr.Markdown("Upload an image of a medieval document and select a model to transcribe it. The tool will detect lines and transcribe each line separately.")

    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model", value="Medieval Base")

    with gr.Row():
        output_image = gr.Image(type="pil", label="Detected Lines")
        transcription_output = gr.Textbox(label="Transcription", lines=10)

    submit_button = gr.Button("Transcribe")
    submit_button.click(fn=process_image, inputs=[input_image, model_dropdown], outputs=[output_image, transcription_output])

iface.launch()