wjbmattingly committed on
Commit
546d56f
1 Parent(s): 0456d74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -6,7 +6,7 @@ import json
6
  from PIL import Image, ImageDraw
7
  import os
8
  import tempfile
9
-
10
  # Dictionary of model names and their corresponding HuggingFace model IDs
11
  MODEL_OPTIONS = {
12
  "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
@@ -37,7 +37,7 @@ def load_model(model_name):
37
  current_model_name = model_name
38
 
39
  # Move model to GPU
40
- # current_model = current_model.to('cuda')
41
 
42
  return current_processor, current_model
43
 
@@ -69,9 +69,17 @@ def process_image(image, model_name):
69
  # Crop the line from the original image
70
  line_image = image.crop((x1, y1, x2, y2))
71
 
 
 
 
 
 
 
 
 
72
  # Prepare image for TrOCR
73
- pixel_values = processor(line_image, return_tensors="pt").pixel_values
74
- # pixel_values = pixel_values.to('cuda')
75
 
76
  # Generate (no beam search)
77
  with torch.no_grad():
 
6
  from PIL import Image, ImageDraw
7
  import os
8
  import tempfile
9
+ import numpy as np
10
  # Dictionary of model names and their corresponding HuggingFace model IDs
11
  MODEL_OPTIONS = {
12
  "Microsoft Handwritten": "microsoft/trocr-base-handwritten",
 
37
  current_model_name = model_name
38
 
39
  # Move model to GPU
40
+ current_model = current_model.to('cuda')
41
 
42
  return current_processor, current_model
43
 
 
69
  # Crop the line from the original image
70
  line_image = image.crop((x1, y1, x2, y2))
71
 
72
+ # Convert to grayscale if it's not already
73
+ if line_image.mode != 'L':
74
+ line_image = line_image.convert('L')
75
+
76
+ # Convert to numpy array and normalize
77
+ line_image_np = np.array(line_image).astype(np.float32) / 255.0
78
+ line_image_np = np.expand_dims(line_image_np, axis=0) # Add channel dimension
79
+
80
  # Prepare image for TrOCR
81
+ pixel_values = processor(images=line_image_np, return_tensors="pt").pixel_values
82
+ pixel_values = pixel_values.to('cuda')
83
 
84
  # Generate (no beam search)
85
  with torch.no_grad():