Spaces:
Runtime error
Runtime error
Updated to use arbitrary model paths
Browse files
app.py
CHANGED
@@ -31,25 +31,22 @@ from generate_videos import generate_frames, video_from_interpolations, vid_to_g
|
|
31 |
model_dir = "models"
|
32 |
os.makedirs(model_dir, exist_ok=True)
|
33 |
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
|
38 |
def get_models():
|
39 |
os.makedirs(model_dir, exist_ok=True)
|
40 |
|
41 |
-
|
42 |
-
hf_hub_download(repo_id=repo_id, filename=file_path)
|
43 |
-
if not "akhaliq" in repo_id:
|
44 |
-
shutil.move(file_path, os.path.join(model_dir, file_path))
|
45 |
-
elif "stylegan2" in file_path:
|
46 |
-
shutil.move(file_path, os.path.join(model_dir, "base.pt"))
|
47 |
|
48 |
-
|
|
|
|
|
49 |
|
50 |
-
return
|
51 |
|
52 |
-
|
53 |
|
54 |
class ImageEditor(object):
|
55 |
def __init__(self):
|
@@ -62,18 +59,20 @@ class ImageEditor(object):
|
|
62 |
|
63 |
self.generators = {}
|
64 |
|
65 |
-
for
|
|
|
|
|
66 |
g_ema = Generator(
|
67 |
model_size, latent_size, n_mlp, channel_multiplier=channel_mult
|
68 |
).to(self.device)
|
69 |
|
70 |
-
checkpoint = torch.load(
|
71 |
|
72 |
g_ema.load_state_dict(checkpoint['g_ema'])
|
73 |
|
74 |
self.generators[model] = g_ema
|
75 |
|
76 |
-
self.experiment_args = {"model_path": "
|
77 |
self.experiment_args["transform"] = transforms.Compose(
|
78 |
[
|
79 |
transforms.Resize((256, 256)),
|
@@ -96,7 +95,7 @@ class ImageEditor(object):
|
|
96 |
self.e4e_net.cuda()
|
97 |
|
98 |
self.shape_predictor = dlib.shape_predictor(
|
99 |
-
|
100 |
)
|
101 |
|
102 |
print("setup complete")
|
@@ -120,11 +119,11 @@ class ImageEditor(object):
|
|
120 |
):
|
121 |
|
122 |
if output_style == 'all':
|
123 |
-
styles = model_list
|
124 |
elif output_style == 'list - enter below':
|
125 |
styles = style_list.split(",")
|
126 |
for style in styles:
|
127 |
-
if style not in model_list:
|
128 |
raise ValueError(f"Encountered style '{style}' in the style_list which is not an available option.")
|
129 |
else:
|
130 |
styles = [output_style]
|
|
|
31 |
model_dir = "models"
|
32 |
os.makedirs(model_dir, exist_ok=True)
|
33 |
|
34 |
+
model_repos = {"e4e": ("akhaliq/JoJoGAN_e4e_ffhq_encode", "e4e_ffhq_encode.pt"),
|
35 |
+
"dlib": ("akhaliq/jojogan_dlib", "shape_predictor_68_face_landmarks.dat"),
|
36 |
+
"base": ("akhaliq/jojogan-stylegan2-ffhq-config-f", "stylegan2-ffhq-config-f.pt")}
|
37 |
|
38 |
def get_models():
|
39 |
os.makedirs(model_dir, exist_ok=True)
|
40 |
|
41 |
+
model_paths = {}
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
+
for model_name, repo_details in model_repos.items():
|
44 |
+
download_path = hf_hub_download(repo_id=repo_details[0], filename=repo_details[1])
|
45 |
+
model_paths[model_name] = download_path
|
46 |
|
47 |
+
return model_paths
|
48 |
|
49 |
+
model_paths = get_models()
|
50 |
|
51 |
class ImageEditor(object):
|
52 |
def __init__(self):
|
|
|
59 |
|
60 |
self.generators = {}
|
61 |
|
62 |
+
self.model_list = [name for name in model_paths.keys() if name not in ["e4e", "dlib"]]
|
63 |
+
|
64 |
+
for model in self.model_list:
|
65 |
g_ema = Generator(
|
66 |
model_size, latent_size, n_mlp, channel_multiplier=channel_mult
|
67 |
).to(self.device)
|
68 |
|
69 |
+
checkpoint = torch.load(model_paths[model])
|
70 |
|
71 |
g_ema.load_state_dict(checkpoint['g_ema'])
|
72 |
|
73 |
self.generators[model] = g_ema
|
74 |
|
75 |
+
self.experiment_args = {"model_path": model_paths["e4e"]}
|
76 |
self.experiment_args["transform"] = transforms.Compose(
|
77 |
[
|
78 |
transforms.Resize((256, 256)),
|
|
|
95 |
self.e4e_net.cuda()
|
96 |
|
97 |
self.shape_predictor = dlib.shape_predictor(
|
98 |
+
model_paths["dlib"]
|
99 |
)
|
100 |
|
101 |
print("setup complete")
|
|
|
119 |
):
|
120 |
|
121 |
if output_style == 'all':
|
122 |
+
styles = self.model_list
|
123 |
elif output_style == 'list - enter below':
|
124 |
styles = style_list.split(",")
|
125 |
for style in styles:
|
126 |
+
if style not in self.model_list:
|
127 |
raise ValueError(f"Encountered style '{style}' in the style_list which is not an available option.")
|
128 |
else:
|
129 |
styles = [output_style]
|