unknown committed on
Commit
12d8e68
1 Parent(s): e8a0fdf
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -154,7 +154,7 @@ class FoleyController:
154
  frames, duration = read_frames_with_moviepy(input_video, max_frame_nums=max_frame_nums)
155
  if duration >= 10:
156
  duration = 10
157
- time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2).to(self.device)
158
  time_frames = video_transform(time_frames)
159
  time_frames = {'frames': time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)}
160
  preds = self.time_detector(time_frames)
@@ -166,7 +166,7 @@ class FoleyController:
166
  # w -> b c h w
167
  time_condition = torch.FloatTensor(time_condition).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 256, 1)
168
 
169
- images = self.image_processor(images=frames, return_tensors="pt").to(self.device)
170
  image_embeddings = self.image_encoder(**images).image_embeds
171
  image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)
172
  neg_image_embeddings = torch.zeros_like(image_embeddings)
 
154
  frames, duration = read_frames_with_moviepy(input_video, max_frame_nums=max_frame_nums)
155
  if duration >= 10:
156
  duration = 10
157
+ time_frames = torch.FloatTensor(frames).permute(0, 3, 1, 2).to('cuda')
158
  time_frames = video_transform(time_frames)
159
  time_frames = {'frames': time_frames.unsqueeze(0).permute(0, 2, 1, 3, 4)}
160
  preds = self.time_detector(time_frames)
 
166
  # w -> b c h w
167
  time_condition = torch.FloatTensor(time_condition).unsqueeze(0).unsqueeze(0).unsqueeze(0).repeat(1, 1, 256, 1)
168
 
169
+ images = self.image_processor(images=frames, return_tensors="pt").to('cuda')
170
  image_embeddings = self.image_encoder(**images).image_embeds
171
  image_embeddings = torch.mean(image_embeddings, dim=0, keepdim=True).unsqueeze(0).unsqueeze(0)
172
  neg_image_embeddings = torch.zeros_like(image_embeddings)