azharaslam committed · verified
Commit ce3c203 · 1 Parent(s): 18d1755

Update app.py

Files changed (1)
  1. app.py +20 -67
app.py CHANGED
@@ -15,18 +15,21 @@ from PIL import Image
 
 from transformers.image_transforms import resize, to_channel_dimension_format
 
-
+# Install flash-attn without CUDA build isolation
 subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
-DEVICE = torch.device("cuda")
+# Set the device to GPU if available, otherwise use CPU
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 PROCESSOR = AutoProcessor.from_pretrained(
     "HuggingFaceM4/VLM_WebSight_finetuned",
 )
 MODEL = AutoModelForCausalLM.from_pretrained(
     "HuggingFaceM4/VLM_WebSight_finetuned",
     trust_remote_code=True,
-    torch_dtype=torch.bfloat16,
+    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
 ).to(DEVICE)
+
+# Determine image sequence length
 if MODEL.config.use_resampler:
     image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
 else:
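The device and dtype changes in this hunk are what make the Space runnable without a GPU. A minimal sketch of the fallback pattern in isolation, using only standard `torch` (the tensor shape is illustrative):

```python
import torch

# Pick CUDA when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# bfloat16 halves memory on GPU; float32 is the safe CPU default,
# since many CPU kernels handle bfloat16 poorly or not at all.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

x = torch.randn(2, 3, dtype=DTYPE, device=DEVICE)
print(x.device, x.dtype)  # e.g. "cuda:0 torch.bfloat16" or "cpu torch.float32"
```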
@@ -36,12 +39,9 @@ else:
 BOS_TOKEN = PROCESSOR.tokenizer.bos_token
 BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
 
-
 ## Utils
 
 def convert_to_rgb(image):
-    # `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
-    # for transparent images. The call to `alpha_composite` handles this case
     if image.mode == "RGB":
         return image
 
@@ -51,8 +51,6 @@ def convert_to_rgb(image):
     alpha_composite = alpha_composite.convert("RGB")
     return alpha_composite
 
-# The processor is the same as the Idefics processor except for the BICUBIC interpolation inside siglip,
-# so this is a hack in order to redefine ONLY the transform method
 def custom_transform(x):
     x = convert_to_rgb(x)
     x = to_numpy_array(x)
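The comment removed above explained why `convert_to_rgb` exists: a bare `image.convert("RGB")` drops the alpha channel, so transparent regions can come out with a wrong background. A hedged sketch of the compositing idea in plain Pillow (the white background color is an assumption; the function's full body is only partially visible in this diff):

```python
from PIL import Image

def convert_to_rgb(image: Image.Image) -> Image.Image:
    # Already RGB: nothing to do.
    if image.mode == "RGB":
        return image
    # Composite the image over an opaque white background so transparent
    # pixels become white instead of arbitrary artifacts.
    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    return alpha_composite.convert("RGB")
```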
@@ -69,13 +67,7 @@ def custom_transform(x):
 
 ## End of Utils
 
-
-IMAGE_GALLERY_PATHS = [
-    f"example_images/{ex_image}"
-    for ex_image in os.listdir(f"example_images")
-]
-
-
+# Install Playwright
 def install_playwright():
     try:
         subprocess.run(["playwright", "install"], check=True)
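The `install_playwright` helper follows the standard `subprocess.run(..., check=True)` pattern. Its `except` branch falls outside this hunk, so here is a hedged sketch of the full helper, with the error handling and messages assumed:

```python
import subprocess

def install_playwright():
    try:
        # check=True raises CalledProcessError on a non-zero exit code.
        subprocess.run(["playwright", "install"], check=True)
        print("Playwright installation successful.")
    except subprocess.CalledProcessError as e:
        print(f"Error during Playwright installation: {e}")
```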
@@ -85,17 +77,15 @@ def install_playwright():
 
 install_playwright()
 
+IMAGE_GALLERY_PATHS = [
+    f"example_images/{ex_image}"
+    for ex_image in os.listdir(f"example_images")
+]
 
-def add_file_gallery(
-    selected_state: gr.SelectData,
-    gallery_list: List[str]
-):
+def add_file_gallery(selected_state: gr.SelectData, gallery_list: List[str]):
     return Image.open(gallery_list.root[selected_state.index].image.path)
 
-
-def render_webpage(
-    html_css_code,
-):
+def render_webpage(html_css_code):
     with sync_playwright() as p:
         browser = p.chromium.launch(headless=True)
         context = browser.new_context(
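The body of `render_webpage` is truncated in this diff; it follows the usual synchronous Playwright pattern. A self-contained sketch, assuming the generated markup is loaded with `set_content` and screenshotted to a local path (the viewport size and file name are assumptions; the `return Image.open(...)` line is visible in the next hunk):

```python
from PIL import Image
from playwright.sync_api import sync_playwright

def render_webpage(html_css_code: str) -> Image.Image:
    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        context = browser.new_context(viewport={"width": 1280, "height": 720})
        page = context.new_page()
        # Load the generated HTML/CSS directly; no web server needed.
        page.set_content(html_css_code)
        output_path_screenshot = "screenshot.png"
        page.screenshot(path=output_path_screenshot, full_page=True)
        browser.close()
    return Image.open(output_path_screenshot)
```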
@@ -115,11 +105,8 @@ def render_webpage(
 
     return Image.open(output_path_screenshot)
 
-
 @spaces.GPU(duration=180)
-def model_inference(
-    image,
-):
+def model_inference(image):
     if image is None:
         raise ValueError("`image` is None. It should be a PIL image.")
 
@@ -132,10 +119,7 @@ def model_inference(
         [image],
         transform=custom_transform
     )
-    inputs = {
-        k: v.to(DEVICE)
-        for k, v in inputs.items()
-    }
+    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
 
     streamer = TextIteratorStreamer(
         PROCESSOR.tokenizer,
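The collapsed dict comprehension moves every tensor the processor returns onto `DEVICE` in one pass. A minimal sketch of the same pattern with stand-in tensors (the key names are illustrative):

```python
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for the processor output: a mapping of names to tensors.
inputs = {
    "input_ids": torch.tensor([[1, 2, 3]]),
    "pixel_values": torch.randn(1, 3, 224, 224),
}
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
assert all(v.device.type == DEVICE.type for v in inputs.values())
```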
@@ -147,16 +131,6 @@ def model_inference(
         max_length=4096,
         streamer=streamer,
     )
-    # Regular generation version
-    # generation_kwargs.pop("streamer")
-    # generated_ids = MODEL.generate(**generation_kwargs)
-    # generated_text = PROCESSOR.batch_decode(
-    #     generated_ids,
-    #     skip_special_tokens=True
-    # )[0]
-    # rendered_page = render_webpage(generated_text)
-    # return generated_text, rendered_page
-    # Token streaming version
     thread = Thread(
         target=MODEL.generate,
         kwargs=generation_kwargs,
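The retained streaming path runs `generate` on a background thread while the main thread consumes decoded text from the `TextIteratorStreamer` as it arrives. A reduced sketch of the pattern with a small causal LM (the `gpt2` checkpoint and prompt are assumptions chosen so the sketch runs anywhere):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("A webpage is", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(**inputs, max_new_tokens=30, streamer=streamer)

# generate() blocks, so run it on a worker thread and iterate the streamer.
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
generated_text = ""
for new_text in streamer:
    generated_text += new_text
    print(new_text, end="", flush=True)
thread.join()
```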
@@ -172,20 +146,8 @@ def model_inference(
         generated_text += new_text
         yield generated_text, rendered_image
 
-
-generated_html = gr.Code(
-    label="Extracted HTML",
-    elem_id="generated_html",
-)
-rendered_html = gr.Image(
-    label="Rendered HTML",
-    show_download_button=False,
-    show_share_button=False,
-)
-# rendered_html = gr.HTML(
-#     label="Rendered HTML"
-# )
-
+generated_html = gr.Code(label="Extracted HTML", elem_id="generated_html")
+rendered_html = gr.Image(label="Rendered HTML", show_download_button=False, show_share_button=False)
 
 css = """
 .gradio-container{max-width: 1000px!important}
@@ -193,7 +155,6 @@ h1{display: flex;align-items: center;justify-content: center;gap: .25em}
 *{transition: width 0.5s ease, flex-grow 0.5s ease}
 """
 
-
 with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo:
     gr.Markdown(
         "Since the model used for this demo *does not generate images*, it is more effective to input standalone website elements or sites with minimal image content."
@@ -208,15 +169,11 @@ with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as d
                 )
             with gr.Group():
                 with gr.Row():
-                    submit_btn = gr.Button(
-                        value="▶️ Submit", visible=True, min_width=120
-                    )
+                    submit_btn = gr.Button(value="▶️ Submit", visible=True, min_width=120)
                     clear_btn = gr.ClearButton(
                         [imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
                     )
-                    regenerate_btn = gr.Button(
-                        value="🔄 Regenerate", visible=True, min_width=120
-                    )
+                    regenerate_btn = gr.Button(value="🔄 Regenerate", visible=True, min_width=120)
         with gr.Column(scale=4):
             rendered_html.render()
 
@@ -235,11 +192,7 @@ with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as d
         )
 
     gr.on(
-        triggers=[
-            imagebox.upload,
-            submit_btn.click,
-            regenerate_btn.click,
-        ],
+        triggers=[imagebox.upload, submit_btn.click, regenerate_btn.click],
         fn=model_inference,
         inputs=[imagebox],
         outputs=[generated_html, rendered_html],
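The collapsed `triggers` list wires three events (image upload, Submit, Regenerate) to one handler instead of three separate listener calls. A minimal sketch of `gr.on` with multiple triggers (Gradio 4.x; the components here are placeholders, not the ones from this app):

```python
import gradio as gr

def echo(text):
    return text

with gr.Blocks() as demo:
    box = gr.Textbox(label="Input")
    btn = gr.Button("Run")
    out = gr.Textbox(label="Output")
    # One listener for several events instead of separate .submit/.click calls.
    gr.on(
        triggers=[box.submit, btn.click],
        fn=echo,
        inputs=[box],
        outputs=[out],
    )

demo.launch()
```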
 