RohitGandikota committed
Commit ec60320 · Parent: b0461af

Update train.py

Files changed (1): train.py (+4, -4)
train.py CHANGED

@@ -7,13 +7,13 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa

    nsteps = 50

-   diffuser = StableDiffuser(scheduler='DDIM').to('cuda:1')
+   diffuser = StableDiffuser(scheduler='DDIM').to('cuda')
    diffuser.train()




-   finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules).to('cuda:0')
+   finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

    optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
    criteria = torch.nn.MSELoss()

@@ -58,8 +58,8 @@ def train(prompt, modules, freeze_modules, iterations, negative_guidance, lr, sa

            iteration = int(iteration / nsteps * 1000)

-           positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1).to('cuda:0')
-           neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1).to('cuda:0')
+           positive_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
+           neutral_latents = diffuser.predict_noise(iteration, latents_steps[0], neutral_text_embeddings, guidance_scale=1)

        with finetuner:
            negative_latents = diffuser.predict_noise(iteration, latents_steps[0], positive_text_embeddings, guidance_scale=1)
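
In short, the commit collapses the earlier two-GPU split (the diffuser on 'cuda:1', the fine-tuned copy and the predicted latents moved to 'cuda:0') onto a single CUDA device, dropping the cross-device .to('cuda:0') transfers. A minimal sketch of the updated setup follows, assuming the repo's StableDiffuser and FineTunedModel wrappers are already imported; the CPU fallback via torch.cuda.is_available() is an illustrative addition, not part of the commit.

import torch

# Sketch of the single-device setup after this commit. StableDiffuser and
# FineTunedModel are this repo's wrappers (imports omitted); the CPU fallback
# below is a hypothetical convenience, not part of the change itself.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

diffuser = StableDiffuser(scheduler='DDIM').to(device)
diffuser.train()

# No separate .to('cuda:0'): the fine-tuned copy stays on the same device
# as the diffuser it wraps.
finetuner = FineTunedModel(diffuser, modules, frozen_modules=freeze_modules)

optimizer = torch.optim.Adam(finetuner.parameters(), lr=lr)
criteria = torch.nn.MSELoss()

With everything on one device, the predict_noise outputs used for positive_latents and neutral_latents no longer need to be copied back to another GPU, which is exactly what the removed .to('cuda:0') calls in the second hunk were doing.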