Pedro Cuenca commited on
Commit
f68e37a
·
1 Parent(s): 810d65b

Gradio UI skeleton for experimentation.

Browse files

Former-commit-id: cf46af3085c3164460d5f5709e995052e1882fb0

app/sample_images/image_0.jpg ADDED
app/sample_images/image_1.jpg ADDED
app/sample_images/image_2.jpg ADDED
app/sample_images/image_3.jpg ADDED
app/sample_images/image_4.jpg ADDED
app/sample_images/image_5.jpg ADDED
app/sample_images/image_6.jpg ADDED
app/sample_images/image_7.jpg ADDED
app/sample_images/readme.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ These images were generated by one of our checkpoints, as responses to the prompt "snowy mountains by the sea".
app/ui_gradio.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ from PIL import Image
5
+ import gradio as gr
6
+
7
+ def compose_predictions(images, caption=None):
8
+ increased_h = 0 if caption is None else 48
9
+ w, h = images[0].size[0], images[0].size[1]
10
+ img = Image.new("RGB", (len(images)*w, h + increased_h))
11
+ for i, img_ in enumerate(images):
12
+ img.paste(img_, (i*w, increased_h))
13
+
14
+ if caption is not None:
15
+ draw = ImageDraw.Draw(img)
16
+ font = ImageFont.truetype("/usr/share/fonts/truetype/liberation2/LiberationMono-Bold.ttf", 40)
17
+ draw.text((20, 3), caption, (255,255,255), font=font)
18
+ return img
19
+
20
+ def compose_predictions_grid(images):
21
+ cols = 4
22
+ rows = len(images) // cols
23
+ w, h = images[0].size[0], images[0].size[1]
24
+ img = Image.new("RGB", (w * cols, h * rows))
25
+ for i, img_ in enumerate(images):
26
+ row = i // cols
27
+ col = i % cols
28
+ img.paste(img_, (w * col, h * row))
29
+ return img
30
+
31
+ def top_k_predictions_real(prompt, num_candidates=32, k=8):
32
+ images = hallucinate(prompt, num_images=num_candidates)
33
+ images = clip_top_k(prompt, images, k=num_preds)
34
+ return images
35
+
36
+ def top_k_predictions(prompt, num_candidates=32, k=8):
37
+ images = []
38
+ for i in range(k):
39
+ image = Image.open(f"sample_images/image_{i}.jpg")
40
+ images.append(image)
41
+ return images
42
+
43
+ def run_inference(prompt, num_images=32, num_preds=8):
44
+ images = top_k_predictions(prompt, num_candidates=num_images, k=num_preds)
45
+ predictions = compose_predictions(images)
46
+ output_title = 'This would be an html string to serve as title for the outputs.'
47
+ output_description = 'This is another random piece of html'
48
+ return (output_title, predictions, output_description)
49
+
50
+ outputs = [
51
+ gr.outputs.HTML(label=""), # To be used as title
52
+ gr.outputs.Image(label='Top predictions'),
53
+ gr.outputs.HTML(label=""), # Additional text that appears in the screenshot
54
+ ]
55
+
56
+ gr.Interface(run_inference,
57
+ inputs=[gr.inputs.Textbox(label='Prompt')], #, gr.inputs.Slider(1,64,1,8, label='Candidates to generate'), gr.inputs.Slider(1,8,1,1, label='Best predictions to show')],
58
+ outputs=outputs,
59
+ title='DALL·E mini',
60
+ description='This is a demo of the DALLE-mini model trained with Jax/Flax on TPU v3-8s during the HuggingFace Community Week',
61
+ article="<p style='text-align: center'> DALLE·mini by Boris Dayma et al. | <a href='https://github.com/borisdayma/dalle-mini'>GitHub</a></p>",
62
+ layout='vertical',
63
+ theme='huggingface',
64
+ examples=[['an armchair in the shape of an avocado'], ['snowy mountains by the sea']],
65
+ allow_flagging=False,
66
+ live=False,
67
+ server_port=8999
68
+ ).launch(
69
+ share=True # Creates temporary public link if true
70
+ )