Spaces:
Runtime error
Runtime error
Badr AlKhamissi
commited on
Commit
·
e8f6bdd
1
Parent(s):
9530fad
losses fix
Browse files- code/losses.py +2 -5
code/losses.py
CHANGED
@@ -14,12 +14,11 @@ from transformers import CLIPProcessor, CLIPModel
|
|
14 |
from diffusers import StableDiffusionPipeline
|
15 |
|
16 |
class SDSLoss(nn.Module):
|
17 |
-
def __init__(self, cfg, device):
|
18 |
super(SDSLoss, self).__init__()
|
19 |
self.cfg = cfg
|
20 |
self.device = device
|
21 |
-
self.pipe =
|
22 |
-
torch_dtype=torch.float16, use_auth_token=cfg.token)
|
23 |
self.pipe = self.pipe.to(self.device)
|
24 |
|
25 |
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
|
@@ -55,8 +54,6 @@ class SDSLoss(nn.Module):
|
|
55 |
text_embeddings = img_emb
|
56 |
uncond_embeddings = img_emb
|
57 |
|
58 |
-
print(text_embeddings.size())
|
59 |
-
print(uncond_embeddings.size())
|
60 |
self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
61 |
self.text_embeddings = self.text_embeddings.repeat_interleave(self.cfg.batch_size, 0)
|
62 |
del self.pipe.tokenizer
|
|
|
14 |
from diffusers import StableDiffusionPipeline
|
15 |
|
16 |
class SDSLoss(nn.Module):
|
17 |
+
def __init__(self, cfg, device, model):
|
18 |
super(SDSLoss, self).__init__()
|
19 |
self.cfg = cfg
|
20 |
self.device = device
|
21 |
+
self.pipe = model
|
|
|
22 |
self.pipe = self.pipe.to(self.device)
|
23 |
|
24 |
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
|
|
|
54 |
text_embeddings = img_emb
|
55 |
uncond_embeddings = img_emb
|
56 |
|
|
|
|
|
57 |
self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
58 |
self.text_embeddings = self.text_embeddings.repeat_interleave(self.cfg.batch_size, 0)
|
59 |
del self.pipe.tokenizer
|