import gradio as gr
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import spaces
from PIL import Image, ImageDraw
import os
import tempfile
import numpy as np
import requests
# Dictionary of model names and their corresponding HuggingFace model IDs
MODEL_OPTIONS = {
"Microsoft Handwritten": "microsoft/trocr-base-handwritten",
"Medieval Base": "medieval-data/trocr-medieval-base",
"Medieval Latin Caroline": "medieval-data/trocr-medieval-latin-caroline",
"Medieval Castilian Hybrida": "medieval-data/trocr-medieval-castilian-hybrida",
"Medieval Humanistica": "medieval-data/trocr-medieval-humanistica",
"Medieval Textualis": "medieval-data/trocr-medieval-textualis",
"Medieval Cursiva": "medieval-data/trocr-medieval-cursiva",
"Medieval Semitextualis": "medieval-data/trocr-medieval-semitextualis",
"Medieval Praegothica": "medieval-data/trocr-medieval-praegothica",
"Medieval Semihybrida": "medieval-data/trocr-medieval-semihybrida",
"Medieval Print": "medieval-data/trocr-medieval-print"
}
def load_model(model_name):
    model_id = MODEL_OPTIONS[model_name]
    processor = TrOCRProcessor.from_pretrained(model_id)
    model = VisionEncoderDecoderModel.from_pretrained(model_id)
    # Move the model to the GPU if one is available, else fall back to CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    return processor, model
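# Note: load_model reloads the checkpoint from the local Hugging Face cache on
# every call. A module-level cache keyed by model_id could avoid this for
# repeated requests, though on ZeroGPU a cached model should stay on CPU and
# only be moved to the device inside the @spaces.GPU-decorated call.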
def detect_lines(image_path):
    # Kraken line-detection API endpoint
    url = "https://wjbmattingly-kraken-api.hf.space/detect_lines"
    # Upload the image and request segmentation with the CATMuS medieval model
    with open(image_path, 'rb') as f:
        files = {'file': ('ms.jpg', f, 'image/jpeg')}
        data = {'model_name': 'catmus-medieval.mlmodel'}
        response = requests.post(url, files=files, data=data)
    response.raise_for_status()
    result = response.json()["result"]["lines"]
    return result
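# Expected response shape (inferred from how the result is consumed below):
# each entry in result carries a 'boundary' key holding the line polygon as a
# list of [x, y] points.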
def extract_line_images(image, lines):
    line_images = []
    for line in lines:
        polygon = line['boundary']
        # Compute the bounding box of the line polygon
        x_coords, y_coords = zip(*polygon)
        x1, y1, x2, y2 = int(min(x_coords)), int(min(y_coords)), int(max(x_coords)), int(max(y_coords))
        # Crop the line from the original image
        line_image = image.crop((x1, y1, x2, y2))
        # Build a mask for the polygon, shifted into the crop's coordinates
        mask = Image.new('L', (x2 - x1, y2 - y1), 0)
        adjusted_polygon = [(int(x - x1), int(y - y1)) for x, y in polygon]
        ImageDraw.Draw(mask).polygon(adjusted_polygon, outline=255, fill=255)
        # Convert images to numpy arrays
        line_array = np.array(line_image)
        mask_array = np.array(mask)
        # Apply the mask: pixels outside the polygon are filled with white (255)
        masked_line = np.where(mask_array[:, :, np.newaxis] == 255, line_array, 255)
        # Convert back to a PIL image
        masked_line_image = Image.fromarray(masked_line.astype('uint8'), 'RGB')
        line_images.append(masked_line_image)
    return line_images
def visualize_lines(image, lines):
    output_image = image.copy()
    draw = ImageDraw.Draw(output_image)
    for line in lines:
        polygon = [(int(x), int(y)) for x, y in line['boundary']]
        draw.polygon(polygon, outline="red")
    return output_image
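# On a ZeroGPU Space, @spaces.GPU allocates a GPU only for the duration of the
# decorated call, which is why the model is loaded (and moved to CUDA) inside
# transcribe_lines rather than at import time.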
@spaces.GPU
def transcribe_lines(line_images, model_name):
    processor, model = load_model(model_name)
    transcriptions = []
    for line_image in line_images:
        # Preprocess the line image and move it to the same device as the model
        pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(model.device)
        # Generate token IDs (greedy decoding, no beam search)
        generated_ids = model.generate(pixel_values)
        # Decode token IDs back to text
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcriptions.append(generated_text)
    return transcriptions
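# generate() above uses greedy decoding by default; passing e.g. num_beams=4
# trades speed for potentially better transcriptions.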
def process_document(image, model_name):
    # Save the uploaded image to a temporary JPEG (JPEG requires RGB mode)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        image.convert("RGB").save(temp_file, format="JPEG")
        temp_file_path = temp_file.name
    try:
        # Step 1: Detect lines
        lines = detect_lines(temp_file_path)
        # Visualize the detected lines on a copy of the page
        output_image = visualize_lines(image, lines)
        # Step 2: Extract individual line images
        line_images = extract_line_images(image, lines)
        # Step 3: Transcribe each line
        transcriptions = transcribe_lines(line_images, model_name)
    finally:
        # Clean up the temporary file even if a step above fails
        os.unlink(temp_file_path)
    return output_image, "\n".join(transcriptions)
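# Example (hypothetical) programmatic use outside the UI:
#   page = Image.open("page.jpg").convert("RGB")
#   annotated, text = process_document(page, "Medieval Base")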
# Gradio interface
def gradio_process_document(image, model_name):
    output_image, transcriptions = process_document(image, model_name)
    return output_image, transcriptions
with gr.Blocks() as iface:
    gr.Markdown("# Document OCR and Transcription")
    gr.Markdown("Upload an image and select a model to detect lines and transcribe the text.")
    with gr.Column():
        input_image = gr.Image(type="pil", label="Upload Image", height=300, width=300)
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="Medieval Base", label="Select Model")
        submit_button = gr.Button("Process")
    with gr.Row():
        output_image = gr.Image(type="pil", label="Detected Lines")
        output_text = gr.Textbox(label="Transcription")
    submit_button.click(
        fn=gradio_process_document,
        inputs=[input_image, model_dropdown],
        outputs=[output_image, output_text]
    )

iface.launch(debug=True)