import os
import tempfile

import gradio as gr
import numpy as np
import requests
import spaces
import torch
from PIL import Image, ImageDraw
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Dictionary of model names and their corresponding HuggingFace model IDs
MODEL_OPTIONS = {
    "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
    "Medieval Base": "medieval-data/trocr-medieval-base",
    "Medieval Latin Caroline": "medieval-data/trocr-medieval-latin-caroline",
    "Medieval Castilian Hybrida": "medieval-data/trocr-medieval-castilian-hybrida",
    "Medieval Humanistica": "medieval-data/trocr-medieval-humanistica",
    "Medieval Textualis": "medieval-data/trocr-medieval-textualis",
    "Medieval Cursiva": "medieval-data/trocr-medieval-cursiva",
    "Medieval Semitextualis": "medieval-data/trocr-medieval-semitextualis",
    "Medieval Praegothica": "medieval-data/trocr-medieval-praegothica",
    "Medieval Semihybrida": "medieval-data/trocr-medieval-semihybrida",
    "Medieval Print": "medieval-data/trocr-medieval-print"
}
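
# Note: each entry above is a TrOCR checkpoint on the Hugging Face Hub; judging
# from the model IDs, the medieval checkpoints appear to target individual
# script families (Caroline, Textualis, Cursiva, etc.).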

def load_model(model_name):
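    """Load the TrOCR processor and model for the selected display name."""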
    model_id = MODEL_OPTIONS[model_name]
    processor = TrOCRProcessor.from_pretrained(model_id)
    model = VisionEncoderDecoderModel.from_pretrained(model_id)
    
    # Move model to GPU if available, else use CPU
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    
    return processor, model

def detect_lines(image_path):
    """Segment the page into text lines via a remote Kraken API."""
    # Hosted Kraken segmentation endpoint
    url = "https://wjbmattingly-kraken-api.hf.space/detect_lines"

    # Upload the image and request segmentation with the CATMuS medieval model
    with open(image_path, 'rb') as f:
        files = {'file': ('ms.jpg', f, 'image/jpeg')}
        data = {'model_name': 'catmus-medieval.mlmodel'}
        response = requests.post(url, files=files, data=data)

    # Fail loudly on HTTP errors instead of raising a KeyError below
    response.raise_for_status()
    return response.json()["result"]["lines"]

def extract_line_images(image, lines):
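    """Crop each detected line and mask out pixels outside its polygon."""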
    line_images = []
    for line in lines:
        polygon = line['boundary']
        
        # Calculate bounding box
        x_coords, y_coords = zip(*polygon)
        x1, y1, x2, y2 = int(min(x_coords)), int(min(y_coords)), int(max(x_coords)), int(max(y_coords))
        
        # Crop the line from the original image
        line_image = image.crop((x1, y1, x2, y2))

        # Create a mask for the polygon
        mask = Image.new('L', (x2-x1, y2-y1), 0)
        adjusted_polygon = [(int(x-x1), int(y-y1)) for x, y in polygon]
        ImageDraw.Draw(mask).polygon(adjusted_polygon, outline=255, fill=255)
        
        # Convert images to numpy arrays
        line_array = np.array(line_image)
        mask_array = np.array(mask)

        # Keep pixels inside the polygon; paint everything outside it white
        masked_line = np.where(mask_array[:, :, np.newaxis] == 255, line_array, 255)

        # Convert back to PIL Image
        masked_line_image = Image.fromarray(masked_line.astype('uint8'), 'RGB')

        line_images.append(masked_line_image)

    return line_images

def visualize_lines(image, lines):
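    """Draw each detected line polygon in red on a copy of the page."""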
    output_image = image.copy()
    draw = ImageDraw.Draw(output_image)
    for line in lines:
        polygon = [(int(x), int(y)) for x, y in line['boundary']]
        draw.polygon(polygon, outline="red")
    return output_image

@spaces.GPU
def transcribe_lines(line_images, model_name):
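    """Run TrOCR over each line crop and return the decoded strings.

    The model is loaded inside this function rather than at module level:
    on ZeroGPU Spaces, CUDA is only available inside @spaces.GPU-decorated
    calls.
    """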
    processor, model = load_model(model_name)
    
    transcriptions = []
    for line_image in line_images:
        # Preprocess the line image and move the tensor to the model's device
        pixel_values = processor(images=line_image, return_tensors="pt").pixel_values.to(model.device)

        # Generate (no beam search)
        generated_ids = model.generate(pixel_values)

        # Decode
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        transcriptions.append(generated_text)

    return transcriptions

def process_document(image, model_name):
    """Full pipeline: segment lines, visualize them, and transcribe each one."""
    # JPEG cannot store an alpha channel, so normalize to RGB first
    image = image.convert("RGB")

    # Save the uploaded image to a temporary file for the segmentation API
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
        image.save(temp_file, format="JPEG")
        temp_file_path = temp_file.name

    # Step 1: Detect lines
    lines = detect_lines(temp_file_path)
    
    # Visualize detected lines
    output_image = visualize_lines(image, lines)
    
    # Step 2: Extract line images
    line_images = extract_line_images(image, lines)
    
    # Step 3: Transcribe lines
    transcriptions = transcribe_lines(line_images, model_name)
    
    # Clean up temporary file
    os.unlink(temp_file_path)
    
    return output_image, "\n".join(transcriptions)

# Gradio interface

with gr.Blocks() as iface:
    gr.Markdown("# Document OCR and Transcription")
    gr.Markdown("Upload an image and select a model to detect lines and transcribe the text.")
    
    with gr.Column():
        input_image = gr.Image(type="pil", label="Upload Image", height=300, width=300)
        model_dropdown = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="Medieval Base", label="Select Model")
        submit_button = gr.Button("Process")
    
    with gr.Row():
        output_image = gr.Image(type="pil", label="Detected Lines")
        output_text = gr.Textbox(label="Transcription")
    
    submit_button.click(
        fn=process_document,
        inputs=[input_image, model_dropdown],
        outputs=[output_image, output_text]
    )

iface.launch(debug=True)
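
# `launch(debug=True)` blocks and surfaces tracebacks in the console; on a
# Hugging Face Space the interface is served automatically when this file runs.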