|
import glob |
|
import os |
|
from configs import global_config, paths_config, hyperparameters |
|
from scripts.latent_creators.sg2_plus_latent_creator import SG2PlusLatentCreator |
|
from scripts.latent_creators.e4e_latent_creator import E4ELatentCreator |
|
from scripts.run_pti import run_PTI |
|
import pickle |
|
import torch |
|
from utils.models_utils import toogle_grad, load_old_G |
|
|
|
|
|
class ExperimentRunner:
    """Drive one experiment over the files in ``paths_config.input_data_path``.

    Construction lists the input directory and loads a generator via
    ``load_old_G``; ``run_experiment`` then optionally runs PTI and/or the
    SG2-plus and e4e latent creators.
    """

    def __init__(self, run_id=''):
        """Collect the input paths and load the generator returned by ``load_old_G``.

        Args:
            run_id: Identifier of the run. ``run_experiment`` replaces it with
                the id returned by ``run_PTI`` when PTI is run.
        """
        # List the input directory once. Both attributes start from the same
        # listing, but they are separate lists so mutating one does not
        # silently change the other.
        self.images_paths = self._list_input_paths(paths_config.input_data_path)
        self.target_paths = list(self.images_paths)
        self.run_id = run_id
        self.sampled_ws = None

        self.old_G = load_old_G()
        # NOTE(review): presumably disables gradient tracking on the loaded
        # generator (see utils.models_utils.toogle_grad) — confirm.
        toogle_grad(self.old_G, False)

    @staticmethod
    def _list_input_paths(directory):
        """Return the sorted paths of the non-hidden entries directly inside *directory*.

        The directory part is passed through ``glob.escape`` so names that
        contain glob metacharacters (``*``, ``?``, ``[``) are matched
        literally. Results are sorted because ``glob`` yields them in
        arbitrary, filesystem-dependent order.
        """
        pattern = os.path.join(glob.escape(str(directory)), '*')
        return sorted(glob.glob(pattern))

    def run_experiment(self, run_pt, create_other_latents, use_multi_id_training, use_wandb=False):
        """Run the selected experiment stages and return the run id.

        Args:
            run_pt: If true, run PTI via ``run_PTI`` and store the run id it
                returns on ``self.run_id``.
            create_other_latents: If true, also create latents with the
                SG2-plus and then the e4e latent creator.
            use_multi_id_training: Forwarded to ``run_PTI``; unused when
                ``run_pt`` is false.
            use_wandb: Forwarded to ``run_PTI`` and to both latent creators.

        Returns:
            ``self.run_id`` after the selected stages have run.
        """
        if run_pt:
            self.run_id = run_PTI(self.run_id, use_wandb=use_wandb,
                                  use_multi_id_training=use_multi_id_training)

        if create_other_latents:
            sg2_plus_latent_creator = SG2PlusLatentCreator(use_wandb=use_wandb)
            sg2_plus_latent_creator.create_latents()

            e4e_latent_creator = E4ELatentCreator(use_wandb=use_wandb)
            e4e_latent_creator.create_latents()

        # Return unoccupied cached GPU memory held by PyTorch's caching
        # allocator once the stages are finished.
        torch.cuda.empty_cache()

        return self.run_id
|
|
|
|
|
if __name__ == '__main__':
    # Enumerate GPUs by PCI bus id (the numbering nvidia-smi uses), then
    # restrict this process to the configured devices. These only take effect
    # if CUDA has not been initialized yet; this relies on none of the
    # imports above initializing it — NOTE(review): confirm.
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    # os.environ values must be str; coerce in case the config holds an int.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(global_config.cuda_visible_devices)

    runner = ExperimentRunner()
    runner.run_experiment(run_pt=True, create_other_latents=False, use_multi_id_training=False)
|
|