# Hugging Face Space app (runs on ZeroGPU "Running on Zero" hardware)
import gradio as gr
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import spaces
from PIL import Image, ImageDraw
import os
import tempfile
import numpy as np
import requests
# Dictionary of model names and their corresponding Hugging Face model IDs
MODEL_OPTIONS = {
    "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
    "Medieval Base": "medieval-data/trocr-medieval-base",
    "Medieval Latin Caroline": "medieval-data/trocr-medieval-latin-caroline",
    "Medieval Castilian Hybrida": "medieval-data/trocr-medieval-castilian-hybrida",
    "Medieval Humanistica": "medieval-data/trocr-medieval-humanistica",
    "Medieval Textualis": "medieval-data/trocr-medieval-textualis",
    "Medieval Cursiva": "medieval-data/trocr-medieval-cursiva",
    "Medieval Semitextualis": "medieval-data/trocr-medieval-semitextualis",
    "Medieval Praegothica": "medieval-data/trocr-medieval-praegothica",
    "Medieval Semihybrida": "medieval-data/trocr-medieval-semihybrida",
    "Medieval Print": "medieval-data/trocr-medieval-print"
}
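
# All entries above are TrOCR checkpoints, so the same processor/model
# classes in load_model() below work for each of them interchangeably.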
def load_model(model_name):
    model_id = MODEL_OPTIONS[model_name]
    processor = TrOCRProcessor.from_pretrained(model_id)
    model = VisionEncoderDecoderModel.from_pretrained(model_id)
    # Move the model to GPU if available, else fall back to CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    return processor, model
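
# Optional: cache loaded models so repeated requests for the same model do
# not re-instantiate the processor and weights on every click. A minimal
# sketch, not part of the original app; the unbounded dict cache is an
# assumption and may need eviction on memory-constrained hardware.
_MODEL_CACHE = {}

def load_model_cached(model_name):
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = load_model(model_name)
    return _MODEL_CACHE[model_name]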
def detect_lines(image_path):
    # Kraken line-segmentation API endpoint
    url = "https://wjbmattingly-kraken-api.hf.space/detect_lines"
    # Upload the page image; the filename and MIME type are fixed because
    # process_document() always writes a JPEG
    with open(image_path, 'rb') as f:
        files = {'file': ('ms.jpg', f, 'image/jpeg')}
        # Specify the segmentation model to use
        data = {'model_name': 'catmus-medieval.mlmodel'}
        # Send the POST request
        response = requests.post(url, files=files, data=data)
    # Fail loudly on HTTP errors instead of raising a confusing KeyError below
    response.raise_for_status()
    return response.json()["result"]["lines"]
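
# Each element of the returned list is expected to look like
#   {"boundary": [[x0, y0], [x1, y1], ...], ...}
# i.e. a polygon outlining one text line. The functions below rely only on
# the "boundary" key; any other fields in the API response are ignored.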
def extract_line_images(image, lines):
    line_images = []
    for line in lines:
        polygon = line['boundary']
        # Calculate the axis-aligned bounding box of the polygon
        x_coords, y_coords = zip(*polygon)
        x1, y1 = int(min(x_coords)), int(min(y_coords))
        x2, y2 = int(max(x_coords)), int(max(y_coords))
        # Crop the line from the original image
        line_image = image.crop((x1, y1, x2, y2))
        # Create a mask for the polygon, shifted into the crop's coordinates
        mask = Image.new('L', (x2 - x1, y2 - y1), 0)
        adjusted_polygon = [(int(x - x1), int(y - y1)) for x, y in polygon]
        ImageDraw.Draw(mask).polygon(adjusted_polygon, outline=255, fill=255)
        # Convert images to numpy arrays
        line_array = np.array(line_image)
        mask_array = np.array(mask)
        # Apply the mask: keep pixels inside the polygon, whiten the rest
        masked_line = np.where(mask_array[:, :, np.newaxis] == 255, line_array, 255)
        # Convert back to a PIL image
        masked_line_image = Image.fromarray(masked_line.astype('uint8'), 'RGB')
        line_images.append(masked_line_image)
    return line_images
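
# Note: pixels outside the line polygon are filled with white (255) rather
# than cropped away, so the rectangle handed to TrOCR contains only the
# target line. This implicitly assumes a light page background; dark or
# stained parchment may call for a different fill value.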
def visualize_lines(image, lines):
    output_image = image.copy()
    draw = ImageDraw.Draw(output_image)
    for line in lines:
        polygon = [(int(x), int(y)) for x, y in line['boundary']]
        draw.polygon(polygon, outline="red")
    return output_image
@spaces.GPU  # request a GPU for this call on ZeroGPU Spaces; this is why `spaces` is imported
def transcribe_lines(line_images, model_name):
    processor, model = load_model(model_name)
    transcriptions = []
    for line_image in line_images:
        # Preprocess the line image and move it to the model's device
        pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(model.device)
        # Generate token IDs (greedy decoding, no beam search)
        generated_ids = model.generate(pixel_values)
        # Decode the IDs back to text
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcriptions.append(generated_text)
    return transcriptions
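
# Optional batched variant: one generate() call over all line crops is
# usually faster on GPU than looping line by line. A sketch under the
# assumption that every line fits in memory at once; not part of the
# original app.
@spaces.GPU
def transcribe_lines_batched(line_images, model_name):
    processor, model = load_model(model_name)
    pixel_values = processor(images=line_images, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values.to(model.device))
    return processor.batch_decode(generated_ids, skip_special_tokens=True)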
def process_document(image, model_name):
    # JPEG cannot store an alpha channel, so normalize to RGB first
    image = image.convert("RGB")
    # Save the uploaded image temporarily for the line-detection API
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        image.save(temp_file, format="JPEG")
        temp_file_path = temp_file.name
    try:
        # Step 1: Detect lines
        lines = detect_lines(temp_file_path)
        # Visualize the detected lines
        output_image = visualize_lines(image, lines)
        # Step 2: Extract line images
        line_images = extract_line_images(image, lines)
        # Step 3: Transcribe lines
        transcriptions = transcribe_lines(line_images, model_name)
    finally:
        # Clean up the temporary file even if a step above fails
        os.unlink(temp_file_path)
    return output_image, "\n".join(transcriptions)
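
# process_document() returns (annotated page image, newline-joined text),
# matching the two Gradio outputs wired up below.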
# Gradio interface
def gradio_process_document(image, model_name):
    output_image, transcriptions = process_document(image, model_name)
    return output_image, transcriptions
with gr.Blocks() as iface:
    gr.Markdown("# Document OCR and Transcription")
    gr.Markdown("Upload an image and select a model to detect lines and transcribe the text.")
    with gr.Column():
        input_image = gr.Image(type="pil", label="Upload Image", height=300, width=300)  # keep the upload preview small
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="Medieval Base", label="Select Model")
        submit_button = gr.Button("Process")
    with gr.Row():
        output_image = gr.Image(type="pil", label="Detected Lines")
        output_text = gr.Textbox(label="Transcription")
    submit_button.click(
        fn=gradio_process_document,
        inputs=[input_image, model_dropdown],
        outputs=[output_image, output_text]
    )

iface.launch(debug=True)