azizinaghsh commited on
Commit
0d8e18a
·
1 Parent(s): 9962869

add seed for CCD

Browse files
Files changed (2) hide show
  1. CCD/src/main.py +12 -3
  2. app.py +4 -4
CCD/src/main.py CHANGED
@@ -306,8 +306,17 @@ def train():
306
  torch.save(ddpm.state_dict(), save_dir + f"model_{ep}.pth")
307
  print('saved model at ' + save_dir + f"model_{ep}.pth")
308
 
 
 
 
 
 
 
309
 
310
- def gen(text: str):
 
 
 
311
  script_dir = os.path.dirname(os.path.abspath(__file__))
312
 
313
  mean_std_path = os.path.join(script_dir, "..", "checkpoints", "Mean_Std.npy")
@@ -373,8 +382,8 @@ def gen(text: str):
373
  # f.write(txt+"\n")
374
 
375
 
376
- def generate_CCD_sample(text: str):
377
- return gen(text)
378
 
379
  if __name__ == "__main__":
380
  import sys
 
306
  torch.save(ddpm.state_dict(), save_dir + f"model_{ep}.pth")
307
  print('saved model at ' + save_dir + f"model_{ep}.pth")
308
 
309
+ def set_seed(seed: int):
310
+ random.seed(seed)
311
+ np.random.seed(seed)
312
+ torch.manual_seed(seed)
313
+ if torch.cuda.is_available():
314
+ torch.cuda.manual_seed_all(seed)
315
 
316
+
317
+ def gen(text: str, seed: int):
318
+ set_seed(seed)
319
+
320
  script_dir = os.path.dirname(os.path.abspath(__file__))
321
 
322
  mean_std_path = os.path.join(script_dir, "..", "checkpoints", "Mean_Std.npy")
 
382
  # f.write(txt+"\n")
383
 
384
 
385
+ def generate_CCD_sample(text: str, seed : int):
386
+ return gen(text, seed)
387
 
388
  if __name__ == "__main__":
389
  import sys
app.py CHANGED
@@ -111,7 +111,7 @@ def generate_ccd(
111
  character_position: list,
112
  ) -> Dict[str, Any]:
113
 
114
- results = generate_CCD_sample(prompt)
115
 
116
  rr.init(f"{3}")
117
  rr.save(".tmp_gr.rrd")
@@ -198,7 +198,7 @@ def launch_app(gen_fn_et: Callable, gen_fn_ccd: Callable):
198
  show_label=True,
199
  label="Character Position (3D vector)",
200
  value="[0.0, 0.0, 0.0]",
201
- interactive=True, # Ensure this is set to True
202
 
203
  )
204
  text = gr.Textbox(
@@ -206,13 +206,13 @@ def launch_app(gen_fn_et: Callable, gen_fn_ccd: Callable):
206
  show_label=True,
207
  label="Text prompt",
208
  value=DEFAULT_TEXT[0],
209
- interactive=True, # Ensure this is set to True
210
 
211
  )
212
  seed = gr.Number(value=33, label="Seed")
213
  guidance = gr.Slider(0, 10, value=1.4, label="Guidance", step=0.1)
214
 
215
- # Add a dropdown menu for selecting the generation model
216
  model_selector = gr.Dropdown(
217
  choices=list(model_options.keys()),
218
  value=list(model_options.keys())[0],
 
111
  character_position: list,
112
  ) -> Dict[str, Any]:
113
 
114
+ results = generate_CCD_sample(prompt, seed)
115
 
116
  rr.init(f"{3}")
117
  rr.save(".tmp_gr.rrd")
 
198
  show_label=True,
199
  label="Character Position (3D vector)",
200
  value="[0.0, 0.0, 0.0]",
201
+ interactive=True,
202
 
203
  )
204
  text = gr.Textbox(
 
206
  show_label=True,
207
  label="Text prompt",
208
  value=DEFAULT_TEXT[0],
209
+ interactive=True,
210
 
211
  )
212
  seed = gr.Number(value=33, label="Seed")
213
  guidance = gr.Slider(0, 10, value=1.4, label="Guidance", step=0.1)
214
 
215
+
216
  model_selector = gr.Dropdown(
217
  choices=list(model_options.keys()),
218
  value=list(model_options.keys())[0],