eliphatfs commited on
Commit
cd542fa
·
1 Parent(s): 471a386

Better UX: no refresh inside form.

Browse files
Files changed (1) hide show
  1. app.py +82 -71
app.py CHANGED
@@ -5,10 +5,12 @@ from huggingface_hub import HfFolder, snapshot_download
5
 
6
  @st.cache_data
7
  def load_support():
8
- HfFolder().save_token(st.secrets['etoken'])
 
9
  sys.path.append(snapshot_download("OpenShape/openshape-demo-support"))
10
 
11
 
 
12
  load_support()
13
 
14
 
@@ -43,13 +45,15 @@ torch.set_grad_enabled(False)
43
 
44
  from openshape.demo import misc_utils, classification, caption, sd_pc2img, retrieval
45
 
 
46
  st.title("OpenShape Demo")
 
47
  prog = st.progress(0.0, "Idle")
48
- tab_cls, tab_text, tab_img, tab_pc, tab_sd, tab_cap = st.tabs([
49
  "Classification",
50
- "Retrieval from Text",
51
- "Retrieval from Image",
52
- "Retrieval from 3D Shape",
53
  "Image Generation",
54
  "Captioning",
55
  ])
@@ -62,7 +66,9 @@ def demo_classification():
62
  if len(cats) > 64:
63
  st.error('Maximum 64 custom categories supported in the demo')
64
  return
65
- if st.button("Run Classification on LVIS Categories"):
 
 
66
  pc = load_data(prog)
67
  col2 = misc_utils.render_pc(pc)
68
  prog.progress(0.5, "Running Classification")
@@ -72,7 +78,7 @@ def demo_classification():
72
  st.text(cat)
73
  st.caption("Similarity %.4f" % sim)
74
  prog.progress(1.0, "Idle")
75
- if st.button("Run Classification on Custom Categories"):
76
  pc = load_data(prog)
77
  col2 = misc_utils.render_pc(pc)
78
  prog.progress(0.5, "Computing Category Embeddings")
@@ -89,40 +95,42 @@ def demo_classification():
89
 
90
 
91
  def demo_captioning():
92
- load_data = misc_utils.input_3d_shape('cap')
93
- cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0)
94
- if st.button("Generate a Caption"):
95
- pc = load_data(prog)
96
- col2 = misc_utils.render_pc(pc)
97
- prog.progress(0.5, "Running Generation")
98
- cap = caption.pc_caption(model_b32, pc, cond_scale)
99
- st.text(cap)
100
- prog.progress(1.0, "Idle")
 
101
 
102
 
103
  def demo_pc2img():
104
- load_data = misc_utils.input_3d_shape('sd')
105
- prompt = st.text_input("Prompt (Optional)")
106
- noise_scale = st.slider('Variation Level', 0, 5, 1)
107
- cfg_scale = st.slider('Guidance Scale', 0.0, 30.0, 10.0)
108
- steps = st.slider('Diffusion Steps', 8, 50, 25)
109
- width = 640 # st.slider('Width', 480, 640, step=32)
110
- height = 640 # st.slider('Height', 480, 640, step=32)
111
- if st.button("Generate"):
112
- pc = load_data(prog)
113
- col2 = misc_utils.render_pc(pc)
114
- prog.progress(0.49, "Running Generation")
115
- if torch.cuda.is_available():
116
- clip_model.cpu()
117
- img = sd_pc2img.pc_to_image(
118
- model_l14, pc, prompt, noise_scale, width, height, cfg_scale, steps,
119
- lambda i, t, _: prog.progress(0.49 + i / (steps + 1) / 2, "Running Diffusion Step %d" % i)
120
- )
121
- if torch.cuda.is_available():
122
- clip_model.cuda()
123
- with col2:
124
- st.image(img)
125
- prog.progress(1.0, "Idle")
 
126
 
127
 
128
  def retrieval_results(results):
@@ -144,43 +152,46 @@ def retrieval_results(results):
144
 
145
  def demo_retrieval():
146
  with tab_text:
147
- k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rtext')
148
- text = st.text_input("Input Text")
149
- if st.button("Run with Text"):
150
- prog.progress(0.49, "Computing Embeddings")
151
- device = clip_model.device
152
- tn = clip_prep(text=[text], return_tensors='pt', truncation=True, max_length=76).to(device)
153
- enc = clip_model.get_text_features(**tn).float().cpu()
154
- prog.progress(0.7, "Running Retrieval")
155
- retrieval_results(retrieval.retrieve(enc, k))
156
- prog.progress(1.0, "Idle")
 
157
 
158
  with tab_img:
159
- k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rimage')
160
- pic = st.file_uploader("Upload an Image")
161
- if st.button("Run with Image"):
162
- img = Image.open(pic)
163
- st.image(img)
164
- prog.progress(0.49, "Computing Embeddings")
165
- device = clip_model.device
166
- tn = clip_prep(images=[img], return_tensors="pt").to(device)
167
- enc = clip_model.get_image_features(pixel_values=tn['pixel_values'].type(half)).float().cpu()
168
- prog.progress(0.7, "Running Retrieval")
169
- retrieval_results(retrieval.retrieve(enc, k))
170
- prog.progress(1.0, "Idle")
 
171
 
172
  with tab_pc:
173
- k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rpc')
174
- load_data = misc_utils.input_3d_shape('retpc')
175
- if st.button("Run with Shape"):
176
- pc = load_data(prog)
177
- col2 = misc_utils.render_pc(pc)
178
- prog.progress(0.49, "Computing Embeddings")
179
- ref_dev = next(model_g14.parameters()).device
180
- enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
181
- prog.progress(0.7, "Running Retrieval")
182
- retrieval_results(retrieval.retrieve(enc, k))
183
- prog.progress(1.0, "Idle")
 
184
 
185
 
186
  try:
 
5
 
6
  @st.cache_data
7
  def load_support():
8
+ if st.secrets.has_key('etoken'):
9
+ HfFolder().save_token(st.secrets['etoken'])
10
  sys.path.append(snapshot_download("OpenShape/openshape-demo-support"))
11
 
12
 
13
+ # st.set_page_config(layout='wide')
14
  load_support()
15
 
16
 
 
45
 
46
  from openshape.demo import misc_utils, classification, caption, sd_pc2img, retrieval
47
 
48
+
49
  st.title("OpenShape Demo")
50
+ st.caption("For faster inference without waiting in queue, you may clone the space and run it yourself.")
51
  prog = st.progress(0.0, "Idle")
52
+ tab_cls, tab_img, tab_text, tab_pc, tab_sd, tab_cap = st.tabs([
53
  "Classification",
54
+ "Retrieval w/ Image",
55
+ "Retrieval w/ Text",
56
+ "Retrieval w/ 3D",
57
  "Image Generation",
58
  "Captioning",
59
  ])
 
66
  if len(cats) > 64:
67
  st.error('Maximum 64 custom categories supported in the demo')
68
  return
69
+ lvis_run = st.button("Run Classification on LVIS Categories")
70
+ custom_run = st.button("Run Classification on Custom Categories")
71
+ if lvis_run:
72
  pc = load_data(prog)
73
  col2 = misc_utils.render_pc(pc)
74
  prog.progress(0.5, "Running Classification")
 
78
  st.text(cat)
79
  st.caption("Similarity %.4f" % sim)
80
  prog.progress(1.0, "Idle")
81
+ if custom_run:
82
  pc = load_data(prog)
83
  col2 = misc_utils.render_pc(pc)
84
  prog.progress(0.5, "Computing Category Embeddings")
 
95
 
96
 
97
  def demo_captioning():
98
+ with st.form("capform"):
99
+ load_data = misc_utils.input_3d_shape('cap')
100
+ cond_scale = st.slider('Conditioning Scale', 0.0, 4.0, 2.0)
101
+ if st.form_submit_button("Generate a Caption"):
102
+ pc = load_data(prog)
103
+ col2 = misc_utils.render_pc(pc)
104
+ prog.progress(0.5, "Running Generation")
105
+ cap = caption.pc_caption(model_b32, pc, cond_scale)
106
+ st.text(cap)
107
+ prog.progress(1.0, "Idle")
108
 
109
 
110
  def demo_pc2img():
111
+ with st.form("sdform"):
112
+ load_data = misc_utils.input_3d_shape('sd')
113
+ prompt = st.text_input("Prompt (Optional)")
114
+ noise_scale = st.slider('Variation Level', 0, 5, 1)
115
+ cfg_scale = st.slider('Guidance Scale', 0.0, 30.0, 10.0)
116
+ steps = st.slider('Diffusion Steps', 8, 50, 25)
117
+ width = 640 # st.slider('Width', 480, 640, step=32)
118
+ height = 640 # st.slider('Height', 480, 640, step=32)
119
+ if st.form_submit_button("Generate"):
120
+ pc = load_data(prog)
121
+ col2 = misc_utils.render_pc(pc)
122
+ prog.progress(0.49, "Running Generation")
123
+ if torch.cuda.is_available():
124
+ clip_model.cpu()
125
+ img = sd_pc2img.pc_to_image(
126
+ model_l14, pc, prompt, noise_scale, width, height, cfg_scale, steps,
127
+ lambda i, t, _: prog.progress(0.49 + i / (steps + 1) / 2, "Running Diffusion Step %d" % i)
128
+ )
129
+ if torch.cuda.is_available():
130
+ clip_model.cuda()
131
+ with col2:
132
+ st.image(img)
133
+ prog.progress(1.0, "Idle")
134
 
135
 
136
  def retrieval_results(results):
 
152
 
153
  def demo_retrieval():
154
  with tab_text:
155
+ with st.form("rtextform"):
156
+ k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rtext')
157
+ text = st.text_input("Input Text")
158
+ if st.form_submit_button("Run with Text"):
159
+ prog.progress(0.49, "Computing Embeddings")
160
+ device = clip_model.device
161
+ tn = clip_prep(text=[text], return_tensors='pt', truncation=True, max_length=76).to(device)
162
+ enc = clip_model.get_text_features(**tn).float().cpu()
163
+ prog.progress(0.7, "Running Retrieval")
164
+ retrieval_results(retrieval.retrieve(enc, k))
165
+ prog.progress(1.0, "Idle")
166
 
167
  with tab_img:
168
+ with st.form("rimgform"):
169
+ k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rimage')
170
+ pic = st.file_uploader("Upload an Image")
171
+ if st.form_submit_button("Run with Image"):
172
+ img = Image.open(pic)
173
+ st.image(img)
174
+ prog.progress(0.49, "Computing Embeddings")
175
+ device = clip_model.device
176
+ tn = clip_prep(images=[img], return_tensors="pt").to(device)
177
+ enc = clip_model.get_image_features(pixel_values=tn['pixel_values'].type(half)).float().cpu()
178
+ prog.progress(0.7, "Running Retrieval")
179
+ retrieval_results(retrieval.retrieve(enc, k))
180
+ prog.progress(1.0, "Idle")
181
 
182
  with tab_pc:
183
+ with st.form("rpcform"):
184
+ k = st.slider("# Shapes to Retrieve", 1, 100, 16, key='rpc')
185
+ load_data = misc_utils.input_3d_shape('retpc')
186
+ if st.form_submit_button("Run with Shape"):
187
+ pc = load_data(prog)
188
+ col2 = misc_utils.render_pc(pc)
189
+ prog.progress(0.49, "Computing Embeddings")
190
+ ref_dev = next(model_g14.parameters()).device
191
+ enc = model_g14(torch.tensor(pc[:, [0, 2, 1, 3, 4, 5]].T[None], device=ref_dev)).cpu()
192
+ prog.progress(0.7, "Running Retrieval")
193
+ retrieval_results(retrieval.retrieve(enc, k))
194
+ prog.progress(1.0, "Idle")
195
 
196
 
197
  try: