taesiri commited on
Commit
856b52a
·
1 Parent(s): 13d5d73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -8
app.py CHANGED
@@ -26,6 +26,8 @@ from huggingface_hub import (
26
  snapshot_download,
27
  )
28
  from PIL import Image
 
 
29
 
30
  cached_latest_posts_df = None
31
  cached_top_posts = None
@@ -33,6 +35,13 @@ last_fetched = None
33
  last_fetched_top = None
34
 
35
 
 
 
 
 
 
 
 
36
  def get_reddit_id(url):
37
  # Regular expression pattern for r/GamePhysics URLs and IDs
38
  pattern = r"https://www\.reddit\.com/r/GamePhysics/comments/([0-9a-zA-Z]+).*|([0-9a-zA-Z]+)"
@@ -106,6 +115,47 @@ def extract_frames_decord(video_path, num_frames=10):
106
  raise Exception(f"Error extracting frames from video: {e}")
107
 
108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  def get_top_posts():
110
  global cached_top_posts
111
  global last_fetched_top
@@ -212,7 +262,7 @@ def load_video(url):
212
  and r1.status_code != 302
213
  ):
214
  raise gr.Error(
215
- f"Video is not in the repo, please try another post. - {r.status_code }"
216
  )
217
 
218
  if r1.status_code == 200 or r1.status_code == 302:
@@ -234,11 +284,11 @@ with gr.Blocks() as demo:
234
  video_player = gr.Video(interactive=False)
235
 
236
  with gr.Column():
237
- gr.Markdown("## Sampled Frames from Video")
238
  num_frames = gr.Slider(minimum=1, maximum=60, step=1, value=10)
239
- sample_decord_btn = gr.Button("Sample decord")
240
 
241
- sampled_frames = gr.Gallery()
242
 
243
  download_samples_btn = gr.Button("Download Samples")
244
  output_files = gr.File()
@@ -250,16 +300,16 @@ with gr.Blocks() as demo:
250
  )
251
 
252
  with gr.Column():
253
- gr.Markdown("## Latest Posts")
254
- with gr.Accordion("Latest Posts"):
255
  latest_post_dataframe = gr.Dataframe()
256
  latest_posts_btn = gr.Button("Refresh Latest Posts")
257
- with gr.Accordion("Top Posts"):
258
  top_posts_dataframe = gr.Dataframe()
259
  top_posts_btn = gr.Button("Refresh Top Posts")
260
 
261
  sample_decord_btn.click(
262
- extract_frames_decord,
263
  inputs=[video_player, num_frames],
264
  outputs=[sampled_frames],
265
  )
 
26
  snapshot_download,
27
  )
28
  from PIL import Image
29
+ import concurrent.futures
30
+
31
 
32
  cached_latest_posts_df = None
33
  cached_top_posts = None
 
35
  last_fetched_top = None
36
 
37
 
38
+ def resize_image(image):
39
+ width, height = image.size
40
+ new_width = width * 0.35
41
+ new_height = height * 0.35
42
+ return image.resize((int(new_width), int(new_height)), Image.BILINEAR)
43
+
44
+
45
  def get_reddit_id(url):
46
  # Regular expression pattern for r/GamePhysics URLs and IDs
47
  pattern = r"https://www\.reddit\.com/r/GamePhysics/comments/([0-9a-zA-Z]+).*|([0-9a-zA-Z]+)"
 
115
  raise Exception(f"Error extracting frames from video: {e}")
116
 
117
 
118
+ def extract_frames_decord_preview(video_path, num_frames=10):
119
+ try:
120
+ start_time = time.time()
121
+
122
+ print(f"Extracting {num_frames} frames from {video_path}")
123
+
124
+ # Load the video
125
+ vr = VideoReader(video_path, ctx=cpu(0))
126
+
127
+ # Calculate the indices for the frames to be extracted
128
+ total_frames = len(vr)
129
+ frame_indices = np.linspace(
130
+ 0, total_frames - 1, num_frames, dtype=int, endpoint=False
131
+ )
132
+
133
+ # Extract frames
134
+ batch_frames = vr.get_batch(frame_indices).asnumpy()
135
+
136
+ # Convert frames to PIL Images
137
+ frame_images = [
138
+ Image.fromarray(batch_frames[i]) for i in range(batch_frames.shape[0])
139
+ ]
140
+
141
+ end_time = time.time()
142
+ print(f"Decord extraction took {end_time - start_time} seconds")
143
+
144
+ # # resize images to save bandwidth, keep aspect ratio
145
+ # for i, image in enumerate(frame_images):
146
+ # width, height = image.size
147
+ # new_width = int(width * 0.35)
148
+ # new_height = int(height * 0.35)
149
+ # frame_images[i] = image.resize((new_width, new_height), Image.ANTIALIAS)
150
+
151
+ with concurrent.futures.ThreadPoolExecutor() as executor:
152
+ frame_images = list(executor.map(resize_image, frame_images))
153
+
154
+ return frame_images
155
+ except Exception as e:
156
+ raise Exception(f"Error extracting frames from video: {e}")
157
+
158
+
159
  def get_top_posts():
160
  global cached_top_posts
161
  global last_fetched_top
 
262
  and r1.status_code != 302
263
  ):
264
  raise gr.Error(
265
+ f"Video is not in the repo, please try another post. - {r1.status_code }"
266
  )
267
 
268
  if r1.status_code == 200 or r1.status_code == 302:
 
284
  video_player = gr.Video(interactive=False)
285
 
286
  with gr.Column():
287
+ gr.Markdown("## Sample frames")
288
  num_frames = gr.Slider(minimum=1, maximum=60, step=1, value=10)
289
+ sample_decord_btn = gr.Button("Sample frames")
290
 
291
+ sampled_frames = gr.Gallery(label="Sampled frames preview")
292
 
293
  download_samples_btn = gr.Button("Download Samples")
294
  output_files = gr.File()
 
300
  )
301
 
302
  with gr.Column():
303
+ gr.Markdown("## Reddits Posts")
304
+ with gr.Tab("Latest Posts"):
305
  latest_post_dataframe = gr.Dataframe()
306
  latest_posts_btn = gr.Button("Refresh Latest Posts")
307
+ with gr.Tab("Top Monthly Posts"):
308
  top_posts_dataframe = gr.Dataframe()
309
  top_posts_btn = gr.Button("Refresh Top Posts")
310
 
311
  sample_decord_btn.click(
312
+ extract_frames_decord_preview,
313
  inputs=[video_player, num_frames],
314
  outputs=[sampled_frames],
315
  )