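"""Gradio app for line-level OCR of medieval manuscripts.

An uploaded page image is sent to the wjbmattingly-kraken-api Hugging Face
Space for line segmentation, each detected line is cropped and masked out of
the page, and the crops are transcribed with the selected TrOCR model.
"""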
import gradio as gr
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import spaces
from PIL import Image, ImageDraw
import os
import tempfile
import numpy as np
import requests

# Dictionary of model names and their corresponding HuggingFace model IDs
MODEL_OPTIONS = {
    "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
    "Medieval Base": "medieval-data/trocr-medieval-base",
    "Medieval Latin Caroline": "medieval-data/trocr-medieval-latin-caroline",
    "Medieval Castilian Hybrida": "medieval-data/trocr-medieval-castilian-hybrida",
    "Medieval Humanistica": "medieval-data/trocr-medieval-humanistica",
    "Medieval Textualis": "medieval-data/trocr-medieval-textualis",
    "Medieval Cursiva": "medieval-data/trocr-medieval-cursiva",
    "Medieval Semitextualis": "medieval-data/trocr-medieval-semitextualis",
    "Medieval Praegothica": "medieval-data/trocr-medieval-praegothica",
    "Medieval Semihybrida": "medieval-data/trocr-medieval-semihybrida",
    "Medieval Print": "medieval-data/trocr-medieval-print"
}

def load_model(model_name):
    model_id = MODEL_OPTIONS[model_name]
    processor = TrOCRProcessor.from_pretrained(model_id)
    model = VisionEncoderDecoderModel.from_pretrained(model_id)

    # Move model to GPU if available, else use CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    return processor, model

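# Reloading the processor and weights from the Hub on every request is slow.
# A minimal memoization sketch, assuming the cached model can safely be reused
# across requests (untested here; on ZeroGPU Spaces the device context may
# complicate this). The function name is an addition, not part of the app:
from functools import lru_cache

@lru_cache(maxsize=None)
def load_model_cached(model_name):
    return load_model(model_name)
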
def detect_lines(image_path):
    # API endpoint of the Kraken line-detection Space
    url = "https://wjbmattingly-kraken-api.hf.space/detect_lines"

    # Upload the image and specify the Kraken segmentation model to use
    with open(image_path, 'rb') as f:
        files = {'file': ('ms.jpg', f, 'image/jpeg')}
        data = {'model_name': 'catmus-medieval.mlmodel'}

        # Send the POST request
        response = requests.post(url, files=files, data=data)

    response.raise_for_status()
    return response.json()["result"]["lines"]

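# The call above assumes a response shaped roughly like the sketch below,
# inferred from the fields accessed here and in extract_line_images (the
# authoritative schema lives in the Kraken API Space, not in this file):
#
#   {"result": {"lines": [{"boundary": [[x0, y0], [x1, y1], ...]}, ...]}}
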
def extract_line_images(image, lines):
    line_images = []
    for line in lines:
        polygon = line['boundary']

        # Calculate bounding box
        x_coords, y_coords = zip(*polygon)
        x1, y1, x2, y2 = int(min(x_coords)), int(min(y_coords)), int(max(x_coords)), int(max(y_coords))

        # Crop the line from the original image
        line_image = image.crop((x1, y1, x2, y2))

        # Create a mask for the polygon
        mask = Image.new('L', (x2 - x1, y2 - y1), 0)
        adjusted_polygon = [(int(x - x1), int(y - y1)) for x, y in polygon]
        ImageDraw.Draw(mask).polygon(adjusted_polygon, outline=255, fill=255)

        # Convert images to numpy arrays
        line_array = np.array(line_image)
        mask_array = np.array(mask)

        # Apply the mask: keep pixels inside the polygon, white elsewhere
        masked_line = np.where(mask_array[:, :, np.newaxis] == 255, line_array, 255)

        # Convert back to PIL Image
        masked_line_image = Image.fromarray(masked_line.astype('uint8'), 'RGB')
        line_images.append(masked_line_image)
    return line_images

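# Note: cropping to the bounding box alone would keep ink from neighboring
# lines inside the rectangle; the polygon mask above whites those pixels out
# so the recognizer only sees the target line.
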
def visualize_lines(image, lines):
    output_image = image.copy()
    draw = ImageDraw.Draw(output_image)
    for line in lines:
        polygon = [(int(x), int(y)) for x, y in line['boundary']]
        draw.polygon(polygon, outline="red")
    return output_image

@spaces.GPU
def transcribe_lines(line_images, model_name):
    processor, model = load_model(model_name)
    transcriptions = []
    for line_image in line_images:
        # Process the line image and move it to the same device as the model
        pixel_values = processor(images=line_image, return_tensors="pt").pixel_values
        pixel_values = pixel_values.to(model.device)

        # Generate (no beam search)
        generated_ids = model.generate(pixel_values)

        # Decode
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcriptions.append(generated_text)
    return transcriptions

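# A possible speedup, left as an untested sketch rather than wired into the
# app: TrOCRProcessor accepts a list of images, so all lines could go through
# a single batched generate() call. The function name and batching behavior
# are assumptions, not part of the original app.
@spaces.GPU
def transcribe_lines_batched(line_images, model_name):
    processor, model = load_model(model_name)
    if not line_images:
        return []
    # The processor resizes every crop to the model's fixed input size,
    # so variable-sized line images can share one batch
    pixel_values = processor(images=line_images, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values.to(model.device))
    return processor.batch_decode(generated_ids, skip_special_tokens=True)
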
def process_document(image, model_name):
    # Masking and JPEG saving both assume a 3-channel image
    image = image.convert("RGB")

    # Save the uploaded image temporarily
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        image.save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

    # Step 1: Detect lines
    lines = detect_lines(temp_file_path)

    # Visualize detected lines
    output_image = visualize_lines(image, lines)

    # Step 2: Extract line images
    line_images = extract_line_images(image, lines)

    # Step 3: Transcribe lines
    transcriptions = transcribe_lines(line_images, model_name)

    # Clean up temporary file
    os.unlink(temp_file_path)

    return output_image, "\n".join(transcriptions)

# Gradio interface
with gr.Blocks() as iface:
    gr.Markdown("# Document OCR and Transcription")
    gr.Markdown("Upload an image and select a model to detect lines and transcribe the text.")
    with gr.Column():
        input_image = gr.Image(type="pil", label="Upload Image", height=300, width=300)  # keep the preview compact
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="Medieval Base", label="Select Model")
        submit_button = gr.Button("Process")
    with gr.Row():
        output_image = gr.Image(type="pil", label="Detected Lines")
        output_text = gr.Textbox(label="Transcription")
    submit_button.click(
        fn=process_document,
        inputs=[input_image, model_dropdown],
        outputs=[output_image, output_text]
    )

iface.launch(debug=True)