gitlost-murali commited on
Commit
95431d3
·
1 Parent(s): 0aa610a

use askui-ml-helper library

Browse files
Files changed (3) hide show
  1. app.py +2 -91
  2. requirements.txt +2 -4
  3. utils.py +0 -144
app.py CHANGED
@@ -1,98 +1,9 @@
1
  import gradio as gr
2
- from PIL import Image, ImageDraw
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
- import torch
7
- from transformers import Pix2StructProcessor, Pix2StructVisionModel
8
- from utils import download_default_font, render_header
9
-
10
- class Pix2StructForRegression(nn.Module):
11
- def __init__(self, sourcemodel_path, device):
12
- super(Pix2StructForRegression, self).__init__()
13
- self.model = Pix2StructVisionModel.from_pretrained(sourcemodel_path)
14
- self.regression_layer1 = nn.Linear(768, 1536)
15
- self.dropout1 = nn.Dropout(0.1)
16
- self.regression_layer2 = nn.Linear(1536, 768)
17
- self.dropout2 = nn.Dropout(0.1)
18
- self.regression_layer3 = nn.Linear(768, 2)
19
- self.device = device
20
-
21
- def forward(self, *args, **kwargs):
22
- outputs = self.model(*args, **kwargs)
23
- sequence_output = outputs.last_hidden_state
24
- first_token_output = sequence_output[:, 0, :]
25
-
26
- x = F.relu(self.regression_layer1(first_token_output))
27
- x = F.relu(self.regression_layer2(x))
28
- regression_output = torch.sigmoid(self.regression_layer3(x))
29
-
30
- return regression_output
31
-
32
- def load_state_dict_file(self, checkpoint_path, strict=True):
33
- state_dict = torch.load(checkpoint_path, map_location=self.device)
34
- self.load_state_dict(state_dict, strict=strict)
35
-
36
- class Inference:
37
- def __init__(self) -> None:
38
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
- self.model, self.processor = self.load_model_and_processor("google/matcha-base", "model/pta-text-v0.1.pt")
40
-
41
- def load_model_and_processor(self, model_name, checkpoint_path):
42
- model = Pix2StructForRegression(sourcemodel_path=model_name, device=self.device)
43
- model.load_state_dict_file(checkpoint_path=checkpoint_path)
44
- model.eval()
45
- model = model.to(self.device)
46
- processor = Pix2StructProcessor.from_pretrained(model_name, is_vqa=False)
47
- return model, processor
48
-
49
- def prepare_image(self, image, prompt, processor):
50
- image = image.resize((1920, 1080))
51
- download_default_font_path = download_default_font()
52
- rendered_image, _, render_variables = render_header(
53
- image=image,
54
- header=prompt,
55
- bbox={"xmin": 0, "ymin": 0, "xmax": 0, "ymax": 0},
56
- font_path=download_default_font_path,
57
- )
58
- encoding = processor(
59
- images=rendered_image,
60
- max_patches=2048,
61
- add_special_tokens=True,
62
- return_tensors="pt",
63
- )
64
- return encoding, render_variables
65
-
66
- def predict_coordinates(self, encoding, model, render_variables):
67
- with torch.no_grad():
68
- pred_regression_outs = model(flattened_patches=encoding["flattened_patches"], attention_mask=encoding["attention_mask"])
69
- new_height = render_variables["height"]
70
- new_header_height = render_variables["header_height"]
71
- new_total_height = render_variables["total_height"]
72
-
73
- pred_regression_outs[:, 1] = (
74
- (pred_regression_outs[:, 1] * new_total_height) - new_header_height
75
- ) / new_height
76
-
77
- pred_coordinates = pred_regression_outs.squeeze().tolist()
78
- return pred_coordinates
79
-
80
- def draw_circle_on_image(self, image, coordinates):
81
- x, y = coordinates[0] * image.width, coordinates[1] * image.height
82
- draw = ImageDraw.Draw(image)
83
- radius = 5
84
- draw.ellipse((x-radius, y-radius, x+radius, y+radius), fill="red")
85
- return image
86
-
87
- def process_image_and_draw_circle(self, image, prompt):
88
- encoding, render_variables = self.prepare_image(image, prompt, self.processor)
89
- pred_coordinates = self.predict_coordinates(encoding.to(self.device) , self.model, render_variables)
90
- result_image = self.draw_circle_on_image(image, pred_coordinates)
91
- return result_image
92
 
 
93
 
94
  def main():
95
- inference = Inference()
96
  # Gradio Interface
97
  iface = gr.Interface(
98
  fn=inference.process_image_and_draw_circle,
 
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
+ from askui_ml_helper.utils.pta_text import PtaTextInference
4
 
5
  def main():
6
+ inference = PtaTextInference("model/pta-text-v0.1.pt")
7
  # Gradio Interface
8
  iface = gr.Interface(
9
  fn=inference.process_image_and_draw_circle,
requirements.txt CHANGED
@@ -1,4 +1,2 @@
1
- torch
2
- transformers
3
- gradio
4
- Pillow
 
1
+ askui-ml-helper
2
+ gradio
 
 
utils.py DELETED
@@ -1,144 +0,0 @@
1
- import io
2
- import os
3
- import textwrap
4
- from typing import Dict, Optional, Tuple
5
-
6
- from huggingface_hub import hf_hub_download
7
- from PIL import Image, ImageDraw, ImageFont
8
-
9
- DEFAULT_FONT_PATH = "ybelkada/fonts"
10
-
11
-
12
- def download_default_font():
13
- font_path = hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF")
14
- return font_path
15
-
16
-
17
- def render_text(
18
- text: str,
19
- text_size: int = 36,
20
- text_color: str = "black",
21
- background_color: str = "white",
22
- left_padding: int = 5,
23
- right_padding: int = 5,
24
- top_padding: int = 5,
25
- bottom_padding: int = 5,
26
- font_bytes: Optional[bytes] = None,
27
- font_path: Optional[str] = None,
28
- ) -> Image.Image:
29
- """
30
- Render text. This script is entirely adapted from the original script that can be found here:
31
- https://github.com/google-research/pix2struct/blob/main/pix2struct/preprocessing/preprocessing_utils.py
32
-
33
- Args:
34
- text (`str`, *optional*, defaults to ):
35
- Text to render.
36
- text_size (`int`, *optional*, defaults to 36):
37
- Size of the text.
38
- text_color (`str`, *optional*, defaults to `"black"`):
39
- Color of the text.
40
- background_color (`str`, *optional*, defaults to `"white"`):
41
- Color of the background.
42
- left_padding (`int`, *optional*, defaults to 5):
43
- Padding on the left.
44
- right_padding (`int`, *optional*, defaults to 5):
45
- Padding on the right.
46
- top_padding (`int`, *optional*, defaults to 5):
47
- Padding on the top.
48
- bottom_padding (`int`, *optional*, defaults to 5):
49
- Padding on the bottom.
50
- font_bytes (`bytes`, *optional*):
51
- Bytes of the font to use. If `None`, the default font will be used.
52
- font_path (`str`, *optional*):
53
- Path to the font to use. If `None`, the default font will be used.
54
- """
55
- wrapper = textwrap.TextWrapper(
56
- width=80
57
- ) # Add new lines so that each line is no more than 80 characters.
58
- lines = wrapper.wrap(text=text)
59
- wrapped_text = "\n".join(lines)
60
-
61
- if font_bytes is not None and font_path is None:
62
- font = io.BytesIO(font_bytes)
63
- elif font_path is not None:
64
- font = font_path
65
- else:
66
- font = hf_hub_download(DEFAULT_FONT_PATH, "Arial.TTF")
67
- raise ValueError(
68
- "Either font_bytes or font_path must be provided. "
69
- f"Using default font {font}."
70
- )
71
- font = ImageFont.truetype(font, encoding="UTF-8", size=text_size)
72
-
73
- # Use a temporary canvas to determine the width and height in pixels when
74
- # rendering the text.
75
- temp_draw = ImageDraw.Draw(Image.new("RGB", (1, 1), background_color))
76
- _, _, text_width, text_height = temp_draw.textbbox((0, 0), wrapped_text, font)
77
-
78
- # Create the actual image with a bit of padding around the text.
79
- image_width = text_width + left_padding + right_padding
80
- image_height = text_height + top_padding + bottom_padding
81
- image = Image.new("RGB", (image_width, image_height), background_color)
82
- draw = ImageDraw.Draw(image)
83
- draw.text(
84
- xy=(left_padding, top_padding), text=wrapped_text, fill=text_color, font=font
85
- )
86
- return image
87
-
88
-
89
- # Adapted from https://github.com/google-research/pix2struct/blob/0e1779af0f4db4b652c1d92b3bbd2550a7399123/pix2struct/preprocessing/preprocessing_utils.py#L87
90
- def render_header(
91
- image: Image.Image, header: str, bbox: Dict[str, float], font_path: str, **kwargs
92
- ) -> Tuple[Image.Image, Tuple[float, float, float, float]]:
93
- """
94
- Renders the input text as a header on the input image and updates the bounding box.
95
-
96
- Args:
97
- image (Image.Image):
98
- The image to render the header on.
99
- header (str):
100
- The header text.
101
- bbox (Dict[str,float]):
102
- The bounding box in relative position (0-1), format ("x_min": 0,
103
- "y_min": 0,
104
- "x_max": 0,
105
- "y_max": 0).
106
- input_data_format (Union[str, ChildProcessError], optional):
107
- The data format of the image.
108
-
109
- Returns:
110
- Tuple[Image.Image, Dict[str, float] ]:
111
- The image with the header rendered and the updated bounding box.
112
- """
113
- assert os.path.exists(font_path), f"Font path {font_path} does not exist."
114
- header_image = render_text(text=header, font_path=font_path, **kwargs)
115
- new_width = max(header_image.width, image.width)
116
-
117
- new_height = int(image.height * (new_width / image.width))
118
- new_header_height = int(header_image.height * (new_width / header_image.width))
119
-
120
- new_image = Image.new("RGB", (new_width, new_height + new_header_height), "white")
121
- new_image.paste(header_image.resize((new_width, new_header_height)), (0, 0))
122
- new_image.paste(image.resize((new_width, new_height)), (0, new_header_height))
123
-
124
- new_total_height = new_image.height
125
-
126
- new_bbox = {
127
- "xmin": bbox["xmin"],
128
- "ymin": ((bbox["ymin"] * new_height) + new_header_height)
129
- / new_total_height, # shift y_min down by the header's relative height
130
- "xmax": bbox["xmax"],
131
- "ymax": ((bbox["ymax"] * new_height) + new_header_height)
132
- / new_total_height, # shift y_min down by the header's relative height
133
- }
134
-
135
- return (
136
- new_image,
137
- new_bbox,
138
- {
139
- "width": new_width,
140
- "height": new_height,
141
- "header_height": new_header_height,
142
- "total_height": new_total_height,
143
- },
144
- )