# Obai33 — app.py (commit afa022d, verified)
# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
import zipfile
import os
import gradio as gr
from PIL import Image
CHARS = "~=" + " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789,.'-!?:;\""
BLANK = 0
PAD = 1
CHARS_DICT = {c: i for i, c in enumerate(CHARS)}
TEXTLEN = 30
tokens_list = list(CHARS_DICT.keys())
silence_token = '|'
if silence_token not in tokens_list:
tokens_list.append(silence_token)
def fit_picture(img):
    """Convert a PIL image into a normalized 1x32x400 tensor.

    The image is scaled (aspect ratio preserved) to fit inside a 32x400
    box, then padded on the right/bottom edges to exactly that size.
    """
    target_height, target_width = 32, 400
    # Choose the largest scaled size that still fits inside the target box.
    aspect_ratio = img.width / img.height
    if aspect_ratio > target_width / target_height:
        resize_width = target_width
        resize_height = int(target_width / aspect_ratio)
    else:
        resize_height = target_height
        resize_width = int(target_height * aspect_ratio)
    pad_right = max(target_width - resize_width, 0)
    pad_bottom = max(target_height - resize_height, 0)
    # NOTE(review): Normalize runs before Resize/Pad, so the pad fill of 0 is
    # mid-gray in the normalized [-1, 1] range — presumably matching training.
    pipeline = torchvision.transforms.Compose([
        torchvision.transforms.Grayscale(num_output_channels=1),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(0.5, 0.5),
        torchvision.transforms.Resize((resize_height, resize_width)),
        torchvision.transforms.Pad((0, 0, pad_right, pad_bottom),
                                   fill=0, padding_mode='constant'),
    ])
    return pipeline(img)
def load_model(filename):
    """Restore the module-level recognizer and optimizer from a checkpoint.

    The checkpoint is loaded onto the CPU; the caller is responsible for
    moving the model to the desired device afterwards.
    """
    checkpoint = torch.load(filename, map_location=torch.device('cpu'), weights_only=True)
    recognizer.load_state_dict(checkpoint["recognizer"])
    optimizer.load_state_dict(checkpoint["optimizer"])
def ctc_decode_sequence(seq):
    """Collapse a raw CTC best-path label sequence into output labels.

    Standard CTC post-processing: merge runs of identical labels, then
    drop blanks.  Fixes two defects in the previous version: the final
    symbol of the sequence was always dropped (it was only re-appended
    via a hard-coded check for label 66, i.e. '.'), and an empty input
    crashed on ``seq[-1]``.
    """
    ret = []
    prev = BLANK
    for x in seq:
        # Flush the previous label once its run ends; blanks are skipped.
        if prev != BLANK and prev != x:
            ret.append(prev)
        prev = x
    # Flush the trailing label, which the loop above never emits.
    if prev != BLANK:
        ret.append(prev)
    return ret
def ctc_decode(codes):
    """Decode a batch of label predictions (time x batch array), one column per sample."""
    return [ctc_decode_sequence(column) for column in codes.T]
def decode_text(codes):
    """Map a sequence of label indices to the string they spell."""
    return ''.join(CHARS[code] for code in codes)
class Residual(torch.nn.Module):
    """Residual block: two 3x3 conv+BN layers with a skip connection.

    A 1x1 strided convolution projects the skip path whenever the channel
    count or stride changes; otherwise the input passes through unchanged.
    Dropout2d is applied to the block output.
    """

    def __init__(self, in_channels, out_channels, stride, pdrop=0.2):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, 3, stride, 1)
        self.bn1 = torch.nn.BatchNorm2d(out_channels)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.bn2 = torch.nn.BatchNorm2d(out_channels)
        # The skip path needs a projection if the shape changes.
        needs_projection = in_channels != out_channels or stride != 1
        self.skip = (torch.nn.Conv2d(in_channels, out_channels, 1, stride, 0)
                     if needs_projection else torch.nn.Identity())
        self.dropout = torch.nn.Dropout2d(pdrop)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)) + self.skip(x))
        return self.dropout(out)
class TextRecognizer(torch.nn.Module):
    """CRNN recognizer: residual CNN features -> bidirectional LSTM -> label scores.

    ``forward`` returns a (time, batch, labels) tensor of per-step scores,
    the layout expected by CTC decoding.
    """

    def __init__(self, labels):
        super().__init__()
        # Residual stack; height-halving strides reduce a 32-px-tall input
        # to height 1 while width is only halved twice.
        self.feature_extractor = torch.nn.Sequential(
            Residual(1, 32, 1),
            Residual(32, 32, 2),
            Residual(32, 32, 1),
            Residual(32, 64, 2),
            Residual(64, 64, 1),
            Residual(64, 128, (2, 1)),
            Residual(128, 128, 1),
            Residual(128, 128, (2, 1)),
            Residual(128, 128, (2, 1)),
        )
        self.recurrent = torch.nn.LSTM(128, 128, 1, bidirectional=True)
        self.output = torch.nn.Linear(256, labels)  # 256 = 2 x 128 (bidirectional)

    def forward(self, x):
        features = self.feature_extractor(x)   # (batch, 128, 1, width')
        features = features.squeeze(2)         # drop the unit height axis
        features = features.permute(2, 0, 1)   # -> (time, batch, channels)
        recurrent_out, _ = self.recurrent(features)
        return self.output(recurrent_out)
# --- Module-level setup: build the model and load the trained checkpoint ---
recognizer = TextRecognizer(len(CHARS))
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
LR = 1e-3  # learning rate; only needed because the checkpoint stores optimizer state
recognizer.to(DEVICE)
optimizer = torch.optim.Adam(recognizer.parameters(), lr=LR)
load_model('model.pt')  # expects model.pt alongside this script
recognizer.eval()  # inference mode: disables dropout, uses BN running stats
def ctc_read(image):
    """Recognize handwritten text in a PIL image and return it as a string.

    Preprocesses the image to the model's 1x32x400 input, scores it with
    the recognizer, greedily picks the best label per time step, and
    CTC-decodes the result.

    Fixes: the previous version raised NameError when the decoded batch
    was empty (``decoded_text`` was only bound inside the loop) and left
    a debug ``print`` of the tensor size in place.
    """
    image_tensor = fit_picture(image).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        scores = recognizer(image_tensor)
    # Greedy (best-path) decoding: argmax over the label dimension.
    predictions = scores.argmax(2).cpu().numpy()
    decoded_sequences = ctc_decode(predictions)
    if not decoded_sequences:
        return ""
    # Batch size is 1, so the first (only) sequence is the answer.
    return decode_text(decoded_sequences[0])
# Gradio interface: PIL image in, recognized text out.
iface = gr.Interface(
    fn=ctc_read,
    inputs=gr.Image(type="pil"),  # PIL Image input
    outputs="text",               # plain-text output
    title="Handwritten Text Recognition",
    # Fixed user-facing typo: "custome" -> "custom".
    description="Upload an image, and the custom AI will extract the text.",
)
iface.launch(share=True)