Curranj commited on
Commit
841be4d
1 Parent(s): 2ca6c75

Create new file

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ import warnings
4
+
5
+ from IPython.display import display
6
+ from PIL import Image
7
+ from stability_sdk import client
8
+ import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
9
+
10
+ import gradio as gr
11
+ stability_api = client.StabilityInference(
12
+ key=os.environ["Secret"],
13
+ verbose=True,
14
+ )
15
+
16
+
17
+ def infer(prompt):
18
+ # the object returned is a python generator
19
+ answers = stability_api.generate(
20
+ prompt=prompt
21
+ )
22
+
23
+ # iterating over the generator produces the api response
24
+ for resp in answers:
25
+ for artifact in resp.artifacts:
26
+ if artifact.finish_reason == generation.FILTER:
27
+ warnings.warn(
28
+ "Your request activated the API's safety filters and could not be processed."
29
+ "Please modify the prompt and try again.")
30
+ if artifact.type == generation.ARTIFACT_IMAGE:
31
+ img = Image.open(io.BytesIO(artifact.binary))
32
+ return img
33
+
34
+
35
+ block = gr.Blocks(css=".container { max-width: 600px; margin: auto; }")
36
+
37
+ num_samples = 1
38
+
39
+
40
+
41
+ with block as demo:
42
+ gr.Markdown("<h1><center>Stable Diffusion</center></h1>")
43
+ gr.Markdown(
44
+ "Get an image for any prompt you provide!"
45
+ )
46
+ with gr.Group():
47
+ with gr.Box():
48
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
49
+
50
+ text = gr.Textbox(
51
+ label="Enter your prompt", show_label=False, max_lines=1
52
+ ).style(
53
+ border=(True, False, True, True),
54
+ rounded=(True, False, False, True),
55
+ container=False,
56
+ )
57
+ btn = gr.Button("Run").style(
58
+ margin=False,
59
+ rounded=(False, True, True, False),
60
+ )
61
+
62
+
63
+ gallery = gr.Image()
64
+ text.submit(infer, inputs=[text], outputs=gallery)
65
+ btn.click(infer, inputs=[text], outputs=gallery)
66
+
67
+
68
+
69
+
70
+
71
+ demo.launch()