mazpie commited on
Commit
e6f6d44
1 Parent(s): 4943924

Update demo/t2v.py

Browse files
Files changed (1) hide show
  1. demo/t2v.py +3 -5
demo/t2v.py CHANGED
@@ -52,11 +52,12 @@ class Text2Video():
52
  if not os.path.exists(self.result_dir):
53
  os.mkdir(self.result_dir)
54
 
 
 
 
55
  @spaces.GPU
56
  def get_prompt(self, prompt, duration):
57
  torch.cuda.empty_cache()
58
- self.agent.to('cuda')
59
- self.clip.to('cuda')
60
 
61
  print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
62
  start = time.time()
@@ -93,9 +94,6 @@ class Text2Video():
93
 
94
  save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
95
  print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds")
96
- # Offload GPU
97
- self.agent.to('cpu')
98
- self.clip.to('cpu')
99
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
100
 
101
  def download_model(self, model_folder, model_filename):
 
52
  if not os.path.exists(self.result_dir):
53
  os.mkdir(self.result_dir)
54
 
55
+ self.agent.to('cuda')
56
+ self.clip.to('cuda')
57
+
58
  @spaces.GPU
59
  def get_prompt(self, prompt, duration):
60
  torch.cuda.empty_cache()
 
 
61
 
62
  print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
63
  start = time.time()
 
94
 
95
  save_videos(prior_recon.unsqueeze(0), self.result_dir, filenames=[prompt_str], fps=15)
96
  print(f"Saved in {prompt_str}.mp4. Time used: {(time.time() - start):.2f} seconds")
 
 
 
97
  return os.path.join(self.result_dir, f"{prompt_str}.mp4")
98
 
99
  def download_model(self, model_folder, model_filename):