unknown commited on
Commit
3cb7b20
·
1 Parent(s): 72bd3bf

move to gpu

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -72,7 +72,6 @@ class FoleyController:
72
 
73
  self.load_model()
74
 
75
- @spaces.GPU
76
  def load_model(self):
77
  gr.Info("Start Load Models...")
78
  print("Start Load Models...")
@@ -93,15 +92,15 @@ class FoleyController:
93
  vocoder_config_path= "./models/auffusion"
94
  self.vocoder = Generator.from_pretrained(
95
  vocoder_config_path,
96
- subfolder="vocoder").to(self.device)
97
 
98
  # load time detector
99
  time_detector_ckpt = osp.join(osp.join(self.model_dir, 'timestamp_detector.pth.tar'))
100
  time_detector = VideoOnsetNet(False)
101
  self.time_detector, _ = torch_utils.load_model(time_detector_ckpt, time_detector, strict=True)
102
- self.time_detector = self.time_detector.to(self.device)
103
 
104
- self.pipeline = build_foleycrafter().to(self.device)
105
  ckpt = torch.load(temporal_ckpt_path)
106
 
107
  # load temporal adapter
@@ -117,7 +116,7 @@ class FoleyController:
117
  print(f"### Control Net missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
118
 
119
  self.image_processor = CLIPImageProcessor()
120
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained('h94/IP-Adapter', subfolder='models/image_encoder').to(self.device)
121
 
122
  self.pipeline.load_ip_adapter(fc_ckpt, subfolder='semantic', weight_name='semantic_adapter.bin', image_encoder_folder=None)
123
 
@@ -140,7 +139,12 @@ class FoleyController:
140
  cfg_scale_slider,
141
  seed_textbox,
142
  ):
143
-
 
 
 
 
 
144
  vision_transform_list = [
145
  torchvision.transforms.Resize((128, 128)),
146
  torchvision.transforms.CenterCrop((112, 112)),
 
72
 
73
  self.load_model()
74
 
 
75
  def load_model(self):
76
  gr.Info("Start Load Models...")
77
  print("Start Load Models...")
 
92
  vocoder_config_path= "./models/auffusion"
93
  self.vocoder = Generator.from_pretrained(
94
  vocoder_config_path,
95
+ subfolder="vocoder")
96
 
97
  # load time detector
98
  time_detector_ckpt = osp.join(osp.join(self.model_dir, 'timestamp_detector.pth.tar'))
99
  time_detector = VideoOnsetNet(False)
100
  self.time_detector, _ = torch_utils.load_model(time_detector_ckpt, time_detector, strict=True)
101
+ self.time_detector = self.time_detector
102
 
103
+ self.pipeline = build_foleycrafter()
104
  ckpt = torch.load(temporal_ckpt_path)
105
 
106
  # load temporal adapter
 
116
  print(f"### Control Net missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
117
 
118
  self.image_processor = CLIPImageProcessor()
119
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained('h94/IP-Adapter', subfolder='models/image_encoder')
120
 
121
  self.pipeline.load_ip_adapter(fc_ckpt, subfolder='semantic', weight_name='semantic_adapter.bin', image_encoder_folder=None)
122
 
 
139
  cfg_scale_slider,
140
  seed_textbox,
141
  ):
142
+ # move to gpu
143
+ self.time_detector = self.time_detector.to(self.device)
144
+ self.pipeline = self.pipeline.to(self.device)
145
+ self.vocoder = self.vocoder.to(self.device)
146
+ self.image_encoder = self.image_encoder.to(self.device)
147
+
148
  vision_transform_list = [
149
  torchvision.transforms.Resize((128, 128)),
150
  torchvision.transforms.CenterCrop((112, 112)),