Damian Stewart committed
Commit 209d166 · Parent(s): 0002379

provide empty negative prompt when training

Files changed:
- StableDiffuser.py +3 -1
- train.py +0 -2
StableDiffuser.py
CHANGED
@@ -114,9 +114,11 @@ class StableDiffuser(torch.nn.Module):
         latents = noise * self.scheduler.init_noise_sigma
         return latents
 
-    def get_text_embeddings(self, prompts, negative_prompts, n_imgs):
+    def get_text_embeddings(self, prompts, negative_prompts=None, n_imgs=1):
         text_tokens = self.text_tokenize(prompts)
         text_embeddings = self.text_encode(text_tokens)
+        if negative_prompts is None:
+            negative_prompts = [""] * len(prompts)
         unconditional_tokens = self.text_tokenize(negative_prompts)
         unconditional_embeddings = self.text_encode(unconditional_tokens)
         text_embeddings = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)
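With this change, callers that omit negative_prompts get one empty-string negative prompt per positive prompt. A minimal sketch of the embedding layout the method then produces, using illustrative stand-in tensors (the real sequence length and hidden size depend on the text encoder, so the shapes below are assumptions):

import torch

# Stand-ins for the text-encoder outputs (shapes assumed for this sketch)
n_prompts, seq_len, hidden = 2, 77, 768
text_embeddings = torch.randn(n_prompts, seq_len, hidden)           # conditional embeddings
unconditional_embeddings = torch.randn(n_prompts, seq_len, hidden)  # from the [""] * len(prompts) fallback

n_imgs = 3
combined = torch.cat([unconditional_embeddings, text_embeddings]).repeat_interleave(n_imgs, dim=0)

# Unconditional rows come first, then conditional rows, each repeated once per image,
# which is the layout classifier-free guidance consumes.
print(combined.shape)  # torch.Size([12, 77, 768])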
train.py
CHANGED
@@ -36,11 +36,9 @@ def train(repo_id_or_path, img_size, prompt, modules, freeze_modules, iterations
         optimizer.zero_grad()
 
         iteration = torch.randint(1, nsteps - 1, (1,)).item()
-
         latents = diffuser.get_initial_latents(1, img_size, 1)
 
         with finetuner:
-
             latents_steps, _ = diffuser.diffusion(
                 latents,
                 positive_text_embeddings,
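For illustration only (the call site itself is not part of this diff): with the new default in StableDiffuser.get_text_embeddings, training code such as the snippet above can obtain positive_text_embeddings without passing an explicit negative prompt list, along the lines of the assumed call below.

# Hypothetical call site, not shown in this commit: the empty negative prompt
# is now supplied automatically when negative_prompts is omitted.
positive_text_embeddings = diffuser.get_text_embeddings([prompt], n_imgs=1)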