# -*- coding: utf-8 -*-
import torch
import torchvision
import gradio as gr
# Character set: index 0 is the CTC blank ('~'), index 1 the padding symbol ('=')
CHARS = "~=" + " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789,.'-!?:;\""
BLANK = 0
PAD = 1
CHARS_DICT = {c: i for i, c in enumerate(CHARS)}
TEXTLEN = 30
tokens_list = list(CHARS_DICT.keys())
silence_token = '|'
if silence_token not in tokens_list:
    tokens_list.append(silence_token)
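# A few concrete mappings implied by CHARS above (easy to verify by printing
# CHARS_DICT): ' ' -> 2, 'a' -> 3, 'z' -> 28, 'A' -> 29, '0' -> 55.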
def fit_picture(img):
    """Convert a PIL image to a normalized 1x32x400 tensor, preserving aspect ratio."""
    target_height = 32
    target_width = 400
    # Calculate resize dimensions so the image fits inside the target box
    aspect_ratio = img.width / img.height
    if aspect_ratio > (target_width / target_height):
        resize_width = target_width
        resize_height = int(target_width / aspect_ratio)
    else:
        resize_height = target_height
        resize_width = int(target_height * aspect_ratio)
    # Resize transformation
    resize_transform = torchvision.transforms.Resize((resize_height, resize_width))
    # Pad on the right/bottom up to the fixed target size
    padding_height = (target_height - resize_height) if target_height > resize_height else 0
    padding_width = (target_width - resize_width) if target_width > resize_width else 0
    pad_transform = torchvision.transforms.Pad((0, 0, padding_width, padding_height), fill=0, padding_mode='constant')
    transform = torchvision.transforms.Compose([
        torchvision.transforms.Grayscale(num_output_channels=1),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(0.5, 0.5),
        resize_transform,
        pad_transform,
    ])
    return transform(img)
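# Sketch of the preprocessing contract (hypothetical input size; run manually):
# from PIL import Image
# sample = Image.new("RGB", (200, 50), "white")  # 200x50 strip, aspect ratio 4
# print(fit_picture(sample).shape)               # torch.Size([1, 32, 400]);
# the image is resized to 32x128, then zero-padded on the right to width 400.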
def load_model(filename):
    """Restore the recognizer (and optimizer) state from a CPU-mapped checkpoint."""
    data = torch.load(filename, map_location=torch.device('cpu'), weights_only=True)
    recognizer.load_state_dict(data["recognizer"])
    optimizer.load_state_dict(data["optimizer"])
def ctc_decode_sequence(seq):
    """Removes blanks and repetitions from the sequence."""
    ret = []
    prev = BLANK
    for x in seq:
        # Keep a symbol only when it is not blank and starts a new run
        if x != BLANK and x != prev:
            ret.append(x)
        prev = x
    return ret
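# Sanity check for the decoder (3 -> 'a', 4 -> 'b' in CHARS): blanks separate
# repeats, and consecutive duplicates collapse to one symbol.
assert ctc_decode_sequence([BLANK, 3, 3, BLANK, 4, 4]) == [3, 4]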
def ctc_decode(codes):
    """Decode a batch of sequences; codes has shape (time, batch)."""
    ret = []
    for cs in codes.T:  # iterate over the batch dimension
        ret.append(ctc_decode_sequence(cs))
    return ret
def decode_text(codes):
    """Map a list of label indices back to a string."""
    return ''.join(CHARS[c] for c in codes)
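# Example round trip: decode_text(ctc_decode(predictions)[0]) turns the argmax
# labels for one image into a string, e.g. decode_text([3, 4]) == "ab".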
class Residual(torch.nn.Module):
    """Two 3x3 convolutions with batch norm, plus a projection skip when shapes change."""
    def __init__(self, in_channels, out_channels, stride, pdrop=0.2):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, 3, stride, 1)
        self.bn1 = torch.nn.BatchNorm2d(out_channels)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.bn2 = torch.nn.BatchNorm2d(out_channels)
        if in_channels != out_channels or stride != 1:
            # 1x1 convolution matches channel count and stride on the skip path
            self.skip = torch.nn.Conv2d(in_channels, out_channels, 1, stride, 0)
        else:
            self.skip = torch.nn.Identity()
        self.dropout = torch.nn.Dropout2d(pdrop)

    def forward(self, x):
        y = torch.nn.functional.relu(self.bn1(self.conv1(x)))
        y = torch.nn.functional.relu(self.bn2(self.conv2(y)) + self.skip(x))
        return self.dropout(y)
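# Shape sketch (assumed input size): Residual(1, 32, 2) maps (N, 1, 32, 400)
# to (N, 32, 16, 200); a tuple stride such as (2, 1) halves only the height.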
class TextRecognizer(torch.nn.Module):
    def __init__(self, labels):
        super().__init__()
        # The strides collapse the 32-pixel height to 1 while keeping 1/4 of the width
        self.feature_extractor = torch.nn.Sequential(
            Residual(1, 32, 1),
            Residual(32, 32, 2),
            Residual(32, 32, 1),
            Residual(32, 64, 2),
            Residual(64, 64, 1),
            Residual(64, 128, (2, 1)),
            Residual(128, 128, 1),
            Residual(128, 128, (2, 1)),
            Residual(128, 128, (2, 1)),
        )
        self.recurrent = torch.nn.LSTM(128, 128, 1, bidirectional=True)
        self.output = torch.nn.Linear(256, labels)

    def forward(self, x):
        x = self.feature_extractor(x)  # (N, 128, 1, W/4)
        x = x.squeeze(2)               # drop the unit height: (N, 128, W/4)
        x = x.permute(2, 0, 1)         # LSTM expects (time, batch, features)
        x, _ = self.recurrent(x)       # (T, N, 256) from the bidirectional LSTM
        x = self.output(x)             # per-timestep scores over the labels
        return x
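# Quick shape check (a sketch; run manually if you want to verify). With the
# character set above, len(CHARS) == 74 and a 32x400 input yields 100 timesteps:
# dummy = torch.zeros(1, 1, 32, 400)
# print(TextRecognizer(len(CHARS))(dummy).shape)  # torch.Size([100, 1, 74])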
recognizer = TextRecognizer(len(CHARS))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
LR = 1e-3
recognizer.to(DEVICE)
optimizer = torch.optim.Adam(recognizer.parameters(), lr=LR)
load_model('model.pt')
recognizer.eval()
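# 'model.pt' is expected to be a checkpoint dict with "recognizer" and
# "optimizer" state_dicts, matching what load_model() reads above.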
def ctc_read(image):
    """Run the recognizer on one PIL image and return the decoded text."""
    imagefin = fit_picture(image)
    image_tensor = imagefin.unsqueeze(0).to(DEVICE)  # add batch dim: (1, 1, 32, 400)
    with torch.no_grad():
        scores = recognizer(image_tensor)
    predictions = scores.argmax(2).cpu().numpy()  # best label per timestep, (T, 1)
    decoded_sequences = ctc_decode(predictions)
    # Batch size is 1, so convert the single decoded sequence to text
    return decode_text(decoded_sequences[0])
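# Example call outside Gradio (hypothetical file name):
# from PIL import Image
# print(ctc_read(Image.open("sample.png")))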
# Gradio Interface
iface = gr.Interface(
    fn=ctc_read,
    inputs=gr.Image(type="pil"),  # PIL Image input
    outputs="text",               # Text output
    title="Handwritten Text Recognition",
    description="Upload an image, and the custom AI model will extract the text."
)
iface.launch(share=True)