keithhon commited on
Commit
ec7af4e
·
1 Parent(s): f7c951f

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Code inspired by https://huggingface.co/spaces/flax-community/dalle-mini
2
+ """
3
+ import base64
4
+ import os
5
+ import time
6
+ from io import BytesIO
7
+ from multiprocessing import Process
8
+
9
+ import streamlit as st
10
+ from PIL import Image
11
+
12
+ import requests
13
+ import logging
14
+
15
+
16
+ def start_server():
17
+ os.system("uvicorn server:app --port 8080 --host 0.0.0.0 --workers 1")
18
+
19
+
20
+ def load_models():
21
+ if not is_port_in_use(8080):
22
+ with st.spinner(text="Loading models, please wait..."):
23
+ proc = Process(target=start_server, args=(), daemon=True)
24
+ proc.start()
25
+ while not is_port_in_use(8080):
26
+ time.sleep(1)
27
+ st.success("Model server started.")
28
+ else:
29
+ st.success("Model server already running...")
30
+ st.session_state["models_loaded"] = True
31
+
32
+
33
+ def is_port_in_use(port):
34
+ import socket
35
+
36
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
37
+ return s.connect_ex(("0.0.0.0", port)) == 0
38
+
39
+
40
+ def generate(prompt):
41
+ correct_request = f"http://0.0.0.0:8080/correct?prompt={prompt}"
42
+ response = requests.get(correct_request)
43
+ images = response.json()["images"]
44
+ images = [Image.open(BytesIO(base64.b64decode(img))) for img in images]
45
+ return images
46
+
47
+
48
+ if "models_loaded" not in st.session_state:
49
+ st.session_state["models_loaded"] = False
50
+
51
+
52
+ st.header("Logo generator")
53
+ #st.subheader("Generate images from text")
54
+ st.write("Generate logos from text")
55
+
56
+ if not st.session_state["models_loaded"]:
57
+ load_models()
58
+
59
+ prompt = st.text_input("Your text prompt. Tip: start with 'a logo of...':")
60
+
61
+ DEBUG = False
62
+ # UI code taken from https://huggingface.co/spaces/flax-community/dalle-mini/blob/main/app/streamlit/app.py
63
+ if prompt != "":
64
+ container = st.empty()
65
+ container.markdown(
66
+ f"""
67
+ <style> p {{ margin:0 }} div {{ margin:0 }} </style>
68
+ <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
69
+ <div class="stAlert">
70
+ <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
71
+ <div class="st-b7">
72
+ <div class="css-whx05o e13vu3m50">
73
+ <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
74
+ <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
75
+ Generating predictions for: <b>{prompt}</b>
76
+ </div>
77
+ </div>
78
+ </div>
79
+ </div>
80
+ </div>
81
+ </div>
82
+ """,
83
+ unsafe_allow_html=True,
84
+ )
85
+
86
+ print(f"Getting selections: {prompt}")
87
+ selected = generate(prompt)
88
+
89
+ margin = 0.1 # for better position of zoom in arrow
90
+ n_columns = 3
91
+ cols = st.columns([1] + [margin, 1] * (n_columns - 1))
92
+ for i, img in enumerate(selected):
93
+ cols[(i % n_columns) * 2].image(img)
94
+ container.markdown(f"**{prompt}**")
95
+
96
+ st.button("Run again", key="again_button")
97
+
98
+