# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F
import torchvision
import gradio as gr
from PIL import Image

CHARS = "~=" + " abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789,.'-!?:;\""
BLANK = 0
PAD = 1
CHARS_DICT = {c: i for i, c in enumerate(CHARS)}
TEXTLEN = 30

# Token list with an explicit silence marker (not referenced by the inference path below).
tokens_list = list(CHARS_DICT.keys())
silence_token = '|'
if silence_token not in tokens_list:
    tokens_list.append(silence_token)


def fit_picture(img):
    """Resize an image to fit inside 32x400 while keeping its aspect ratio, then pad."""
    target_height = 32
    target_width = 400

    # Calculate resize dimensions
    aspect_ratio = img.width / img.height
    if aspect_ratio > (target_width / target_height):
        resize_width = target_width
        resize_height = int(target_width / aspect_ratio)
    else:
        resize_height = target_height
        resize_width = int(target_height * aspect_ratio)

    # Resize transformation
    resize_transform = torchvision.transforms.Resize((resize_height, resize_width))

    # Pad transformation (pad right/bottom only, so the text stays anchored top-left)
    padding_height = (target_height - resize_height) if target_height > resize_height else 0
    padding_width = (target_width - resize_width) if target_width > resize_width else 0
    pad_transform = torchvision.transforms.Pad((0, 0, padding_width, padding_height),
                                               fill=0, padding_mode='constant')

    preprocess = torchvision.transforms.Compose([
        torchvision.transforms.Grayscale(num_output_channels=1),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(0.5, 0.5),
        resize_transform,
        pad_transform,
    ])

    fin_img = preprocess(img)
    return fin_img


def load_model(filename):
    """Restore the recognizer (and optimizer) state from a checkpoint."""
    data = torch.load(filename, map_location=torch.device('cpu'), weights_only=True)
    recognizer.load_state_dict(data["recognizer"])
    optimizer.load_state_dict(data["optimizer"])


def ctc_decode_sequence(seq):
    """Removes blanks and repetitions from the sequence (greedy CTC decoding)."""
    ret = []
    prev = BLANK
    for x in seq:
        # Emit a label only when it is not a blank and differs from the previous
        # frame, so runs of repeated frames collapse to a single character.
        if x != BLANK and x != prev:
            ret.append(x)
        prev = x
    return ret


def ctc_decode(codes):
    """Decode a batch of sequences (time x batch)."""
    ret = []
    for cs in codes.T:
        ret.append(ctc_decode_sequence(cs))
    return ret


def decode_text(codes):
    chars = [CHARS[c] for c in codes]
    return ''.join(chars)


class Residual(torch.nn.Module):
    def __init__(self, in_channels, out_channels, stride, pdrop=0.2):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, 3, stride, 1)
        self.bn1 = torch.nn.BatchNorm2d(out_channels)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, 3, 1, 1)
        self.bn2 = torch.nn.BatchNorm2d(out_channels)
        if in_channels != out_channels or stride != 1:
            self.skip = torch.nn.Conv2d(in_channels, out_channels, 1, stride, 0)
        else:
            self.skip = torch.nn.Identity()
        self.dropout = torch.nn.Dropout2d(pdrop)

    def forward(self, x):
        y = F.relu(self.bn1(self.conv1(x)))
        y = F.relu(self.bn2(self.conv2(y)) + self.skip(x))
        y = self.dropout(y)
        return y


class TextRecognizer(torch.nn.Module):
    def __init__(self, labels):
        super().__init__()
        self.feature_extractor = torch.nn.Sequential(
            Residual(1, 32, 1),
            Residual(32, 32, 2),
            Residual(32, 32, 1),
            Residual(32, 64, 2),
            Residual(64, 64, 1),
            Residual(64, 128, (2, 1)),
            Residual(128, 128, 1),
            Residual(128, 128, (2, 1)),
            Residual(128, 128, (2, 1)),
        )
        self.recurrent = torch.nn.LSTM(128, 128, 1, bidirectional=True)
        self.output = torch.nn.Linear(256, labels)

    def forward(self, x):
        x = self.feature_extractor(x)   # (batch, 128, 1, time)
        x = x.squeeze(2)                # (batch, 128, time)
        x = x.permute(2, 0, 1)          # (time, batch, 128) for the LSTM
        x, _ = self.recurrent(x)
        x = self.output(x)              # (time, batch, labels)
        return x


recognizer = TextRecognizer(len(CHARS))

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
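
# Shape sketch (a hedged orientation note, not part of the original pipeline): for a
# padded 1x32x400 input, the two blocks with stride 2 in width halve the width twice
# (400 -> 100 time steps), while the five blocks with stride 2 in height collapse the
# height from 32 to 1, so the head emits CTC scores of shape (100, batch, len(CHARS)).
# A quick check (works even before the checkpoint is loaded):
#
#   with torch.no_grad():
#       print(recognizer(torch.zeros(1, 1, 32, 400)).shape)  # (100, 1, len(CHARS))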
print("Device:", DEVICE) LR = 1e-3 recognizer.to(DEVICE) optimizer = torch.optim.Adam(recognizer.parameters(), lr=LR) load_model('model.pt') recognizer.eval() def ctc_read(image): imagefin = fit_picture(image) image_tensor = imagefin.unsqueeze(0).to(DEVICE) print(image_tensor.size()) with torch.no_grad(): scores = recognizer(image_tensor) predictions = scores.argmax(2).cpu().numpy() decoded_sequences = ctc_decode(predictions) # Convert decoded sequences to text for i in decoded_sequences: decoded_text = decode_text(i) return decoded_text # Gradio Interface iface = gr.Interface( fn=ctc_read, inputs=gr.Image(type="pil"), # PIL Image input outputs="text", # Text output title="Handwritten Text Recognition", description="Upload an image, and the custome AI will extract the text." ) iface.launch(share=True)