HikariDawn777 committed on
Commit 59b2a81 · 1 Parent(s): aa505db

feat: initial push

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitignore +49 -0
  2. __assets__/0.jpg +0 -0
  3. __assets__/156.jpg +0 -0
  4. __assets__/274.jpg +0 -0
  5. __assets__/375.jpg +0 -0
  6. __assets__/551.jpg +0 -0
  7. __assets__/91.jpg +0 -0
  8. __assets__/ThisThat_logo.png +0 -0
  9. app.py +475 -4
  10. config/accelerate_config.json +18 -0
  11. config/flowformer_config.py +78 -0
  12. config/train_image2video.yaml +78 -0
  13. config/train_image2video_controlnet.yaml +101 -0
  14. curation_pipeline/add_lang_info.py +38 -0
  15. curation_pipeline/match_dataset_v1.py +117 -0
  16. curation_pipeline/match_dataset_v2.py +137 -0
  17. curation_pipeline/prepare_bridge_csv.py +69 -0
  18. curation_pipeline/prepare_bridge_jsonl.py +47 -0
  19. curation_pipeline/prepare_bridge_v1.py +132 -0
  20. curation_pipeline/prepare_bridge_v2.py +139 -0
  21. curation_pipeline/select_frame_with_this_that.py +421 -0
  22. curation_pipeline/tracking_by_keypoint.py +136 -0
  23. data_loader/video_dataset.py +323 -0
  24. data_loader/video_this_that_dataset.py +326 -0
  25. pretrained/PUT_YOUR_WEIGHT_HERE.md +0 -0
  26. requirements.txt +27 -0
  27. scripts/active_learning_select.py +27 -0
  28. scripts/add_point2img.py +51 -0
  29. scripts/check_video.py +19 -0
  30. scripts/clean_bridge_dataset.py +22 -0
  31. scripts/collect_lang.py +31 -0
  32. scripts/combine_results.py +85 -0
  33. scripts/compress_gif.py +52 -0
  34. scripts/compress_videos.py +55 -0
  35. scripts/crop_video_frames.py +22 -0
  36. scripts/extract_test_dataset.py +18 -0
  37. scripts/generate_noise.py +14 -0
  38. scripts/generate_sam.py +56 -0
  39. scripts/generate_sam_this_that.py +108 -0
  40. scripts/generate_traj.py +601 -0
  41. scripts/interpolate_by_repeat.py +55 -0
  42. scripts/length_stats.py +21 -0
  43. scripts/motion_stats.py +75 -0
  44. scripts/process_llama.py +74 -0
  45. scripts/process_sim.py +59 -0
  46. scripts/resize_img.py +17 -0
  47. scripts/resize_video_seq.py +33 -0
  48. scripts/train_test_split.py +23 -0
  49. scripts/visualize_thisthat_point.py +43 -0
  50. svd/diffusion_arch/transformer_temporal.py +381 -0
.gitignore ADDED
@@ -0,0 +1,49 @@
1
+ .ipynb_checkpoints
2
+ .idea
3
+ __pycache__
4
+
5
+ datasets/
6
+ tmp_imgs
7
+ runs/
8
+ runs_last/
9
+ saved_models/
10
+ pre_trained/
11
+ save_log/
12
+ diffusers/
13
+ weights/
14
+ checkpoints/
15
+ validation_videos*
16
+ pretrained/*
17
+ .gradio/*
18
+
19
+ *.pyc
20
+ *.sh
21
+ *.pth
22
+ *.png
23
+ *.jpg
24
+ *.mp4
25
+ *.txt
26
+ *.json
27
+ *.jsonl
28
+ *.zip
29
+ *.mp4
30
+ *.csv
31
+ *.webp
32
+ *.bin
33
+ *.pkl
34
+ *.safetensors
35
+ *.pt
36
+ *.log
37
+ events.*
38
+ *.yml
39
+ *.gif
40
+ *.npy
41
+ *.out
42
+
43
+ !requirements.txt
44
+ !saved_models/*.md
45
+ !LICENSE.txt
46
+ !config/*
47
+ !__assets__/*
48
+ !__assets__/Bridge_example/*
49
+ !pretrained/PUT_YOUR_WEIGHT_HERE.md
__assets__/0.jpg ADDED
__assets__/156.jpg ADDED
__assets__/274.jpg ADDED
__assets__/375.jpg ADDED
__assets__/551.jpg ADDED
__assets__/91.jpg ADDED
__assets__/ThisThat_logo.png ADDED
app.py CHANGED
@@ -1,7 +1,478 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
1
+ # *************************************************************************
2
+ # Copyright (2023) Bytedance Inc.
3
+ #
4
+ # Copyright (2023) DragDiffusion Authors
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ # *************************************************************************
18
+
19
+ import os, shutil, sys
20
+ import urllib.request
21
+ import argparse
22
+ import imageio
23
+ import math
24
+ import cv2
25
+ import collections
26
+ import numpy as np
27
  import gradio as gr
28
+ from PIL import Image
29
+
30
+ import torch
31
+ from pathlib import Path
32
+ from omegaconf import OmegaConf
33
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
34
+ from accelerate import Accelerator
35
+ from accelerate.utils import ProjectConfiguration
36
+ from diffusers import (
37
+ AutoencoderKLTemporalDecoder,
38
+ DDPMScheduler,
39
+ )
40
+ from diffusers.utils import check_min_version, is_wandb_available, load_image, export_to_video
41
+ from huggingface_hub import hf_hub_download
42
+ from transformers import AutoTokenizer, PretrainedConfig
43
+
44
+
45
+ # Import files from the local folder
46
+ root_path = os.path.abspath('.')
47
+ sys.path.append(root_path)
48
+ from train_code.train_svd import import_pretrained_text_encoder
49
+ from data_loader.video_dataset import tokenize_captions
50
+ from data_loader.video_this_that_dataset import get_thisthat_sam
51
+ from svd.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
52
+ from svd.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline
53
+ from svd.temporal_controlnet import ControlNetModel
54
+ from svd.pipeline_stable_video_diffusion_controlnet import StableVideoDiffusionControlNetPipeline
55
+ from utils.optical_flow_utils import bivariate_Gaussian
56
+
57
+
58
+ # For the 2D dilation
59
+ blur_kernel = bivariate_Gaussian(99, 10, 10, 0, grid = None, isotropic = True)
60
+
61
+
62
+ # Import
63
+ # LENGTH=480 # length of the square area displaying/editing images
64
+ HEIGHT = 256
65
+ WIDTH = 384
66
+
67
+
68
+ MARKDOWN = \
69
+ """
70
+ ## <p style='text-align: center'> This&That </p>
71
+
72
+ [GitHub](https://github.com/Kiteretsu77/This_and_That_VDM) | [Paper](http://arxiv.org/abs/2407.05530) | [Webpage](https://cfeng16.github.io/this-and-that/)
73
+ This&That is a language-, gesture-, and image-conditioned video generation model for robot planning in robotics scenarios (Bridge-dataset-based for this repo).
74
+
75
+ This demo covers the video diffusion model part.
76
+ Only GestureNet is provided in this Gradio demo; check the full test code for all available pretrained weights.
77
+
78
+ ### Note: By default, gesture points are placed at frame indices [4, 10] for two gesture points, or [4] for one gesture point.
79
+ ### Note: The result currently only supports 256x384 resolution.
80
+ ### Note: Click "Clear All" to restart everything; click "Undo Point" to cancel the last point you placed.
81
+
82
+ If This&That is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/This_and_That_VDM). Thanks!
83
+ """
84
+
85
+
86
+ def store_img(img):
87
+
88
+ # when new image is uploaded, `selected_points` should be empty
89
+ return img, []
90
+
91
+
92
+
93
+ def clear_all():
94
+ return None, \
95
+ gr.Image(value=None, height=HEIGHT, width=WIDTH, interactive=False), \
96
+ None, [] # selected points
97
+
98
+
99
+ def undo_points(original_image):
100
+ img = original_image.copy()
101
+ return img, []
102
+
103
+
104
+ # The user clicks the image to select points, which are then drawn on the image [From https://github.com/Yujun-Shi/DragDiffusion]
105
+ def get_points(img, original_image, sel_pix, evt: gr.SelectData):
106
+
107
+ # collect the selected point
108
+ sel_pix.append(evt.index)
109
+
110
+ if len(sel_pix) > 2:
111
+ raise gr.Error("We only at most support two points")
112
+
113
+ if original_image is None:
114
+ original_image = img.copy()
115
+
116
+ # draw points
117
+ points = []
118
+ for idx, point in enumerate(sel_pix):
119
+ if idx % 2 == 0:
120
+ # draw a red circle at the handle point
121
+ cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
122
+ else:
123
+ # draw a green circle at the second point
124
+ cv2.circle(img, tuple(point), 10, (0, 255, 0), -1)
125
+ points.append(tuple(point))
126
+ # draw an arrow from handle point to target point
127
+ # if len(points) == 2:
128
+ # cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
129
+ # points = []
130
+
131
+ return [img if isinstance(img, np.ndarray) else np.array(img), original_image]
132
+
133
+
134
+ def gesturenet_inference(ref_image, prompt, selected_points):
135
+
136
+ # Check some parameters: we must have a prompt and selected points
137
+ if prompt == "" or prompt is None:
138
+ raise gr.Error("Please input text prompt")
139
+ if selected_points == []:
140
+ raise gr.Error("Please click one/two points in the Image")
141
+
142
+ # Prepare the setting
143
+ frame_idxs = [4, 10]
144
+ use_ambiguous_prompt = False
145
+ model_type = "GestureNet"
146
+ huggingface_pretrained_path = "HikariDawn/This-and-That-1.1"
147
+
148
+ print("Text prompt is ", prompt)
149
+
150
+ # Prepare tmp folder
151
+ store_folder_name = "tmp"
152
+ if os.path.exists(store_folder_name):
153
+ shutil.rmtree(store_folder_name)
154
+ os.makedirs(store_folder_name)
155
+
156
+
157
+ # Read the yaml setting files (very important for loading the needed hyperparameters)
158
+ if not os.path.exists(huggingface_pretrained_path):
159
+ yaml_download_path = hf_hub_download(repo_id=huggingface_pretrained_path, subfolder="unet", filename="train_image2video.yaml")
160
+ if model_type == "GestureNet":
161
+ yaml_download_path = hf_hub_download(repo_id=huggingface_pretrained_path, subfolder="gesturenet", filename="train_image2video_gesturenet.yaml")
162
+ else: # If the path is a local path we can concatenate it here
163
+ yaml_download_path = os.path.join(huggingface_pretrained_path, "unet", "train_image2video.yaml")
164
+ if model_type == "GestureNet":
165
+ yaml_download_path = os.path.join(huggingface_pretrained_path, "gesturenet", "train_image2video_gesturenet.yaml")
166
+
167
+ # Load the config
168
+ assert(os.path.exists(yaml_download_path))
169
+ config = OmegaConf.load(yaml_download_path)
170
+
171
+
172
+ ################################################ Prepare vae, unet, image_encoder Same as before #################################################################
173
+ print("Prepare the pretrained model")
174
+ accelerator = Accelerator(
175
+ gradient_accumulation_steps = config["gradient_accumulation_steps"],
176
+ mixed_precision = config["mixed_precision"],
177
+ log_with = config["report_to"],
178
+ project_config = ProjectConfiguration(project_dir=config["output_dir"], logging_dir=Path(config["output_dir"], config["logging_name"])),
179
+ )
180
+ feature_extractor = CLIPImageProcessor.from_pretrained(
181
+ config["pretrained_model_name_or_path"], subfolder="feature_extractor", revision=None
182
+ ) # This instance has no weights; it is just a settings file
183
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained(
184
+ config["pretrained_model_name_or_path"], subfolder="image_encoder", revision=None, variant="fp16"
185
+ )
186
+ vae = AutoencoderKLTemporalDecoder.from_pretrained(
187
+ config["pretrained_model_name_or_path"], subfolder="vae", revision=None, variant="fp16"
188
+ )
189
+ unet = UNetSpatioTemporalConditionModel.from_pretrained(
190
+ huggingface_pretrained_path,
191
+ subfolder = "unet",
192
+ low_cpu_mem_usage = True,
193
+ # variant = "fp16",
194
+ )
195
+
196
+
197
+ # For text ..............................................
198
+ tokenizer = AutoTokenizer.from_pretrained(
199
+ config["pretrained_tokenizer_name_or_path"],
200
+ subfolder = "tokenizer",
201
+ revision = None,
202
+ use_fast = False,
203
+ )
204
+ # Clip Text Encoder
205
+ text_encoder_cls = import_pretrained_text_encoder(config["pretrained_tokenizer_name_or_path"], revision=None)
206
+ text_encoder = text_encoder_cls.from_pretrained(config["pretrained_tokenizer_name_or_path"], subfolder = "text_encoder", revision = None, variant = None)
207
+
208
+
209
+ weight_dtype = torch.float32
210
+ if accelerator.mixed_precision == "fp16":
211
+ weight_dtype = torch.float16
212
+ elif accelerator.mixed_precision == "bf16":
213
+ weight_dtype = torch.bfloat16
214
+
215
+ # Move vae + image_encoder to gpu and cast to weight_dtype
216
+ vae.requires_grad_(False)
217
+ image_encoder.requires_grad_(False)
218
+ unet.requires_grad_(False) # Will switch back at the end
219
+ text_encoder.requires_grad_(False)
220
+
221
+ # Move to accelerator
222
+ vae.to(accelerator.device, dtype=weight_dtype)
223
+ image_encoder.to(accelerator.device, dtype=weight_dtype)
224
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
225
+
226
+ # For GestureNet
227
+ if model_type == "GestureNet":
228
+ unet.to(accelerator.device, dtype=weight_dtype) # There is no need to cast unet in unet training, only needed in controlnet one
229
+
230
+ # Handle the Controlnet first from UNet
231
+ gesturenet = ControlNetModel.from_pretrained(
232
+ huggingface_pretrained_path,
233
+ subfolder = "gesturenet",
234
+ low_cpu_mem_usage = True,
235
+ variant = None,
236
+ )
237
+
238
+ gesturenet.requires_grad_(False)
239
+ gesturenet.to(accelerator.device)
240
+ ##############################################################################################################################################################
241
+
242
+
243
+
244
+
245
+ # Init the pipeline
246
+ pipeline = StableVideoDiffusionControlNetPipeline.from_pretrained(
247
+ config["pretrained_model_name_or_path"], # Still based on regular SVD config
248
+ vae = vae,
249
+ image_encoder = image_encoder,
250
+ unet = unet,
251
+ revision = None, # Set None directly now
252
+ torch_dtype = weight_dtype,
253
+ )
254
+ pipeline = pipeline.to(accelerator.device)
255
+ pipeline.set_progress_bar_config(disable=True)
256
+
257
+
258
+
259
+ ############################## Prepare and Process the condition here ##############################
260
+ org_height, org_width, _ = ref_image.shape
261
+ ref_image_pil = Image.fromarray(ref_image)
262
+ ref_image_pil = ref_image_pil.resize((config["width"], config["height"]))
263
+
264
+
265
+ # Initial the optical flow format we want
266
+ gesture_condition_img = np.zeros((config["video_seq_length"], config["conditioning_channels"], config["height"], config["width"]), dtype=np.float32) # The last image should be empty
267
+
268
+ # Handle the selected points to the condition we want
269
+ for point_idx, point in enumerate(selected_points):
270
+
271
+ frame_idx = frame_idxs[point_idx]
272
+ horizontal, vertical = point
273
+
274
+ # Init the base image
275
+ base_img = np.zeros((org_height, org_width, 3)).astype(np.float32) # Use the original image size
276
+ base_img.fill(255)
277
+
278
+ # Draw square around the target position
279
+ dot_range = 10 # Diameter
280
+ for i in range(-1*dot_range, dot_range+1):
281
+ for j in range(-1*dot_range, dot_range+1):
282
+ dil_vertical, dil_horizontal = vertical + i, horizontal + j
283
+ if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]):
284
+ if point_idx == 0:
285
+ base_img[dil_vertical][dil_horizontal] = [0, 0, 255] # The first point should be red
286
+ else:
287
+ base_img[dil_vertical][dil_horizontal] = [0, 255, 0] # The second point should be green to distinguish the first point
288
+
289
+ # Dilate
290
+ if config["dilate"]:
291
+ base_img = cv2.filter2D(base_img, -1, blur_kernel)
292
+
293
+
294
+ ##############################################################################################################################
295
+ ### The core pipeline of processing is: Dilate -> Resize -> Range Shift -> Transpose Shape -> Store
296
+
297
+ # Resize frames. Don't use negative values and don't resize after shifting to the [0, 1] range
298
+ base_img = cv2.resize(base_img, (config["width"], config["height"]), interpolation = cv2.INTER_CUBIC)
299
+
300
+ # Channel Transform and Range Shift
301
+ if config["conditioning_channels"] == 3:
302
+ # Map to [0, 1] range
303
+ base_img = base_img / 255.0
304
+
305
+ else:
306
+ raise NotImplementedError()
307
+
308
+ # ReOrganize shape
309
+ base_img = base_img.transpose(2, 0, 1) # hwc -> chw
310
+
311
+ # Write base img based on frame_idx
312
+ gesture_condition_img[frame_idx] = base_img # Only the first frame, the rest is 0 initialized
313
+
314
+
315
+ ####################################################################################################
316
+
317
+ # Use the same tokenize process as the dataset preparation stage
318
+ tokenized_prompt = tokenize_captions(prompt, tokenizer, config, is_train=False).unsqueeze(0).to(accelerator.device) # Use unsqueeze to expand dim
319
+
320
+
321
+
322
+ # Call the pipeline
323
+ with torch.autocast("cuda"):
324
+ frames = pipeline(
325
+ image = ref_image_pil,
326
+ condition_img = gesture_condition_img, # numpy [0,1] range
327
+ controlnet = accelerator.unwrap_model(gesturenet),
328
+ prompt = tokenized_prompt,
329
+ use_text = config["use_text"],
330
+ text_encoder = text_encoder,
331
+ height = config["height"],
332
+ width = config["width"],
333
+ num_frames = config["video_seq_length"],
334
+ decode_chunk_size = 8,
335
+ motion_bucket_id = 200,
336
+ # controlnet_image_index = controlnet_image_index,
337
+ # coordinate_values = coordinate_values,
338
+ num_inference_steps = config["num_inference_steps"],
339
+ max_guidance_scale = config["inference_max_guidance_scale"],
340
+ fps = 7,
341
+ use_instructpix2pix = config["use_instructpix2pix"],
342
+ noise_aug_strength = config["inference_noise_aug_strength"],
343
+ controlnet_conditioning_scale = config["outer_conditioning_scale"],
344
+ inner_conditioning_scale = config["inner_conditioning_scale"],
345
+ guess_mode = config["inference_guess_mode"], # False in inference
346
+ image_guidance_scale = config["image_guidance_scale"],
347
+ ).frames[0]
348
+
349
+ # Save frames
350
+ video_file_path = os.path.join(store_folder_name, "tmp.mp4")
351
+ writer = imageio.get_writer(video_file_path, fps=4)
352
+ for idx, frame in enumerate(frames):
353
+ frame.save(os.path.join(store_folder_name, str(idx)+".png"))
354
+ writer.append_data(cv2.cvtColor(cv2.imread(os.path.join(store_folder_name, str(idx)+".png")), cv2.COLOR_BGR2RGB))
355
+ writer.close()
356
+
357
+
358
+
359
+ # Cleaning process
360
+ del pipeline
361
+ torch.cuda.empty_cache()
362
+
363
+ return gr.update(value=video_file_path, width=config["width"], height=config["height"]) # Return the result as needed
364
+
365
+
366
+
367
+ if __name__ == '__main__':
368
+
369
+
370
+ # Gradio demo part
371
+ with gr.Blocks() as demo:
372
+ # layout definition
373
+ with gr.Row():
374
+ gr.Markdown(MARKDOWN)
375
+
376
+ # UI components for editing real images
377
+ with gr.Row(elem_classes=["container"]):
378
+ selected_points = gr.State([]) # store points
379
+ original_image = gr.State(value=None) # store original input image
380
+ with gr.Row():
381
+ with gr.Column():
382
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Click two Points</p>""")
383
+ input_image = gr.Image(label="Input Image", height=HEIGHT, width=WIDTH, interactive=False, elem_id="input_img")
384
+ # gr.Image(type="numpy", label="Click Points", height=HEIGHT, width=WIDTH, interactive=False) # for points clicking
385
+ undo_button = gr.Button("Undo point")
386
+
387
+ # Text prompt
388
+ with gr.Row():
389
+ prompt = gr.Textbox(label="Text Prompt")
390
+
391
+
392
+ with gr.Column():
393
+ gr.Markdown("""<p style="text-align: center; font-size: 20px">Results</p>""")
394
+ frames = gr.Video(value=None, label="Generate Video", show_label=True, height=HEIGHT, width=WIDTH)
395
+ with gr.Row():
396
+ run_button = gr.Button("Run")
397
+ clear_all_button = gr.Button("Clear All")
398
+
399
+
400
+
401
+
402
+ # with gr.Tab("Base Model Config"):
403
+ # with gr.Row():
404
+ # local_models_dir = 'local_pretrained_models'
405
+ # local_models_choice = \
406
+ # [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))]
407
+ # model_path = gr.Dropdown(value="runwayml/stable-diffusion-v1-5",
408
+ # label="Diffusion Model Path",
409
+ # choices=[
410
+ # "runwayml/stable-diffusion-v1-5",
411
+ # "gsdf/Counterfeit-V2.5",
412
+ # "stablediffusionapi/anything-v5",
413
+ # "SG161222/Realistic_Vision_V2.0",
414
+ # ] + local_models_choice
415
+ # )
416
+ # vae_path = gr.Dropdown(value="default",
417
+ # label="VAE choice",
418
+ # choices=["default",
419
+ # "stabilityai/sd-vae-ft-mse"] + local_models_choice
420
+ # )
421
+
422
+ # Examples
423
+ with gr.Row(elem_classes=["container"]):
424
+ gr.Examples(
425
+ [
426
+ ["__assets__/Bridge_example/Task1_v1_511/im_0.jpg", "take this to there"],
427
+ ["__assets__/Bridge_example/Task2_v2_164/im_0.jpg", "put this to there"],
428
+ ["__assets__/Bridge_example/Task3_v2_490/im_0.jpg", "fold this"],
429
+ ["__assets__/Bridge_example/Task4_v2_119/im_0.jpg", "open this"],
430
+
431
+ # ["__assets__/0.jpg", "take this to there"],
432
+ ["__assets__/91.jpg", "take this to there"],
433
+ ["__assets__/156.jpg", "take this to there"],
434
+ # ["__assets__/274.jpg", "take this to there"],
435
+ ["__assets__/375.jpg", "take this to there"],
436
+ # ["__assets__/551.jpg", "take this to there"],
437
+ ],
438
+ [input_image, prompt, selected_points],
439
+ )
440
+
441
+
442
+
443
+
444
+ ####################################### Event Definition #######################################
445
+
446
+ # Draw the points
447
+ input_image.select(
448
+ get_points,
449
+ [input_image, original_image, selected_points],
450
+ [input_image, original_image],
451
+ )
452
+
453
+ # Clean the points
454
+ undo_button.click(
455
+ undo_points,
456
+ [original_image],
457
+ [input_image, selected_points],
458
+ )
459
+
460
+ run_button.click(
461
+ gesturenet_inference,
462
+ inputs = [
463
+ # vae, unet, gesturenet, image_encoder, text_encoder, tokenizer,
464
+ original_image, prompt, selected_points,
465
+ # frame_idxs,
466
+ # config, accelerator, weight_dtype
467
+ ],
468
+ outputs = [frames]
469
+ )
470
+
471
+ clear_all_button.click(
472
+ clear_all,
473
+ [],
474
+ outputs = [original_image, input_image, prompt, selected_points],
475
+ )
476
 
 
 
477
 
478
+ demo.queue().launch(share=True, debug=True)
 
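For reference, `gesturenet_inference` can also be driven without the Gradio UI. Below is a minimal sketch, assuming the full repository and its dependencies are installed; the image path and click coordinates are placeholders that mirror what the demo passes in (a NumPy RGB image, a text prompt, and a list of one or two `(x, y)` points in original-image coordinates).

```python
# Minimal headless driver for gesturenet_inference (a sketch; paths and points are placeholders).
import numpy as np
from PIL import Image

from app import gesturenet_inference  # assumes this diff's app.py and its local imports are available

ref_image = np.array(Image.open("__assets__/91.jpg").convert("RGB"))  # HxWx3 uint8
prompt = "take this to there"
selected_points = [(120, 80), (260, 150)]  # "this" point, then "that" point (hypothetical coordinates)

# Returns a gr.update(...) whose `value` is the path of the rendered mp4 (tmp/tmp.mp4).
result = gesturenet_inference(ref_image, prompt, selected_points)
print(result)
```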
config/accelerate_config.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "compute_environment": "LOCAL_MACHINE",
3
+ "debug": false,
4
+ "distributed_type": "MULTI_GPU",
5
+ "downcast_bf16": "no",
6
+ "gpu_ids": "all",
7
+ "machine_rank": 0,
8
+ "main_training_function": "main",
9
+ "mixed_precision": "fp16",
10
+ "num_machines": 1,
11
+ "num_processes": 8,
12
+ "rdzv_backend": "static",
13
+ "same_network": true,
14
+ "tpu_env": [],
15
+ "tpu_use_cluster": false,
16
+ "tpu_use_sudo": false,
17
+ "use_cpu": false
18
+ }
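This JSON is an Accelerate launcher configuration (single machine, 8 processes, fp16). A minimal sketch of launching a training script with it via the standard `accelerate launch --config_file` flag; `train.py` is a placeholder for whichever training entry point the repo uses (not shown in this 50-file view):

```python
# Sketch: launch a (hypothetical) training entry point with this Accelerate config.
import subprocess

subprocess.run(
    [
        "accelerate", "launch",
        "--config_file", "config/accelerate_config.json",
        "train.py",  # placeholder entry point
    ],
    check=True,
)
```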
config/flowformer_config.py ADDED
@@ -0,0 +1,78 @@
1
+ from yacs.config import CfgNode as CN
2
+ _CN = CN()
3
+
4
+ _CN.name = 'default'
5
+ _CN.suffix ='sintel'
6
+ _CN.gamma = 0.75
7
+ _CN.max_flow = 400
8
+ _CN.batch_size = 6
9
+ _CN.sum_freq = 100
10
+ _CN.val_freq = 100000000
11
+ _CN.image_size = [432, 960]
12
+ _CN.add_noise = False
13
+ _CN.use_smoothl1 = False
14
+ _CN.critical_params = []
15
+
16
+ _CN.transformer = 'percostformer3'
17
+
18
+ ### change the path here
19
+ _CN.model = "pretrained/sintel.pth"
20
+
21
+ _CN.percostformer3 = CN()
22
+ _CN.percostformer3.pe = 'linear'
23
+ _CN.percostformer3.dropout = 0.0
24
+ _CN.percostformer3.droppath = 0.0
25
+ _CN.percostformer3.encoder_latent_dim = 256 # in twins, this is 256
26
+ _CN.percostformer3.query_latent_dim = 64
27
+ _CN.percostformer3.cost_latent_input_dim = 64
28
+ _CN.percostformer3.cost_latent_token_num = 8
29
+ _CN.percostformer3.cost_latent_dim = 128
30
+ _CN.percostformer3.cost_heads_num = 1
31
+ # encoder
32
+ _CN.percostformer3.pretrain = True
33
+ _CN.percostformer3.use_convertor = False
34
+ _CN.percostformer3.del_layers = True
35
+ _CN.percostformer3.encoder_depth = 3
36
+ _CN.percostformer3.expand_factor = 4
37
+ _CN.percostformer3.vertical_encoder_attn = "twins"
38
+ _CN.percostformer3.attn_dim = 128
39
+ _CN.percostformer3.patch_size = 8
40
+ _CN.percostformer3.patch_embed = 'single'
41
+ _CN.percostformer3.cross_attn = "all"
42
+ _CN.percostformer3.gma = "GMA"
43
+ _CN.percostformer3.vert_c_dim = 64
44
+ _CN.percostformer3.cost_encoder_res = True
45
+ _CN.percostformer3.cnet = 'twins'
46
+ _CN.percostformer3.fnet = 'twins'
47
+ _CN.percostformer3.flow_or_pe = "and"
48
+ _CN.percostformer3.use_patch = False # use cost patch rather than local cost as query
49
+ _CN.percostformer3.use_rpe = False
50
+ _CN.percostformer3.detach_local = False
51
+ _CN.percostformer3.no_sc = False
52
+ _CN.percostformer3.r_16 =-1
53
+ _CN.percostformer3.quater_refine = False
54
+ # pretrain config
55
+ _CN.percostformer3.pretrain_mode = False
56
+ _CN.percostformer3.pic_size = [368, 496, 368, 496]
57
+ _CN.percostformer3.mask_ratio = 0.5
58
+ _CN.percostformer3.query_num = 30
59
+ _CN.percostformer3.no_border = True
60
+ _CN.percostformer3.gt_r = 15
61
+ _CN.percostformer3.fix_pe = False
62
+ # decoder
63
+ _CN.percostformer3.decoder_depth = 12
64
+ _CN.percostformer3.critical_params = ['vert_c_dim', 'encoder_depth', 'vertical_encoder_attn', "use_patch", "flow_or_pe", "use_rpe", "dropout", "detach_local", "expand_factor"]
65
+
66
+
67
+ ### TRAINER
68
+ _CN.trainer = CN()
69
+ _CN.trainer.scheduler = 'OneCycleLR'
70
+ _CN.trainer.optimizer = 'adamw'
71
+ _CN.trainer.canonical_lr = 12.5e-5
72
+ _CN.trainer.adamw_decay = 1e-5
73
+ _CN.trainer.clip = 1.0
74
+ _CN.trainer.num_steps = 120000
75
+ _CN.trainer.epsilon = 1e-8
76
+ _CN.trainer.anneal_strategy = 'linear'
77
+ def get_cfg():
78
+ return _CN.clone()
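`get_cfg()` returns a clone of the yacs `CfgNode` defined above, so callers can override fields (for example the optical-flow checkpoint path in `model`) without mutating the module-level defaults. A minimal usage sketch:

```python
# Sketch: obtain and customize the FlowFormer config defined above.
from config.flowformer_config import get_cfg

cfg = get_cfg()                      # clone of the defaults; safe to mutate
cfg.model = "pretrained/sintel.pth"  # optical-flow checkpoint path (the `model` field above)
cfg.percostformer3.decoder_depth = 12

print(cfg.transformer, cfg.trainer.canonical_lr)
```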
config/train_image2video.yaml ADDED
@@ -0,0 +1,78 @@
1
+
2
+ # Model Setting
3
+ pretrained_model_name_or_path: stabilityai/stable-video-diffusion-img2vid # -xt is for 25 frames version
4
+ load_unet_path: # This is usually used to load a pretrained UNet; e.g., you may want to start from one of your previously trained checkpoints
5
+ video_seq_length: 14 # Standardized to 14
6
+ process_fps: 7
7
+ train_noise_aug_strength: 0.1
8
+ scheduler: EDM
9
+ conditioning_dropout_prob: 0.1
10
+
11
+
12
+ # Dataset Setting
13
+ dataset_name: Bridge # WebVid / Bridge
14
+ dataset_path: [../sanity_check/bridge_v1_raw, ../sanity_check/bridge_v2_raw]
15
+ output_dir: checkpoints/img2video
16
+ height: 256 # Ratio that is functional: 256:384 576:1024 320:512 320:576
17
+ width: 384 # It is said that the height and width should be a multiple of 64
18
+ dataloader_num_workers: 4 # Don't set this too large; video diffusion processing is usually slow, so you don't need that many workers for early loading
19
+ flip_aug_prob: 0.45 # Whether we flip the GT and cond vertically
20
+ acceleration_tolerance: 4 # Recommended setting
21
+
22
+
23
+ # Text setting
24
+ use_text: True # If this is True, we will use text value
25
+ pretrained_tokenizer_name_or_path: stabilityai/stable-diffusion-2-1-base # Use SD 2.1
26
+ empty_prompts_proportion: 0.0 # Useless now, we already have CFG in training
27
+ mix_ambiguous: False # Whether we mix ambiguous prompt for "this" and "that"
28
+
29
+
30
+ # Motion setting (useless right now...)
31
+ motion_bucket_id: 200 # Set it for exact value; If this is none, we will use below setting
32
+ dataset_motion_mean: 35.3 # For 14 fps, it is N(35.3, 18.5)
33
+ dataset_motion_std: 18.5 # For 25 fps, it is N(?, ?)
34
+ svd_motion_mean: 165
35
+ svd_motion_std: 22.5
36
+
37
+
38
+ # Training setting
39
+ resume_from_checkpoint: False # latest/False
40
+ num_train_iters: 100000 # Will automatically choose the checkpoints at 99K
41
+ partial_finetune: False # Whether we just tune some params to speed up
42
+ train_batch_size: 1 # This is the batch size per GPU
43
+ checkpointing_steps: 3000
44
+ validation_step: 300
45
+ logging_name: logging
46
+ seed: 42
47
+ validation_img_folder: # Prepare your own validation dataset
48
+ validation_store_folder: validation_results
49
+ checkpoints_total_limit: 15
50
+
51
+ # Noise Strength
52
+ noise_mean: 0.5 # Regular Img2Video: (0.7, 1.6); Text2Video: (0.5, 1.4)
53
+ noise_std: 1.4
54
+
55
+
56
+ # Inference
57
+ num_inference_steps: 25
58
+ inference_noise_aug_strength: 0.1
59
+ inference_max_guidance_scale: 3.0 # Training and testing may use different values depending on the scenario
60
+
61
+
62
+ # Learning Rate and Optimizer
63
+ learning_rate: 1e-5 # Usually this is ok
64
+ scale_lr: False # TODO: Is it needed to scale the learning rate?
65
+ adam_beta1: 0.9
66
+ adam_beta2: 0.999
67
+ use_8bit_adam: True # Need this to save more memory
68
+ adam_weight_decay: 1e-2
69
+ adam_epsilon: 1e-08
70
+ lr_warmup_steps: 500
71
+ lr_decay_scale: 0.5
72
+
73
+
74
+ # Other Setting
75
+ mixed_precision: fp16
76
+ gradient_accumulation_steps: 1
77
+ gradient_checkpointing: 1
78
+ report_to: tensorboard
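This YAML is loaded with OmegaConf, the same way `app.py` loads the copy shipped alongside the Hub checkpoint. A minimal sketch of reading a few of the keys consumed at inference time:

```python
# Sketch: load the training/inference config the same way app.py does.
from omegaconf import OmegaConf

config = OmegaConf.load("config/train_image2video.yaml")

# Keys consumed by the pipelines in this repo (dictionary-style access, as in app.py).
print(config["height"], config["width"])        # 256 384
print(config["video_seq_length"])               # 14
print(config["num_inference_steps"])            # 25
print(config["pretrained_model_name_or_path"])  # stabilityai/stable-video-diffusion-img2vid
```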
config/train_image2video_controlnet.yaml ADDED
@@ -0,0 +1,101 @@
1
+
2
+ # Model Setting
3
+ pretrained_model_name_or_path: stabilityai/stable-video-diffusion-img2vid # stabilityai/pretrained
4
+ load_unet_path: ../saved_weights/v4_VL_paper/checkpoint-99000 # None/specific path; this is the path to the pretrained UNet
5
+ load_controlnet_path: # None/specific path; for a checkpoint loaded from a pretrained ControlNet path
6
+ video_seq_length: 14
7
+ process_fps: 7
8
+ train_noise_aug_strength: 0.1
9
+ scheduler: EDM
10
+ conditioning_dropout_prob: 0.1
11
+
12
+
13
+ # Dataset Setting
14
+ data_loader_type: thisthat # thisthat
15
+ dataset_name: Bridge # Bridge
16
+ dataset_path: [../sanity_check/bridge_v1_TT14, ../sanity_check/bridge_v2_TT14] # ../Bridge_filter_flow, ../Bridge_v2_filter_flow/]
17
+ output_dir: checkpoints/img2video
18
+ height: 256 # Ratio that is functional: 256:384 576:1024 320:448 320:576 512:640 448:640
19
+ width: 384 # It is said that the height and width should be a multiple of 64
20
+ dataloader_num_workers: 4 # For Debug, it only needs 1
21
+ flip_aug_prob: 0.45 # Whether we flip the GT and cond vertically
22
+ # No acceleration_tolerance, since TT dataset already filter those out
23
+
24
+
25
+ # Text setting
26
+ use_text: True # If this is True, we will use text value
27
+ pretrained_tokenizer_name_or_path: stabilityai/stable-diffusion-2-1-base # Use SD 2.1
28
+ empty_prompts_proportion: 0.0
29
+ mix_ambiguous: False # Whether we mix ambiguous prompt for "this" and "that"
30
+
31
+
32
+ # Mask setting
33
+ mask_unet_vae: False # Whether we use mask to map latents to be zero padding
34
+ mask_controlnet_vae: False
35
+ mask_proportion: 0.0
36
+
37
+
38
+ # Condition Setting
39
+ conditioning_channels: 3 # Usually it is 3
40
+ num_points_left: # 1 # For flow: You can only choose one between flow_select_rate and num_points_left; num_points_left should be higher priority
41
+ flow_select_rate: 0.99 # For flow
42
+ threshold_factor: 0.2 # For flow
43
+ dilate: True # Traj must be True for dilate
44
+ inner_conditioning_scale: 1.0 # Conditioning scale for the internal value; the default starts from 1.0
45
+ outer_conditioning_scale: 1.0 # Outer conditioning scale for the whole trainable conditioning copy (a bit interesting: this was accidentally set to 2.0 at one point)
46
+
47
+
48
+ # Motion setting
49
+ motion_bucket_id: 200
50
+ dataset_motion_mean: 25 # For 14 fps, it is N(25, 10)
51
+ dataset_motion_std: 10 # For 25 fps, it is N(18, 7)
52
+ svd_motion_mean: 180
53
+ svd_motion_std: 30
54
+
55
+
56
+
57
+ # Training setting
58
+ resume_from_checkpoint: False # latest/False
59
+ num_train_iters: 30100 # Will automatically choose the checkpoints
60
+ partial_finetune: False # Whether we just tune some params to speed up
61
+ train_batch_size: 1 # This is the batch size per GPU
62
+ checkpointing_steps: 3000
63
+ validation_step: 300
64
+ logging_name: logging
65
+ seed: 42
66
+ validation_img_folder: datasets/validation_TT14
67
+ validation_store_folder: validation_videos
68
+ checkpoints_total_limit: 15
69
+
70
+
71
+ # Noise Strength
72
+ noise_mean: 0.5 # Regular Img2Video: (0.7, 1.6); Text2Video: (0.5, 1.4)
73
+ noise_std: 1.4
74
+
75
+
76
+ # Inference
77
+ num_inference_steps: 25
78
+ use_instructpix2pix: False # Whether we will use the instructPix2Pix mode, which involves 3 inputs; it may need tuning to get a better result in the end.
79
+ inference_noise_aug_strength: 0.1
80
+ inference_max_guidance_scale: 3.0 # Training and testing may use different values depending on the scenario
81
+ inference_guess_mode: False # Whether we use guess mode in the controlnet
82
+ image_guidance_scale: 2.5 # Empirically, 2.5 is the best value. It seems this is not used now
83
+
84
+
85
+ # Learning Rate and Optimizer
86
+ learning_rate: 5e-6 # 5e-6 is the LR we test that is just right
87
+ scale_lr: False # TODO: Is it needed to scale the learning rate?
88
+ adam_beta1: 0.9
89
+ adam_beta2: 0.999
90
+ use_8bit_adam: True # Need this to save more memory
91
+ adam_weight_decay: 1e-2
92
+ adam_epsilon: 1e-08
93
+ lr_warmup_steps: 500
94
+ lr_decay_scale: 0.5
95
+
96
+
97
+ # Other Setting
98
+ mixed_precision: fp16
99
+ gradient_accumulation_steps: 1 # ????
100
+ gradient_checkpointing: 1 # ????
101
+ report_to: tensorboard
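The GestureNet/ControlNet config is loaded the same way; if command-line overrides are wanted, they can be merged on top of the YAML with OmegaConf. A sketch (the override mechanism is an assumption, not the repo's documented interface; the keys shown are defined above):

```python
# Sketch: load the ControlNet/GestureNet config and merge command-line overrides,
# e.g. `python this_script.py learning_rate=1e-5 outer_conditioning_scale=1.0`.
from omegaconf import OmegaConf

base = OmegaConf.load("config/train_image2video_controlnet.yaml")
cli = OmegaConf.from_cli()           # dotlist-style overrides taken from sys.argv
config = OmegaConf.merge(base, cli)

print(config["conditioning_channels"], config["dilate"])
print(config["inference_max_guidance_scale"])
```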
curation_pipeline/add_lang_info.py ADDED
@@ -0,0 +1,38 @@
1
+ '''
2
+ Add the processed lang information
3
+ '''
4
+ import os, sys, shutil
5
+ import json
6
+
7
+
8
+ if __name__ == "__main__":
9
+
10
+ # Main config file path information
11
+ processed_json_file_path = "updated_bridge_v2.json"
12
+
13
+
14
+ # Read the json file
15
+ file = open(processed_json_file_path)
16
+ data = json.load(file)
17
+
18
+
19
+ # Iterate all the folders inside
20
+ start_idx = 0
21
+ for seq_instance in data:
22
+ target_path = seq_instance["images0"]
23
+ print("We are processing ", target_path)
24
+
25
+ processed_lang_txt_path = os.path.join(target_path, "processed_lang.txt")
26
+ if os.path.exists(processed_lang_txt_path):
27
+ os.remove(processed_lang_txt_path)
28
+
29
+ # Write the action + This + That into the sequence.
30
+ processed_lang_txt = open(processed_lang_txt_path, "a")
31
+ processed_lang_txt.write(str(seq_instance["action"])+"\n")
32
+ processed_lang_txt.write(str(seq_instance["this"])+"\n")
33
+ processed_lang_txt.write(str(seq_instance["that"])+"\n")
34
+
35
+
36
+ start_idx += 1
37
+
38
+ print("We have ", start_idx)
curation_pipeline/match_dataset_v1.py ADDED
@@ -0,0 +1,117 @@
1
+ '''
2
+ This file matches the selected frames with the Bridge dataset.
3
+ We need to use some tricks to select the items.
4
+ '''
5
+ import os, sys, shutil
6
+ import cv2
7
+ import numpy as np
8
+
9
+
10
+
11
+
12
+ def compare_img(imageA, imageB):
13
+ # the 'Mean Squared Error' between the two images is the
14
+ # sum of the squared difference between the two images;
15
+ # NOTE: the two images must have the same dimension
16
+ err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2)
17
+ err /= float(imageA.shape[0] * imageA.shape[1])
18
+
19
+ # return the MSE, the lower the error, the more "similar"
20
+ # the two images are
21
+ return err
22
+
23
+
24
+
25
+ def search_path(dataset_path, target_path, store_txt_path):
26
+
27
+ # We only need to care about the Bridge v1 dataset area
28
+ target_img_path = os.path.join(target_path, "im_0.jpg")
29
+ target_img = cv2.imread(target_img_path)
30
+
31
+ # Iterate all the folders inside
32
+ for scene_name in sorted(os.listdir(dataset_path)):
33
+ # print("We are reading scene", scene_name)
34
+ scene_dir = os.path.join(dataset_path, scene_name)
35
+
36
+ for task_name in os.listdir(scene_dir):
37
+ task_dir = os.path.join(scene_dir, task_name)
38
+
39
+ for time_clock in os.listdir(task_dir):
40
+ if time_clock == "lmdb":
41
+ continue # Skip lmdb folder
42
+
43
+ time_dir = os.path.join(task_dir, time_clock, "raw", "traj_group0")
44
+ if not os.path.exists(time_dir):
45
+ continue
46
+
47
+ for traj_name in os.listdir(time_dir):
48
+ traj_path = os.path.join(time_dir, traj_name)
49
+ if not os.path.isdir(traj_path):
50
+ continue
51
+
52
+ # Directly move policy_out_file_path; just in case there is also valuable information there
53
+ policy_out_file_path = os.path.join(traj_path, "policy_out.pkl")
54
+ if not os.path.exists(policy_out_file_path):
55
+ continue
56
+
57
+ # Check the lang txt file
58
+ lang_txt_file_path = os.path.join(traj_path, "lang.txt")
59
+ if not os.path.exists(lang_txt_file_path):
60
+ continue
61
+
62
+
63
+ # Last thing to locate to the right path
64
+ for img_name in os.listdir(traj_path):
65
+ if img_name != "images0": # Only consider one camera angle
66
+ continue
67
+
68
+ img_folder_path = os.path.join(traj_path, img_name)
69
+ if not os.path.isdir(img_folder_path):
70
+ continue
71
+
72
+
73
+ # Compare two image
74
+ img_path = os.path.join(img_folder_path, "im_0.jpg")
75
+ # print("img_folder_path is ", img_path)
76
+ compare_sample_img = cv2.imread(img_path)
77
+ error = compare_img(target_img, compare_sample_img)
78
+
79
+ if error == 0:
80
+ # Continue to all the rest for at least 5 images
81
+ status = True
82
+ for idx in range (10):
83
+ idx_img_path = os.path.join(img_folder_path, "im_"+str(idx)+".jpg")
84
+ idx_target_img_path = os.path.join(target_path, "im_"+str(idx)+".jpg")
85
+ idx_compare_sample_img = cv2.imread(idx_img_path)
86
+ idx_target_img = cv2.imread(idx_target_img_path)
87
+ error = compare_img(idx_target_img, idx_compare_sample_img)
88
+
89
+ if error != 0:
90
+ status = False
91
+ break
92
+
93
+ if status:
94
+ print("We found one at ", img_path)
95
+ f = open(store_txt_path, "a")
96
+ f.write(target_path + " " + img_folder_path + "\n")
97
+ return True
98
+
99
+ return False
100
+
101
+
102
+ if __name__ == "__main__":
103
+ input_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/datasets_rob/Bridge_v1_test_raw"
104
+ dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v1/berkeley" # 直接从本地新unzip的获取,怕之前的被xuweiyi改动过
105
+ store_txt_path = "match_info.txt"
106
+
107
+ if os.path.exists(store_txt_path):
108
+ os.remove(store_txt_path)
109
+
110
+ for img_name in sorted(os.listdir(input_path)):
111
+ target_path = os.path.join(input_path, img_name)
112
+ print("We are finding for ", target_path)
113
+
114
+ status = search_path(dataset_path, target_path, store_txt_path)
115
+
116
+ if not status:
117
+ print("we cannot find one")
curation_pipeline/match_dataset_v2.py ADDED
@@ -0,0 +1,137 @@
1
+ '''
2
+ This file matches the selected frames with the Bridge dataset.
3
+ We need to use some tricks to select the items.
4
+ '''
5
+ import os, sys, shutil
6
+ import cv2
7
+ import numpy as np
8
+
9
+
10
+
11
+
12
+ def compare_img(imageA, imageB):
13
+ # the 'Mean Squared Error' between the two images is the
14
+ # sum of the squared difference between the two images;
15
+ # NOTE: the two images must have the same dimension
16
+ err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2)
17
+ err /= float(imageA.shape[0] * imageA.shape[1])
18
+
19
+ # return the MSE, the lower the error, the more "similar"
20
+ # the two images are
21
+ return err
22
+
23
+
24
+
25
+ def search_path(dataset_path, target_path, store_txt_path):
26
+
27
+ # We only needs to care about Bridge v1 dataset area
28
+ target_img_path = os.path.join(target_path, "im_0.jpg")
29
+ if not os.path.exists(target_img_path):
30
+ print("The image we read is False")
31
+ return False
32
+ target_img = cv2.imread(target_img_path)
33
+
34
+ # Iterate all the folders inside
35
+ for scene_name in sorted(os.listdir(dataset_path)):
36
+ scene_dir = os.path.join(dataset_path, scene_name)
37
+
38
+ for task_name in sorted(os.listdir(scene_dir)):
39
+ task_dir = os.path.join(scene_dir, task_name)
40
+
41
+ for order_name in sorted(os.listdir(task_dir)):
42
+ order_dir = os.path.join(task_dir, order_name)
43
+
44
+ for time_clock in sorted(os.listdir(order_dir)):
45
+ if time_clock == "lmdb":
46
+ continue # Skip lmdb folder
47
+
48
+ time_dir = os.path.join(order_dir, time_clock, "raw", "traj_group0")
49
+ if not os.path.exists(time_dir):
50
+ continue
51
+
52
+ for traj_name in sorted(os.listdir(time_dir)):
53
+ traj_path = os.path.join(time_dir, traj_name)
54
+ if not os.path.isdir(traj_path):
55
+ continue
56
+
57
+ # Directly move policy_out_file_path; just in case there is also valuable information there
58
+ policy_out_file_path = os.path.join(traj_path, "policy_out.pkl")
59
+ if not os.path.exists(policy_out_file_path):
60
+ continue
61
+
62
+ # Check the lang txt file
63
+ lang_txt_file_path = os.path.join(traj_path, "lang.txt")
64
+ if not os.path.exists(lang_txt_file_path):
65
+ continue
66
+
67
+
68
+ for img_name in sorted(os.listdir(traj_path)):
69
+ if img_name != "images0": # Only consider one camera angle
70
+ continue
71
+
72
+ img_folder_path = os.path.join(traj_path, img_name)
73
+ if not os.path.isdir(img_folder_path):
74
+ continue
75
+
76
+
77
+ # Compare two image
78
+ img_path = os.path.join(img_folder_path, "im_0.jpg")
79
+ if not os.path.exists(img_path):
80
+ print(img_folder_path + " doesn't even have im_0.jpg")
81
+ continue
82
+ # print("img_folder_path is ", img_path)
83
+ compare_sample_img = cv2.imread(img_path)
84
+ # try:
85
+ # compare_sample_img.shape
86
+ # except Exception:
87
+ # print("The compare_sample_img cannot be red")
88
+ # continue
89
+ error = compare_img(target_img, compare_sample_img)
90
+
91
+ if error == 0:
92
+ # Continue to all the rest for at least 5 images
93
+ status = True
94
+ for idx in range (10):
95
+ idx_img_path = os.path.join(img_folder_path, "im_"+str(idx)+".jpg")
96
+ idx_target_img_path = os.path.join(target_path, "im_"+str(idx)+".jpg")
97
+ if not os.path.exists(idx_img_path):
98
+ print("The idx_img_path long idx we see only at ", idx)
99
+ continue
100
+ if not os.path.exists(idx_target_img_path):
101
+ print("The idx_target_img_path long idx we see only at ", idx)
102
+ continue
103
+ idx_compare_sample_img = cv2.imread(idx_img_path)
104
+ idx_target_img = cv2.imread(idx_target_img_path)
105
+ error = compare_img(idx_target_img, idx_compare_sample_img)
106
+
107
+ if error != 0:
108
+ status = False
109
+ break
110
+
111
+ if status:
112
+ print("We found one at ", img_path)
113
+ f = open(store_txt_path, "a")
114
+ f.write(target_path + " " + img_folder_path + "\n")
115
+ return True
116
+
117
+ return False
118
+
119
+
120
+ if __name__ == "__main__":
121
+ input_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/datasets_rob/Bridge_v2_test_raw"
122
+ dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v2" # 直接从本地新unzip的获取,怕之前的被xuweiyi改动过
123
+ store_txt_path = "match_info_v2_p1.txt"
124
+ start_idx = 0
125
+ end_idx = 500
126
+
127
+ if os.path.exists(store_txt_path):
128
+ os.remove(store_txt_path)
129
+
130
+ for img_name in sorted(os.listdir(input_path))[start_idx:end_idx]:
131
+ target_path = os.path.join(input_path, img_name)
132
+ print("We are finding for ", target_path)
133
+
134
+ status = search_path(dataset_path, target_path, store_txt_path)
135
+
136
+ if not status:
137
+ print("we cannot find one")
curation_pipeline/prepare_bridge_csv.py ADDED
@@ -0,0 +1,69 @@
1
+ '''
2
+ This file prepares the dataset as a CSV file following the format required by Open-Sora
3
+ '''
4
+
5
+ import os, sys, shutil
6
+ import json
7
+ import csv
8
+
9
+ # Import files from the local folder
10
+ root_path = os.path.abspath('.')
11
+ sys.path.append(root_path)
12
+ # from curation_pipeline.prepare_bridge_v1 import read_bridge_v1
13
+ # from curation_pipeline.prepare_bridge_v2 import read_bridge_v2
14
+
15
+
16
+
17
+ def iter_dataset(dataset_path):
18
+ lists = []
19
+ for sub_folder_name in os.listdir(dataset_path):
20
+ sub_folder_path = os.path.join(dataset_path, sub_folder_name)
21
+
22
+ # Check number of frames
23
+ max_length = len(os.listdir(sub_folder_path))
24
+ for check_idx in range(max_length):
25
+ if not os.path.exists(os.path.join(sub_folder_path, 'im_' + str(check_idx) + '.jpg')): # Should be sequentially exists
26
+ break
27
+ num_frames = check_idx
28
+
29
+ # Read the text
30
+ txt_path = os.path.join(sub_folder_path, "lang.txt")
31
+ f = open(txt_path, "r")
32
+ lang_prompt = f.readline()
33
+
34
+ lists.append([sub_folder_path, lang_prompt, num_frames, 480, 640])
35
+ # break
36
+ return lists
37
+
38
+
39
+
40
+ if __name__ == "__main__":
41
+ v1_dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/sanity_check/bridge_v1_raw"
42
+ v2_dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/sanity_check/bridge_v2_raw"
43
+ store_name = "Bridge_raw.csv"
44
+
45
+ if os.path.exists(store_name):
46
+ os.remove(store_name)
47
+
48
+
49
+ # Execute
50
+ full_lists = [["path", "text", "num_frames", "height", "width"]]
51
+
52
+ v1_lists = iter_dataset(v1_dataset_path)
53
+ full_lists.extend(v1_lists)
54
+ v2_lists = iter_dataset(v2_dataset_path)
55
+ full_lists.extend(v2_lists)
56
+ print("Full length is ", len(full_lists))
57
+
58
+
59
+ # Store as csv file
60
+ with open(store_name, 'w') as outfile:
61
+ write = csv.writer(outfile)
62
+ write.writerows(full_lists)
63
+
64
+
65
+
66
+ # with open('output.jsonl', 'w') as outfile:
67
+ # for entry in JSON_file:
68
+ # json.dump(entry, outfile)
69
+ # outfile.write('\n')
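The resulting `Bridge_raw.csv` has the header `path,text,num_frames,height,width`, one row per `images0` folder. A small sketch of reading it back with the standard `csv` module:

```python
# Sketch: read back the Open-Sora-style CSV written by prepare_bridge_csv.py.
import csv

with open("Bridge_raw.csv", newline="") as f:
    reader = csv.DictReader(f)  # header: path, text, num_frames, height, width
    for row in reader:
        print(row["path"], row["text"].strip(), row["num_frames"])
```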
curation_pipeline/prepare_bridge_jsonl.py ADDED
@@ -0,0 +1,47 @@
1
+ '''
2
+ This file prepares the dataset as a JSONL file
3
+ '''
4
+
5
+ import os, sys, shutil
6
+ import json
7
+
8
+ # Import files from the local folder
9
+ root_path = os.path.abspath('.')
10
+ sys.path.append(root_path)
11
+ from curation_pipeline.prepare_bridge_v1 import read_bridge_v1
12
+ from curation_pipeline.prepare_bridge_v2 import read_bridge_v2
13
+
14
+
15
+ if __name__ == "__main__":
16
+ v1_dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v1/berkeley"
17
+ v2_dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v2"
18
+ store_name = "store.jsonl"
19
+
20
+ if os.path.exists(store_name):
21
+ os.remove(store_name)
22
+
23
+
24
+ # Execute
25
+ full_lists = []
26
+
27
+ v1_lists = read_bridge_v1(v1_dataset_path, "", copyfile=False)
28
+ full_lists.extend(v1_lists)
29
+ v2_lists = read_bridge_v2(v2_dataset_path, "", copyfile=False)
30
+ full_lists.extend(v2_lists)
31
+ print("Full length is ", len(full_lists))
32
+
33
+
34
+ with open(store_name, 'w') as outfile:
35
+ for list_name in full_lists:
36
+ instance = dict()
37
+ instance["file_path"] = list_name
38
+
39
+ json.dump(instance, outfile)
40
+ outfile.write('\n')
41
+
42
+
43
+
44
+ # with open('output.jsonl', 'w') as outfile:
45
+ # for entry in JSON_file:
46
+ # json.dump(entry, outfile)
47
+ # outfile.write('\n')
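Each line of `store.jsonl` is a single JSON object with a `file_path` key pointing at an `images0` folder. A minimal sketch of reading it back:

```python
# Sketch: iterate over the JSONL index written by prepare_bridge_jsonl.py.
import json

with open("store.jsonl") as f:
    for line in f:
        record = json.loads(line)
        print(record["file_path"])  # path to one .../raw/traj_group0/<traj>/images0 folder
```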
curation_pipeline/prepare_bridge_v1.py ADDED
@@ -0,0 +1,132 @@
1
+ '''
2
+ This repository is used to prepare Bridge dataset
3
+ '''
4
+ import os, sys, shutil
5
+
6
+
7
+ def read_bridge_v1(dataset_path, train_store_path, test_store_path, test_dataset_lists, copyfile=True):
8
+ # copyfile is True when we need to copy the file to the target destination
9
+
10
+ start_idx = 0
11
+ target_lists = []
12
+ prefix_len = len(dataset_path) + 1
13
+
14
+ # Iterate all the folders inside
15
+ for scene_name in sorted(os.listdir(dataset_path)):
16
+ print("We are reading scene ", scene_name)
17
+ scene_dir = os.path.join(dataset_path, scene_name)
18
+ for task_name in sorted(os.listdir(scene_dir)):
19
+ task_dir = os.path.join(scene_dir, task_name)
20
+
21
+ for time_clock in sorted(os.listdir(task_dir)):
22
+ if time_clock == "lmdb":
23
+ continue # Skip lmdb folder
24
+
25
+ time_dir = os.path.join(task_dir, time_clock, "raw", "traj_group0")
26
+ if not os.path.exists(time_dir):
27
+ continue
28
+
29
+ for traj_name in sorted(os.listdir(time_dir)):
30
+ traj_path = os.path.join(time_dir, traj_name)
31
+ if not os.path.isdir(traj_path):
32
+ continue
33
+
34
+ # Directly move policy_out_file_path; just in case there is also valuable information there
35
+ policy_out_file_path = os.path.join(traj_path, "policy_out.pkl")
36
+ if not os.path.exists(policy_out_file_path):
37
+ continue
38
+
39
+ # Check the lang txt file
40
+ lang_txt_file_path = os.path.join(traj_path, "lang.txt")
41
+ if not os.path.exists(lang_txt_file_path):
42
+ continue
43
+
44
+
45
+ for img_name in sorted(os.listdir(traj_path)):
46
+ if img_name != "images0": # Only consider one camera angle
47
+ continue
48
+
49
+ img_folder_path = os.path.join(traj_path, img_name)
50
+ if not os.path.isdir(img_folder_path):
51
+ continue
52
+
53
+ ############################################ Main Process ####################################################
54
+
55
+ # # First Sanity check (Make sure the input source is jpg good)
56
+ # length = len(os.listdir(img_folder_path))
57
+ # status = True
58
+ # for check_idx in range(length):
59
+ # if not os.path.exists(os.path.join(img_folder_path, 'im_' + str(check_idx) + '.jpg')): # Should be sequentially exists
60
+ # status = False
61
+ # break
62
+
63
+ # Now we can copy the folder to our destination
64
+ target_lists.append(img_folder_path)
65
+ if copyfile:
66
+ print("img_folder_path[prefix_len:] is ", img_folder_path[prefix_len:])
67
+ if img_folder_path[prefix_len:] in test_dataset_lists:
68
+ # Store to test set
69
+ target_dir = os.path.join(test_store_path, str(start_idx))
70
+ else:
71
+ # This is training set
72
+ target_dir = os.path.join(train_store_path, str(start_idx))
73
+
74
+ print("Copy " + str(img_folder_path) + " to " + str(target_dir))
75
+ shutil.copytree(img_folder_path, target_dir)
76
+
77
+
78
+ # Sanity check
79
+ length = len(os.listdir(target_dir))
80
+ status = True
81
+ for check_idx in range(length):
82
+ if not os.path.exists(os.path.join(target_dir, 'im_' + str(check_idx) + '.jpg')): # Should be sequentially exists
83
+ status = False
84
+ break
85
+
86
+ if not status:
87
+ # If they didn't have sequential files we need, we will remove and begin again without updating start_idx
88
+ print("This file cannot pass the sanity check. We will remove it!")
89
+ shutil.rmtree(target_dir)
90
+ continue
91
+
92
+ # Move other auxiliary files
93
+ shutil.copy(policy_out_file_path, os.path.join(target_dir, "policy_out.pkl"))
94
+ shutil.copy(lang_txt_file_path, os.path.join(target_dir, "lang.txt"))
95
+
96
+ ################################################################################################################
97
+
98
+ # Update the idx
99
+ start_idx += 1
100
+
101
+ print("We have ", start_idx, " number of cases")
102
+
103
+ # Return a list of file path
104
+ return target_lists
105
+
106
+
107
+
108
+ if __name__ == "__main__":
109
+ dataset_path = "/Path/to/Bridge/raw/bridge_data_v1/berkeley" # Until Bridge v1 - berkeley section
110
+ train_store_path = "/Path/to/Bridge/train/bridge_v1_raw"
111
+ test_store_path = "/Path/to/Bridge/train/bridge_v1_test_raw"
112
+ test_dataset_predefined_path = "test_path.txt" # This will be provided by us
113
+
114
+
115
+ # Make dir if needed
116
+ if os.path.exists(train_store_path):
117
+ shutil.rmtree(train_store_path)
118
+ os.makedirs(train_store_path)
119
+ if os.path.exists(test_store_path):
120
+ shutil.rmtree(test_store_path)
121
+ os.makedirs(test_store_path)
122
+
123
+
124
+ # Read Test dataset path
125
+ test_dataset_lists = []
126
+ read_file = open(test_dataset_predefined_path, "r")
127
+ for line in read_file.readlines():
128
+ test_dataset_lists.append(line[:-1])
129
+ print("test_dataset_lists is ", test_dataset_lists)
130
+
131
+
132
+ read_bridge_v1(dataset_path, train_store_path, test_store_path, test_dataset_lists)
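`test_path.txt` is expected to hold one path per line, relative to `dataset_path` (the script strips the dataset prefix via `prefix_len` before the membership check). A small sketch of that split logic with a hypothetical entry:

```python
# Sketch: the train/test split check used above, with a hypothetical test_path.txt entry.
dataset_path = "/Path/to/Bridge/raw/bridge_data_v1/berkeley"
prefix_len = len(dataset_path) + 1

# One line of test_path.txt: a path relative to dataset_path (hypothetical scene/task names).
test_dataset_lists = ["toykitchen1/put_cup_on_plate/2021-05-20_10-00-00/raw/traj_group0/traj0/images0"]

img_folder_path = dataset_path + "/" + test_dataset_lists[0]
is_test = img_folder_path[prefix_len:] in test_dataset_lists
print("goes to test set" if is_test else "goes to train set")
```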
curation_pipeline/prepare_bridge_v2.py ADDED
@@ -0,0 +1,139 @@
1
+ '''
2
+ This repository is used to prepare Bridge dataset
3
+ '''
4
+ import os, sys, shutil
5
+
6
+
7
+
8
+ def read_bridge_v2(dataset_path, train_store_path, test_store_path, test_dataset_lists, copyfile=True):
9
+ # copyfile is True most of the time
10
+
11
+ start_idx = 0
12
+ target_lists = []
13
+ prefix_len = len(dataset_path) + 1
14
+
15
+ # Iterate all the folders inside
16
+ for scene_name in sorted(os.listdir(dataset_path)):
17
+ print("We are reading scene ", scene_name)
18
+ scene_dir = os.path.join(dataset_path, scene_name)
19
+
20
+ for task_name in sorted(os.listdir(scene_dir)):
21
+ task_dir = os.path.join(scene_dir, task_name)
22
+
23
+ for order_name in sorted(os.listdir(task_dir)):
24
+ order_dir = os.path.join(task_dir, order_name)
25
+
26
+ for time_clock in sorted(os.listdir(order_dir)):
27
+ if time_clock == "lmdb":
28
+ continue # Skip lmdb folder
29
+
30
+ time_dir = os.path.join(order_dir, time_clock, "raw", "traj_group0")
31
+ if not os.path.exists(time_dir):
32
+ print("time_dir does not exist for ", time_dir)
33
+ continue
34
+
35
+ for traj_name in sorted(os.listdir(time_dir)):
36
+ traj_path = os.path.join(time_dir, traj_name)
37
+ if not os.path.isdir(traj_path):
38
+ print("traj_path does not exist for ", traj_path)
39
+ continue
40
+
41
+ # Directly move policy_out_file_path; just in case there is also valuable information there
42
+ policy_out_file_path = os.path.join(traj_path, "policy_out.pkl")
43
+ if not os.path.exists(policy_out_file_path):
44
+ continue
45
+
46
+ # Check the lang txt file
47
+ lang_txt_file_path = os.path.join(traj_path, "lang.txt")
48
+ if not os.path.exists(lang_txt_file_path):
49
+ continue
50
+
51
+
52
+ for img_name in sorted(os.listdir(traj_path)):
53
+ if img_name != "images0": # Only consider one camera angle
54
+ continue
55
+
56
+ img_folder_path = os.path.join(traj_path, img_name)
57
+ if not os.path.isdir(img_folder_path):
58
+ print("img_folder_path does not exist for ", img_folder_path)
59
+ continue
60
+
61
+ ############################################ Main Process ####################################################
62
+
63
+ # # First Sanity check (Make sure the input source is jpg good)
64
+ # length = len(os.listdir(img_folder_path))
65
+ # status = True
66
+ # for check_idx in range(length):
67
+ # if not os.path.exists(os.path.join(img_folder_path, 'im_' + str(check_idx) + '.jpg')): # Should be sequentially exists
68
+ # status = False
69
+ # break
70
+
71
+ # Record this folder as a candidate sample
72
+ target_lists.append(img_folder_path)
73
+ if copyfile:
74
+ print("img_folder_path[prefix_len:] is ", img_folder_path[prefix_len:])
75
+ if img_folder_path[prefix_len:] in test_dataset_lists:
76
+ # Store to test set
77
+ target_dir = os.path.join(test_store_path, str(start_idx))
78
+ else:
79
+ # This is training set
80
+ target_dir = os.path.join(train_store_path, str(start_idx))
81
+
82
+ # Now we can copy the folder to our destination
83
+ print("Copy " + str(img_folder_path) + " to " + str(os.path.join(train_store_path, str(start_idx))))
84
+ shutil.copytree(img_folder_path, target_dir)
85
+
86
+ # Sanity check
87
+ length = len(os.listdir(target_dir))
88
+ status = True
89
+ for check_idx in range(length):
90
+ if not os.path.exists(os.path.join(target_dir, 'im_' + str(check_idx) + '.jpg' )): # Frames should exist sequentially
91
+ status = False
92
+ break
93
+
94
+ if not status:
95
+ # If the copied folder is missing sequential frames, remove it and continue without updating start_idx
96
+ print("This file cannot pass the sanity check. We will remove it!")
97
+ shutil.rmtree(target_dir)
98
+ continue
99
+
100
+ # Move other auxiliary files
101
+ shutil.copy(policy_out_file_path, os.path.join(target_dir, "policy_out.pkl"))
102
+ shutil.copy(lang_txt_file_path, os.path.join(target_dir, "lang.txt"))
103
+
104
+ # Update the idx
105
+ start_idx += 1
106
+
107
+ print("We have ", start_idx)
108
+
109
+ # Return a list of file paths
110
+ return target_lists
111
+
112
+
113
+
114
+ if __name__ == "__main__":
115
+ dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v2"
116
+ train_store_path = "../sanity_check/bridge_v2_raw"
117
+ test_store_path = "../sanity_check/bridge_v2_test_raw"
118
+ test_dataset_predefined_path = "test_path_v2.txt"
119
+
120
+
121
+ # Make dir if needed
122
+ if os.path.exists(train_store_path):
123
+ shutil.rmtree(train_store_path)
124
+ os.makedirs(train_store_path)
125
+ if os.path.exists(test_store_path):
126
+ shutil.rmtree(test_store_path)
127
+ os.makedirs(test_store_path)
128
+
129
+ # Read Test dataset path
130
+ test_dataset_lists = []
131
+ read_file = open(test_dataset_predefined_path, "r")
132
+ for line in read_file.readlines():
133
+ test_dataset_lists.append(line[:-1])
134
+ print("test_dataset_lists is ", test_dataset_lists)
135
+
136
+
137
+ read_bridge_v2(dataset_path, train_store_path, test_store_path, test_dataset_lists)
138
+
139
+
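For reference, the nested loops above walk a fixed directory depth (scene / task / order / timestamp / raw / traj_group0 / traj* / images0) and keep only trajectories that also carry policy_out.pkl and lang.txt. A compact glob-based sketch that yields the same image folders is shown below; it is not used by the script itself and is only meant for quickly inspecting a raw download.

    import glob, os

    def iter_bridge_v2_image_folders(dataset_path):
        # Same depth as the nested loops above:
        # scene / task / order / timestamp / raw / traj_group0 / traj* / images0
        pattern = os.path.join(dataset_path, "*", "*", "*", "*",
                               "raw", "traj_group0", "*", "images0")
        for img_folder in sorted(glob.glob(pattern)):
            traj_dir = os.path.dirname(img_folder)
            # Keep only trajectories that also carry the auxiliary files copied above
            if os.path.exists(os.path.join(traj_dir, "policy_out.pkl")) and \
               os.path.exists(os.path.join(traj_dir, "lang.txt")):
                yield img_folder

    # e.g. print(sum(1 for _ in iter_bridge_v2_image_folders("/path/to/bridge_data_v2")))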
curation_pipeline/select_frame_with_this_that.py ADDED
@@ -0,0 +1,421 @@
1
+ '''
2
+ This repository is used to prepare Bridge dataset with this that conditioning
3
+ '''
4
+ import os, sys, shutil
5
+ import pickle
6
+ from ultralytics import YOLO
7
+ from PIL import Image, ImageDraw
8
+ import numpy as np
9
+ import cv2
10
+ import math
11
+ import collections
12
+ from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
13
+
14
+
15
+ def show_mask(mask, random_color=False):
16
+ if random_color:
17
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
18
+ else:
19
+ color = np.array([30/255, 144/255, 255/255, 0.6])
20
+ h, w = mask.shape[-2:]
21
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
22
+
23
+ return mask_image * 255
24
+
25
+
26
+ def read_center_point(model, img_path, do_visualization, store_path):
27
+
28
+ action_img = Image.open(img_path)
29
+ prediction = model.predict(source=action_img, save=False)[0] # Only 1 frame
30
+
31
+ if not hasattr(prediction, "boxes"):
32
+ print("Detection Fail: We cannot have boxes attribute")
33
+ return None, None # None means detection failed; skip this case
34
+
35
+ # Save at a temporary place for visualization
36
+ if do_visualization:
37
+ prediction.save(filename=store_path)
38
+
39
+
40
+ bounding_boxes = prediction.boxes.xywh
41
+ num, dim = bounding_boxes.shape
42
+ assert(dim == 4)
43
+
44
+ # Collect the center points of all bounding boxes
45
+ edge_point_cord = []
46
+ center_points = []
47
+ for idx in range(num):
48
+ x, y, w, h = bounding_boxes[idx].detach().cpu().numpy()
49
+ center_point = [x, y] # TODO: consider y+(h/4); empirically, shifting down by 25% of the box height usually helps
50
+
51
+ edge_point_cord.extend([ (x+w//2, y+h//2), (x-w//2, y+h//2), (x-w//2, y-h//2), (x+w//2, y-h//2) ])
52
+
53
+
54
+ if w <= 15 or h <= 15: # If a bounding box is too small, we will disregard this case
55
+ return None, None
56
+
57
+ # Calculate the distance between current one and previous points for sanity check
58
+ for point in center_points: # Check all previous points
59
+ give_up_threshold = 90
60
+ if center_point[0] - point[0] >= give_up_threshold:
61
+ print("Two points are too far away and neglect the case")
62
+ return None, None
63
+ if center_point[1] - point[1] >= give_up_threshold:
64
+ print("Two points are too far away and neglect the case")
65
+ return None, None
66
+
67
+ # Append to the list
68
+ center_points.append(center_point)
69
+
70
+
71
+ if len(center_points) == 0 or len(center_points) > 2:
72
+ print("Detection Fail: We cannot detect bounding boxes")
73
+ return None, None
74
+
75
+ # Calculating the average distance among center_points
76
+ if len(center_points) == 2:
77
+ first_box, second_box = center_points
78
+
79
+ center_x = (first_box[0] + second_box[0]) / 2
80
+ center_y = (first_box[1] + second_box[1]) / 2
81
+
82
+ distance = math.sqrt(abs(first_box[0] - second_box[0])**2 + abs(first_box[1] - second_box[1])**2)
83
+
84
+ return [center_x, center_y, distance], edge_point_cord
85
+
86
+ return [*center_points[0], 100], edge_point_cord # With a single box the distance would be 0; we use a higher placeholder value to avoid 2-1-2 box detections across consecutive frames
87
+
88
+
89
+
90
+ def detect_gripper(gripper_detection_model, input_dir, action_start, action_end, do_visualization, store_dir, sample_failure_collect_folder=None):
91
+
92
+ # Process the first point first (it matters more, so we repeat it 3 times); then quickly process the last point
93
+
94
+ # Process the first action frame by iterating over the next three frames and choosing the closest one
95
+ first_center_points = []
96
+ edge_point_cords = []
97
+ for idx in range(3): # Repeat 3 times
98
+ action_start_path = os.path.join(input_dir, "im_"+str(action_start + idx)+".jpg")
99
+ first_center_point, edge_point_cord = read_center_point(gripper_detection_model, action_start_path, do_visualization, os.path.join(store_dir, "contact_first"+str(idx)+".jpg")) # The first frame
100
+
101
+ if idx == 0 and first_center_point is None:
102
+ message = "Cannot find the first contact point!"
103
+
104
+ print("The contact point we cannot detect is at ", action_start_path)
105
+ if sample_failure_collect_folder != "":
106
+ shutil.copyfile(action_start_path, os.path.join(sample_failure_collect_folder, str(len(os.listdir(sample_failure_collect_folder)))+".jpg") )
107
+
108
+ return (None, None, message)
109
+
110
+ if first_center_point is not None:
111
+ first_center_points.append([action_start + idx, first_center_point])
112
+
113
+ # Add edge points
114
+ print(edge_point_cord)
115
+ edge_point_cords.extend(edge_point_cord) # Note: simply extending with every point may not be robust for some edge cases
116
+
117
+
118
+ # Select the candidate whose two detected boxes are closest to each other
119
+ first_center_points.sort(key=lambda x: x[1][2])
120
+ first_center_point = first_center_points[0][1][:2]
121
+ start_idx = first_center_points[0][0]
122
+ print("first_center_point is " + str(first_center_point) + " with idx " + str(start_idx))
123
+ order_idx = [start_idx, action_end]
124
+
125
+
126
+ # Find xmin, ymin, xmax, ymax over all collected edge points; this serves as the bounding box for SAM
127
+ edge_point_cords.sort(key=lambda x: x[0])
128
+ xmin = int(edge_point_cords[0][0])
129
+ xmax = int(edge_point_cords[-1][0])
130
+
131
+ edge_point_cords.sort(key=lambda x: x[1])
132
+ ymin = int(edge_point_cords[0][1])
133
+ ymax = int(edge_point_cords[-1][1])
134
+
135
+ bbox_info = (xmin, xmax, ymin, ymax)
136
+
137
+
138
+ # Process the last action frame
139
+ action_end_path = os.path.join(input_dir, "im_"+str(action_end)+".jpg")
140
+ last_center_point, edge_point_cord = read_center_point(gripper_detection_model, action_end_path, do_visualization, os.path.join(store_dir, "contact_last.jpg")) # The last frame
141
+ if last_center_point is None:
142
+ message = "Cannot find the last contact point!"
143
+
144
+ print("The contact point we cannot detect is at ", action_start_path)
145
+ if sample_failure_collect_folder != "":
146
+ store_name = str(len(os.listdir(sample_failure_collect_folder))) + ".jpg"
147
+ shutil.copyfile(action_end_path, os.path.join(sample_failure_collect_folder, store_name) )
148
+
149
+ return (None, bbox_info, message)
150
+ last_center_point = last_center_point[:2]
151
+
152
+
153
+ # Check whether the two center points are too close; if so, merge them into one point
154
+ merge_threshold = 30
155
+ if math.sqrt((first_center_point[0] - last_center_point[0])**2 + (first_center_point[1] - last_center_point[1])**2) <= merge_threshold:
156
+ print("Merge two points to one!")
157
+ message = "Success!"
158
+ return ([[first_center_point], order_idx], bbox_info, message)
159
+
160
+
161
+ # Return needed information
162
+ message = "Success!"
163
+ return ([[first_center_point, last_center_point], order_idx], bbox_info, message)
164
+
165
+
166
+
167
+
168
+ def visualize_this_that(base_img, bbox_info, this_that_points):
169
+
170
+ # Draw a green dot only for the start point
171
+ for point in this_that_points:
172
+ print("point is ", point)
173
+ target_horizontal, target_vertical = point
174
+ target_horizontal, target_vertical = int(target_horizontal), int(target_vertical)
175
+
176
+ dot_range = 3
177
+ for i in range(-1*dot_range, dot_range+1):
178
+ for j in range(-1*dot_range, dot_range+1):
179
+ dil_vertical, dil_horizontal = target_vertical + i, target_horizontal + j
180
+ if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]):
181
+ base_img[dil_vertical, dil_horizontal, :] = [0, 128, 0]
182
+ # else:
183
+ # # print("The traj is out of boundary!!!!!!!!!!!!!!!!!!!!! and we won't consider it") # 现在
184
+ # return (False, base_img)
185
+
186
+ # Draw the bounding box
187
+ xmin, xmax, ymin, ymax = bbox_info
188
+ base_img = cv2.rectangle(base_img, (xmin, ymin), (xmax, ymax), color=(0,0,255), thickness=2)
189
+
190
+ return (True, base_img)
191
+
192
+
193
+
194
+ def manage_seq_range(input_dir, store_dir, sample_failure_collect_folder, total_frames_needed,
195
+ max_original_input_tolerate, gripper_detection_model, sam_predictor, do_visualization):
196
+
197
+ # Find valid image lists
198
+ num_frames_input = 0
199
+ for file_name in os.listdir(input_dir):
200
+ if file_name.startswith("im_"):
201
+ num_frames_input += 1
202
+ for idx in range(num_frames_input):
203
+ target_path = os.path.join(input_dir, "im_"+str(idx)+".jpg")
204
+ if not os.path.exists(target_path):
205
+ print("We don't have ", target_path)
206
+ message = "Invalid error" # Make sure that every file in this order is existed, this is quite important
207
+ return (False, message)
208
+
209
+
210
+ if num_frames_input > max_original_input_tolerate:
211
+ message = "The number of frames is too long for constructing the sequence length needed"
212
+ return (False, message)
213
+
214
+ if num_frames_input < total_frames_needed:
215
+ message = "The number of frames is too short for constructing the sequence length needed"
216
+ return (False, message)
217
+
218
+
219
+
220
+ # Prepare this and that based on policy_out.pkl
221
+ policy_out_file_path = os.path.join(input_dir, "policy_out.pkl")
222
+ with open(policy_out_file_path, "rb") as f:
223
+ policy = pickle.load(f)
224
+
225
+ actions_codes = []
226
+ action_start, action_end = None, None
227
+ for idx, item in enumerate(policy):
228
+ action_value = item["actions"][-1]
229
+ if action_start is None and action_value == 0.0:
230
+ action_start = idx
231
+
232
+ if (action_start is not None) and (action_end is None) and (action_value == 1.0):
233
+ action_end = idx # Record the first 1.0 that appears after the first 0.0
234
+ actions_codes.append(action_value)
235
+
236
+ if action_start is None or action_end is None:
237
+ message = "We cannot read an action_start or action_end code!"
238
+ return (False, message) # Both start and end actions are required (they usually come as a pair)
239
+
240
+ print("actions_codes is ", actions_codes)
241
+ print("the start end idx we read is ", action_start, action_end)
242
+
243
+
244
+ # Detect the gripper (should return a list with exactly two x,y coordinate points)
245
+ detection_return_info, bbox_info, detect_message = detect_gripper(
246
+ gripper_detection_model,
247
+ input_dir,
248
+ action_start,
249
+ action_end,
250
+ do_visualization = do_visualization,
251
+ store_dir = store_dir,
252
+ sample_failure_collect_folder = sample_failure_collect_folder,
253
+ )
254
+ if detection_return_info is None:
255
+ return (False, detect_message)
256
+
257
+ detected_point, old_seq_idx = detection_return_info
258
+ print("detected_point is ", detected_point)
259
+
260
+
261
+ # Visualize if needed
262
+ base_img = cv2.imread(os.path.join(input_dir, "im_0.jpg"))
263
+ if do_visualization:
264
+ status, visual_img = visualize_this_that(base_img, bbox_info, detected_point)
265
+ if status:
266
+ cv2.imwrite(os.path.join(store_dir, "visualization.png"), visual_img)
267
+
268
+
269
+
270
+ # SAM process based on bbox_info
271
+ xmin, xmax, ymin, ymax = bbox_info
272
+ sam_predictor.set_image(np.uint8(base_img))
273
+ positive_point_cords = np.array([[ int(detected_point[0][0]), int(detected_point[0][1]) ]])
274
+ positive_point_cords = np.array(positive_point_cords)
275
+ positive_point_labels = np.ones(len(positive_point_cords))
276
+
277
+ # Predict the mask based on the point and bounding box designed
278
+ masks, scores, logits = sam_predictor.predict(
279
+ point_coords = positive_point_cords,
280
+ point_labels = positive_point_labels,
281
+ box = np.array([xmin, ymin, xmax, ymax])[None, :],
282
+ multimask_output = False,
283
+ )
284
+ print(scores)
285
+ for mask_idx, mask in enumerate(masks):
286
+ mask_img = show_mask(mask)
287
+ cv2.imwrite(os.path.join(store_dir, "mask_" + str(mask_idx) + ".png"), mask_img)
288
+
289
+
290
+
291
+ ################################ Move the img ######################################
292
+ # Calculate needed parameters
293
+ division_factor = num_frames_input // total_frames_needed
294
+ remain_frames = (num_frames_input % total_frames_needed) - 1 # -1 for adaptation
295
+
296
+ # Define the gap
297
+ gaps = [division_factor for _ in range(total_frames_needed-1)]
298
+ for idx in range(remain_frames):
299
+ if idx % 2 == 0:
300
+ gaps[idx//2] += 1 # Start to end order
301
+ else:
302
+ gaps[-1*(1+(idx//2))] += 1 # End to start order
303
+
304
+ # Map the gap to the specific orders
305
+ idx_orders = [1] # Starting from 1 (a one-frame shift) should not matter much
306
+ for global_idx, gap in enumerate(gaps):
307
+ idx_orders.append(idx_orders[-1] + gap)
308
+ if idx_orders[-1] >= num_frames_input:
309
+ message = "Invalid error"
310
+ return (False, message)
311
+ # assert(idx_orders[-1] < num_frames_input)
312
+ assert(len(idx_orders) == total_frames_needed)
313
+
314
+
315
+ # Copy the essential files first
316
+ for global_idx, cur_idx in enumerate(idx_orders):
317
+ source_path = os.path.join(input_dir, "im_"+str(cur_idx)+".jpg")
318
+ destination_path = os.path.join(store_dir, "im_"+str(global_idx)+".jpg")
319
+
320
+ if not os.path.exists(source_path): # Theoretically, source_path must exist
321
+ message = "We couldn't find the source path. Theoretically, source_path must exists!" # 有一种可能就是我们丢失了一些地方,在cp或者本来就没有,记得统计数量
322
+ return (False, message)
323
+
324
+ shutil.copyfile(source_path, destination_path)
325
+
326
+ # Map order_idx to the cropped version
327
+ mapped_seq_idx = []
328
+ for old_idx in old_seq_idx:
329
+ tmp = []
330
+ for tmp_idx, new_idx in enumerate(range(len(idx_orders))):
331
+ tmp.append((tmp_idx, abs(old_idx - idx_orders[new_idx])))
332
+ # Sort by the smallest distance
333
+ tmp.sort(key=lambda x: x[1])
334
+ mapped_seq_idx.append(tmp[0][0])
335
+
336
+ print("Before the idx is ", old_seq_idx)
337
+ print("mapped idx is ", mapped_seq_idx)
338
+
339
+
340
+ # Write the information to new destination
341
+ f = open(os.path.join(store_dir, "data.txt"), "a")
342
+ f.write(str(mapped_seq_idx[0]) + " " + str(detected_point[0][0]) + " " + str(detected_point[0][1]) + "\n")
343
+ if len(detected_point) == 2: # Two points excluding the last idx
344
+ f.write(str(mapped_seq_idx[1]) + " " + str(detected_point[1][0]) + " " + str(detected_point[1][1]) + "\n")
345
+ f.close()
346
+
347
+
348
+ # Move lang.txt file
349
+ shutil.copyfile(os.path.join(input_dir, 'lang.txt'), os.path.join(store_dir, 'lang.txt'))
350
+
351
+
352
+ message = "Success!"
353
+ return (True, message)
354
+
355
+
356
+
357
+
358
+ if __name__ == "__main__":
359
+
360
+ # General storage setting
361
+ dataset_path = "../datasets_rob/Bridge_v2_raw"
362
+ destination_path = "../sanity_check/bridge_v2_TT14_longer_tolerance"
363
+ sample_failure_collect_folder = "" # This is to collect cases that fail for active learning
364
+
365
+ total_frames_needed = 14
366
+ max_original_input_tolerate = 56 # 40 for 14 fps; 60 for 25fps;
367
+ do_visualization = True
368
+
369
+
370
+ # YOLO model init
371
+ yolo_pretrained_path = "pretrained/yolov8n_best.pt"
372
+ gripper_detection_model = YOLO("yolov8n.yaml") # build a new model from scratch
373
+ gripper_detection_model = YOLO(yolo_pretrained_path) # load a pretrained model (recommended for training)
374
+
375
+ # SAM model init
376
+ model_type = "vit_h"
377
+ sam_pretrained_path = "pretrained/sam_vit_h_4b8939.pth"
378
+ sam = sam_model_registry[model_type](checkpoint=sam_pretrained_path).to(device="cuda")
379
+ sam_predictor = SamPredictor(sam) # There are many settings available here
380
+
381
+
382
+ # Make dir if needed
383
+ if os.path.exists(destination_path):
384
+ shutil.rmtree(destination_path)
385
+ os.makedirs(destination_path)
386
+
387
+ # Prepare the folder to collect failure cases
388
+ if sample_failure_collect_folder != "":
389
+ if os.path.exists(sample_failure_collect_folder):
390
+ shutil.rmtree(sample_failure_collect_folder)
391
+ os.makedirs(sample_failure_collect_folder)
392
+
393
+
394
+
395
+ # Collect the message
396
+ message_dict = collections.defaultdict(int)
397
+
398
+
399
+ store_idx = 0
400
+ for folder_name in sorted(os.listdir(dataset_path)):
401
+ input_folder_path = os.path.join(dataset_path, folder_name)
402
+ store_folder_path = os.path.join(destination_path, "0"*(6-len(str(store_idx)))+str(store_idx))
403
+ print("We are processing ", input_folder_path)
404
+
405
+ # Prepare store_folder_path folder
406
+ os.makedirs(store_folder_path)
407
+
408
+ status, message = manage_seq_range(input_folder_path, store_folder_path, sample_failure_collect_folder, total_frames_needed, max_original_input_tolerate, gripper_detection_model, sam_predictor, do_visualization)
409
+ if status: # We only update store_idx when this sample is successfully written
410
+ store_idx += 1
411
+ else:
412
+ print("This status failed! Message: " + message)
413
+ shutil.rmtree(store_folder_path)
414
+ # break # For debug
415
+
416
+ # Collect the info into the dict
417
+ message_dict[message] += 1
418
+
419
+ print("We have " + str(store_idx) + " valid dataset")
420
+ print("message_dict info is ", message_dict)
421
+
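The data.txt file written by manage_seq_range holds one point per line in the form "frame_idx x y" (x horizontal, y vertical, in the original frame resolution); the first line is the "this" (contact) point and the optional second line is the "that" (target) point. A small reader sketch is given below; the loaders later in this commit do their own parsing.

    def read_data_txt(data_txt_path):
        """Parse the 'frame_idx x y' lines written by manage_seq_range above."""
        points = []
        with open(data_txt_path, "r") as f:
            for line in f:
                parts = line.split()
                if len(parts) != 3:
                    continue
                # frame index, horizontal (x), vertical (y)
                points.append((int(parts[0]), float(parts[1]), float(parts[2])))
        return points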
curation_pipeline/tracking_by_keypoint.py ADDED
@@ -0,0 +1,136 @@
1
+ import os, shutil, sys
2
+ import argparse
3
+ import gdown
4
+ import cv2
5
+ import numpy as np
6
+ import os
7
+ import sys
8
+ import requests
9
+ import json
10
+ import torchvision
11
+ import torch
12
+ import psutil
13
+ import time
14
+ try:
15
+ from mmcv.cnn import ConvModule
16
+ except:
17
+ os.system("mim install mmcv")
18
+
19
+
20
+ # Import files from the local folder
21
+ root_path = os.path.abspath('.')
22
+ sys.path.append(root_path)
23
+ from track_anything_code.model import TrackingAnything
24
+ from track_anything_code.track_anything_module import get_frames_from_video, download_checkpoint, parse_augment, sam_refine, vos_tracking_video
25
+ from scripts.compress_videos import compress_video
26
+
27
+
28
+
29
+
30
+ if __name__ == "__main__":
31
+ dataset_path = "Bridge_v1_TT14"
32
+ video_name = "combined.mp4"
33
+ verbose = True # If verbose, the per-frame masks are also written out for inspection
34
+
35
+
36
+ ################################################## Model setup ####################################################
37
+ # check and download checkpoints if needed
38
+ sam_checkpoint = "sam_vit_h_4b8939.pth"
39
+ sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
40
+ xmem_checkpoint = "XMem-s012.pth"
41
+ xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"
42
+
43
+
44
+ folder ="./pretrained"
45
+ SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
46
+ xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)
47
+
48
+ # argument
49
+ args = parse_augment()
50
+ args.device = "cuda" # Any GPU is ok
51
+
52
+ # Initialize the Track model
53
+ track_model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
54
+ ###################################################################################################################
55
+
56
+
57
+ # Iterate all files under the folder
58
+ for sub_folder_name in sorted(os.listdir(dataset_path)):
59
+
60
+ ################################################## Setting ####################################################
61
+ sub_folder_path = os.path.join(dataset_path, sub_folder_name)
62
+
63
+ click_state = [[],[]]
64
+ interactive_state = {
65
+ "inference_times": 0,
66
+ "negative_click_times" : 0,
67
+ "positive_click_times": 0,
68
+ "mask_save": args.mask_save,
69
+ "multi_mask": {
70
+ "mask_names": [],
71
+ "masks": []
72
+ },
73
+ "track_end_number": None,
74
+ "resize_ratio": 1
75
+ }
76
+ ###################################################################################################################
77
+
78
+
79
+ video_path = os.path.join(sub_folder_path, video_name)
80
+ if not os.path.exists(video_path):
81
+ print("We cannot find the path of the ", video_path, " and we will compress one")
82
+ status = compress_video(sub_folder_path, video_name)
83
+ if not status:
84
+ print("We still cannot generate a video")
85
+ continue
86
+
87
+ # Read video state
88
+ video_state = {
89
+ "user_name": "",
90
+ "video_name": "",
91
+ "origin_images": None,
92
+ "painted_images": None,
93
+ "masks": None,
94
+ "inpaint_masks": None,
95
+ "logits": None,
96
+ "select_frame_number": 0,
97
+ "fps": 30
98
+ }
99
+ video_state, template_frame = get_frames_from_video(video_path, video_state, track_model)
100
+
101
+
102
+
103
+ ########################################################## Get the sam point based on the data.txt ###########################################################
104
+ data_txt_path = os.path.join(sub_folder_path, "data.txt")
105
+ if not os.path.exists(data_txt_path):
106
+ print("We cannot find data.txt in this folder")
107
+ continue
108
+
109
+ data_file = open(data_txt_path, 'r')
110
+ lines = data_file.readlines()
111
+ frame_idx, horizontal, vertical = lines[0][:-2].split(' ') # Only read the first point
112
+ point_cord = [int(float(horizontal)), int(float(vertical))]
113
+
114
+ # Process by SAM
115
+ track_model.samcontroler.sam_controler.reset_image() # Reset the image to clean history
116
+ painted_image, video_state, interactive_state, operation_log = sam_refine(track_model, video_state, "Positive", click_state, interactive_state, point_cord)
117
+ ################################################################################################################################################################
118
+
119
+
120
+
121
+ ######################################################### Get the tracking output ########################################################################
122
+
123
+ # Track the video for processing
124
+ segment_output_path = os.path.join(sub_folder_path, "segment_output.gif")
125
+ video_state = vos_tracking_video(track_model, segment_output_path, video_state, interactive_state, mask_dropdown=[])[0] # mask_dropdown is empty now
126
+
127
+ # Extract the mask needed by us for further point calculating
128
+ masks = video_state["masks"] # In the range [0, 1]
129
+
130
+ if verbose:
131
+ for idx, mask in enumerate(masks):
132
+ cv2.imwrite(os.path.join(sub_folder_path, "mask"+str(idx)+".png"), mask*255)
133
+
134
+ ##############################################################################################################################################################
135
+
136
+
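The script above stops after writing the per-frame masks out for "further point calculating". A hypothetical follow-up step is sketched below: collapsing each binary mask into a single centroid point per frame. This is only an assumption about the downstream computation; it is not part of this file.

    import numpy as np

    def mask_centroids(masks):
        # Hypothetical post-processing sketch: collapse each binary mask
        # (values in [0, 1]) into a single (horizontal, vertical) point.
        centroids = []
        for mask in masks:
            ys, xs = np.nonzero(mask > 0.5)
            if len(xs) == 0:
                centroids.append(None)          # tracking lost in this frame
            else:
                centroids.append((float(xs.mean()), float(ys.mean())))
        return centroids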
data_loader/video_dataset.py ADDED
@@ -0,0 +1,323 @@
1
+ import os, sys
2
+ import json
3
+ import cv2
4
+ import math
5
+ import shutil
6
+ import numpy as np
7
+ import random
8
+ import collections
9
+ from PIL import Image
10
+ import torch
11
+ from torch.utils.data import Dataset
12
+
13
+ # Import files from the local folder
14
+ root_path = os.path.abspath('.')
15
+ sys.path.append(root_path)
16
+ from utils.img_utils import resize_with_antialiasing, numpy_to_pt
17
+
18
+
19
+
20
+ def get_video_frames(config, video_frame_path, flip = False):
21
+
22
+ video_seq_length = config["video_seq_length"]
23
+
24
+ # Calculate needed parameters
25
+ num_frames_input = 0
26
+ for file_name in os.listdir(video_frame_path):
27
+ if file_name.startswith("im_"):
28
+ num_frames_input += 1
29
+ total_frames_needed = video_seq_length
30
+ division_factor = num_frames_input // total_frames_needed
31
+ remain_frames = (num_frames_input % total_frames_needed) - 1 # -1 for adaptation
32
+
33
+
34
+ # Define the gap
35
+ gaps = [division_factor for _ in range(total_frames_needed-1)]
36
+ for idx in range(remain_frames):
37
+ if idx % 2 == 0:
38
+ gaps[idx//2] += 1 # Start to end order
39
+ else:
40
+ gaps[-1*(1+(idx//2))] += 1 # End to start order
41
+
42
+
43
+ # Find needed file
44
+ needed_img_path = []
45
+ cur_idx = 0
46
+ for gap in gaps:
47
+ img_path = os.path.join(video_frame_path, "im_" + str(cur_idx) + ".jpg")
48
+ needed_img_path.append(img_path)
49
+
50
+ # Update the idx
51
+ cur_idx += gap
52
+ # Append the last one
53
+ img_path = os.path.join(video_frame_path, "im_" + str(cur_idx) + ".jpg")
54
+ needed_img_path.append(img_path)
55
+
56
+
57
+ # Read all img_path based on the order
58
+ video_frames = []
59
+ for img_path in needed_img_path:
60
+ if not os.path.exists(img_path):
61
+ print("We don't have ", img_path)
62
+ frame = cv2.imread(img_path)
63
+
64
+ try:
65
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
66
+ except Exception:
67
+ print("The exception places is ", img_path)
68
+
69
+ # Resize frames
70
+ frame = cv2.resize(frame, (config["width"], config["height"]), interpolation = cv2.INTER_CUBIC)
71
+
72
+ # Flip aug
73
+ if flip:
74
+ frame = np.fliplr(frame)
75
+
76
+ # Collect frames
77
+ video_frames.append(np.expand_dims(frame, axis=0)) # The frame is already RGB, there is no need to convert here.
78
+
79
+
80
+ # Concatenate
81
+ video_frames = np.concatenate(video_frames, axis=0)
82
+ assert(len(video_frames) == video_seq_length)
83
+
84
+ return video_frames
85
+
86
+
87
+
88
+ def tokenize_captions(prompt, tokenizer, config, is_train=True):
89
+ '''
90
+ Tokenize the text prompt with the prepared tokenizer from SD 2.1
91
+ '''
92
+
93
+ captions = []
94
+ if random.random() < config["empty_prompts_proportion"]:
95
+ captions.append("")
96
+ elif isinstance(prompt, str):
97
+ captions.append(prompt)
98
+ elif isinstance(prompt, (list, np.ndarray)):
99
+ # take a random caption if there are multiple
100
+ captions.append(random.choice(prompt) if is_train else prompt[0])
101
+ else:
102
+ raise ValueError(
103
+ f"Caption column should contain either strings or lists of strings."
104
+ )
105
+
106
+ inputs = tokenizer(
107
+ captions, max_length = tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
108
+ )
109
+ return inputs.input_ids[0]
110
+
111
+
112
+
113
+ class Video_Dataset(Dataset):
114
+ '''
115
+ Video Dataset to load sequential frames for training with needed pre-processing
116
+ '''
117
+
118
+ def __init__(self, config, device, normalize=True, tokenizer=None):
119
+
120
+ # Attribute variables
121
+ self.config = config
122
+ self.device = device
123
+ self.normalize = normalize
124
+ self.tokenizer = tokenizer
125
+
126
+ # Obtain values
127
+ self.video_seq_length = config["video_seq_length"]
128
+ self.height = config["height"]
129
+ self.width = config["width"]
130
+
131
+ # Process data
132
+ self.video_lists = []
133
+ stats_analysis = collections.defaultdict(int)
134
+ print("Process all files to check valid datasets....")
135
+ for dataset_path in config["dataset_path"]:
136
+ for video_name in sorted(os.listdir(dataset_path)):
137
+ video_path = os.path.join(dataset_path, video_name)
138
+ all_files = os.listdir(video_path)
139
+
140
+
141
+ valid = True
142
+ # Valid check 1: the number of files should be in sequential order
143
+ num_frames_input = 0
144
+ for file_name in os.listdir(video_path):
145
+ if file_name.startswith("im_"):
146
+ num_frames_input += 1
147
+ for idx in range(num_frames_input):
148
+ img_path = 'im_' + str(idx) + '.jpg'
149
+ if img_path not in all_files: # Frames should exist sequentially
150
+ valid = False
151
+ stats_analysis["incomplete_img"] += 1
152
+ break
153
+
154
+
155
+ # Valid check 1.5: the number of files must be longer than video_seq_length and less than self.config["acceleration_tolerance"]*self.config["video_seq_length"]
156
+ if num_frames_input < self.config["video_seq_length"]:
157
+ stats_analysis["too_little_frames"] += 1
158
+ valid = False
159
+ if num_frames_input > self.config["acceleration_tolerance"] * self.config["video_seq_length"]:
160
+ stats_analysis["too_many_frames"] += 1
161
+ valid = False
162
+
163
+ if not valid: # Early exit here (placed mid-way) to speed things up
164
+ continue
165
+
166
+
167
+ # Valid check 2: language if needed
168
+ if config["use_text"] and not os.path.exists(os.path.join(dataset_path, video_name, "lang.txt")):
169
+ stats_analysis["no_lang_txt"] += 1
170
+ valid = False
171
+
172
+
173
+ # Valid check 3: motion if needed
174
+ if config["motion_bucket_id"] is None:
175
+ flow_path = os.path.join(dataset_path, video_name, "flow.txt")
176
+ if "flow.txt" not in all_files:
177
+ stats_analysis["no_flow_txt"] += 1
178
+ valid = False
179
+ else:
180
+ file = open(flow_path, 'r')
181
+ info = file.readlines()
182
+ if len(info) == 0:
183
+ stats_analysis["no_flow_txt"] += 1
184
+ valid = False
185
+
186
+
187
+ if valid:
188
+ self.video_lists.append(video_path)
189
+ print("stats_analysis is ", stats_analysis)
190
+ print("Valid dataset length is ", len(self.video_lists))
191
+
192
+
193
+ def __len__(self):
194
+ return len(self.video_lists)
195
+
196
+
197
+
198
+ def _get_motion_value(self, sub_folder_path):
199
+ ''' Read the motion value from the prepared flow.txt file; the flow is precomputed to speed things up
200
+ '''
201
+
202
+ # Read the flow.txt
203
+ flow_path = os.path.join(sub_folder_path, 'flow.txt')
204
+ file = open(flow_path, 'r')
205
+ info = file.readlines()
206
+ per_video_movement = float(info[0][:-2])
207
+
208
+ # Map the raw reflected_motion_bucket_id to target range based on the number of images have
209
+ num_frames_input = 0
210
+ for file_name in os.listdir(sub_folder_path): # num_frames_input is the total number of files with name begin with im_
211
+ if file_name.startswith("im_"):
212
+ num_frames_input += 1
213
+
214
+ # Correct the value based on the number of frames relative to video_seq_length
215
+ per_video_movement_correct = per_video_movement * (num_frames_input/self.config["video_seq_length"])
216
+
217
+ # Map from one Normal Distribution to another Normal Distribution
218
+ z = (per_video_movement_correct - self.config["dataset_motion_mean"]) / (self.config["dataset_motion_std"] + 0.001)
219
+ reflected_motion_bucket_id = int((z * self.config["svd_motion_std"]) + self.config["svd_motion_mean"])
220
+
221
+
222
+ print("We map " + str(per_video_movement) + " to " + str(per_video_movement_correct) + " by length " + str(num_frames_input) + " to bucket_id of " + str(reflected_motion_bucket_id))
223
+ return reflected_motion_bucket_id
224
+
225
+
226
+
227
+ def __getitem__(self, idx):
228
+ ''' Get item by idx and pre-process by Resize and Normalize to [0, 1]
229
+ Args:
230
+ idx (int): The index to the file in the directory
231
+ Returns:
232
+ video_frames (torch.float32): The Pytorch tensor format of obtained frames (max: 1.0; min: 0.0)
233
+ reflected_motion_bucket_id (tensor): Motion value if optical flow is provided, else a fixed value from the config
234
+ prompt (tensor): Tokenized text
235
+ '''
236
+
237
+ # Prepare the text if needed:
238
+ if self.config["use_text"]:
239
+ # Read the file
240
+ file_path = os.path.join(self.video_lists[idx], "lang.txt")
241
+ file = open(file_path, 'r')
242
+ prompt = file.readlines()[0] # Only read the first line
243
+
244
+ if self.config["mix_ambiguous"] and os.path.exists(os.path.join(self.video_lists[idx], "processed_text.txt")):
245
+ # If we don't have this txt file, we skip
246
+
247
+ ######################################################## Mix up prompt ########################################################
248
+
249
+ # Read the file
250
+ file_path = os.path.join(self.video_lists[idx], "processed_text.txt")
251
+ file = open(file_path, 'r')
252
+ prompts = [line for line in file.readlines()] # Read all lines (action / this / there)
253
+
254
+ # Get the components
255
+ action = prompts[0][:-1]
256
+ this = prompts[1][:-1]
257
+ there = prompts[2][:-1]
258
+
259
+
260
+ random_value = random.random()
261
+ # If less than 0.4, we don't care, just use the most concrete one
262
+ if random_value >= 0.4 and random_value < 0.6:
263
+ # Mask pick object to "This"
264
+ prompt = action + " this to " + there
265
+ elif random_value >= 0.6 and random_value < 0.8:
266
+ # Mask place position to "There"
267
+ prompt = action + " " + this + " to there"
268
+ elif random_value >= 0.8 and random_value < 1.0:
269
+ # Just be like "this to there"
270
+ prompt = action + " this to there"
271
+
272
+ # print("New prompt is ", prompt)
273
+ ###################################################################################################################################################
274
+
275
+ # else:
276
+ # print("We don't have llama processed prompt at ", self.video_lists[idx])
277
+
278
+ else:
279
+ prompt = ""
280
+
281
+ # Tokenize text prompt
282
+ tokenized_prompt = tokenize_captions(prompt, self.tokenizer, self.config)
283
+
284
+
285
+ # Dataset aug by chance (it is needed to check whether there is any object position words [left|right] in the prompt text)
286
+ flip = False
287
+ if random.random() < self.config["flip_aug_prob"]:
288
+ if self.config["use_text"]:
289
+ if prompt.find("left") == -1 and prompt.find("right") == -1: # Cannot have position word, like left and right (up and down is ok)
290
+ flip = True
291
+ else:
292
+ flip = True
293
+
294
+
295
+ # Read frames for different datasets; Currently, we have WebVid / Bridge
296
+ if self.config["dataset_name"] == "Bridge":
297
+ video_frames = get_video_frames(self.config, self.video_lists[idx], flip=flip)
298
+ else:
299
+ raise NotImplementedError("We don't support this dataset loader")
300
+
301
+
302
+ # Scale [0, 255] -> [-1, 1]
303
+ if self.normalize:
304
+ video_frames = video_frames.astype(np.float32) / 127.5 - 1 # Be careful to cast to float32
305
+
306
+ # Transform to Pytorch Tensor in the range [-1, 1]
307
+ video_frames = numpy_to_pt(video_frames)
308
+ # print("length of input frames has ", len(video_frames))
309
+
310
+
311
+ # Get the motion value based on the optical flow
312
+ if self.config["motion_bucket_id"] is None:
313
+ reflected_motion_bucket_id = self._get_motion_value(self.video_lists[idx])
314
+ else:
315
+ reflected_motion_bucket_id = self.config["motion_bucket_id"]
316
+
317
+
318
+ # The tensor we returned is torch float32. We won't cast here for mixed precision training!
319
+ return {
320
+ "video_frames" : video_frames,
321
+ "reflected_motion_bucket_id" : reflected_motion_bucket_id,
322
+ "prompt": tokenized_prompt,
323
+ }
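For clarity, the gap construction in get_video_frames evenly subsamples video_seq_length frames out of num_frames_input and hands the leftover frames out alternately from the front and the back of the clip. A standalone sketch of the same index computation, with a worked example, is shown below.

    def subsample_indices(num_frames_input, video_seq_length):
        # Same gap construction as get_video_frames above.
        division_factor = num_frames_input // video_seq_length
        remain_frames = (num_frames_input % video_seq_length) - 1   # -1 for adaptation
        gaps = [division_factor] * (video_seq_length - 1)
        for idx in range(remain_frames):
            if idx % 2 == 0:
                gaps[idx // 2] += 1                  # fill from the start
            else:
                gaps[-1 * (1 + idx // 2)] += 1       # fill from the end
        indices, cur_idx = [0], 0
        for gap in gaps:
            cur_idx += gap
            indices.append(cur_idx)
        return indices

    # subsample_indices(30, 14) -> [0, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27]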
data_loader/video_this_that_dataset.py ADDED
@@ -0,0 +1,326 @@
1
+ import os, sys
2
+ import json
3
+ import cv2
4
+ import math
5
+ import shutil
6
+ import numpy as np
7
+ import random
8
+ from PIL import Image
9
+ import torch.nn.functional as F
10
+ import torch
11
+ import os.path as osp
12
+ import time
13
+ from moviepy.editor import VideoFileClip
14
+ from torch.utils.data import Dataset
15
+
16
+ # Import files from the local folder
17
+ root_path = os.path.abspath('.')
18
+ sys.path.append(root_path)
19
+ from utils.img_utils import resize_with_antialiasing, numpy_to_pt
20
+ from utils.optical_flow_utils import flow_to_image, filter_uv, bivariate_Gaussian
21
+ from data_loader.video_dataset import tokenize_captions
22
+
23
+
24
+ # For the 2D dilation
25
+ blur_kernel = bivariate_Gaussian(99, 10, 10, 0, grid = None, isotropic = True)
26
+
27
+
28
+ def get_thisthat_sam(config, intput_dir, store_dir = None, flip = False, verbose=False):
29
+ '''
30
+ Args:
31
+ intput_dir (str): Path to the folder (with data.txt and im_*.jpg frames) to process
32
+ '''
33
+
34
+ # Read file
35
+ file_path = os.path.join(intput_dir, "data.txt")
36
+ file1 = open(file_path, 'r')
37
+ Lines = file1.readlines()
38
+
39
+
40
+ # Initial the optical flow format we want
41
+ thisthat_condition = np.zeros((config["video_seq_length"], config["conditioning_channels"], config["height"], config["width"]), dtype=np.float32) # The last image should be empty
42
+
43
+
44
+ # Init the image
45
+ sample_img = cv2.imread(os.path.join(intput_dir, "im_0.jpg"))
46
+ org_height, org_width, _ = sample_img.shape
47
+
48
+ # Prepare masking
49
+ controlnet_image_index = []
50
+ coordinate_values = []
51
+
52
+ # Iterate all points in the txt file
53
+ for idx in range(len(Lines)):
54
+
55
+ # Read points
56
+ frame_idx, horizontal, vertical = Lines[idx].split(' ')
57
+ frame_idx, vertical, horizontal = int(frame_idx), int(float(vertical)), int(float(horizontal))
58
+
59
+ # Read the mask frame idx
60
+ controlnet_image_index.append(frame_idx)
61
+ coordinate_values.append((vertical, horizontal))
62
+
63
+
64
+ # Init the base image
65
+ base_img = np.zeros((org_height, org_width, 3)).astype(np.float32) # Use the original image size
66
+ base_img.fill(255)
67
+
68
+ # Draw square around the target position
69
+ dot_range = 10 # Diameter
70
+ for i in range(-1*dot_range, dot_range+1):
71
+ for j in range(-1*dot_range, dot_range+1):
72
+ dil_vertical, dil_horizontal = vertical + i, horizontal + j
73
+ if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]):
74
+ if idx == 0:
75
+ base_img[dil_vertical][dil_horizontal] = [0, 0, 255] # The first point should be red
76
+ else:
77
+ base_img[dil_vertical][dil_horizontal] = [0, 255, 0] # The second point should be green to distinguish the first point
78
+
79
+ # Dilate
80
+ if config["dilate"]:
81
+ base_img = cv2.filter2D(base_img, -1, blur_kernel)
82
+
83
+
84
+ ##############################################################################################################################
85
+ ### The core pipeline of processing is: Dilate -> Resize -> Range Shift -> Transpose Shape -> Store
86
+
87
+ # Resize frames here, while values are still in [0, 255]; don't resize negative or [0, 1] data
88
+ base_img = cv2.resize(base_img, (config["width"], config["height"]), interpolation = cv2.INTER_CUBIC)
89
+
90
+
91
+ # Flip the image for aug if needed
92
+ if flip:
93
+ base_img = np.fliplr(base_img)
94
+
95
+
96
+ # Channel Transform and Range Shift
97
+ if config["conditioning_channels"] == 3:
98
+ # Map to [0, 1] range
99
+ if store_dir is not None and verbose: # For the first frame condition visualization
100
+ cv2.imwrite(os.path.join(store_dir, "condition_TT"+str(idx)+".png"), base_img)
101
+ base_img = base_img / 255.0
102
+
103
+ else:
104
+ raise NotImplementedError()
105
+
106
+
107
+ # ReOrganize shape
108
+ base_img = base_img.transpose(2, 0, 1) # hwc -> chw
109
+
110
+
111
+ # Check the min max value range
112
+ # if verbose:
113
+ # print("{} min, max range value is {} - {}".format(intput_dir, np.min(base_img), np.max(base_img)))
114
+
115
+
116
+ # Write base img based on frame_idx
117
+ thisthat_condition[frame_idx] = base_img # Only the first frame, the rest is 0 initialized
118
+
119
+ ##############################################################################################################################
120
+
121
+
122
+ if config["motion_bucket_id"] is None:
123
+ # Use a fixed motion value based on the stats collected beforehand
124
+ reflected_motion_bucket_id = 200
125
+ else:
126
+ reflected_motion_bucket_id = config["motion_bucket_id"]
127
+
128
+
129
+ # print("Motion Bucket ID is ", reflected_motion_bucket_id)
130
+ return (thisthat_condition, reflected_motion_bucket_id, controlnet_image_index, coordinate_values)
131
+
132
+
133
+
134
+ class Video_ThisThat_Dataset(Dataset):
135
+ '''
136
+ Video Dataset to load sequential frames for training with needed pre-processing and process with optical flow
137
+ '''
138
+
139
+ def __init__(self, config, device, normalize=True, tokenizer=None):
140
+ # Attribute variables
141
+ self.config = config
142
+ self.device = device
143
+ self.normalize = normalize
144
+ self.tokenizer = tokenizer
145
+
146
+ # Obtain values
147
+ self.video_seq_length = config["video_seq_length"]
148
+ self.height = config["height"]
149
+ self.width = config["width"]
150
+
151
+ # Process data
152
+ self.video_lists = []
153
+ for dataset_path in config["dataset_path"]:
154
+ for video_name in sorted(os.listdir(dataset_path)):
155
+ if not os.path.exists(os.path.join(dataset_path, video_name, "data.txt")):
156
+ continue
157
+
158
+ self.video_lists.append(os.path.join(dataset_path, video_name))
159
+ print("length of the dataset is ", len(self.video_lists))
160
+
161
+
162
+
163
+
164
+ def __len__(self):
165
+ return len(self.video_lists)
166
+
167
+
168
+ def _extract_frame_bridge(self, idx, flip=False):
169
+ ''' Extract the frame in video based on the needed fps from already extracted frame
170
+ Args:
171
+ idx (int): The index to the file in the directory
172
+ flip (bool): Bool for whether we will flip
173
+ Returns:
174
+ video_frames (numpy): Extracted video frames in numpy format
175
+ '''
176
+
177
+ # Init the Video Reader
178
+ # The Bridge dataset frames follow the naming pattern im_x.jpg, so we read them by index
179
+ video_frame_path = self.video_lists[idx]
180
+
181
+
182
+ # Find needed file
183
+ needed_img_path = []
184
+ for idx in range(self.video_seq_length):
185
+ img_path = os.path.join(video_frame_path, "im_" + str(idx) + ".jpg")
186
+ needed_img_path.append(img_path)
187
+
188
+
189
+
190
+ # Read all img_path based on the order
191
+ video_frames = []
192
+ for img_path in needed_img_path:
193
+ if not os.path.exists(img_path):
194
+ print("We don't have ", img_path)
195
+ frame = cv2.imread(img_path)
196
+
197
+ try:
198
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
199
+ except Exception:
200
+ print("The exception place is ", img_path)
201
+ # Resize frames
202
+ frame = cv2.resize(frame, (self.width, self.height), interpolation = cv2.INTER_CUBIC)
203
+
204
+ # Flip aug
205
+ if flip:
206
+ frame = np.fliplr(frame)
207
+
208
+ # Collect frames
209
+ video_frames.append(np.expand_dims(frame, axis=0)) # The frame is already RGB, there is no need to convert here.
210
+
211
+
212
+ # Concatenate
213
+ video_frames = np.concatenate(video_frames, axis=0)
214
+ assert(len(video_frames) == self.video_seq_length)
215
+
216
+ # Returns
217
+ return video_frames
218
+
219
+
220
+
221
+
222
+ def __getitem__(self, idx):
223
+ ''' Get item by idx and pre-process by Resize and Normalize to [0, 1]
224
+ Args:
225
+ idx (int): The index to the file in the directory
226
+ Returns:
227
+ return_dict (dict): video_frames (torch.float32) [-1, 1] and controlnet_condition (torch.float32) [0, 1]
228
+ '''
229
+
230
+ # Prepare the text if needed:
231
+ if self.config["use_text"]:
232
+ # Read the file
233
+ file_path = os.path.join(self.video_lists[idx], "lang.txt")
234
+ file = open(file_path, 'r')
235
+ prompt = file.readlines()[0] # Only read the first line
236
+
237
+ if self.config["mix_ambiguous"] and os.path.exists(os.path.join(self.video_lists[idx], "processed_text.txt")):
238
+ # If we don't have this txt file, we skip
239
+
240
+ ######################################################## Mix up prompt ########################################################
241
+
242
+ # Read the file
243
+ file_path = os.path.join(self.video_lists[idx], "processed_text.txt")
244
+ file = open(file_path, 'r')
245
+ prompts = [line for line in file.readlines()] # Read all lines (action / this / there)
246
+
247
+ # Get the components
248
+ action = prompts[0][:-1]
249
+ this = prompts[1][:-1]
250
+ there = prompts[2][:-1]
251
+
252
+
253
+ random_value = random.random()
254
+ # If less than 0.4, we don't care, just use the most concrete one
255
+ if random_value >= 0.4 and random_value < 0.6:
256
+ # Mask pick object to "This"
257
+ prompt = action + " this to " + there
258
+ elif random_value >= 0.6 and random_value < 0.8:
259
+ # Mask place position to "There"
260
+ prompt = action + " " + this + " to there"
261
+ elif random_value >= 0.8 and random_value < 1.0:
262
+ # Just be like "this to there"
263
+ prompt = action + " this to there"
264
+
265
+ # print("New prompt is ", prompt)
266
+ ###################################################################################################################################################
267
+
268
+ # else:
269
+ # print("We don't have llama processed prompt at ", self.video_lists[idx])
270
+
271
+ else:
272
+ prompt = ""
273
+
274
+ # Tokenize text prompt
275
+ tokenized_prompt = tokenize_captions(prompt, self.tokenizer, self.config)
276
+
277
+
278
+
279
+ # Dataset aug by chance (it is needed to check whether there is any object position words [left|right] in the prompt text)
280
+ flip = False
281
+ if random.random() < self.config["flip_aug_prob"]:
282
+ if self.config["use_text"]:
283
+ if prompt.find("left") == -1 and prompt.find("right") == -1: # Cannot have position word, like left and right (up and down is ok)
284
+ flip = True
285
+ else:
286
+ flip = True
287
+
288
+
289
+
290
+ # Read frames for different datasets; currently, we have WebVid / Bridge
291
+ if self.config["dataset_name"] == "Bridge":
292
+ video_frames_raw = self._extract_frame_bridge(idx, flip=flip)
293
+ else:
294
+ raise NotImplementedError("We don't support this dataset loader")
295
+
296
+
297
+ # Scale [0, 255] -> [-1, 1] if needed
298
+ if self.normalize:
299
+ video_frames = video_frames_raw.astype(np.float32) / 127.5 - 1 # Be careful to cast to float32
300
+
301
+ # Transform to Pytorch Tensor in the range [-1, 1]
302
+ video_frames = numpy_to_pt(video_frames)
303
+
304
+
305
+ # Generate the pairs we need
306
+ intput_dir = self.video_lists[idx]
307
+
308
+ # Get the This That point information
309
+ controlnet_condition, reflected_motion_bucket_id, controlnet_image_index, coordinate_values = get_thisthat_sam(self.config, intput_dir, flip=flip)
310
+ controlnet_condition = torch.from_numpy(controlnet_condition)
311
+
312
+ # Cast other value to tensor
313
+ reflected_motion_bucket_id = torch.tensor(reflected_motion_bucket_id, dtype=torch.float32)
314
+ controlnet_image_index = torch.tensor(controlnet_image_index, dtype=torch.int32)
315
+ coordinate_values = torch.tensor(coordinate_values, dtype=torch.int32)
316
+
317
+
318
+ # The tensor we returned is torch float32. We won't cast here for mixed precision training!
319
+ return {"video_frames" : video_frames,
320
+ "controlnet_condition" : controlnet_condition,
321
+ "reflected_motion_bucket_id" : reflected_motion_bucket_id,
322
+ "controlnet_image_index": controlnet_image_index,
323
+ "prompt": tokenized_prompt,
324
+ "coordinate_values": coordinate_values, # Useless now, but I still passed back
325
+ }
326
+
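A minimal sketch of wiring this dataset into a PyTorch DataLoader follows. The config keys are the ones read by Video_ThisThat_Dataset and get_thisthat_sam above; the concrete values and the tokenizer repo id are illustrative assumptions (the actual values come from the training configs in this commit).

    from torch.utils.data import DataLoader
    from transformers import CLIPTokenizer
    from data_loader.video_this_that_dataset import Video_ThisThat_Dataset

    # Illustrative config; keys mirror those read in the dataset/get_thisthat_sam code above
    config = {
        "dataset_name": "Bridge",
        "dataset_path": ["../datasets_rob/Bridge_v2_TT14"],   # illustrative path
        "video_seq_length": 14,
        "height": 256,                                         # illustrative resolution
        "width": 384,
        "conditioning_channels": 3,
        "dilate": True,
        "use_text": True,
        "mix_ambiguous": False,
        "flip_aug_prob": 0.5,
        "empty_prompts_proportion": 0.1,
        "motion_bucket_id": 200,
    }
    # Assumed SD 2.1 tokenizer location (the docstring above only says "tokenizer from SD2.1")
    tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
    dataset = Video_ThisThat_Dataset(config, device="cuda", tokenizer=tokenizer)
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    batch = next(iter(loader))   # keys: video_frames, controlnet_condition, prompt, ...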
pretrained/PUT_YOUR_WEIGHT_HERE.md ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,27 @@
1
+ # Libraries without strict version pins
2
+ opencv-python
3
+ transformers
4
+ accelerate
5
+ requests
6
+ moviepy
7
+ omegaconf
8
+ # xformers
9
+ tensorboard
10
+ einops
11
+ yacs
12
+ loguru
13
+ imageio
14
+ pyparsing
15
+ ultralytics
16
+ lpips
17
+ matplotlib
18
+ gradio
19
+ torch==2.0.1
20
+ torchvision
21
+
22
+ # Libraries with strict version pins
23
+ bitsandbytes==0.43.0
24
+ diffusers==0.25.1
25
+ timm==0.4.12
26
+ scipy==1.9.3
27
+ pyiqa==0.1.7
scripts/active_learning_select.py ADDED
@@ -0,0 +1,27 @@
1
+ import os, shutil
2
+ import random
3
+
4
+
5
+ if __name__ == "__main__":
6
+ start_idx = 950
7
+ end_idx = 1020
8
+ select_num = 70
9
+
10
+ label_start_idx = 632
11
+ input_parent_dir = "../Bridge"
12
+ store_dir = "../bridge_select3"
13
+
14
+ if os.path.exists(store_dir):
15
+ shutil.rmtree(store_dir)
16
+ os.makedirs(store_dir)
17
+
18
+ for idx in range(start_idx, end_idx):
19
+ folder_path = os.path.join(input_parent_dir, str(idx))
20
+ select_idx = random.randint(0, len(os.listdir(folder_path)) - 1) # randint is inclusive on both ends
21
+ for file_idx, img_name in enumerate(os.listdir(folder_path)): # file_idx avoids shadowing the outer idx
22
+ if idx == select_idx and img_name != "policy_out.pkl":
23
+ img_path = os.path.join(folder_path, img_name)
24
+ target_path = os.path.join(store_dir, str(label_start_idx) + ".jpg")
25
+ label_start_idx += 1
26
+ shutil.copy(img_path, target_path)
27
+
scripts/add_point2img.py ADDED
@@ -0,0 +1,51 @@
1
+ '''
2
+ This file adds the annotated this/that points onto the first image of each sequence
3
+ '''
4
+
5
+ import os, shutil, sys
+ import cv2
+ import numpy as np
6
+
7
+ if __name__ == "__main__":
8
+ input_folder_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/model_results/Human_Study/Input_Bridge_human_evaluation"
9
+ store_path = "point_highlighted"
10
+
11
+ if os.path.exists(store_path): # Recreate the output folder, not the input folder
12
+ shutil.rmtree(store_path)
13
+ os.makedirs(store_path)
14
+
15
+
16
+ for instance_name in os.listdir(input_folder_path):
17
+
18
+ sub_folder_dir = os.path.join(input_folder_path, instance_name)
19
+
20
+ # Read file
21
+ file_path = os.path.join(sub_folder_dir, "data.txt")
22
+ file1 = open(file_path, 'r')
23
+ Lines = file1.readlines()
24
+
25
+ # Read the first img
26
+ first_img_path = os.path.join(sub_folder_dir, "im_0.jpg")
27
+
28
+
29
+ # Init the image
30
+ base_img = cv2.imread(first_img_path).astype(np.float32) # Use the original image size
31
+
32
+ # Draw the point
33
+ for idx in range(len(Lines)):
34
+ # Read points
35
+ frame_idx, horizontal, vertical = Lines[idx].split(' ')
36
+ frame_idx, vertical, horizontal = int(frame_idx), int(float(vertical)), int(float(horizontal))
37
+
38
+ # Draw square around the target position
39
+ dot_range = 15 # Diameter
40
+ for i in range(-1*dot_range, dot_range+1):
41
+ for j in range(-1*dot_range, dot_range+1):
42
+ dil_vertical, dil_horizontal = vertical + i, horizontal + j
43
+ if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]):
44
+ if idx == 0:
45
+ base_img[dil_vertical][dil_horizontal] = [0, 0, 255] # The first point should be red
46
+ else:
47
+ base_img[dil_vertical][dil_horizontal] = [0, 255, 0] # The second point should be green to distinguish the first point
48
+
49
+ # Save the highlighted first frame into the output folder (output filename chosen here for illustration)
+ cv2.imwrite(os.path.join(store_path, instance_name + ".png"), base_img)
50
+
51
+
scripts/check_video.py ADDED
@@ -0,0 +1,19 @@
1
+ '''
2
+ This file makes sure that the video files are readable by moviepy, so that the data loader can read them.
3
+ '''
4
+ import os
5
+ from moviepy.editor import VideoFileClip
6
+
7
+ if __name__ == "__main__":
8
+ video_dir = "../webvid_sample"
9
+ delete_abnormal_video = True # Whether you want to delete these abnormal video directly
10
+
11
+ for video_name in sorted(os.listdir(video_dir)):
12
+ video_path = os.path.join(video_dir, video_name)
13
+ try:
14
+ objVideoreader = VideoFileClip(filename=video_path)
15
+ except Exception:
16
+ print("There is an exception of reading: ", video_path)
17
+ if delete_abnormal_video:
18
+ print("We will remove this abnormal video source")
19
+ os.remove(video_path)
scripts/clean_bridge_dataset.py ADDED
@@ -0,0 +1,22 @@
1
+ '''
2
+ Sometimes the Bridge dataset contains corrupted downloads; we need to clean them up
3
+ '''
4
+ import os, shutil
5
+
6
+ # TODO: later, merge this directly into prepare_bridge_dataset
7
+ if __name__ == "__main__":
8
+ dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/Bridge"
9
+
10
+ for sub_folder in sorted(os.listdir(dataset_path)):
11
+ sub_folder_path = os.path.join(dataset_path, sub_folder)
12
+
13
+ img_lists = os.listdir(sub_folder_path)
14
+ if len(img_lists) < 14:
15
+ print("The folder is too short, we will remove them all")
16
+ shutil.rmtree(sub_folder_path)
17
+ continue
18
+ for img_name in img_lists:
19
+ img_path = os.path.join(sub_folder_path, img_name)
20
+ if not img_name.startswith("im_"):
21
+ print("We remove ", img_path)
22
+ os.remove(img_path)
scripts/collect_lang.py ADDED
@@ -0,0 +1,31 @@
1
+ '''
2
+ This file collects every lang.txt into a new directory, which makes it convenient to compress and scp the language annotations for post-processing
3
+ '''
4
+ import os, sys, shutil
5
+
6
+ if __name__ == "__main__":
7
+ parent_dir = "../datasets_rob"
8
+ dataset_paths = ["Bridge_v1_TT14", "Bridge_v2_TT14"]
9
+ store_folder = "../full_text_tmp"
10
+
11
+ # Manage the store folder
12
+ if os.path.exists(store_folder):
13
+ shutil.rmtree(store_folder)
14
+ os.makedirs(store_folder)
15
+
16
+
17
+ for dataset_name in dataset_paths:
18
+ store_path = os.path.join(store_folder, dataset_name)
19
+ if os.path.exists(store_path):
20
+ shutil.rmtree(store_path)
21
+ os.makedirs(store_path)
22
+
23
+ # Iterate all the files
24
+ for sub_folder_name in os.listdir(os.path.join(parent_dir, dataset_name)):
25
+ print("We are processing ", sub_folder_name)
26
+ lang_txt_path = os.path.join(parent_dir, dataset_name, sub_folder_name, "lang.txt")
27
+
28
+ # Store on the new address
29
+ store_file_path = os.path.join(store_path, sub_folder_name)
30
+ os.makedirs(store_file_path)
31
+ shutil.copyfile(lang_txt_path, os.path.join(store_file_path, "lang.txt"))
scripts/combine_results.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ This script combines the generated GIFs that share the same index into a single grid GIF
3
+ '''
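+ # Assumption (inferred from the code below): every folder listed in data_paths
+ # contains the same set of sub-folders, and each sub-folder already holds a
+ # 'combined.gif' (e.g. produced by scripts/compress_gif.py).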
4
+
5
+ import os, shutil, sys
6
+ import imageio
7
+ import math
8
+ import cv2
9
+ from PIL import Image
10
+ import collections
11
+ import numpy as np
12
+
13
+
14
+ if __name__ == "__main__":
15
+
16
+ # Basic setting
17
+ data_paths = [
18
+ "human_evaluation_v3_V_raw_prompt",
19
+ "human_evaluation_v3_VG_raw_prompt_no_sam",
20
+ "human_evaluation_v3_VL_ambiguous_prompt",
21
+
22
+ "../datasets_rob/Bridge_human_evaluation",
23
+
24
+ "human_evaluation_v3_VL_raw_prompt",
25
+ "human_evaluation_v3_VGL_raw_prompt_no_sam",
26
+ "human_evaluation_v3_VGL_ambiguous_prompt_no_sam",
27
+ ]
28
+ store_path = "combined_results_human_evaluation"
29
+ sample_data_path = data_paths[0]
30
+ gif_per_row = 4 # Number of GIF files per row
31
+
32
+
33
+ # Create folder
34
+ if os.path.exists(store_path):
35
+ shutil.rmtree(store_path)
36
+ os.makedirs(store_path)
37
+
38
+
39
+ # Iterate the sample
40
+ for instance_idx, sub_folder_name in enumerate(sorted(os.listdir(sample_data_path))):
41
+ print("we are processing ", sub_folder_name)
42
+
43
+ collected_gif_paths = []
44
+ for data_path in data_paths:
45
+ collected_gif_paths.append(os.path.join(data_path, sub_folder_name, 'combined.gif'))
46
+
47
+ # Merge frames together
48
+ rows = math.ceil(len(collected_gif_paths) / gif_per_row)
49
+ cols = gif_per_row
50
+
51
+ # Read all input GIFs and find maximum dimensions
52
+ gifs = []
53
+ max_width, max_height = 0, 0
54
+ for path in collected_gif_paths:
55
+ gif = imageio.mimread(path)
56
+ max_width = max(max_width, gif[0].shape[1])
57
+ max_height = max(max_height, gif[0].shape[0])
58
+ gifs.append(gif)
59
+
60
+ # Create blank canvas for concatenated GIF
61
+ frames_length = len(gifs[0])
62
+ canvas_width = max_width * cols
63
+ canvas_height = max_height * rows
64
+ canvas = np.zeros((frames_length, canvas_height, canvas_width, 3), dtype=np.uint8)
65
+
66
+
67
+ # push each frame into the canvas placeholder
68
+ gif_index = 0
69
+ for row in range(rows):
70
+ for col in range(cols):
71
+ gif = gifs[gif_index]
72
+ gif_height, gif_width, _ = gif[0].shape
73
+ start_y = row * max_height
74
+ start_x = col * max_width
75
+ for i in range(frames_length):
76
+ canvas[i, start_y:start_y+gif_height, start_x:start_x+gif_width, :] = gif[i]
77
+
78
+ # Update index
79
+ gif_index += 1
80
+ if gif_index == len(collected_gif_paths):
81
+ break
82
+
83
+
84
+ # Write the concatenated GIF
85
+ imageio.mimsave(os.path.join(store_path, sub_folder_name + ".gif"), canvas, duration=0.05, quality=100)
scripts/compress_gif.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, shutil, sys
2
+ import cv2
3
+ import imageio
4
+ import numpy as np
5
+
6
+
7
+ def compress_gif(sub_folder_path):
8
+
9
+ # Check valid length
10
+ all_files = os.listdir(sub_folder_path)
11
+ num_frames_input = 0
12
+ valid = True
13
+ for file_name in os.listdir(sub_folder_path):
14
+ if file_name.startswith("im_"):
15
+ num_frames_input += 1
16
+ for idx in range(num_frames_input):
17
+ img_path = 'im_' + str(idx) + '.jpg'
18
+ if img_path not in all_files: # Should be sequential existing
19
+ valid = False
20
+ break
21
+ if not valid:
22
+ print("We cannot generate a video because the video is not sequential")
23
+ return False
24
+
25
+
26
+ if num_frames_input == 0:
27
+ print("We cannot generate a video because the input length is 0")
28
+ return False
29
+
30
+ img_lists = []
31
+ for idx in range(num_frames_input):
32
+ img_path = os.path.join(sub_folder_path, "im_" + str(idx) + ".jpg")
33
+ img_lists.append(cv2.resize(cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB), (384, 256)))
34
+
35
+ imageio.mimsave(os.path.join(sub_folder_path, 'combined.gif'), np.array(img_lists), duration=0.05, quality=100)
36
+
37
+ return True
38
+
39
+
40
+ if __name__ == "__main__":
41
+ dataset_path = "../datasets_rob/Bridge_human_evaluation" # ../datasets_rob/Bridge_v1_raw
42
+
43
+ for sub_folder_name in sorted(os.listdir(dataset_path)):
44
+ print("We are processing ", sub_folder_name)
45
+ sub_folder_path = os.path.join(dataset_path, sub_folder_name)
46
+
47
+ status = compress_gif(sub_folder_path)
48
+
49
+
50
+
51
+
52
+
scripts/compress_videos.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, shutil, sys
2
+ from moviepy.editor import ImageSequenceClip
3
+
4
+
5
+ def compress_video(sub_folder_path, video_name):
6
+ store_path = os.path.join(sub_folder_path, video_name)
7
+
8
+ if os.path.exists(store_path):
9
+ os.remove(store_path)
10
+
11
+
12
+ # Check valid length
13
+ all_files = os.listdir(sub_folder_path)
14
+ num_frames_input = 0
15
+ valid = True
16
+ for file_name in os.listdir(sub_folder_path):
17
+ if file_name.startswith("im_"):
18
+ num_frames_input += 1
19
+ for idx in range(num_frames_input):
20
+ img_path = 'im_' + str(idx) + '.jpg'
21
+ if img_path not in all_files: # Should be sequential existing
22
+ valid = False
23
+ break
24
+ if not valid:
25
+ print("We cannot generate a video because the video is not sequential")
26
+ return False
27
+
28
+
29
+ if num_frames_input == 0:
30
+ print("We cannot generate a video because the input length is 0")
31
+ return False
32
+
33
+ img_lists = []
34
+ for idx in range(num_frames_input):
35
+ img_path = os.path.join(sub_folder_path, "im_" + str(idx) + ".jpg")
36
+ img_lists.append(img_path)
37
+
38
+ clip = ImageSequenceClip(img_lists, fps=4)
39
+ clip.write_videofile(store_path)
40
+
41
+ return True
42
+
43
+
44
+ if __name__ == "__main__":
45
+ dataset_path = "../datasets_rob/Bridge_v2_raw" # ../datasets_rob/Bridge_v1_raw
46
+
47
+ for sub_folder_name in sorted(os.listdir(dataset_path)):
48
+ sub_folder_path = os.path.join(dataset_path, sub_folder_name)
49
+
50
+ status = compress_video(sub_folder_path, "combined.mp4") # video_name is required; "combined.mp4" is an assumed output filename
51
+
52
+
53
+
54
+
55
+
scripts/crop_video_frames.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ This file crops each folder of video frames down to the needed frame length, for the mass evaluation
3
+ '''
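+ # Each sub-folder is expected to contain frames named 0.png, 1.png, ...;
+ # frames with index >= needed_frame_length are deleted in place.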
4
+ import os, shutil, sys
5
+ import cv2
6
+
7
+
8
+ if __name__ == "__main__":
9
+ input_folder = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/StreamingT2V_results"
10
+ needed_frame_length = 14
11
+
12
+ idx = 0
13
+ for file_name in sorted(os.listdir(input_folder)):
14
+ print("We are processing ", file_name)
15
+ sub_folder_path = os.path.join(input_folder, file_name)
16
+
17
+ for idx in range(len(os.listdir(sub_folder_path))):
18
+ if idx >= needed_frame_length:
19
+ target_path = os.path.join(sub_folder_path, str(idx)+".png")
20
+ os.remove(target_path)
21
+
22
+
scripts/extract_test_dataset.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Extract the test dataset from the txt file
3
+ '''
4
+
5
+ if __name__ == "__main__":
6
+ txt_path = "match_info_v2.txt"
7
+ store_path = "test_path_v2.txt"
8
+ start_idx = len("/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v2/")
9
+
10
+ read_file = open(txt_path, "r")
11
+ write_file = open(store_path, "w")
12
+ for line in read_file.readlines():
13
+ test_dataset_path = line.split(' ')[1]
14
+ test_instance = test_dataset_path[start_idx:]
15
+
16
+ write_file.write(test_instance)
17
+
18
+
scripts/generate_noise.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+
5
+ # Set the dimensions of the image
6
+ height = 256
7
+ width = 256
8
+
9
+ # Generate random pixel values
10
+ noise = np.random.rand(height, width, 3) * 255 # Scale to the [0, 255] range for an 8-bit image
11
+
12
+
13
+ for idx in range (4):
14
+ cv2.imwrite("noise"+str(idx)+".png", noise)
scripts/generate_sam.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, shutil
2
+ import cv2
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
6
+
7
+
8
+ def show_anns(anns):
9
+ if len(anns) == 0:
10
+ return
11
+
12
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
13
+ ax = plt.gca()
14
+ ax.set_autoscale_on(True)
15
+
16
+ img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 3))
17
+ # img[:,:,3] = 0
18
+ for ann in sorted_anns:
19
+ m = ann['segmentation']
20
+ color_mask = np.concatenate([np.random.random(3)])
21
+ img[m] = color_mask
22
+
23
+ return img*255
24
+
25
+
26
+
27
+
28
+ if __name__ == "__main__":
29
+ input_parent_folder = "../Bridge_filter_flow"
30
+
31
+
32
+ # Init SAM for segmentation task
33
+ model_type = "vit_h"
34
+ weight_path = "pretrained/sam_vit_h_4b8939.pth"
35
+
36
+
37
+
38
+ sam = sam_model_registry[model_type](checkpoint=weight_path).to(device="cuda")
39
+ mask_generator = SamAutomaticMaskGenerator(sam) # There are many configurable settings here
40
+
41
+
42
+ for sub_dir_name in sorted(os.listdir(input_parent_folder)):
43
+ print("We are processing ", sub_dir_name)
44
+ ref_img_path = os.path.join(input_parent_folder, sub_dir_name, 'im_0.jpg')
45
+ store_path = os.path.join(input_parent_folder, sub_dir_name, 'sam.png')
46
+
47
+ image = cv2.imread(ref_img_path)
48
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
49
+
50
+ mask = mask_generator.generate(image)
51
+ mask_img = show_anns(mask)
52
+
53
+ cv2.imwrite(store_path, mask_img)
54
+
55
+
56
+
scripts/generate_sam_this_that.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, shutil
2
+ import cv2
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
6
+
7
+
8
+ def show_anns(anns):
9
+ if len(anns) == 0:
10
+ return
11
+
12
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
13
+ ax = plt.gca()
14
+ ax.set_autoscale_on(True)
15
+
16
+ img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 3))
17
+ # img[:,:,3] = 0
18
+ for ann in sorted_anns:
19
+ m = ann['segmentation']
20
+ color_mask = np.concatenate([np.random.random(3)])
21
+ img[m] = color_mask
22
+
23
+ return img*255
24
+
25
+
26
+ def show_mask(mask, random_color=False):
27
+ if random_color:
28
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
29
+ else:
30
+ color = np.array([30/255, 144/255, 255/255, 0.6])
31
+ h, w = mask.shape[-2:]
32
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
33
+
34
+ return mask_image * 255
35
+
36
+
37
+ def show_points(coords, labels, ax, marker_size=375):
38
+ pos_points = coords[labels==1]
39
+ neg_points = coords[labels==0]
40
+ ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
41
+ ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1)
42
+
43
+
44
+ if __name__ == "__main__":
45
+ input_parent_folder = "validation_tmp"
46
+
47
+
48
+ # Init SAM for segmentation task
49
+ model_type = "vit_h"
50
+ weight_path = "pretrained/sam_vit_h_4b8939.pth"
51
+
52
+
53
+
54
+ sam = sam_model_registry[model_type](checkpoint=weight_path).to(device="cuda")
55
+ sam_predictor = SamPredictor(sam)
56
+ mask_generator = SamAutomaticMaskGenerator(sam)
57
+
58
+
59
+ # Iterate the folder
60
+ for sub_dir_name in sorted(os.listdir(input_parent_folder)):
61
+ print("We are processing ", sub_dir_name)
62
+ ref_img_path = os.path.join(input_parent_folder, sub_dir_name, 'im_0.jpg')
63
+ data_txt_path = os.path.join(input_parent_folder, sub_dir_name, 'data.txt')
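+ # data.txt is expected to hold one point per line in the form
+ # "<frame_idx> <horizontal> <vertical>", matching the parsing below.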
64
+
65
+
66
+ # Read the image and process
67
+ image = cv2.imread(ref_img_path)
68
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
69
+
70
+
71
+ # Read the positive point
72
+ data_file = open(data_txt_path, 'r')
73
+ lines = data_file.readlines()
74
+ for idx in range(len(lines)):
75
+ frame_idx, horizontal, vertical = lines[idx].split(' ')
76
+ vertical, horizontal = int(float(vertical)), int(float(horizontal))
77
+ positive_point_cords = [[horizontal, vertical]]
78
+
79
+ positive_point_cords = np.array(positive_point_cords)
80
+ positive_point_labels = np.ones(len(positive_point_cords))
81
+ print(positive_point_cords)
82
+
83
+
84
+
85
+ # Set the SAM predictor
86
+ sam_predictor.set_image(np.uint8(image))
87
+ masks, scores, logits = sam_predictor.predict(
88
+ point_coords = positive_point_cords, # Only positive points here
89
+ point_labels = positive_point_labels,
90
+ multimask_output = False,
91
+ )
92
+ # print("Detected mask length is ", len(masks))
93
+
94
+ # Visualize
95
+ mask_img = show_mask(masks[0])
96
+ cv2.imwrite(os.path.join(input_parent_folder, sub_dir_name, "first_contact0.png"), mask_img)
97
+
98
+ break
99
+
100
+
101
+ # SAM all
102
+ sam_all = mask_generator.generate(image)
103
+ all_sam_imgs = show_anns(sam_all)
104
+ cv2.imwrite("sam_all.png", all_sam_imgs)
105
+
106
+
107
+
108
+
scripts/generate_traj.py ADDED
@@ -0,0 +1,601 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ import copy
4
+ import os, shutil
5
+ import imageio
6
+ import cv2
7
+ from PIL import Image, ImageDraw
8
+ import os.path as osp
9
+ import random
10
+ import numpy as np
11
+ import torch.multiprocessing as mp
12
+ from multiprocessing import set_start_method
13
+ import math, time, gc
14
+ import torch
15
+ import torch.nn.functional as F
16
+ import matplotlib.pyplot as plt
17
+ from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
18
+
19
+
20
+ # Import files from the local path
21
+ root_path = os.path.abspath('.')
22
+ sys.path.append(root_path)
23
+ from config.flowformer_config import get_cfg
24
+ from flowformer_code.utils import flow_viz, frame_utils
25
+ from flowformer_code.utils.utils import InputPadder
26
+ from flowformer_code.FlowFormer import build_flowformer
27
+
28
+
29
+
30
+
31
+ TRAIN_SIZE = [432, 960]
32
+
33
+ def show_anns(anns):
34
+ if len(anns) == 0:
35
+ return
36
+
37
+ sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
38
+ ax = plt.gca()
39
+ ax.set_autoscale_on(True)
40
+
41
+ img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
42
+ img[:,:,3] = 0
43
+ for ann in sorted_anns:
44
+ m = ann['segmentation']
45
+ color_mask = np.concatenate([np.random.random(3), [0.35]])
46
+ img[m] = color_mask
47
+
48
+ return img*255
49
+
50
+
51
+ def show_mask(mask, random_color=False):
52
+ if random_color:
53
+ color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
54
+ else:
55
+ color = np.array([30/255, 144/255, 255/255, 0.6])
56
+ h, w = mask.shape[-2:]
57
+ mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
58
+
59
+ return mask_image * 255
60
+
61
+
62
+ def compute_grid_indices(image_shape, patch_size=TRAIN_SIZE, min_overlap=20):
63
+ if min_overlap >= TRAIN_SIZE[0] or min_overlap >= TRAIN_SIZE[1]:
64
+ raise ValueError(
65
+ f"Overlap should be less than size of patch (got {min_overlap}"
66
+ f"for patch size {patch_size}).")
67
+ if image_shape[0] == TRAIN_SIZE[0]:
68
+ hs = list(range(0, image_shape[0], TRAIN_SIZE[0]))
69
+ else:
70
+ hs = list(range(0, image_shape[0], TRAIN_SIZE[0] - min_overlap))
71
+ if image_shape[1] == TRAIN_SIZE[1]:
72
+ ws = list(range(0, image_shape[1], TRAIN_SIZE[1]))
73
+ else:
74
+ ws = list(range(0, image_shape[1], TRAIN_SIZE[1] - min_overlap))
75
+
76
+ # Make sure the final patch is flush with the image boundary
77
+ hs[-1] = image_shape[0] - patch_size[0]
78
+ ws[-1] = image_shape[1] - patch_size[1]
79
+ return [(h, w) for h in hs for w in ws]
80
+
81
+
82
+
83
+ def compute_flow(model, image1, image2, weights=None):
84
+ print(f"computing flow...")
85
+
86
+ image_size = image1.shape[1:]
87
+
88
+ image1, image2 = image1[None].cuda(), image2[None].cuda()
89
+
90
+ hws = compute_grid_indices(image_size)
91
+ if weights is None: # no tile
92
+ padder = InputPadder(image1.shape)
93
+ image1, image2 = padder.pad(image1, image2)
94
+
95
+ flow_pre, _ = model(image1, image2)
96
+
97
+ flow_pre = padder.unpad(flow_pre)
98
+ flow = flow_pre[0].permute(1, 2, 0).cpu().numpy()
99
+ else: # tile
100
+ flows = 0
101
+ flow_count = 0
102
+
103
+ for idx, (h, w) in enumerate(hws):
104
+ image1_tile = image1[:, :, h:h+TRAIN_SIZE[0], w:w+TRAIN_SIZE[1]]
105
+ image2_tile = image2[:, :, h:h+TRAIN_SIZE[0], w:w+TRAIN_SIZE[1]]
106
+ flow_pre, _ = model(image1_tile, image2_tile)
107
+ padding = (w, image_size[1]-w-TRAIN_SIZE[1], h, image_size[0]-h-TRAIN_SIZE[0], 0, 0)
108
+ flows += F.pad(flow_pre * weights[idx], padding)
109
+ flow_count += F.pad(weights[idx], padding)
110
+
111
+ flow_pre = flows / flow_count
112
+ flow = flow_pre[0].permute(1, 2, 0).cpu().numpy()
113
+
114
+ return flow
115
+
116
+
117
+ def compute_adaptive_image_size(image_size):
118
+ target_size = TRAIN_SIZE
119
+ scale0 = target_size[0] / image_size[0]
120
+ scale1 = target_size[1] / image_size[1]
121
+
122
+ if scale0 > scale1:
123
+ scale = scale0
124
+ else:
125
+ scale = scale1
126
+
127
+ image_size = (int(image_size[1] * scale), int(image_size[0] * scale))
128
+
129
+ return image_size
130
+
131
+
132
+ def prepare_image(viz_root_dir, fn1, fn2, keep_size):
133
+ print(f"preparing image...")
134
+
135
+ image1 = frame_utils.read_gen(fn1)
136
+ image2 = frame_utils.read_gen(fn2)
137
+ image1 = np.array(image1).astype(np.uint8)[..., :3]
138
+ image2 = np.array(image2).astype(np.uint8)[..., :3]
139
+ if not keep_size:
140
+ dsize = compute_adaptive_image_size(image1.shape[0:2])
141
+ image1 = cv2.resize(image1, dsize=dsize, interpolation=cv2.INTER_CUBIC)
142
+ image2 = cv2.resize(image2, dsize=dsize, interpolation=cv2.INTER_CUBIC)
143
+ image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
144
+ image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
145
+
146
+
147
+ dirname = osp.dirname(fn1)
148
+ filename = osp.splitext(osp.basename(fn1))[0]
149
+
150
+ viz_dir = osp.join(viz_root_dir, dirname)
151
+ # if not osp.exists(viz_dir):
152
+ # os.makedirs(viz_dir)
153
+
154
+ viz_fn = osp.join(viz_dir, filename + '.png')
155
+
156
+ return image1, image2, viz_fn
157
+
158
+
159
+ def build_model():
160
+ print(f"building model...")
161
+ cfg = get_cfg()
162
+ model = torch.nn.DataParallel(build_flowformer(cfg))
163
+ model.load_state_dict(torch.load(cfg.model))
164
+
165
+ model.cuda()
166
+ model.eval()
167
+
168
+ return model
169
+
170
+
171
+ def filter_uv(flow, threshold_factor = 0.2):
172
+ u = flow[:,:,0]
173
+ v = flow[:,:,1]
174
+
175
+ rad = np.sqrt(np.square(u) + np.square(v))
176
+ rad_max = np.max(rad)
177
+
178
+ threshold = threshold_factor * rad_max
179
+ flow[:,:,0][rad < threshold] = 0
180
+ flow[:,:,1][rad < threshold] = 0
181
+
182
+ return flow
183
+
184
+
185
+ def visualize_traj(base_img, traj_path, connect_points = True):
186
+ target_vertical, target_horizontal = traj_path[-1]
187
+
188
+ if connect_points and len(traj_path) > 1:
189
+ # Draw a line to connect two point to show motion direction
190
+ start_coordinate = (traj_path[-2][1], traj_path[-2][0])
191
+ end_coordinate = (traj_path[-1][1], traj_path[-1][0])
192
+ pil_img = Image.fromarray(base_img)
193
+
194
+ # Draw the line
195
+ color = 'red'
196
+ draw = ImageDraw.Draw(pil_img)
197
+ draw.line([start_coordinate, end_coordinate], fill = color, width = 3)
198
+
199
+ base_img = np.array(pil_img)
200
+
201
+
202
+ # Draw a green dot only for the start point
203
+ if len(traj_path) == 1:
204
+ dot_range = 3
205
+ for i in range(-1*dot_range, dot_range+1):
206
+ for j in range(-1*dot_range, dot_range+1):
207
+ dil_vertical, dil_horizontal = target_vertical + i, target_horizontal + j
208
+ if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]):
209
+ base_img[dil_vertical][dil_horizontal] = [0, 128, 0]
210
+ else:
211
+ print("The traj is out of boundary!!!!!!!!!!!!!!!!!!!!! and we won't consider it") # 现在
212
+ return (False, base_img)
213
+
214
+ return (True, base_img)
215
+
216
+
217
+
218
+ def calculate_flow(viz_root_dir, store_dir, img_pairs, optical_flow_model, sam_predictor, SAM_positive_sample_num, SAM_negative_sample_num, mask_generator, traj_visualization, keep_size, verbose=False):
219
+
220
+ # Trajectory prepare
221
+ traj_path = [] # It collects all points traversed in a temporal order
222
+ is_hard_to_track = False # If this is True, it means that at some point the tracking struggled to find the dx and dy movement; in that case, using this trajectory is not recommended
223
+ hard_track_idxs = set()
224
+ traj_image_lists = []
225
+
226
+
227
+ # Iterate all image pairs
228
+ for idx, img_pair in enumerate(img_pairs):
229
+
230
+ fn1, fn2 = img_pair
231
+ print(f"processing {fn1}, {fn2}...")
232
+
233
+ image1, image2, viz_fn = prepare_image(viz_root_dir, fn1, fn2, keep_size) # Be careful: image1 and image2 may have different resolutions if keep_size is False
234
+ # Generate the optical flow and filter those that is small motion
235
+ flow_uv = filter_uv(compute_flow(optical_flow_model, image1, image2, None))
236
+
237
+ # if verbose:
238
+ # Store the visualization of flow_uv
239
+ # flow_img = flow_viz.flow_to_image(flow_uv)
240
+ # cv2.imwrite("optical_flow_" + str(idx+1) + ".png", flow_img[:, :, [2,1,0]])
241
+
242
+ if idx == 0:
243
+ # We will store the first image to memory for further visualization purpose
244
+
245
+ # Base img
246
+ # base_img = np.uint8(np.transpose(image1.numpy(), (1,2,0)))
247
+
248
+ # SAM figure
249
+ # sam_all = mask_generator.generate(image1)
250
+ # base_img = show_anns(sam_all)
251
+ # base_img = np.transpose(base_img, (1,2,0))
252
+
253
+ # Plain white image
254
+ base_img = np.zeros(np.transpose(image1.numpy(), (1,2,0)).shape, dtype=np.uint8)
255
+ base_img.fill(255)
256
+
257
+
258
+
259
+
260
+ # Extract moving points (positive point)
261
+ positive_point_cords = []
262
+ nonzeros = np.nonzero(flow_uv) # [(vertical), (horizontal)]
263
+ if len(nonzeros[0]) < SAM_positive_sample_num:
264
+ # We require the number of points to be more than SAM_positive_sample_num
265
+ return False
266
+ positive_orders = np.random.choice(len(nonzeros[0]), SAM_positive_sample_num, replace=False) # We randomly sample points instead of using all of them in the sam_predictor prediction
267
+ for i in range(len(nonzeros[0])):
268
+ if i in positive_orders:
269
+ positive_point_cords.append([nonzeros[1][i], nonzeros[0][i]]) # According to the documentation, the order should be horizontal first, then vertical
270
+ positive_point_cords = np.array(positive_point_cords)
271
+ positive_point_labels = np.ones(len(positive_point_cords))
272
+
273
+
274
+ # Define negative sample (outside the optical flow choice)
275
+ if SAM_negative_sample_num != 0:
276
+ skip_prob = 2 * SAM_negative_sample_num / (flow_uv.shape[0]*flow_uv.shape[1] - len(nonzeros[0]))
277
+ negative_point_cords = []
278
+ for i in range(flow_uv.shape[0]):
279
+ for j in range(flow_uv.shape[1]):
280
+ if flow_uv[i][j][0] == 0 and flow_uv[i][j][1] == 0: # 0 marks the no-motion zone; low motion was already filtered to zero above
281
+ if random.random() < skip_prob:
282
+ negative_point_cords.append([j, i]) # According to the documentation, the order should be horizontal first, then vertical
283
+ negative_point_cords = np.array(negative_point_cords) # [:SAM_negative_sample_num]
284
+ negative_point_labels = np.zeros(len(negative_point_cords)) # Make sure that it is less than / equals to SAM_negative_sample_num quantity
285
+
286
+
287
+
288
+ ################## Use SAM to filter out what we need (& use negative points) ##################
289
+ if idx == 0: # Only consider the first frame now.
290
+ # With sample coordinate
291
+ sam_predictor.set_image(np.uint8(np.transpose(image1.numpy(), (1,2,0))))
292
+ if SAM_negative_sample_num != 0 and len(negative_point_cords) != 0:
293
+ all_point_cords = np.concatenate((positive_point_cords, negative_point_cords), axis=0)
294
+ all_point_labels = np.concatenate((positive_point_labels, negative_point_labels), axis=0)
295
+ else:
296
+ all_point_cords = positive_point_cords
297
+ all_point_labels = positive_point_labels
298
+
299
+ masks, scores, logits = sam_predictor.predict(
300
+ point_coords=all_point_cords,
301
+ point_labels=all_point_labels,
302
+ multimask_output=False,
303
+ )
304
+ mask = masks[0] # TODO: Make sure we really pick the largest mask here rather than the second largest or others; there may be a bug, since we assume the first mask is the largest one
305
+ # if verbose:
306
+ # cv2.imwrite("mask_"+str(idx+1)+".png", (np.uint8(mask)*255))
307
+ # annotated_img = show_mask(mask)
308
+ # cv2.imwrite("annotated.png", annotated_img)
309
+
310
+
311
+ ################## Choose the one we need as the reference for the future tracking ##################
312
+ # Choose a random point in the mask
313
+ target_zone = np.nonzero(mask) # [(vertical), (horizontal)]
314
+ target_zone = [(target_zone[0][i], target_zone[1][i]) for i in range(len(target_zone[0]))] # Now, the structure is [(vertical, horizontal), ...]
315
+
316
+ repeat_time = 0
317
+ loop2find = True
318
+ while loop2find:
319
+ loop2find = False
320
+ start_point = target_zone[np.random.choice(len(target_zone), 1, replace=False)[0]]
321
+ start_vertical, start_horizontal = start_point
322
+
323
+ repeat_time += 1
324
+ if repeat_time == 100:
325
+ # In some minor case, it may have infinite loop, so we need to manually break if it is looping
326
+ print("We are still hard to find a optimal first point, but we cannot let it loop")
327
+ break
328
+
329
+ # Try to choose a start_point that is more centralized (Not close to the border)
330
+ fast_break = False
331
+ for i in range(-15, 15):
332
+ for j in range(-15, 15):
333
+ dil_vertical, dil_horizontal = start_vertical + i, start_horizontal + j
334
+ if (0 <= dil_vertical and dil_vertical < mask.shape[0]) and (0 <= dil_horizontal and dil_horizontal < mask.shape[1]):
335
+ if mask[dil_vertical][dil_horizontal] == 0:
336
+ print("We need to change to a new position for the start p Since this one is close to the border of the object...........")
337
+ loop2find = True
338
+ fast_break = True
339
+ break
340
+ else:
341
+ # We don't want to consider points that are close to the boundary
342
+ print("We need to change to a new position Since this one is close to the border of the image...........")
343
+ loop2find = True
344
+ fast_break = True
345
+ break
346
+ if fast_break:
347
+ break
348
+ traj_path.append(start_point)
349
+
350
+ status, base_img = visualize_traj(base_img, traj_path)
351
+ if status == False: # If the traj is False, we won't consider it anymore.
352
+ file = open("log.txt", "a")
353
+ file.write("Invalid start point\n")
354
+ return False
355
+
356
+ # Read from the last one in traj
357
+ ref_vertical, ref_horizontal = traj_path[-1][0], traj_path[-1][1]
358
+
359
+
360
+ # Get the average motion vector from the points surrounding the ref_point (8+1 directions), since this gives the most accurate statistics
361
+ horizon_lists, vertical_lists = [], []
362
+ start_range, end_range = -5, 5
363
+
364
+ # Calculate the average motion based on surrounding motion
365
+ search_times = 0
366
+ while len(horizon_lists) == 0: # If we cannot find a direction, we use average value inside this mask, but we will flag it.
367
+ search_times += 1
368
+
369
+ if search_times > 1:
370
+ print("This is hard to track!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! and we have tracked " + str(search_times) + " times")
371
+ # TODO: For out-of-boundary cases where search_times reaches 8-10, cut off the later frames, since they are very inaccurate; re-tracking a new point is also possible but not very meaningful. Decide based on the overall data quantity.
372
+ is_hard_to_track = True
373
+ hard_track_idxs.add(idx)
374
+
375
+ if abs(start_range) >= flow_uv.shape[0]//2:
376
+ file = open("log.txt", "a")
377
+ file.write("This folder has search all space but didn't find any place to track optical flow\n")
378
+ return False # If we have already search for the whole graph but didn't find anything to track, we discard this sample
379
+
380
+ # Search a larger nearby space; enlarging the search range should be the most stable choice
381
+ for i in range(start_range, end_range):
382
+ for j in range(start_range, end_range):
383
+ target_vertical, target_horizontal = ref_vertical + i, ref_horizontal + j
384
+ if 0 <= target_vertical and target_vertical < flow_uv.shape[0] and 0 <= target_horizontal and target_horizontal < flow_uv.shape[1]:
385
+ if flow_uv[target_vertical, target_horizontal, 0] == 0 or flow_uv[target_vertical, target_horizontal, 1] == 0:
386
+ continue # Ignore zero vector to ensure only calculate moving position
387
+ horizon_lists.append(flow_uv[target_vertical, target_horizontal, 0]) # Horizontal motion strength
388
+ vertical_lists.append(flow_uv[target_vertical, target_horizontal, 1]) # Vertical motion strength
389
+
390
+ # If nothing was found, we keep searching over a larger space
391
+ start_range -= 10
392
+ end_range += 10
393
+
394
+ average_dx = sum(horizon_lists)/len(horizon_lists)
395
+ average_dy = sum(vertical_lists)/len(vertical_lists)
396
+ print("average movement is ", (average_dx, average_dy))
397
+ traj_path.append(( int(traj_path[-1][0] + average_dy), int(traj_path[-1][1] + average_dx))) # Append the motion in independent order
398
+
399
+ print(traj_path)
400
+
401
+
402
+ ##################### Visualize the trajectory path (Debug Purpose) #####################
403
+ status, base_img = visualize_traj(base_img, traj_path)
404
+ if status == False: # If the traj is False, we won't consider it anymore.
405
+ return False
406
+
407
+ cv2.imwrite(os.path.join(store_dir, "traj_path.png"), cv2.cvtColor(base_img, cv2.COLOR_BGR2RGB))
408
+
409
+ if traj_visualization:
410
+ status, single_traj_img = visualize_traj(np.uint8(np.transpose(image1.numpy(), (1,2,0))), traj_path[:-1], connect_points=False)
411
+ if status == False: # If the traj is False, we won't consider it anymore.
412
+ return False
413
+
414
+ traj_write_path = os.path.join(store_dir, "traj_"+str(idx)+".png")
415
+ # cv2.imwrite(traj_write_path, cv2.cvtColor(single_traj_img, cv2.COLOR_BGR2RGB))
416
+ traj_image_lists.append(traj_write_path)
417
+
418
+
419
+ # if traj_visualization:
420
+ # images = []
421
+ # for filename in traj_image_lists:
422
+ # images.append(imageio.imread(filename))
423
+ # # os.remove(filename) # Remove when used
424
+ # imageio.mimsave(os.path.join(store_dir, 'traj_motion.gif'), images, duration=0.05)
425
+
426
+
427
+ # TODO: If the trajectory is hard to track, we could aggressively retry a few times and use the length of hard_track_idxs to roughly judge which attempt is best, picking the best of three
428
+ if is_hard_to_track:
429
+ if len(hard_track_idxs) >= len(img_pairs)//3: # If more than a third of the traj is hard to track, we consider discarding this one
430
+ file = open("log.txt", "a")
431
+ file.write("we have a lot of times hard to find dx and dy movement. Under this circumstance, we are not very recommended to use the track\n")
432
+ return False
433
+
434
+
435
+ # Write a file store all position for further utilization
436
+ txt_path = os.path.join(store_dir, "traj_data.txt")
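+ # traj_data.txt stores one "<vertical> <horizontal>" pair per line, in temporal order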
437
+ if os.path.exists(txt_path):
438
+ os.remove(txt_path)
439
+ file = open(txt_path, "a")
440
+ for traj in traj_path:
441
+ file.write(str(traj[0]) + " " + str(traj[1]) + "\n")
442
+ # Save in numpy information
443
+ # with open(os.path.join(store_dir, 'traj_data.npy'), 'wb') as f:
444
+ # np.save(f, flow_uv)
445
+ print("We write ", traj_path)
446
+ return True
447
+
448
+
449
+
450
+ def manage_seq_range(input_dir, store_dir, total_frame_needed):
451
+
452
+ lists = os.listdir(input_dir)
453
+ lists = lists[2:-2]
454
+ num_frames_input = len(lists)
455
+
456
+ if num_frames_input < total_frame_needed:
457
+ print("The number of frames is too short for constructing the sequnece length needed")
458
+ return False
459
+
460
+
461
+ division_factor = num_frames_input // total_frame_needed
462
+ remain_frame = num_frames_input % total_frame_needed
463
+
464
+ gaps = [division_factor for _ in range(total_frame_needed)]
465
+ for idx in range(remain_frame):
466
+ gaps[idx] += 1
467
+
468
+
469
+ cur_idx = 2
470
+ for global_idx, gap in enumerate(gaps):
471
+ source_path = os.path.join(input_dir, "im_"+str(cur_idx)+".jpg")
472
+ destination_path = os.path.join(store_dir, "im_"+str(global_idx)+".jpg")
473
+
474
+ shutil.copyfile(source_path, destination_path)
475
+ cur_idx += gap
476
+
477
+ return True
478
+
479
+
480
+ def generate_pairs(dirname, start_idx, end_idx):
481
+ img_pairs = []
482
+ for idx in range(start_idx, end_idx):
483
+ img1 = osp.join(dirname, f'im_{idx}.jpg')
484
+ img2 = osp.join(dirname, f'im_{idx+1}.jpg')
485
+ # img1 = f'{idx:06}.png'
486
+ # img2 = f'{idx+1:06}.png'
487
+ img_pairs.append((img1, img2))
488
+
489
+ return img_pairs
490
+
491
+
492
+ def process_partial_request(request_list, num_frames, traj_visualization, viz_root_dir):
493
+
494
+
495
+ # Init the optical flow model
496
+ optical_flow_model = build_model()
497
+
498
+ # Init SAM for segmentation task
499
+ model_type = "vit_h"
500
+ weight_path = "pretrained/sam_vit_h_4b8939.pth"
501
+ SAM_positive_sample_num = 20 # How many points we use for the positive sample num ()
502
+ SAM_negative_sample_num = 0 # How many points we use for the negative sample num
503
+
504
+ print("In multi processing, we will build an instance of mask_generator independently")
505
+ sam = sam_model_registry[model_type](checkpoint=weight_path).to(device="cuda")
506
+ mask_generator = SamAutomaticMaskGenerator(sam)
507
+ print("In multi processing, we will build an instance of sam_predictor independently")
508
+ sam_predictor = SamPredictor(sam)
509
+
510
+
511
+ counter = 0
512
+ while True:
513
+ counter += 1
514
+ if counter == 10:
515
+ counter = 0
516
+ gc.collect()
517
+ print("We will sleep here to clear memory")
518
+ time.sleep(5)
519
+ info = request_list[0]
520
+ request_list = request_list[1:]
521
+ if info == None:
522
+ print("This queue ends")
523
+ break
524
+
525
+
526
+ # Process each sub_input_dir and store the information there
527
+ sub_input_dir = info
528
+
529
+
530
+ img_pairs = generate_pairs(sub_input_dir, 0, num_frames-1)
531
+ print(img_pairs)
532
+
533
+ with torch.no_grad():
534
+
535
+ # Calculate the optical flow and return a status to say whther this generated flow is usable
536
+ status = calculate_flow(viz_root_dir, sub_input_dir, img_pairs, optical_flow_model, sam_predictor, SAM_positive_sample_num, SAM_negative_sample_num,
537
+ mask_generator, traj_visualization, keep_size = True)
538
+
539
+ # file = open("log.txt", "a")
540
+ print("The status for folder " + sub_input_dir + " is " + str(status) + "\n")
541
+
542
+ if status == False:
543
+ # If the status is failed, we will remove it afterwords
544
+ print("The status is Failed, so we won't store this one as one promising data")
545
+ else:
546
+ print("We have successfully process one!")
547
+
548
+
549
+ if __name__ == '__main__':
550
+
551
+ # Manage the parameters
552
+ parser = argparse.ArgumentParser()
553
+ parser.add_argument('--input_dir', default = '../validation_flow14/')
554
+ parser.add_argument('--num_workers', type = int, default = 1) # Number of parallel worker processes
555
+ parser.add_argument('--viz_root_dir', default = 'viz_results')
556
+ parser.add_argument('--traj_visualization', default = True) # If this is True, save per-frame trajectory visualizations
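+ # Example invocation (assumed; adjust the paths to your own data):
+ #   python scripts/generate_traj.py --input_dir ../validation_flow14/ --num_workers 1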
557
+
558
+ # list_start = 0
559
+ # list_end = 25000
560
+ num_frames = 14
561
+
562
+ args = parser.parse_args()
563
+ input_dir = args.input_dir
564
+ num_workers = args.num_workers
565
+ viz_root_dir = args.viz_root_dir
566
+ traj_visualization = args.traj_visualization
567
+
568
+
569
+
570
+ store_idx = 0
571
+ dir_list = []
572
+ for sub_input_name in sorted(os.listdir(input_dir)):
573
+ sub_input_dir = os.path.join(input_dir, sub_input_name)
574
+ # sub_store_dir = os.path.join(store_dir, "0"*(7-len(str(store_idx)))+str(store_idx))
575
+ store_idx += 1
576
+ dir_list.append(sub_input_dir)
577
+
578
+ # Truncate the list to the target
579
+ # dir_list = dir_list[list_start:]
580
+
581
+
582
+ # Use multiprocessing to handle to speed up
583
+ num = math.ceil(len(dir_list) / num_workers)
584
+ for idx in range(num_workers):
585
+ # set_start_method('spawn', force=True)
586
+
587
+ request_list = dir_list[:num]
588
+ request_list.append(None)
589
+ dir_list = dir_list[num:]
590
+
591
+
592
+ process_partial_request(request_list, num_frames, traj_visualization, viz_root_dir) # This is for debug purpose
593
+ # p = mp.Process(target=process_partial_request, args=(request_list, num_frames, traj_visualization, viz_root_dir, ))
594
+ # p.start()
595
+
596
+ print("Submitted all jobs!")
597
+ # p.join() # Without this, the multiprocessing run seemed to end on its own for some reason
598
+ print("All task finished!")
599
+
600
+
601
+
scripts/interpolate_by_repeat.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ This file repeats frames so that the sequence reaches the target number of frames needed
3
+ '''
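+ # Worked example: with 7 input frames and total_frames_needed = 14,
+ # division_factor = 2 and remain_frames = -1 (the adjustment loop is skipped),
+ # so every frame is simply repeated twice.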
4
+ import os, shutil, sys
5
+
6
+ if __name__ == "__main__":
7
+ input_path = "/nfs/turbo/coe-jjparkcv/boyangwa/AVDC/AVDC_results"
8
+ store_path = "/nfs/turbo/coe-jjparkcv/boyangwa/AVDC/AVDC_results_interpolated"
9
+ total_frames_needed = 14
10
+
11
+ # Handle the file folder management
12
+ if os.path.exists(store_path):
13
+ shutil.rmtree(store_path)
14
+ os.makedirs(store_path)
15
+
16
+ for video_name in sorted(os.listdir(input_path)):
17
+ sub_input_path = os.path.join(input_path, video_name)
18
+ sub_store_path = os.path.join(store_path, video_name)
19
+
20
+ # Create the store place
21
+ os.makedirs(sub_store_path)
22
+
23
+ # Find valid image lists
24
+ num_frames_input = 0
25
+ for file_name in os.listdir(sub_input_path):
26
+ if file_name.endswith("png"):
27
+ num_frames_input += 1
28
+ print("num_frames_input is ", num_frames_input)
29
+
30
+ # Calculate needed parameters
31
+ division_factor = total_frames_needed // num_frames_input
32
+ remain_frames = (total_frames_needed % num_frames_input) - 1 # -1 for adaptation
33
+
34
+ # Define the gap
35
+ gaps = [division_factor for _ in range(num_frames_input)]
36
+ for idx in range(remain_frames):
37
+ if idx % 2 == 0:
38
+ gaps[idx//2] += 1 # Start to end order
39
+ else:
40
+ gaps[-1*(1+(idx//2))] += 1 # End to start order
41
+
42
+ print("gaps is ", gaps)
43
+
44
+
45
+ # Write to the new folder
46
+ store_idx = 0
47
+ for frame_idx, gap in enumerate(gaps):
48
+ for tmp in range(gap): # Repeat copy gap num of times
49
+ img_path = os.path.join(sub_input_path, str(frame_idx)+".png")
50
+ shutil.copyfile(img_path, os.path.join(sub_store_path, str(store_idx)+".png"))
51
+ store_idx += 1
52
+
53
+
54
+
55
+
scripts/length_stats.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, shutil
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+
5
+
6
+ if __name__ == "__main__":
7
+ input_folder_path = "../Bridge_v2"
8
+
9
+ average_length = []
10
+
11
+ # Iterate each file
12
+ for sub_folder_name in sorted(os.listdir(input_folder_path)):
13
+ sub_folder_path = os.path.join(input_folder_path, sub_folder_name)
14
+
15
+ average_length.append(len(os.listdir(sub_folder_path))) # May be one more than expected, but we keep it
16
+ print("average length of {} is {}".format(sub_folder_name, average_length[-1]))
17
+
18
+ print("average_movement_list is ", average_length)
19
+ n, bins, patches = plt.hist(average_length, bins=100)
20
+ plt.savefig("dataset_length2.png")
21
+
scripts/motion_stats.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, shutil
2
+ import numpy as np
3
+ import math
4
+ from statistics import mean
5
+ import matplotlib.pyplot as plt
6
+
7
+
8
+ if __name__ == "__main__":
9
+ input_folder_paths = ["../datasets_rob/Bridge_v1_raw", "../datasets_rob/Bridge_v2_raw"] # "../datasets_rob/Bridge_v1_raw", "../datasets_rob/Bridge_v2_raw"
10
+ num_frames = 14
11
+ store_name = "movement.png"
12
+
13
+
14
+ average_movement_list = []
15
+ not_valid_num = 0
16
+ not_exists_num = 0
17
+ # Iterate each file
18
+ for input_folder_path in input_folder_paths:
19
+ for sub_folder_name in sorted(os.listdir(input_folder_path)):
20
+ sub_folder_path = os.path.join(input_folder_path, sub_folder_name)
21
+ flow_path = os.path.join(sub_folder_path, 'flow.txt')
22
+
23
+ if not os.path.exists(flow_path):
24
+ not_exists_num += 1
25
+ continue
26
+
27
+
28
+ # Read the movement
29
+ file = open(flow_path, 'r')
30
+ info = file.readlines()
31
+ print(info)
32
+ if len(info) == 0:
33
+ not_valid_num += 1
34
+ continue
35
+ info = info[0][:-2]
36
+ per_video_movement = float(info)
37
+
38
+
39
+ # Calculate the number of frames in this video
40
+ num_frames_input = 0
41
+ valid = True
42
+ for file_name in os.listdir(sub_folder_path): # num_frames_input is the total number of files with name begin with im_
43
+ if file_name.startswith("im_"):
44
+ num_frames_input += 1
45
+ for idx in range(num_frames_input): # Ensure that this number is concurrent
46
+ img_path = os.path.join(sub_folder_path, 'im_' + str(idx) + '.jpg')
47
+ if not os.path.exists(img_path): # Should be sequential existing
48
+ valid = False
49
+ break
50
+ if num_frames_input < 2:
51
+ valid = False
52
+ if not valid:
53
+ not_valid_num += 1
54
+ print("This is not valid path")
55
+ continue
56
+
57
+ average_movement_list.append(per_video_movement * (num_frames_input/num_frames)) # May be one more than expected, but we keep this
58
+ print("average movement of {} is {}".format(sub_folder_name, average_movement_list[-1]))
59
+
60
+ print("not_exists_num is ", not_exists_num)
61
+ print("not_valid_num is ", not_valid_num)
62
+ print("average_movement_list length is ", len(average_movement_list))
63
+
64
+ # Get mean and variance data
65
+ mean_value = mean(average_movement_list)
66
+ std_value = math.sqrt(np.var(average_movement_list))
67
+ print("Mean is ", mean_value)
68
+ print("std_value is ", std_value)
69
+
70
+ # Plot the figure
71
+ n, bins, patches = plt.hist(average_movement_list, bins=100)
72
+ plt.title("Mean" + str(mean_value) + "_STD"+str(std_value))
73
+ plt.savefig(store_name)
74
+
75
+
scripts/process_llama.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Process the llama file for the next step
3
+ '''
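+ # Each JSONL record is expected to carry "file_path", "input" and "output"
+ # fields, where "output" starts with "action:" and holds one "key: value"
+ # line per step (inferred from the parsing below).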
4
+ import os, shutil, sys
5
+ import json
6
+ import pandas as pd
7
+ import collections
8
+
9
+
10
+ if __name__ == "__main__":
11
+
12
+ # Define important path
13
+ json_path = "../SVD1/v1.jsonl"
14
+ folder_path = "/home/kiteret/Desktop/StableVideoDiffusion/full_text_tmp/"
15
+
16
+
17
+ # Read the json file
18
+ with open(json_path, 'r') as json_file:
19
+ json_list = list(json_file)
20
+
21
+ # Iterate all the json files
22
+ length_stats = collections.defaultdict(int)
23
+ for json_info in json_list:
24
+ json_info = json.loads(json_info)
25
+
26
+
27
+ # Define the path to write
28
+ key_start = len("/home/chfeng/llama3/full_text_tmp/")
29
+ key_end = len("lang.txt")
30
+ sub_path = json_info["file_path"][key_start:int(-1*key_end)]
31
+ new_text_path = os.path.join(folder_path, sub_path, "processed_text.txt")
32
+ if os.path.exists(new_text_path):
33
+ os.remove(new_text_path)
34
+
35
+
36
+ # Sanity check for the case where the input is missing
37
+ if json_info["input"] == "":
38
+ print("It is weird for the input is empty in the LLM process for ", sub_path)
39
+ continue
40
+
41
+
42
+ # Re-Define the content
43
+ outputs = json_info["output"]
44
+ if outputs.find("action:") != 0:
45
+ print("It is weird for no actions: keyword in the outputs for ", sub_path, " with prompt ", outputs)
46
+ continue
47
+
48
+ # Prepare write file
49
+ contents = outputs.split('\n')
50
+ f = open(new_text_path, "a")
51
+
52
+ # Iterate
53
+ effective_length = 0
54
+ for idx, content in enumerate(contents):
55
+ key_word = content.split(":")[1][1:]
56
+ if key_word != "":
57
+ effective_length += 1
58
+ else:
59
+ if idx == 1:
60
+ print("It is abnormal for the this content to be empty ", sub_path, " with prompt ", outputs)
61
+ f.write(key_word + "\n")
62
+ # if effective_length == 2:
63
+ # print("short prompt case is ", sub_path, " with prompt ", outputs)
64
+ if effective_length < 2: # For those only 1 or zero, we won't consider them
65
+ print("The prompt is too short for ", sub_path, " with prompt ", outputs)
66
+ os.remove(new_text_path)
67
+
68
+ length_stats[effective_length] += 1
69
+
70
+ print("length_stats is ", length_stats)
71
+
72
+
73
+
74
+
scripts/process_sim.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ This is a script to process Mark's data.
3
+ '''
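+ # Inferred from the code below: every block of 10 runs (run_0..run_9,
+ # run_10..run_19, ...) is collapsed into one sample: the frames come from the
+ # first run of the block, and the ten lang.txt prompts are merged into a
+ # single lang.txt with the most descriptive one written first.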
4
+ import os, sys, shutil
5
+
6
+ if __name__ == "__main__":
7
+ file_path = "/nfs/turbo/coe-jjparkcv/datasets/isaac-gym-pick-place/full/dataset_v3_proc"
8
+ store_path = "../datasets_rob/sim_raw"
9
+ most_descriptive_prompt_idx = 6 # Index starts from 0
10
+
11
+
12
+ # Folder management
13
+ if os.path.exists(store_path):
14
+ shutil.rmtree(store_path)
15
+ os.makedirs(store_path)
16
+
17
+ # Check length
18
+ file_names = os.listdir(file_path)
19
+ target_length = len(file_names) // 10 # 10 files as a cycle
20
+
21
+
22
+ for idx in range(target_length):
23
+ sub_folder_path = os.path.join(file_path, "run_"+str(10*idx))
24
+ if not os.path.exists(sub_folder_path):
25
+ continue
26
+
27
+ # Prepare the target position
28
+ sub_store_path = os.path.join(store_path, str(idx))
29
+ os.makedirs(sub_store_path)
30
+
31
+ # Find the key prompt to read it
32
+ prompt_content = []
33
+ for tmp_idx in range(10):
34
+ tmp_text_path = os.path.join(file_path, "run_"+str(10*idx + tmp_idx), "lang.txt") # Usually, the 6th is the most concrete version
35
+ if not os.path.exists(tmp_text_path):
36
+ continue
37
+ file = open(tmp_text_path, 'r')
38
+ prompt_content.append(file.readlines()[0])
39
+ file.close()
40
+ print("prompt_content we have num ", len(prompt_content))
41
+
42
+
43
+
44
+ # Copy the image into the target position and copy the data.txt
45
+ for file_name in os.listdir(sub_folder_path):
46
+ if file_name == "lang.txt":
47
+ continue
48
+ shutil.copyfile(os.path.join(sub_folder_path, file_name), os.path.join(sub_store_path, file_name))
49
+
50
+ # Handle the lang.txt
51
+ target_lang_txt_path = os.path.join(sub_store_path, "lang.txt")
52
+ f = open(target_lang_txt_path, "a")
53
+ f.write(prompt_content[most_descriptive_prompt_idx]+"\n")
54
+ for tmp_idx in range(10):
55
+ if tmp_idx == most_descriptive_prompt_idx:
56
+ continue
57
+ f.write(prompt_content[tmp_idx]+"\n")
58
+ f.close()
59
+
scripts/resize_img.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, shutil
2
+ import cv2
3
+
4
+ if __name__ == "__main__":
5
+ input_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/resize"
6
+ output_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/resize_resized"
7
+
8
+ if os.path.exists(output_path):
9
+ shutil.rmtree(output_path)
10
+ os.makedirs(output_path)
11
+
12
+ for img_name in os.listdir(input_path):
13
+ img_path = os.path.join(input_path, img_name)
14
+ img = cv2.imread(img_path)
15
+ img = cv2.resize(img, (384, 256))
16
+ store_path = os.path.join(output_path, img_name)
17
+ cv2.imwrite(store_path, img)
scripts/resize_video_seq.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ This file is designed to resize the video sequence to the target resolution
3
+ '''
4
+ import os, sys, shutil
5
+ import cv2
6
+
7
+ if __name__ == "__main__":
8
+ input_folder = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/model_results/SVD_results"
9
+ store_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/model_results/SVD_results_resized"
10
+ target_height, target_width = 256, 384
11
+
12
+ if os.path.exists(store_path):
13
+ shutil.rmtree(store_path)
14
+ os.makedirs(store_path)
15
+
16
+ for video_name in sorted(os.listdir(input_folder)):
17
+ print("We are processing ", video_name)
18
+ sub_video_folder = os.path.join(input_folder, video_name)
19
+ sub_store_folder = os.path.join(store_path, video_name)
20
+ os.makedirs(sub_store_folder)
21
+
22
+ for img_name in os.listdir(sub_video_folder):
23
+ if not img_name.endswith("jpg") and not img_name.endswith("png"):
24
+ continue
25
+
26
+ img_path = os.path.join(sub_video_folder, img_name)
27
+ store_img_path = os.path.join(sub_store_folder, img_name)
28
+ img = cv2.imread(img_path)
29
+
30
+ # Resize
31
+ img = cv2.resize(img, (target_width, target_height))
32
+ cv2.imwrite(store_img_path, img)
33
+
scripts/train_test_split.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, shutil
2
+ import random
3
+
4
+
5
+ if __name__ == "__main__":
6
+ base_dataset_path = "../datasets_rob/Bridge_v1_raw"
7
+ test_store_path = "../datasets_rob/Bridge_v1_test_raw"
8
+ split_ratio = 0.1 # [0, 1] range
9
+
10
+ # Prepare the folder
11
+ if os.path.exists(test_store_path):
12
+ shutil.rmtree(test_store_path)
13
+ os.makedirs(test_store_path)
14
+
15
+ full_img_lists = os.listdir(base_dataset_path)
16
+ random.shuffle(full_img_lists)
17
+ target_test_length = int(len(full_img_lists) * split_ratio)
18
+ test_img_lists = full_img_lists[-1 * target_test_length : ]
19
+
20
+ # Move the lists based on test_img_lists
21
+ for test_img_name in test_img_lists:
22
+ shutil.move(os.path.join(base_dataset_path, test_img_name), os.path.join(test_store_path, test_img_name))
23
+
scripts/visualize_thisthat_point.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ This script is provided to change the destination point and visualize it.
3
+ '''
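+ # data.txt holds "<frame_idx> <horizontal> <vertical>" per line; the first
+ # point is drawn from the file, while (new_w, new_h) below specify a manually
+ # chosen replacement destination point.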
4
+
5
+ import os, cv2
6
+
7
+
8
+ def draw_dot(ref_img, new_h, new_w):
9
+ # Draw the dot
10
+ dot_range = 3
11
+ for i in range(-1*dot_range, dot_range+1):
12
+ for j in range(-1*dot_range, dot_range+1):
13
+ dil_vertical, dil_horizontal = new_h + i, new_w + j
14
+ if (0 <= dil_vertical and dil_vertical < ref_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < ref_img.shape[1]):
15
+ ref_img[dil_vertical, dil_horizontal, :] = [0, 128, 0]
16
+
17
+ return ref_img
18
+
19
+
20
+ if __name__ == "__main__":
21
+ instance_path = "datasets/validation_thisthat14/000049/"
22
+ new_w, new_h = 385, 310
23
+ # 256.1850280761719 241.71287155151367
24
+
25
+ # Read the items
26
+ data_path = os.path.join(instance_path, "data.txt")
27
+ ref_img_path = os.path.join(instance_path, "im_0.jpg")
28
+ ref_img = cv2.imread(ref_img_path)
29
+
30
+
31
+ # Read the first point
32
+ file1 = open(data_path, 'r')
33
+ Lines = file1.readlines()
34
+ frame_idx, horizontal, vertical = Lines[0].split(' ')
35
+ ref_img = draw_dot(ref_img, int(float(vertical)), int(float(horizontal)))
36
+
37
+ # Second dot
38
+ ref_img = draw_dot(ref_img, new_h, new_w)
39
+
40
+
41
+
42
+ # Store the image
43
+ cv2.imwrite("visual.png", ref_img)
svd/diffusion_arch/transformer_temporal.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ from dataclasses import dataclass
+ from typing import Any, Dict, Optional
+ import torch
+ from torch import nn
+
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
+ from diffusers.utils import BaseOutput
+ from diffusers.models.attention import BasicTransformerBlock, TemporalBasicTransformerBlock
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+ from diffusers.models.modeling_utils import ModelMixin
+ from diffusers.models.resnet import AlphaBlender
+
+
+ @dataclass
+ class TransformerTemporalModelOutput(BaseOutput):
+     """
+     The output of [`TransformerTemporalModel`].
+
+     Args:
+         sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
+             The hidden states output conditioned on `encoder_hidden_states` input.
+     """
+
+     sample: torch.FloatTensor
+
+
+ class TransformerTemporalModel(ModelMixin, ConfigMixin):
+     """
+     A Transformer model for video-like data.
+
+     Parameters:
+         num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+         attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+         in_channels (`int`, *optional*):
+             The number of channels in the input and output (specify if the input is **continuous**).
+         num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+         cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+         attention_bias (`bool`, *optional*):
+             Configure if the `TransformerBlock` attention should contain a bias parameter.
+         sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+             This is fixed during training since it is used to learn a number of position embeddings.
+         activation_fn (`str`, *optional*, defaults to `"geglu"`):
+             Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
+             activation functions.
+         norm_elementwise_affine (`bool`, *optional*):
+             Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
+         double_self_attention (`bool`, *optional*):
+             Configure if each `TransformerBlock` should contain two self-attention layers.
+         positional_embeddings: (`str`, *optional*):
+             The type of positional embeddings to apply to the sequence input before use.
+         num_positional_embeddings: (`int`, *optional*):
+             The maximum length of the sequence over which to apply positional embeddings.
+     """
+
+     @register_to_config
+     def __init__(
+         self,
+         num_attention_heads: int = 16,
+         attention_head_dim: int = 88,
+         in_channels: Optional[int] = None,
+         out_channels: Optional[int] = None,
+         num_layers: int = 1,
+         dropout: float = 0.0,
+         norm_num_groups: int = 32,
+         cross_attention_dim: Optional[int] = None,
+         attention_bias: bool = False,
+         sample_size: Optional[int] = None,
+         activation_fn: str = "geglu",
+         norm_elementwise_affine: bool = True,
+         double_self_attention: bool = True,
+         positional_embeddings: Optional[str] = None,
+         num_positional_embeddings: Optional[int] = None,
+     ):
+         super().__init__()
+         self.num_attention_heads = num_attention_heads
+         self.attention_head_dim = attention_head_dim
+         inner_dim = num_attention_heads * attention_head_dim
+
+         self.in_channels = in_channels
+
+         self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+         self.proj_in = nn.Linear(in_channels, inner_dim)
+
+         # 3. Define transformer blocks
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     inner_dim,
+                     num_attention_heads,
+                     attention_head_dim,
+                     dropout=dropout,
+                     cross_attention_dim=cross_attention_dim,
+                     activation_fn=activation_fn,
+                     attention_bias=attention_bias,
+                     double_self_attention=double_self_attention,
+                     norm_elementwise_affine=norm_elementwise_affine,
+                     positional_embeddings=positional_embeddings,
+                     num_positional_embeddings=num_positional_embeddings,
+                 )
+                 for d in range(num_layers)
+             ]
+         )
+
+         self.proj_out = nn.Linear(inner_dim, in_channels)
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         encoder_hidden_states: Optional[torch.LongTensor] = None,
+         timestep: Optional[torch.LongTensor] = None,
+         class_labels: torch.LongTensor = None,
+         num_frames: int = 1,
+         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+         return_dict: bool = True,
+     ) -> TransformerTemporalModelOutput:
+         """
+         The [`TransformerTemporalModel`] forward method.
+
+         Args:
+             hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+                 Input hidden_states.
+             encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                 self-attention.
+             timestep ( `torch.LongTensor`, *optional*):
+                 Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+             class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+                 Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+                 `AdaLayerZeroNorm`.
+             num_frames (`int`, *optional*, defaults to 1):
+                 The number of frames to be processed per batch. This is used to reshape the hidden states.
+             cross_attention_kwargs (`dict`, *optional*):
+                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                 `self.processor` in
+                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of
+                 a plain tuple.
+
+         Returns:
+             [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
+                 If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
+                 returned, otherwise a `tuple` where the first element is the sample tensor.
+         """
+         # 1. Input
+         batch_frames, channel, height, width = hidden_states.shape
+         batch_size = batch_frames // num_frames
+
+         residual = hidden_states
+
+         hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
+         hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
+
+         hidden_states = self.norm(hidden_states)
+         hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
+
+         hidden_states = self.proj_in(hidden_states)
+
+         # 2. Blocks
+         for block in self.transformer_blocks:
+             hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states=encoder_hidden_states,
+                 timestep=timestep,
+                 cross_attention_kwargs=cross_attention_kwargs,
+                 class_labels=class_labels,
+             )
+
+         # 3. Output
+         hidden_states = self.proj_out(hidden_states)
+         hidden_states = (
+             hidden_states[None, None, :]
+             .reshape(batch_size, height, width, num_frames, channel)
+             .permute(0, 3, 4, 1, 2)
+             .contiguous()
+         )
+         hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
+
+         output = hidden_states + residual
+
+         if not return_dict:
+             return (output,)
+
+         return TransformerTemporalModelOutput(sample=output)
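# ------------------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the committed file): a dummy forward pass
# through TransformerTemporalModel as defined above. The hyper-parameters below are
# assumptions, chosen so that num_attention_heads * attention_head_dim == in_channels
# and in_channels is divisible by norm_num_groups (32).
# ------------------------------------------------------------------------------------
model = TransformerTemporalModel(num_attention_heads=8, attention_head_dim=40, in_channels=320)
batch_size, num_frames = 1, 4
latents = torch.randn(batch_size * num_frames, 320, 16, 16)  # frames stacked along the batch dimension
out = model(latents, num_frames=num_frames).sample           # temporal attention mixes the 4 frames per pixel
assert out.shape == latents.shape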
+
+
+ class TransformerSpatioTemporalModel(nn.Module):
+     """
+     A Transformer model for video-like data.
+
+     Parameters:
+         num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+         attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+         in_channels (`int`, *optional*):
+             The number of channels in the input and output (specify if the input is **continuous**).
+         out_channels (`int`, *optional*):
+             The number of channels in the output (specify if the input is **continuous**).
+         num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+         cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+     """
+
+     def __init__(
+         self,
+         num_attention_heads: int = 16,
+         attention_head_dim: int = 88,
+         in_channels: int = 320,
+         out_channels: Optional[int] = None,
+         num_layers: int = 1,
+         cross_attention_dim: Optional[int] = None,
+     ):
+         super().__init__()
+         self.num_attention_heads = num_attention_heads
+         self.attention_head_dim = attention_head_dim
+
+         inner_dim = num_attention_heads * attention_head_dim
+         self.inner_dim = inner_dim
+
+         # 2. Define input layers
+         self.in_channels = in_channels
+         self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
+         self.proj_in = nn.Linear(in_channels, inner_dim)
+
+         # 3. Define transformer blocks
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 BasicTransformerBlock(
+                     inner_dim,
+                     num_attention_heads,
+                     attention_head_dim,
+                     cross_attention_dim=cross_attention_dim,
+                 )
+                 for d in range(num_layers)
+             ]
+         )
+
+         time_mix_inner_dim = inner_dim
+         self.temporal_transformer_blocks = nn.ModuleList(
+             [
+                 TemporalBasicTransformerBlock(
+                     inner_dim,
+                     time_mix_inner_dim,
+                     num_attention_heads,
+                     attention_head_dim,
+                     cross_attention_dim=cross_attention_dim,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         time_embed_dim = in_channels * 4
+         self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
+         self.time_proj = Timesteps(in_channels, True, 0)
+         self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
+
+         # 4. Define output layers
+         self.out_channels = in_channels if out_channels is None else out_channels
+         # TODO: should use out_channels for continuous projections
+         self.proj_out = nn.Linear(inner_dim, in_channels)
+
+         self.gradient_checkpointing = False
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: Optional[torch.Tensor] = None,
+         image_only_indicator: Optional[torch.Tensor] = None,
+         return_dict: bool = True,
+     ):
+         """
+         Args:
+             hidden_states (`torch.FloatTensor` of shape `(batch size * num frames, channel, height, width)`):
+                 Input hidden_states.
+             num_frames (`int`):
+                 The number of frames to be processed per batch, inferred from the last dimension of
+                 `image_only_indicator`. This is used to reshape the hidden states.
+             encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+                 self-attention.
+             image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
+                 A tensor indicating whether the input contains only images. 1 indicates that the input contains only
+                 images, 0 indicates that the input contains video frames.
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of
+                 a plain tuple.
+
+         Returns:
+             [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
+                 If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
+                 returned, otherwise a `tuple` where the first element is the sample tensor.
+         """
+         # 1. Input
+         batch_frames, _, height, width = hidden_states.shape
+         num_frames = image_only_indicator.shape[-1]
+         batch_size = batch_frames // num_frames
+
+         time_context = encoder_hidden_states
+         time_context_first_timestep = time_context[None, :].reshape(
+             batch_size, num_frames, -1, time_context.shape[-1]
+         )[:, 0]  # the cross-attention context for the temporal blocks only uses the first frame
+
+
+         encoder_hidden_states_dim = time_context_first_timestep.shape[1]
+         time_context = time_context_first_timestep[None, :].broadcast_to(
+             height * width, batch_size, encoder_hidden_states_dim, time_context.shape[-1]
+         )
+         time_context = time_context.reshape(height * width * batch_size, encoder_hidden_states_dim, time_context.shape[-1])
+
+         residual = hidden_states
+
+         hidden_states = self.norm(hidden_states)
+         inner_dim = hidden_states.shape[1]
+         hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
+         hidden_states = self.proj_in(hidden_states)
+
+         num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
+         num_frames_emb = num_frames_emb.repeat(batch_size, 1)
+         num_frames_emb = num_frames_emb.reshape(-1)
+         t_emb = self.time_proj(num_frames_emb)
+
+         # `Timesteps` does not contain any weights and will always return f32 tensors,
+         # but time_embedding might actually be running in fp16, so we need to cast here.
+         # There might be better ways to encapsulate this.
+         t_emb = t_emb.to(dtype=hidden_states.dtype)
+
+         emb = self.time_pos_embed(t_emb)
+         emb = emb[:, None, :]
+
+         # 2. Blocks
+         for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
+             if self.training and self.gradient_checkpointing:
+                 hidden_states = torch.utils.checkpoint.checkpoint(
+                     block,
+                     hidden_states,
+                     None,
+                     encoder_hidden_states,
+                     None,
+                     use_reentrant=False,
+                 )
+             else:
+                 hidden_states = block(
+                     hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                 )
+
+             hidden_states_mix = hidden_states
+             hidden_states_mix = hidden_states_mix + emb
+
+             hidden_states_mix = temporal_block(
+                 hidden_states_mix,
+                 num_frames=num_frames,
+                 encoder_hidden_states=time_context,
+             )
+             hidden_states = self.time_mixer(
+                 x_spatial=hidden_states,
+                 x_temporal=hidden_states_mix,
+                 image_only_indicator=image_only_indicator,
+             )
+
+         # 3. Output
+         hidden_states = self.proj_out(hidden_states)
+         hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+         output = hidden_states + residual
+
+         if not return_dict:
+             return (output,)
+
+         return TransformerTemporalModelOutput(sample=output)
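To round out the listing, a similar hedged sketch for `TransformerSpatioTemporalModel`: the context width (`cross_attention_dim=1024`), the single context token per frame, and the 8x8 spatial size are illustrative assumptions; `num_frames` is recovered inside `forward` from the last dimension of `image_only_indicator`, and zeros there mark video frames.

# Illustrative sketch only, not part of the committed module.
model = TransformerSpatioTemporalModel(
    num_attention_heads=8, attention_head_dim=40, in_channels=320, cross_attention_dim=1024
)
batch_size, num_frames = 1, 4
latents = torch.randn(batch_size * num_frames, 320, 8, 8)    # (batch * frames, C, H, W)
context = torch.randn(batch_size * num_frames, 1, 1024)      # one cross-attention token per frame
indicator = torch.zeros(batch_size, num_frames)               # 0 = video frame, 1 = still image
out = model(latents, encoder_hidden_states=context, image_only_indicator=indicator).sample
assert out.shape == latents.shape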