darknoon committed
Commit 7569b80 · 1 Parent(s): 9817019

Add match image objective

Files changed (1)
  1. app.py +140 -56
app.py CHANGED
@@ -2,12 +2,13 @@ import gradio as gr
 from playwright.async_api import async_playwright, Page
 from PIL import Image
 from io import BytesIO
-from anthropic import Anthropic, TextEvent
+from anthropic import Anthropic
 from dotenv import load_dotenv
 import os
 from typing import Literal
 import time
 from base64 import b64encode
+from contextlib import asynccontextmanager

 load_dotenv()
 # check for ANTHROPIC_API_KEY
@@ -53,23 +54,25 @@ Assume that the content is being inserted into a template like this:
 {apply_tailwind("your html here")}
 """

-improve_prompt = """
-Given the current draft of the webpage you generated for me as HTML and the screenshot of it rendered, improve the HTML to look nicer.
-"""

+def messages_text_to_web(prompt):
+    return [
+        {"role": "user", "content": prompt},
+    ]

-def stream_initial(prompt):
+
+# returns the full text of the response each time
+def stream_claude(messages, system=system_prompt, max_tokens=2000):
+    text = ""
     with anthropic.messages.stream(
         model=model,
-        max_tokens=2000,
-        system=system_prompt,
-        messages=[
-            {"role": "user", "content": prompt},
-        ],
+        max_tokens=max_tokens,
+        system=system,
+        messages=messages,
     ) as stream:
-        for message in stream:
-            if isinstance(message, TextEvent):
-                yield message.text
+        for chunk in stream.text_stream:
+            text += chunk
+            yield text


 def format_image(image: bytes, media_type: Literal["image/png", "image/jpeg"]):
@@ -84,13 +87,14 @@ def format_image(image: bytes, media_type: Literal["image/png", "image/jpeg"]):
     }


-def stream_with_visual_feedback(prompt, history: list[tuple[str, bytes]]):
+def visual_feedback_messages(prompt, history: list[tuple[str, bytes]]):
     """
     history is a list of tuples of (content, image) corresponding to iterations of generation and rendering
     """
-    print(f"History has {len(history)} images")
-
-    messages = [
+    improve_prompt = """
+    Given the current draft of the webpage you generated for me as HTML and the screenshot of it rendered, improve the HTML to look nicer.
+    """
+    return [
         {"role": "user", "content": prompt},
         *[
             item
@@ -118,27 +122,57 @@ def stream_with_visual_feedback(prompt, history: list[tuple[str, bytes]]):
         ],
     ]

-    with anthropic.messages.stream(
-        model=model,
-        max_tokens=2000,
-        system=system_prompt,
-        messages=messages,
-    ) as stream:
-        for message in stream:
-            if isinstance(message, TextEvent):
-                yield message.text
+
+def match_image_messages(image_bytes: bytes, history: list[tuple[bytes, bytes]]):
+    improve_prompt = """
+    Given the current draft of the webpage you generated for me as HTML and the original screenshot, improve the HTML to match closer to the original screenshot.
+    """
+
+    return [
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "Please generate a webpage that matches the image below as closely as possible:",
+                },
+                format_image(image_bytes, "image/png"),
+            ],
+        },
+        *[
+            item
+            for content, image_bytes in history
+            for item in [
+                {
+                    "role": "assistant",
+                    "content": content,
+                },
+                {
+                    "role": "user",
+                    "content": [
+                        {
+                            "type": "text",
+                            "text": "Here is a screenshot of the above HTML code rendered in a browser:",
+                        },
+                        format_image(image_bytes, "image/png"),
+                        {
+                            "type": "text",
+                            "text": improve_prompt,
+                        },
+                    ],
+                },
+            ]
+        ],
+    ]


 async def render_html(page: Page, content: str):
-    start_time = t()
+    start_time = time.perf_counter()
     await page.set_content(content)
     # weird, can we set scale to 2.0 directly instead of "device", ie whatever server this is running on?
     image_bytes = await page.screenshot(type="png", scale="device", full_page=True)
-    return image_bytes, t() - start_time
-
-
-def t():
-    return time.perf_counter()
+    dt = time.perf_counter() - start_time
+    return image_bytes, dt


 def apply_template(content, template):
@@ -151,48 +185,78 @@ def to_pil(image_bytes: bytes):
     return Image.open(BytesIO(image_bytes))


+@asynccontextmanager
+async def browser(width, height):
+    async with async_playwright() as p:
+        browser = await p.chromium.launch()
+        page = await browser.new_page(viewport={"width": width, "height": height})
+        try:
+            yield page
+        finally:
+            await browser.close()
+
+
+async def throttle(generator, every=0.25):
+    last_emit_time = 0
+    for item in generator:
+        current_time = time.perf_counter()
+        if current_time - last_emit_time >= every:
+            yield item
+            last_emit_time = current_time
+    # always emit the last item
+    yield item
+
+
 async def generate_with_visual_feedback(
     prompt,
     template,
     resolution: str = "512",
     num_iterations: int = 1,
 ):
-    render_every = 0.25
-    resolution = {"512": (512, 512), "1024": (1024, 1024)}[resolution]
-    async with async_playwright() as p:
-        browser = await p.chromium.launch()
-        page = await browser.new_page(
-            viewport={"width": resolution[0], "height": resolution[1]}
-        )
-        last_yield = t()
+    width = {"512": 512, "1024": 1024}[resolution]
+    async with browser(width, width) as page:
         history = []
         for i in range(num_iterations):
-            stream = (
-                stream_initial(prompt)
+            messages = (
+                messages_text_to_web(prompt)
                 if i == 0
-                else stream_with_visual_feedback(prompt, history)
+                else visual_feedback_messages(prompt, history)
             )
             content = ""
-            for chunk in stream:
-                content = content + chunk
-                current_time = t()
-                if current_time - last_yield >= render_every:
-                    image_bytes, render_time = await render_html(
-                        page, apply_template(content, template)
-                    )
-                    yield to_pil(image_bytes), content, render_time
-                    last_yield = t()
+            async for content in throttle(stream_claude(messages), every=0.25):
+                image_bytes, render_time = await render_html(
+                    page, apply_template(content, template)
+                )
+                yield to_pil(image_bytes), content, render_time
+            history.append((content, image_bytes))
+
+
+def to_image_bytes(image: Image.Image) -> bytes:
+    buffer = BytesIO()
+    image.save(buffer, format="PNG")
+    return buffer.getvalue()
+
+
+async def match_image_with_visual_feedback(image, template, resolution, num_iterations):
+    width = {"512": 512, "1024": 1024}[resolution]
+    async with browser(width, width) as page:
+        history = []
+        for i in range(num_iterations):
+            image.thumbnail((width, width), Image.Resampling.LANCZOS)
+            messages = match_image_messages(to_image_bytes(image), history)
+            async for content in throttle(stream_claude(messages), 0.25):
+                image_bytes, render_time = await render_html(
+                    page, apply_template(content, template)
+                )
+                yield to_pil(image_bytes), content, render_time
             # always render the final image of each iteration
             image_bytes, render_time = await render_html(
                 page, apply_template(content, template)
             )
             history.append((content, image_bytes))
-            yield to_pil(image_bytes), content, render_time
-        # cleanup
-        await browser.close()


-demo = gr.Interface(
+demo_generate = gr.Interface(
     generate_with_visual_feedback,
     inputs=[
         gr.Textbox(
@@ -212,6 +276,26 @@ demo = gr.Interface(
     ],
 )

+demo_match_image = gr.Interface(
+    match_image_with_visual_feedback,
+    inputs=[
+        gr.Image(type="pil", label="Original Image", image_mode="RGB", format="png"),
+        gr.Dropdown(choices=["tailwind"], label="Template", value="tailwind"),
+        gr.Dropdown(choices=["512", "1024"], label="Page Width", value="512"),
+        gr.Slider(1, 10, 3, step=1, label="Iterations"),
+    ],
+    outputs=[
+        gr.Image(type="pil", label="Rendered HTML", image_mode="RGB", format="png"),
+        gr.Textbox(lines=5, label="Code"),
+        gr.Number(label="Render Time", precision=2),
+    ],
+)
+
+demo = gr.TabbedInterface(
+    [demo_match_image, demo_generate],
+    ["Match Image", "Generate"],
+)
+

 if __name__ == "__main__":
     prepare_playwright_if_needed()
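Usage note: the new match-image flow can also be driven outside the Gradio UI. The following is a minimal sketch, not part of the commit, assuming this Space's app.py is importable, ANTHROPIC_API_KEY is set, Playwright's Chromium is installed, and a hypothetical reference screenshot target.png exists; the "tailwind" template and "512" width mirror the dropdown defaults above.

import asyncio

from PIL import Image

from app import match_image_with_visual_feedback  # assumption: run from the Space repo so app.py is importable


async def main():
    # "target.png" is a hypothetical reference screenshot to reproduce
    target = Image.open("target.png")
    # match_image_with_visual_feedback is an async generator: it yields
    # (rendered PIL image, current HTML draft, render time in seconds),
    # at most once every 0.25 s while Claude streams, for each iteration
    async for rendered, html, render_time in match_image_with_visual_feedback(
        target, template="tailwind", resolution="512", num_iterations=3
    ):
        print(f"{len(html)} chars of HTML, rendered in {render_time:.2f}s")
        rendered.save("latest_render.png")


asyncio.run(main())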
 
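On the streaming side, stream_claude yields the cumulative response text after each chunk, and throttle downsamples those snapshots so the page is re-rendered at most once per `every` seconds while still emitting the final, complete snapshot. A rough sketch of that behaviour with a stand-in generator (fake_stream is illustrative, not part of the commit):

import asyncio
import time

from app import throttle  # assumption: run from the Space repo so app.py is importable


def fake_stream():
    # stand-in for stream_claude(): yields the cumulative text so far after each chunk
    text = ""
    for token in ["<html>", "<body>", "<h1>hi</h1>", "</body>", "</html>"]:
        text += token
        time.sleep(0.1)  # pretend a chunk arrives every 100 ms
        yield text


async def main():
    # intermediate snapshots are dropped so at most one arrives per 0.25 s,
    # but the last, complete snapshot is always yielded
    async for snapshot in throttle(fake_stream(), every=0.25):
        print(f"{len(snapshot)} chars so far")


asyncio.run(main())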