unknown commited on
Commit
1197691
·
1 Parent(s): 91d0a5f

move to gpu in foley

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -120,11 +120,7 @@ class FoleyController:
120
 
121
  self.pipeline.load_ip_adapter(fc_ckpt, subfolder='semantic', weight_name='semantic_adapter.bin', image_encoder_folder=None)
122
 
123
- # move to gpu
124
- self.time_detector.to(self.device)
125
- self.pipeline.to(self.device)
126
- self.vocoder.to(self.device)
127
- self.image_encoder.to(self.device)
128
 
129
  gr.Info("Load Finish!")
130
  print("Load Finish!")
@@ -145,14 +141,20 @@ class FoleyController:
145
  cfg_scale_slider,
146
  seed_textbox,
147
  ):
 
 
 
 
 
 
148
  vision_transform_list = [
149
  torchvision.transforms.Resize((128, 128)),
150
  torchvision.transforms.CenterCrop((112, 112)),
151
  torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
152
  ]
153
  video_transform = torchvision.transforms.Compose(vision_transform_list)
154
- # if not self.loaded:
155
- # raise gr.Error("Error with loading model")
156
  generator = torch.Generator()
157
  if seed_textbox != "":
158
  torch.manual_seed(int(seed_textbox))
 
120
 
121
  self.pipeline.load_ip_adapter(fc_ckpt, subfolder='semantic', weight_name='semantic_adapter.bin', image_encoder_folder=None)
122
 
123
+ self.move_to_device()
 
 
 
 
124
 
125
  gr.Info("Load Finish!")
126
  print("Load Finish!")
 
141
  cfg_scale_slider,
142
  seed_textbox,
143
  ):
144
+ # move to gpu
145
+ self.time_detector.to(self.device)
146
+ self.pipeline.to(self.device)
147
+ self.vocoder.to(self.device)
148
+ self.image_encoder.to(self.device)
149
+
150
  vision_transform_list = [
151
  torchvision.transforms.Resize((128, 128)),
152
  torchvision.transforms.CenterCrop((112, 112)),
153
  torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
154
  ]
155
  video_transform = torchvision.transforms.Compose(vision_transform_list)
156
+ if not self.loaded:
157
+ raise gr.Error("Error with loading model")
158
  generator = torch.Generator()
159
  if seed_textbox != "":
160
  torch.manual_seed(int(seed_textbox))