ucalyptus commited on
Commit
98df4c6
1 Parent(s): f2514bc

Update tune.py

Browse files
Files changed (1) hide show
  1. tune.py +18 -1
tune.py CHANGED
@@ -24,9 +24,26 @@ hyperparameters.lpips_type = 'squeeze'
24
 
25
  from scripts.run_pti import run_PTI
26
 
 
 
 
 
 
 
 
 
 
 
 
27
  def tune():
28
  model_id = run_PTI(run_name='',use_wandb=False, use_multi_id_training=False)
29
-
 
 
 
 
 
 
30
  #----------------------------------------------------------------------------
31
  if __name__ == '__main__':
32
  tune()
 
24
 
25
  from scripts.run_pti import run_PTI
26
 
27
+ def load_generator(model_id):
28
+ with open(f'{paths_config.checkpoints_dir}/model_{model_id}_file.pt', 'rb') as f_new:
29
+ new_G = torch.load(f_new).cuda()
30
+ return new_G
31
+
32
+ def tensor_to_pil(img):
33
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).detach().cpu().numpy()[0]
34
+ plt.axis('off')
35
+ resized_image = Image.fromarray(img,mode='RGB').resize((256,256))
36
+ return resized_image
37
+
38
  def tune():
39
  model_id = run_PTI(run_name='',use_wandb=False, use_multi_id_training=False)
40
+ w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
41
+ embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}'
42
+ w_pivot = torch.load(f'{embedding_dir}/0.pt')
43
+ new_G = load_generator(model_id)
44
+ new_image = new_G.synthesis(w_pivot, noise_mode='const', force_fp32 = True)
45
+ tensor_to_pil(new_image).save("output/out.png")
46
+
47
  #----------------------------------------------------------------------------
48
  if __name__ == '__main__':
49
  tune()