NIRVANALAN commited on
Commit
b6efe48
1 Parent(s): 44bb30d
nsr/lsgm/train_util_diffusion_lsgm_noD_joint.py CHANGED
@@ -154,8 +154,11 @@ class SDETrainLoopJoint(TrainLoopDiffusionWithRec):
154
  self.ddp_rec_model = functools.partial(self.model, model_name='rec')
155
  self.ddp_ddpm_model = functools.partial(self.model, model_name='ddpm')
156
 
157
- self.rec_model = self.ddp_model.module.rec_model
158
- self.ddpm_model = self.ddp_model.module.ddpm_model # compatability
 
 
 
159
 
160
  # TODO, required?
161
  # for param in self.ddp_rec_model.module.decoder.triplane_decoder.parameters( # type: ignore
 
154
  self.ddp_rec_model = functools.partial(self.model, model_name='rec')
155
  self.ddp_ddpm_model = functools.partial(self.model, model_name='ddpm')
156
 
157
+ # self.rec_model = self.ddp_model.module.rec_model
158
+ # self.ddpm_model = self.ddp_model.module.ddpm_model # compatability
159
+
160
+ self.rec_model = self.ddp_model.rec_model
161
+ self.ddpm_model = self.ddp_model.ddpm_model # compatability
162
 
163
  # TODO, required?
164
  # for param in self.ddp_rec_model.module.decoder.triplane_decoder.parameters( # type: ignore