bkhmsi commited on
Commit
4435e47
·
1 Parent(s): 6cee4ab

bug fix + cleaned losses

Browse files
Files changed (1) hide show
  1. code/losses.py +0 -11
code/losses.py CHANGED
@@ -9,9 +9,6 @@ from shapely.geometry import Point
9
  from shapely.geometry.polygon import Polygon
10
  from torchvision import transforms
11
  from PIL import Image
12
- from transformers import CLIPProcessor, CLIPModel
13
-
14
- from diffusers import StableDiffusionPipeline
15
 
16
  class SDSLoss(nn.Module):
17
  def __init__(self, cfg, device, model):
@@ -20,11 +17,6 @@ class SDSLoss(nn.Module):
20
  self.device = device
21
  self.pipe = model
22
 
23
- # self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").to(self.device)
24
- # self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
25
-
26
- # default scheduler: PNDMScheduler(beta_start=0.00085, beta_end=0.012,
27
- # beta_schedule="scaled_linear", num_train_timesteps=1000)
28
  self.alphas = self.pipe.scheduler.alphas_cumprod.to(self.device)
29
  self.sigmas = (1 - self.pipe.scheduler.alphas_cumprod).to(self.device)
30
 
@@ -46,9 +38,6 @@ class SDSLoss(nn.Module):
46
 
47
  self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
48
  self.text_embeddings = self.text_embeddings.repeat_interleave(self.cfg.batch_size, 0)
49
- del self.pipe.tokenizer
50
- del self.pipe.text_encoder
51
-
52
 
53
  def forward(self, x_aug):
54
  sds_loss = 0
 
9
  from shapely.geometry.polygon import Polygon
10
  from torchvision import transforms
11
  from PIL import Image
 
 
 
12
 
13
  class SDSLoss(nn.Module):
14
  def __init__(self, cfg, device, model):
 
17
  self.device = device
18
  self.pipe = model
19
 
 
 
 
 
 
20
  self.alphas = self.pipe.scheduler.alphas_cumprod.to(self.device)
21
  self.sigmas = (1 - self.pipe.scheduler.alphas_cumprod).to(self.device)
22
 
 
38
 
39
  self.text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
40
  self.text_embeddings = self.text_embeddings.repeat_interleave(self.cfg.batch_size, 0)
 
 
 
41
 
42
  def forward(self, x_aug):
43
  sds_loss = 0