Update tune.py
Browse files
tune.py
CHANGED
@@ -24,9 +24,26 @@ hyperparameters.lpips_type = 'squeeze'
|
|
24 |
|
25 |
from scripts.run_pti import run_PTI
|
26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
def tune():
|
28 |
model_id = run_PTI(run_name='',use_wandb=False, use_multi_id_training=False)
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
#----------------------------------------------------------------------------
|
31 |
if __name__ == '__main__':
|
32 |
tune()
|
|
|
24 |
|
25 |
from scripts.run_pti import run_PTI
|
26 |
|
27 |
+
def load_generator(model_id):
|
28 |
+
with open(f'{paths_config.checkpoints_dir}/model_{model_id}_file.pt', 'rb') as f_new:
|
29 |
+
new_G = torch.load(f_new).cuda()
|
30 |
+
return new_G
|
31 |
+
|
32 |
+
def tensor_to_pil(img):
|
33 |
+
img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()[0]
|
34 |
+
plt.axis('off')
|
35 |
+
resized_image = Image.fromarray(img,mode='RGB').resize((256,256))
|
36 |
+
return resized_image
|
37 |
+
|
38 |
def tune():
|
39 |
model_id = run_PTI(run_name='',use_wandb=False, use_multi_id_training=False)
|
40 |
+
w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
|
41 |
+
embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}'
|
42 |
+
w_pivot = torch.load(f'{embedding_dir}/0.pt')
|
43 |
+
new_G = load_generator(model_id)
|
44 |
+
new_image = new_G.synthesis(w_pivot, noise_mode='const', force_fp32 = True)
|
45 |
+
tensor_to_pil(new_image).save("output/out.png")
|
46 |
+
|
47 |
#----------------------------------------------------------------------------
|
48 |
if __name__ == '__main__':
|
49 |
tune()
|