azizinaghsh
commited on
Commit
·
0d8e18a
1
Parent(s):
9962869
add seed for CCD
Browse files- CCD/src/main.py +12 -3
- app.py +4 -4
CCD/src/main.py
CHANGED
@@ -306,8 +306,17 @@ def train():
|
|
306 |
torch.save(ddpm.state_dict(), save_dir + f"model_{ep}.pth")
|
307 |
print('saved model at ' + save_dir + f"model_{ep}.pth")
|
308 |
|
|
|
|
|
|
|
|
|
|
|
|
|
309 |
|
310 |
-
|
|
|
|
|
|
|
311 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
312 |
|
313 |
mean_std_path = os.path.join(script_dir, "..", "checkpoints", "Mean_Std.npy")
|
@@ -373,8 +382,8 @@ def gen(text: str):
|
|
373 |
# f.write(txt+"\n")
|
374 |
|
375 |
|
376 |
-
def generate_CCD_sample(text: str):
|
377 |
-
return gen(text)
|
378 |
|
379 |
if __name__ == "__main__":
|
380 |
import sys
|
|
|
306 |
torch.save(ddpm.state_dict(), save_dir + f"model_{ep}.pth")
|
307 |
print('saved model at ' + save_dir + f"model_{ep}.pth")
|
308 |
|
309 |
+
def set_seed(seed: int):
|
310 |
+
random.seed(seed)
|
311 |
+
np.random.seed(seed)
|
312 |
+
torch.manual_seed(seed)
|
313 |
+
if torch.cuda.is_available():
|
314 |
+
torch.cuda.manual_seed_all(seed)
|
315 |
|
316 |
+
|
317 |
+
def gen(text: str, seed: int):
|
318 |
+
set_seed(seed)
|
319 |
+
|
320 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
321 |
|
322 |
mean_std_path = os.path.join(script_dir, "..", "checkpoints", "Mean_Std.npy")
|
|
|
382 |
# f.write(txt+"\n")
|
383 |
|
384 |
|
385 |
+
def generate_CCD_sample(text: str, seed : int):
|
386 |
+
return gen(text, seed)
|
387 |
|
388 |
if __name__ == "__main__":
|
389 |
import sys
|
app.py
CHANGED
@@ -111,7 +111,7 @@ def generate_ccd(
|
|
111 |
character_position: list,
|
112 |
) -> Dict[str, Any]:
|
113 |
|
114 |
-
results = generate_CCD_sample(prompt)
|
115 |
|
116 |
rr.init(f"{3}")
|
117 |
rr.save(".tmp_gr.rrd")
|
@@ -198,7 +198,7 @@ def launch_app(gen_fn_et: Callable, gen_fn_ccd: Callable):
|
|
198 |
show_label=True,
|
199 |
label="Character Position (3D vector)",
|
200 |
value="[0.0, 0.0, 0.0]",
|
201 |
-
interactive=True,
|
202 |
|
203 |
)
|
204 |
text = gr.Textbox(
|
@@ -206,13 +206,13 @@ def launch_app(gen_fn_et: Callable, gen_fn_ccd: Callable):
|
|
206 |
show_label=True,
|
207 |
label="Text prompt",
|
208 |
value=DEFAULT_TEXT[0],
|
209 |
-
interactive=True,
|
210 |
|
211 |
)
|
212 |
seed = gr.Number(value=33, label="Seed")
|
213 |
guidance = gr.Slider(0, 10, value=1.4, label="Guidance", step=0.1)
|
214 |
|
215 |
-
|
216 |
model_selector = gr.Dropdown(
|
217 |
choices=list(model_options.keys()),
|
218 |
value=list(model_options.keys())[0],
|
|
|
111 |
character_position: list,
|
112 |
) -> Dict[str, Any]:
|
113 |
|
114 |
+
results = generate_CCD_sample(prompt, seed)
|
115 |
|
116 |
rr.init(f"{3}")
|
117 |
rr.save(".tmp_gr.rrd")
|
|
|
198 |
show_label=True,
|
199 |
label="Character Position (3D vector)",
|
200 |
value="[0.0, 0.0, 0.0]",
|
201 |
+
interactive=True,
|
202 |
|
203 |
)
|
204 |
text = gr.Textbox(
|
|
|
206 |
show_label=True,
|
207 |
label="Text prompt",
|
208 |
value=DEFAULT_TEXT[0],
|
209 |
+
interactive=True,
|
210 |
|
211 |
)
|
212 |
seed = gr.Number(value=33, label="Seed")
|
213 |
guidance = gr.Slider(0, 10, value=1.4, label="Guidance", step=0.1)
|
214 |
|
215 |
+
|
216 |
model_selector = gr.Dropdown(
|
217 |
choices=list(model_options.keys()),
|
218 |
value=list(model_options.keys())[0],
|