jadechoghari commited on
Commit
f1d908c
·
verified ·
1 Parent(s): e662c48

Update diffloss.py

Browse files
Files changed (1) hide show
  1. diffloss.py +13 -3
diffloss.py CHANGED
@@ -91,10 +91,20 @@ class TimestepEmbedder(nn.Module):
91
  embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
92
  return embedding
93
 
 
 
 
 
94
  def forward(self, t):
95
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
96
- t_emb = self.mlp(t_freq)
97
- return t_emb
 
 
 
 
 
 
98
 
99
 
100
  class ResBlock(nn.Module):
 
91
  embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
92
  return embedding
93
 
94
+ # def forward(self, t):
95
+ # t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
96
+ # t_emb = self.mlp(t_freq)
97
+ # return t_emb
98
  def forward(self, t):
99
+ t = t.to(self.mlp.weight.device)
100
+
101
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
102
+
103
+ t_freq = t_freq.to(self.mlp.weight.device)
104
+
105
+ t_emb = self.mlp(t_freq)
106
+
107
+ return t_emb
108
 
109
 
110
  class ResBlock(nn.Module):