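"""Gradio app for line-level transcription of medieval documents.

Pipeline: Kraken performs baseline segmentation to find text lines, each line
is cropped from the uploaded page image, and a selectable TrOCR model
transcribes the crop. Detected baselines are drawn on the image for display.
"""
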
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
import subprocess
import json
from PIL import Image, ImageDraw
import os
import tempfile

# Dictionary of model names and their corresponding Hugging Face model IDs
MODEL_OPTIONS = {
    "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
    "Medieval Base": "medieval-data/trocr-medieval-base",
    "Medieval Latin Caroline": "medieval-data/trocr-medieval-latin-caroline",
    "Medieval Castilian Hybrida": "medieval-data/trocr-medieval-castilian-hybrida",
    "Medieval Humanistica": "medieval-data/trocr-medieval-humanistica",
    "Medieval Textualis": "medieval-data/trocr-medieval-textualis",
    "Medieval Cursiva": "medieval-data/trocr-medieval-cursiva",
    "Medieval Semitextualis": "medieval-data/trocr-medieval-semitextualis",
    "Medieval Praegothica": "medieval-data/trocr-medieval-praegothica",
    "Medieval Semihybrida": "medieval-data/trocr-medieval-semihybrida",
    "Medieval Print": "medieval-data/trocr-medieval-print",
}

# Global variables to store the current model and processor
current_model = None
current_processor = None
current_model_name = None


def load_model(model_name):
    global current_model, current_processor, current_model_name
    if model_name != current_model_name:
        model_id = MODEL_OPTIONS[model_name]
        current_processor = TrOCRProcessor.from_pretrained(model_id)
        current_model = VisionEncoderDecoderModel.from_pretrained(model_id)
        current_model_name = model_name

        # Move model to GPU if available, else use CPU
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        current_model = current_model.to(device)
    return current_processor, current_model
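
# Optional warm-up (not in the original app): eagerly load the default model
# at import time so the first request does not pay the download and
# initialization cost. Left commented out to preserve lazy loading.
# load_model("Medieval Base")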


def process_image(image, model_name):
    # Save the uploaded image to a temporary file. JPEG cannot store an alpha
    # channel, so convert to RGB first to avoid a save error on RGBA uploads.
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_img:
        image.convert("RGB").save(temp_img, format="JPEG")
        temp_img_path = temp_img.name

    # Run Kraken for line detection, writing the segmentation to a temporary
    # JSON file so concurrent requests do not overwrite each other's output.
    # Passing the arguments as a list avoids shell-quoting issues.
    with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_json:
        lines_json_path = temp_json.name
    kraken_command = [
        "kraken", "-i", temp_img_path, lines_json_path,
        "binarize", "segment", "-bl",
    ]
    subprocess.run(kraken_command, check=True)

    # Load the detected lines from the JSON file
    with open(lines_json_path, 'r') as f:
        lines_data = json.load(f)
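
    # Expected (abridged) shape of the Kraken baseline-segmentation JSON read
    # above; only the keys used below are shown:
    #     {
    #       "type": "baselines",
    #       "lines": [
    #         {"baseline": [[x0, y0], ...], "boundary": [[x0, y0], ...]},
    #         ...
    #       ]
    #     }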

    processor, model = load_model(model_name)

    # Determine device (load_model has already moved the model there)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Process each detected line
    transcriptions = []
    for line in lines_data['lines']:
        # The two baseline endpoints alone describe a near zero-height box,
        # so crop the axis-aligned bounding box of the line's boundary polygon
        xs = [point[0] for point in line['boundary']]
        ys = [point[1] for point in line['boundary']]
        line_image = image.crop((min(xs), min(ys), max(xs), max(ys)))

        # The TrOCR processor expects an RGB image and performs its own
        # resizing, rescaling, and normalization, so pass the crop directly
        pixel_values = processor(images=line_image.convert("RGB"),
                                 return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(device)

        # Generate (greedy decoding, no beam search)
        with torch.no_grad():
            generated_ids = model.generate(pixel_values)
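        # Not part of the original app: beam search via the standard
        # generate() arguments can trade speed for accuracy, e.g.
        #     generated_ids = model.generate(pixel_values, num_beams=4)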

        # Decode the generated token IDs back to text
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcriptions.append(generated_text)

    # Clean up temporary files
    os.unlink(temp_img_path)
    os.unlink(lines_json_path)

    # Draw the detected baselines onto the image for display
    draw = ImageDraw.Draw(image)
    for line in lines_data['lines']:
        # ImageDraw.line expects a sequence of (x, y) points
        coords = [tuple(point) for point in line['baseline']]
        draw.line(coords, fill="red", width=2)

    return image, "\n".join(transcriptions)


# Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Medieval Document Transcription")
    gr.Markdown(
        "Upload an image of a medieval document and select a model to "
        "transcribe it. The tool will detect lines and transcribe each line separately."
    )
    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), label="Select Model", value="Medieval Base")
    with gr.Row():
        output_image = gr.Image(type="pil", label="Detected Lines")
        transcription_output = gr.Textbox(label="Transcription", lines=10)
    submit_button = gr.Button("Transcribe")
    submit_button.click(fn=process_image, inputs=[input_image, model_dropdown], outputs=[output_image, transcription_output])

iface.launch(share=True)
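
# Note: when this app runs on Hugging Face Spaces, the app is served directly
# and share=True has no effect (Gradio logs a warning); the flag only matters
# when running the script locally.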