Spaces:
Sleeping
Sleeping
gitlost-murali
commited on
Commit
·
95431d3
1
Parent(s):
0aa610a
use askui-ml-helper library
Browse files- app.py +2 -91
- requirements.txt +2 -4
- utils.py +0 -144
app.py
CHANGED
@@ -1,98 +1,9 @@
|
|
1 |
import gradio as gr
|
2 |
-
from PIL import Image, ImageDraw
|
3 |
-
import torch
|
4 |
-
import torch.nn as nn
|
5 |
-
import torch.nn.functional as F
|
6 |
-
import torch
|
7 |
-
from transformers import Pix2StructProcessor, Pix2StructVisionModel
|
8 |
-
from utils import download_default_font, render_header
|
9 |
-
|
10 |
-
class Pix2StructForRegression(nn.Module):
|
11 |
-
def __init__(self, sourcemodel_path, device):
|
12 |
-
super(Pix2StructForRegression, self).__init__()
|
13 |
-
self.model = Pix2StructVisionModel.from_pretrained(sourcemodel_path)
|
14 |
-
self.regression_layer1 = nn.Linear(768, 1536)
|
15 |
-
self.dropout1 = nn.Dropout(0.1)
|
16 |
-
self.regression_layer2 = nn.Linear(1536, 768)
|
17 |
-
self.dropout2 = nn.Dropout(0.1)
|
18 |
-
self.regression_layer3 = nn.Linear(768, 2)
|
19 |
-
self.device = device
|
20 |
-
|
21 |
-
def forward(self, *args, **kwargs):
|
22 |
-
outputs = self.model(*args, **kwargs)
|
23 |
-
sequence_output = outputs.last_hidden_state
|
24 |
-
first_token_output = sequence_output[:, 0, :]
|
25 |
-
|
26 |
-
x = F.relu(self.regression_layer1(first_token_output))
|
27 |
-
x = F.relu(self.regression_layer2(x))
|
28 |
-
regression_output = torch.sigmoid(self.regression_layer3(x))
|
29 |
-
|
30 |
-
return regression_output
|
31 |
-
|
32 |
-
def load_state_dict_file(self, checkpoint_path, strict=True):
|
33 |
-
state_dict = torch.load(checkpoint_path, map_location=self.device)
|
34 |
-
self.load_state_dict(state_dict, strict=strict)
|
35 |
-
|
36 |
-
class Inference:
|
37 |
-
def __init__(self) -> None:
|
38 |
-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
39 |
-
self.model, self.processor = self.load_model_and_processor("google/matcha-base", "model/pta-text-v0.1.pt")
|
40 |
-
|
41 |
-
def load_model_and_processor(self, model_name, checkpoint_path):
|
42 |
-
model = Pix2StructForRegression(sourcemodel_path=model_name, device=self.device)
|
43 |
-
model.load_state_dict_file(checkpoint_path=checkpoint_path)
|
44 |
-
model.eval()
|
45 |
-
model = model.to(self.device)
|
46 |
-
processor = Pix2StructProcessor.from_pretrained(model_name, is_vqa=False)
|
47 |
-
return model, processor
|
48 |
-
|
49 |
-
def prepare_image(self, image, prompt, processor):
|
50 |
-
image = image.resize((1920, 1080))
|
51 |
-
download_default_font_path = download_default_font()
|
52 |
-
rendered_image, _, render_variables = render_header(
|
53 |
-
image=image,
|
54 |
-
header=prompt,
|
55 |
-
bbox={"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0},
|
56 |
-
font_path=download_default_font_path,
|
57 |
-
)
|
58 |
-
encoding = processor(
|
59 |
-
images=rendered_image,
|
60 |
-
max_patches=2048,
|
61 |
-
add_special_tokens=True,
|
62 |
-
return_tensors="pt",
|
63 |
-
)
|
64 |
-
return encoding, render_variables
|
65 |
-
|
66 |
-
def predict_coordinates(self, encoding, model, render_variables):
|
67 |
-
with torch.no_grad():
|
68 |
-
pred_regression_outs = model(flattened_patches=encoding["flattened_patches"], attention_mask=encoding["attention_mask"])
|
69 |
-
new_height = render_variables["height"]
|
70 |
-
new_header_height = render_variables["header_height"]
|
71 |
-
new_total_height = render_variables["total_height"]
|
72 |
-
|
73 |
-
pred_regression_outs[:, 1] = (
|
74 |
-
(pred_regression_outs[:, 1] * new_total_height) - new_header_height
|
75 |
-
) / new_height
|
76 |
-
|
77 |
-
pred_coordinates = pred_regression_outs.squeeze().tolist()
|
78 |
-
return pred_coordinates
|
79 |
-
|
80 |
-
def draw_circle_on_image(self, image, coordinates):
|
81 |
-
x, y = coordinates[0] * image.width, coordinates[1] * image.height
|
82 |
-
draw = ImageDraw.Draw(image)
|
83 |
-
radius = 5
|
84 |
-
draw.ellipse((x-radius, y-radius, x+radius, y+radius), fill="red")
|
85 |
-
return image
|
86 |
-
|
87 |
-
def process_image_and_draw_circle(self, image, prompt):
|
88 |
-
encoding, render_variables = self.prepare_image(image, prompt, self.processor)
|
89 |
-
pred_coordinates = self.predict_coordinates(encoding.to(self.device) , self.model, render_variables)
|
90 |
-
result_image = self.draw_circle_on_image(image, pred_coordinates)
|
91 |
-
return result_image
|
92 |
|
|
|
93 |
|
94 |
def main():
|
95 |
-
inference =
|
96 |
# Gradio Interface
|
97 |
iface = gr.Interface(
|
98 |
fn=inference.process_image_and_draw_circle,
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
from askui_ml_helper.utils.pta_text import PtaTextInference
|
4 |
|
5 |
def main():
|
6 |
+
inference = PtaTextInference("model/pta-text-v0.1.pt")
|
7 |
# Gradio Interface
|
8 |
iface = gr.Interface(
|
9 |
fn=inference.process_image_and_draw_circle,
|
requirements.txt
CHANGED
@@ -1,4 +1,2 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
gradio
|
4 |
-
Pillow
|
|
|
1 |
+
askui-ml-helper
|
2 |
+
gradio
|
|
|
|
utils.py
DELETED
@@ -1,144 +0,0 @@
|
|
1 |
-
import io
|
2 |
-
import os
|
3 |
-
import textwrap
|
4 |
-
from typing import Dict, Optional, Tuple
|
5 |
-
|
6 |
-
from huggingface_hub import hf_hub_download
|
7 |
-
from PIL import Image, ImageDraw, ImageFont
|
8 |
-
|
9 |
-
DEFAULT_FONT_PATH = "ybelkada/fonts"
|
10 |
-
|
11 |
-
|
12 |
-
def download_default_font():
|
13 |
-
font_path = hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF")
|
14 |
-
return font_path
|
15 |
-
|
16 |
-
|
17 |
-
def render_text(
|
18 |
-
text: str,
|
19 |
-
text_size: int = 36,
|
20 |
-
text_color: str = "black",
|
21 |
-
background_color: str = "white",
|
22 |
-
left_padding: int = 5,
|
23 |
-
right_padding: int = 5,
|
24 |
-
top_padding: int = 5,
|
25 |
-
bottom_padding: int = 5,
|
26 |
-
font_bytes: Optional[bytes] = None,
|
27 |
-
font_path: Optional[str] = None,
|
28 |
-
) -> Image.Image:
|
29 |
-
"""
|
30 |
-
Render text. This script is entirely adapted from the original script that can be found here:
|
31 |
-
https://github.com/google-research/pix2struct/blob/main/pix2struct/preprocessing/preprocessing_utils.py
|
32 |
-
|
33 |
-
Args:
|
34 |
-
text (`str`, *optional*, defaults to ):
|
35 |
-
Text to render.
|
36 |
-
text_size (`int`, *optional*, defaults to 36):
|
37 |
-
Size of the text.
|
38 |
-
text_color (`str`, *optional*, defaults to `"black"`):
|
39 |
-
Color of the text.
|
40 |
-
background_color (`str`, *optional*, defaults to `"white"`):
|
41 |
-
Color of the background.
|
42 |
-
left_padding (`int`, *optional*, defaults to 5):
|
43 |
-
Padding on the left.
|
44 |
-
right_padding (`int`, *optional*, defaults to 5):
|
45 |
-
Padding on the right.
|
46 |
-
top_padding (`int`, *optional*, defaults to 5):
|
47 |
-
Padding on the top.
|
48 |
-
bottom_padding (`int`, *optional*, defaults to 5):
|
49 |
-
Padding on the bottom.
|
50 |
-
font_bytes (`bytes`, *optional*):
|
51 |
-
Bytes of the font to use. If `None`, the default font will be used.
|
52 |
-
font_path (`str`, *optional*):
|
53 |
-
Path to the font to use. If `None`, the default font will be used.
|
54 |
-
"""
|
55 |
-
wrapper = textwrap.TextWrapper(
|
56 |
-
width=80
|
57 |
-
) # Add new lines so that each line is no more than 80 characters.
|
58 |
-
lines = wrapper.wrap(text=text)
|
59 |
-
wrapped_text = "\n".join(lines)
|
60 |
-
|
61 |
-
if font_bytes is not None and font_path is None:
|
62 |
-
font = io.BytesIO(font_bytes)
|
63 |
-
elif font_path is not None:
|
64 |
-
font = font_path
|
65 |
-
else:
|
66 |
-
font = hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF")
|
67 |
-
raise ValueError(
|
68 |
-
"Either font_bytes or font_path must be provided. "
|
69 |
-
f"Using default font {font}."
|
70 |
-
)
|
71 |
-
font = ImageFont.truetype(font, encoding="UTF-8", size=text_size)
|
72 |
-
|
73 |
-
# Use a temporary canvas to determine the width and height in pixels when
|
74 |
-
# rendering the text.
|
75 |
-
temp_draw = ImageDraw.Draw(Image.new("RGB", (1, 1), background_color))
|
76 |
-
_, _, text_width, text_height = temp_draw.textbbox((0, 0), wrapped_text, font)
|
77 |
-
|
78 |
-
# Create the actual image with a bit of padding around the text.
|
79 |
-
image_width = text_width + left_padding + right_padding
|
80 |
-
image_height = text_height + top_padding + bottom_padding
|
81 |
-
image = Image.new("RGB", (image_width, image_height), background_color)
|
82 |
-
draw = ImageDraw.Draw(image)
|
83 |
-
draw.text(
|
84 |
-
xy=(left_padding, top_padding), text=wrapped_text, fill=text_color, font=font
|
85 |
-
)
|
86 |
-
return image
|
87 |
-
|
88 |
-
|
89 |
-
# Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L87
|
90 |
-
def render_header(
|
91 |
-
image: Image.Image, header: str, bbox: Dict[str, float], font_path: str, **kwargs
|
92 |
-
) -> Tuple[Image.Image, Tuple[float, float, float, float]]:
|
93 |
-
"""
|
94 |
-
Renders the input text as a header on the input image and updates the bounding box.
|
95 |
-
|
96 |
-
Args:
|
97 |
-
image (Image.Image):
|
98 |
-
The image to render the header on.
|
99 |
-
header (str):
|
100 |
-
The header text.
|
101 |
-
bbox (Dict[str,float]):
|
102 |
-
The bounding box in relative position (0-1), format ("x_min": 0,
|
103 |
-
"y_min": 0,
|
104 |
-
"x_max": 0,
|
105 |
-
"y_max": 0).
|
106 |
-
input_data_format (Union[str, ChildProcessError], optional):
|
107 |
-
The data format of the image.
|
108 |
-
|
109 |
-
Returns:
|
110 |
-
Tuple[Image.Image, Dict[str, float] ]:
|
111 |
-
The image with the header rendered and the updated bounding box.
|
112 |
-
"""
|
113 |
-
assert os.path.exists(font_path), f"Font path {font_path} does not exist."
|
114 |
-
header_image = render_text(text=header, font_path=font_path, **kwargs)
|
115 |
-
new_width = max(header_image.width, image.width)
|
116 |
-
|
117 |
-
new_height = int(image.height * (new_width / image.width))
|
118 |
-
new_header_height = int(header_image.height * (new_width / header_image.width))
|
119 |
-
|
120 |
-
new_image = Image.new("RGB", (new_width, new_height + new_header_height), "white")
|
121 |
-
new_image.paste(header_image.resize((new_width, new_header_height)), (0, 0))
|
122 |
-
new_image.paste(image.resize((new_width, new_height)), (0, new_header_height))
|
123 |
-
|
124 |
-
new_total_height = new_image.height
|
125 |
-
|
126 |
-
new_bbox = {
|
127 |
-
"xmin": bbox["xmin"],
|
128 |
-
"ymin": ((bbox["ymin"] * new_height) + new_header_height)
|
129 |
-
/ new_total_height, # shift y_min down by the header's relative height
|
130 |
-
"xmax": bbox["xmax"],
|
131 |
-
"ymax": ((bbox["ymax"] * new_height) + new_header_height)
|
132 |
-
/ new_total_height, # shift y_min down by the header's relative height
|
133 |
-
}
|
134 |
-
|
135 |
-
return (
|
136 |
-
new_image,
|
137 |
-
new_bbox,
|
138 |
-
{
|
139 |
-
"width": new_width,
|
140 |
-
"height": new_height,
|
141 |
-
"header_height": new_header_height,
|
142 |
-
"total_height": new_total_height,
|
143 |
-
},
|
144 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|