Commit 59b2a81
Parent(s): aa505db
feat: initial push
Note: this view is limited to 50 files because the commit contains too many changes.
- .gitignore +49 -0
- __assets__/0.jpg +0 -0
- __assets__/156.jpg +0 -0
- __assets__/274.jpg +0 -0
- __assets__/375.jpg +0 -0
- __assets__/551.jpg +0 -0
- __assets__/91.jpg +0 -0
- __assets__/ThisThat_logo.png +0 -0
- app.py +475 -4
- config/accelerate_config.json +18 -0
- config/flowformer_config.py +78 -0
- config/train_image2video.yaml +78 -0
- config/train_image2video_controlnet.yaml +101 -0
- curation_pipeline/add_lang_info.py +38 -0
- curation_pipeline/match_dataset_v1.py +117 -0
- curation_pipeline/match_dataset_v2.py +137 -0
- curation_pipeline/prepare_bridge_csv.py +69 -0
- curation_pipeline/prepare_bridge_jsonl.py +47 -0
- curation_pipeline/prepare_bridge_v1.py +132 -0
- curation_pipeline/prepare_bridge_v2.py +139 -0
- curation_pipeline/select_frame_with_this_that.py +421 -0
- curation_pipeline/tracking_by_keypoint.py +136 -0
- data_loader/video_dataset.py +323 -0
- data_loader/video_this_that_dataset.py +326 -0
- pretrained/PUT_YOUR_WEIGHT_HERE.md +0 -0
- requirements.txt +27 -0
- scripts/active_learning_select.py +27 -0
- scripts/add_point2img.py +51 -0
- scripts/check_video.py +19 -0
- scripts/clean_bridge_dataset.py +22 -0
- scripts/collect_lang.py +31 -0
- scripts/combine_results.py +85 -0
- scripts/compress_gif.py +52 -0
- scripts/compress_videos.py +55 -0
- scripts/crop_video_frames.py +22 -0
- scripts/extract_test_dataset.py +18 -0
- scripts/generate_noise.py +14 -0
- scripts/generate_sam.py +56 -0
- scripts/generate_sam_this_that.py +108 -0
- scripts/generate_traj.py +601 -0
- scripts/interpolate_by_repeat.py +55 -0
- scripts/length_stats.py +21 -0
- scripts/motion_stats.py +75 -0
- scripts/process_llama.py +74 -0
- scripts/process_sim.py +59 -0
- scripts/resize_img.py +17 -0
- scripts/resize_video_seq.py +33 -0
- scripts/train_test_split.py +23 -0
- scripts/visualize_thisthat_point.py +43 -0
- svd/diffusion_arch/transformer_temporal.py +381 -0
.gitignore
ADDED
@@ -0,0 +1,49 @@
+.ipynb_checkpoints
+.idea
+__pycache__
+
+datasets/
+tmp_imgs
+runs/
+runs_last/
+saved_models/
+pre_trained/
+save_log/
+diffusers/
+weights/
+checkpoints/
+validation_videos*
+pretrained/*
+.gradio/*
+
+*.pyc
+*.sh
+*.pth
+*.png
+*.jpg
+*.mp4
+*.txt
+*.json
+*.jsonl
+*.zip
+*.mp4
+*.csv
+*.webp
+*.bin
+*.pkl
+*.safetensors
+*.pt
+*.log
+events.*
+*.yml
+*.gif
+*.npy
+*.out
+
+!requirements.txt
+!saved_models/*.md
+!LICENSE.txt
+!config/*
+!__assets__/*
+!__assets__/Bridge_example/*
+!pretrained/PUT_YOUR_WEIGHT_HERE.md
__assets__/0.jpg
ADDED
__assets__/156.jpg
ADDED
__assets__/274.jpg
ADDED
__assets__/375.jpg
ADDED
__assets__/551.jpg
ADDED
__assets__/91.jpg
ADDED
__assets__/ThisThat_logo.png
ADDED
app.py
CHANGED
@@ -1,7 +1,478 @@
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo
-demo.launch()
+# *************************************************************************
+# Copyright (2023) Bytedance Inc.
+#
+# Copyright (2023) DragDiffusion Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# *************************************************************************
+
+import os, shutil, sys
+import urllib.request
+import argparse
+import imageio
+import math
+import cv2
+import collections
+import numpy as np
 import gradio as gr
+from PIL import Image
+
+import torch
+from pathlib import Path
+from omegaconf import OmegaConf
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+from accelerate import Accelerator
+from accelerate.utils import ProjectConfiguration
+from diffusers import (
+    AutoencoderKLTemporalDecoder,
+    DDPMScheduler,
+)
+from diffusers.utils import check_min_version, is_wandb_available, load_image, export_to_video
+from huggingface_hub import hf_hub_download
+from transformers import AutoTokenizer, PretrainedConfig
+
+
+# Import files from the local folder
+root_path = os.path.abspath('.')
+sys.path.append(root_path)
+from train_code.train_svd import import_pretrained_text_encoder
+from data_loader.video_dataset import tokenize_captions
+from data_loader.video_this_that_dataset import get_thisthat_sam
+from svd.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
+from svd.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline
+from svd.temporal_controlnet import ControlNetModel
+from svd.pipeline_stable_video_diffusion_controlnet import StableVideoDiffusionControlNetPipeline
+from utils.optical_flow_utils import bivariate_Gaussian
+
+
+# For the 2D dilation
+blur_kernel = bivariate_Gaussian(99, 10, 10, 0, grid=None, isotropic=True)
+
+
+# LENGTH=480  # length of the square area displaying/editing images
+HEIGHT = 256
+WIDTH = 384
+
+
+MARKDOWN = \
+"""
+## <p style='text-align: center'> This&That </p>
+
+[GitHub](https://github.com/Kiteretsu77/This_and_That_VDM) | [Paper](http://arxiv.org/abs/2407.05530) | [Webpage](https://cfeng16.github.io/this-and-that/)
+This&That is a language-, gesture-, and image-conditioned video generation model for robot planning (Bridge-dataset-based for this repo).
+
+This demo covers the Video Diffusion Model part.
+Only GestureNet is provided in this Gradio demo; check the full test code for all available pretrained weights.
+
+### Note: By default, gesture points are placed at frame indices [4, 10] for two points, or [4] for one point.
+### Note: Only 256x384 output is supported for now.
+### Note: Click "Clear All" to restart everything; click "Undo Point" to remove the last point you placed.
+
+If This&That is helpful, please help star the [GitHub Repo](https://github.com/Kiteretsu77/This_and_That_VDM). Thanks!
+"""
+
+
+def store_img(img):
+    # when a new image is uploaded, `selected_points` should be empty
+    return img, []
+
+
+def clear_all():
+    return None, \
+        gr.Image(value=None, height=HEIGHT, width=WIDTH, interactive=False), \
+        None, []  # selected points
+
+
+def undo_points(original_image):
+    img = original_image.copy()
+    return img, []
+
+
+# The user clicks the image to select points, which are drawn on the image [From https://github.com/Yujun-Shi/DragDiffusion]
+def get_points(img, original_image, sel_pix, evt: gr.SelectData):
+
+    # collect the selected point
+    sel_pix.append(evt.index)
+
+    if len(sel_pix) > 2:
+        raise gr.Error("We support at most two points")
+
+    if original_image is None:
+        original_image = img.copy()
+
+    # draw points
+    points = []
+    for idx, point in enumerate(sel_pix):
+        if idx % 2 == 0:
+            # draw a red circle at the handle point
+            cv2.circle(img, tuple(point), 10, (255, 0, 0), -1)
+        else:
+            # draw a green circle at the target point
+            cv2.circle(img, tuple(point), 10, (0, 255, 0), -1)
+        points.append(tuple(point))
+        # draw an arrow from handle point to target point
+        # if len(points) == 2:
+        #     cv2.arrowedLine(img, points[0], points[1], (255, 255, 255), 4, tipLength=0.5)
+        #     points = []
+
+    return [img if isinstance(img, np.ndarray) else np.array(img), original_image]
+
+
+def gesturenet_inference(ref_image, prompt, selected_points):
+
+    # Check some parameters; we must have a prompt and selected points
+    if prompt == "" or prompt is None:
+        raise gr.Error("Please input a text prompt")
+    if selected_points == []:
+        raise gr.Error("Please click one or two points on the image")
+
+    # Prepare the settings
+    frame_idxs = [4, 10]
+    use_ambiguous_prompt = False
+    model_type = "GestureNet"
+    huggingface_pretrained_path = "HikariDawn/This-and-That-1.1"
+
+    print("Text prompt is ", prompt)
+
+    # Prepare the tmp folder
+    store_folder_name = "tmp"
+    if os.path.exists(store_folder_name):
+        shutil.rmtree(store_folder_name)
+    os.makedirs(store_folder_name)
+
+    # Read the yaml setting files (very important for loading the needed hyperparameters)
+    if not os.path.exists(huggingface_pretrained_path):
+        yaml_download_path = hf_hub_download(repo_id=huggingface_pretrained_path, subfolder="unet", filename="train_image2video.yaml")
+        if model_type == "GestureNet":
+            yaml_download_path = hf_hub_download(repo_id=huggingface_pretrained_path, subfolder="gesturenet", filename="train_image2video_gesturenet.yaml")
+    else:  # If the path is a local path we can concatenate it here
+        yaml_download_path = os.path.join(huggingface_pretrained_path, "unet", "train_image2video.yaml")
+        if model_type == "GestureNet":
+            yaml_download_path = os.path.join(huggingface_pretrained_path, "gesturenet", "train_image2video_gesturenet.yaml")
+
+    # Load the config
+    assert(os.path.exists(yaml_download_path))
+    config = OmegaConf.load(yaml_download_path)
+
+    ######################### Prepare vae, unet, image_encoder (same as before) #########################
+    print("Prepare the pretrained model")
+    accelerator = Accelerator(
+        gradient_accumulation_steps = config["gradient_accumulation_steps"],
+        mixed_precision = config["mixed_precision"],
+        log_with = config["report_to"],
+        project_config = ProjectConfiguration(project_dir=config["output_dir"], logging_dir=Path(config["output_dir"], config["logging_name"])),
+    )
+    feature_extractor = CLIPImageProcessor.from_pretrained(
+        config["pretrained_model_name_or_path"], subfolder="feature_extractor", revision=None
+    )  # This instance has no weights; it is just a settings file
+    image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+        config["pretrained_model_name_or_path"], subfolder="image_encoder", revision=None, variant="fp16"
+    )
+    vae = AutoencoderKLTemporalDecoder.from_pretrained(
+        config["pretrained_model_name_or_path"], subfolder="vae", revision=None, variant="fp16"
+    )
+    unet = UNetSpatioTemporalConditionModel.from_pretrained(
+        huggingface_pretrained_path,
+        subfolder = "unet",
+        low_cpu_mem_usage = True,
+        # variant = "fp16",
+    )
+
+    # For text
+    tokenizer = AutoTokenizer.from_pretrained(
+        config["pretrained_tokenizer_name_or_path"],
+        subfolder = "tokenizer",
+        revision = None,
+        use_fast = False,
+    )
+    # CLIP Text Encoder
+    text_encoder_cls = import_pretrained_text_encoder(config["pretrained_tokenizer_name_or_path"], revision=None)
+    text_encoder = text_encoder_cls.from_pretrained(config["pretrained_tokenizer_name_or_path"], subfolder = "text_encoder", revision = None, variant = None)
+
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Freeze vae + image_encoder and cast to weight_dtype
+    vae.requires_grad_(False)
+    image_encoder.requires_grad_(False)
+    unet.requires_grad_(False)  # Will switch back at the end
+    text_encoder.requires_grad_(False)
+
+    # Move to the accelerator device
+    vae.to(accelerator.device, dtype=weight_dtype)
+    image_encoder.to(accelerator.device, dtype=weight_dtype)
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+
+    # For GestureNet
+    if model_type == "GestureNet":
+        unet.to(accelerator.device, dtype=weight_dtype)  # There is no need to cast unet in unet training; it is only needed in the controlnet one
+
+        # Handle the ControlNet first, built from the UNet
+        gesturenet = ControlNetModel.from_pretrained(
+            huggingface_pretrained_path,
+            subfolder = "gesturenet",
+            low_cpu_mem_usage = True,
+            variant = None,
+        )
+
+        gesturenet.requires_grad_(False)
+        gesturenet.to(accelerator.device)
+    ######################################################################################################
+
+    # Init the pipeline
+    pipeline = StableVideoDiffusionControlNetPipeline.from_pretrained(
+        config["pretrained_model_name_or_path"],  # Still based on the regular SVD config
+        vae = vae,
+        image_encoder = image_encoder,
+        unet = unet,
+        revision = None,  # Set None directly now
+        torch_dtype = weight_dtype,
+    )
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    ######################### Prepare and process the condition here #########################
+    org_height, org_width, _ = ref_image.shape
+    ref_image_pil = Image.fromarray(ref_image)
+    ref_image_pil = ref_image_pil.resize((config["width"], config["height"]))
+
+    # Initialize the optical flow format we want
+    gesture_condition_img = np.zeros((config["video_seq_length"], config["conditioning_channels"], config["height"], config["width"]), dtype=np.float32)  # The last image should be empty
+
+    # Convert the selected points into the condition we want
+    for point_idx, point in enumerate(selected_points):
+
+        frame_idx = frame_idxs[point_idx]
+        horizontal, vertical = point
+
+        # Init the base image
+        base_img = np.zeros((org_height, org_width, 3)).astype(np.float32)  # Use the original image size
+        base_img.fill(255)
+
+        # Draw a square around the target position
+        dot_range = 10  # Diameter
+        for i in range(-1*dot_range, dot_range+1):
+            for j in range(-1*dot_range, dot_range+1):
+                dil_vertical, dil_horizontal = vertical + i, horizontal + j
+                if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]):
+                    if point_idx == 0:
+                        base_img[dil_vertical][dil_horizontal] = [0, 0, 255]  # The first point should be red
+                    else:
+                        base_img[dil_vertical][dil_horizontal] = [0, 255, 0]  # The second point should be green to distinguish it from the first point
+
+        # Dilate
+        if config["dilate"]:
+            base_img = cv2.filter2D(base_img, -1, blur_kernel)
+
+        ###################################################################################################
+        ### The core processing pipeline is: Dilate -> Resize -> Range Shift -> Transpose Shape -> Store
+
+        # Resize frames. Don't use negative values and don't resize in [0,1]
+        base_img = cv2.resize(base_img, (config["width"], config["height"]), interpolation = cv2.INTER_CUBIC)
+
+        # Channel transform and range shift
+        if config["conditioning_channels"] == 3:
+            # Map to [0, 1] range
+            base_img = base_img / 255.0
+        else:
+            raise NotImplementedError()
+
+        # Reorganize shape
+        base_img = base_img.transpose(2, 0, 1)  # hwc -> chw
+
+        # Write the base img based on frame_idx
+        gesture_condition_img[frame_idx] = base_img  # Only this frame; the rest stays 0 initialized
+
+    ###########################################################################################
+
+    # Use the same tokenize process as the dataset preparation stage
+    tokenized_prompt = tokenize_captions(prompt, tokenizer, config, is_train=False).unsqueeze(0).to(accelerator.device)  # Use unsqueeze to expand dim
+
+    # Call the pipeline
+    with torch.autocast("cuda"):
+        frames = pipeline(
+            image = ref_image_pil,
+            condition_img = gesture_condition_img,  # numpy [0,1] range
+            controlnet = accelerator.unwrap_model(gesturenet),
+            prompt = tokenized_prompt,
+            use_text = config["use_text"],
+            text_encoder = text_encoder,
+            height = config["height"],
+            width = config["width"],
+            num_frames = config["video_seq_length"],
+            decode_chunk_size = 8,
+            motion_bucket_id = 200,
+            # controlnet_image_index = controlnet_image_index,
+            # coordinate_values = coordinate_values,
+            num_inference_steps = config["num_inference_steps"],
+            max_guidance_scale = config["inference_max_guidance_scale"],
+            fps = 7,
+            use_instructpix2pix = config["use_instructpix2pix"],
+            noise_aug_strength = config["inference_noise_aug_strength"],
+            controlnet_conditioning_scale = config["outer_conditioning_scale"],
+            inner_conditioning_scale = config["inner_conditioning_scale"],
+            guess_mode = config["inference_guess_mode"],  # False in inference
+            image_guidance_scale = config["image_guidance_scale"],
+        ).frames[0]
+
+    # Save frames
+    video_file_path = os.path.join(store_folder_name, "tmp.mp4")
+    writer = imageio.get_writer(video_file_path, fps=4)
+    for idx, frame in enumerate(frames):
+        frame.save(os.path.join(store_folder_name, str(idx)+".png"))
+        writer.append_data(cv2.cvtColor(cv2.imread(os.path.join(store_folder_name, str(idx)+".png")), cv2.COLOR_BGR2RGB))
+    writer.close()
+
+    # Cleaning process
+    del pipeline
+    torch.cuda.empty_cache()
+
+    return gr.update(value=video_file_path, width=config["width"], height=config["height"])  # Return the result based on the need
+
+
+if __name__ == '__main__':
+
+    # Gradio demo part
+    with gr.Blocks() as demo:
+        # layout definition
+        with gr.Row():
+            gr.Markdown(MARKDOWN)
+
+        # UI components for editing real images
+        with gr.Row(elem_classes=["container"]):
+            selected_points = gr.State([])  # store points
+            original_image = gr.State(value=None)  # store the original input image
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("""<p style="text-align: center; font-size: 20px">Click two Points</p>""")
+                    input_image = gr.Image(label="Input Image", height=HEIGHT, width=WIDTH, interactive=False, elem_id="input_img")
+                    # gr.Image(type="numpy", label="Click Points", height=HEIGHT, width=WIDTH, interactive=False)  # for points clicking
+                    undo_button = gr.Button("Undo point")
+
+                    # Text prompt
+                    with gr.Row():
+                        prompt = gr.Textbox(label="Text Prompt")
+
+                with gr.Column():
+                    gr.Markdown("""<p style="text-align: center; font-size: 20px">Results</p>""")
+                    frames = gr.Video(value=None, label="Generate Video", show_label=True, height=HEIGHT, width=WIDTH)
+                    with gr.Row():
+                        run_button = gr.Button("Run")
+                        clear_all_button = gr.Button("Clear All")
+
+        # with gr.Tab("Base Model Config"):
+        #     with gr.Row():
+        #         local_models_dir = 'local_pretrained_models'
+        #         local_models_choice = \
+        #             [os.path.join(local_models_dir,d) for d in os.listdir(local_models_dir) if os.path.isdir(os.path.join(local_models_dir,d))]
+        #         model_path = gr.Dropdown(value="runwayml/stable-diffusion-v1-5",
+        #                                  label="Diffusion Model Path",
+        #                                  choices=[
+        #                                      "runwayml/stable-diffusion-v1-5",
+        #                                      "gsdf/Counterfeit-V2.5",
+        #                                      "stablediffusionapi/anything-v5",
+        #                                      "SG161222/Realistic_Vision_V2.0",
+        #                                  ] + local_models_choice
+        #                                  )
+        #         vae_path = gr.Dropdown(value="default",
+        #                                label="VAE choice",
+        #                                choices=["default",
+        #                                         "stabilityai/sd-vae-ft-mse"] + local_models_choice
+        #                                )
+
+        # Examples
+        with gr.Row(elem_classes=["container"]):
+            gr.Examples(
+                [
+                    ["__assets__/Bridge_example/Task1_v1_511/im_0.jpg", "take this to there"],
+                    ["__assets__/Bridge_example/Task2_v2_164/im_0.jpg", "put this to there"],
+                    ["__assets__/Bridge_example/Task3_v2_490/im_0.jpg", "fold this"],
+                    ["__assets__/Bridge_example/Task4_v2_119/im_0.jpg", "open this"],
+
+                    # ["__assets__/0.jpg", "take this to there"],
+                    ["__assets__/91.jpg", "take this to there"],
+                    ["__assets__/156.jpg", "take this to there"],
+                    # ["__assets__/274.jpg", "take this to there"],
+                    ["__assets__/375.jpg", "take this to there"],
+                    # ["__assets__/551.jpg", "take this to there"],
+                ],
+                [input_image, prompt, selected_points],
+            )
+
+        ####################################### Event Definition #######################################
+
+        # Draw the points
+        input_image.select(
+            get_points,
+            [input_image, original_image, selected_points],
+            [input_image, original_image],
+        )
+
+        # Clear the points
+        undo_button.click(
+            undo_points,
+            [original_image],
+            [input_image, selected_points],
+        )
+
+        run_button.click(
+            gesturenet_inference,
+            inputs = [
+                # vae, unet, gesturenet, image_encoder, text_encoder, tokenizer,
+                original_image, prompt, selected_points,
+                # frame_idxs,
+                # config, accelerator, weight_dtype
+            ],
+            outputs = [frames]
+        )
+
+        clear_all_button.click(
+            clear_all,
+            [],
+            outputs = [original_image, input_image, prompt, selected_points],
+        )
+
+    demo.queue().launch(share=True, debug=True)
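For reference, the gesture-conditioned inference above can also be driven without the Gradio UI. The following is a minimal sketch, assuming it is run from the repo root so that app.py and its local imports resolve, a CUDA device is available, and the HikariDawn/This-and-That-1.1 weights can be downloaded; the two points are hypothetical (x, y) pixel coordinates, matching what gr.SelectData.index would provide.

import cv2
from app import gesturenet_inference

# Load a Bridge-style reference frame (convert BGR -> RGB to match what Gradio passes in)
ref_image = cv2.cvtColor(cv2.imread("__assets__/91.jpg"), cv2.COLOR_BGR2RGB)

# "this" point and "that" point; they are placed on frames 4 and 10 respectively
selected_points = [[200, 150], [320, 180]]  # hypothetical (x, y) coordinates

# Returns a gr.update(...) whose value is the path of the rendered tmp/tmp.mp4
result = gesturenet_inference(ref_image, "take this to there", selected_points)
print(result)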
config/accelerate_config.json
ADDED
@@ -0,0 +1,18 @@
+{
+    "compute_environment": "LOCAL_MACHINE",
+    "debug": false,
+    "distributed_type": "MULTI_GPU",
+    "downcast_bf16": "no",
+    "gpu_ids": "all",
+    "machine_rank": 0,
+    "main_training_function": "main",
+    "mixed_precision": "fp16",
+    "num_machines": 1,
+    "num_processes": 8,
+    "rdzv_backend": "static",
+    "same_network": true,
+    "tpu_env": [],
+    "tpu_use_cluster": false,
+    "tpu_use_sudo": false,
+    "use_cpu": false
+}
config/flowformer_config.py
ADDED
@@ -0,0 +1,78 @@
+from yacs.config import CfgNode as CN
+_CN = CN()
+
+_CN.name = 'default'
+_CN.suffix = 'sintel'
+_CN.gamma = 0.75
+_CN.max_flow = 400
+_CN.batch_size = 6
+_CN.sum_freq = 100
+_CN.val_freq = 100000000
+_CN.image_size = [432, 960]
+_CN.add_noise = False
+_CN.use_smoothl1 = False
+_CN.critical_params = []
+
+_CN.transformer = 'percostformer3'
+
+### change the path here
+_CN.model = "pretrained/sintel.pth"
+
+_CN.percostformer3 = CN()
+_CN.percostformer3.pe = 'linear'
+_CN.percostformer3.dropout = 0.0
+_CN.percostformer3.droppath = 0.0
+_CN.percostformer3.encoder_latent_dim = 256  # in twins, this is 256
+_CN.percostformer3.query_latent_dim = 64
+_CN.percostformer3.cost_latent_input_dim = 64
+_CN.percostformer3.cost_latent_token_num = 8
+_CN.percostformer3.cost_latent_dim = 128
+_CN.percostformer3.cost_heads_num = 1
+# encoder
+_CN.percostformer3.pretrain = True
+_CN.percostformer3.use_convertor = False
+_CN.percostformer3.del_layers = True
+_CN.percostformer3.encoder_depth = 3
+_CN.percostformer3.expand_factor = 4
+_CN.percostformer3.vertical_encoder_attn = "twins"
+_CN.percostformer3.attn_dim = 128
+_CN.percostformer3.patch_size = 8
+_CN.percostformer3.patch_embed = 'single'
+_CN.percostformer3.cross_attn = "all"
+_CN.percostformer3.gma = "GMA"
+_CN.percostformer3.vert_c_dim = 64
+_CN.percostformer3.cost_encoder_res = True
+_CN.percostformer3.cnet = 'twins'
+_CN.percostformer3.fnet = 'twins'
+_CN.percostformer3.flow_or_pe = "and"
+_CN.percostformer3.use_patch = False  # use cost patch rather than local cost as query
+_CN.percostformer3.use_rpe = False
+_CN.percostformer3.detach_local = False
+_CN.percostformer3.no_sc = False
+_CN.percostformer3.r_16 = -1
+_CN.percostformer3.quater_refine = False
+# pretrain config
+_CN.percostformer3.pretrain_mode = False
+_CN.percostformer3.pic_size = [368, 496, 368, 496]
+_CN.percostformer3.mask_ratio = 0.5
+_CN.percostformer3.query_num = 30
+_CN.percostformer3.no_border = True
+_CN.percostformer3.gt_r = 15
+_CN.percostformer3.fix_pe = False
+# decoder
+_CN.percostformer3.decoder_depth = 12
+_CN.percostformer3.critical_params = ['vert_c_dim', 'encoder_depth', 'vertical_encoder_attn', "use_patch", "flow_or_pe", "use_rpe", "dropout", "detach_local", "expand_factor"]
+
+
+### TRAINER
+_CN.trainer = CN()
+_CN.trainer.scheduler = 'OneCycleLR'
+_CN.trainer.optimizer = 'adamw'
+_CN.trainer.canonical_lr = 12.5e-5
+_CN.trainer.adamw_decay = 1e-5
+_CN.trainer.clip = 1.0
+_CN.trainer.num_steps = 120000
+_CN.trainer.epsilon = 1e-8
+_CN.trainer.anneal_strategy = 'linear'
+def get_cfg():
+    return _CN.clone()
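The FlowFormer settings above are plain yacs nodes, so downstream code is expected to read them through get_cfg() rather than touching _CN directly. A minimal usage sketch, assuming yacs is installed and the repo root is on sys.path so the file is importable as config.flowformer_config:

from config.flowformer_config import get_cfg

cfg = get_cfg()                      # a detached clone of _CN, safe to mutate locally
cfg.model = "pretrained/sintel.pth"  # override the checkpoint path if needed
print(cfg.percostformer3.encoder_latent_dim)  # 256, the twins encoder width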
config/train_image2video.yaml
ADDED
@@ -0,0 +1,78 @@
+
+# Model Setting
+pretrained_model_name_or_path: stabilityai/stable-video-diffusion-img2vid  # -xt is the 25-frame version
+load_unet_path:  # This is usually used to load a pretrained UNet, e.g., to start from one of your previously trained checkpoints
+video_seq_length: 14  # Standardized to 14
+process_fps: 7
+train_noise_aug_strength: 0.1
+scheduler: EDM
+conditioning_dropout_prob: 0.1
+
+
+# Dataset Setting
+dataset_name: Bridge  # WebVid / Bridge
+dataset_path: [../sanity_check/bridge_v1_raw, ../sanity_check/bridge_v2_raw]
+output_dir: checkpoints/img2video
+height: 256  # Ratios that are functional: 256:384 576:1024 320:512 320:576
+width: 384   # The height and width should be multiples of 64
+dataloader_num_workers: 4  # Don't set this too large; video diffusion processing is slow, so few workers are needed for early loading
+flip_aug_prob: 0.45  # Probability of flipping the GT and cond vertically
+acceleration_tolerance: 4  # Recommended setting
+
+
+# Text setting
+use_text: True  # If this is True, we will use the text value
+pretrained_tokenizer_name_or_path: stabilityai/stable-diffusion-2-1-base  # Use SD 2.1
+empty_prompts_proportion: 0.0  # Useless now, we already have CFG in training
+mix_ambiguous: False  # Whether we mix ambiguous prompts for "this" and "that"
+
+
+# Motion setting (not used right now)
+motion_bucket_id: 200  # Set this for an exact value; if this is none, we will use the settings below
+dataset_motion_mean: 35.3  # For 14 fps, it is N(35.3, 18.5)
+dataset_motion_std: 18.5   # For 25 fps, it is N(?, ?)
+svd_motion_mean: 165
+svd_motion_std: 22.5
+
+
+# Training setting
+resume_from_checkpoint: False  # latest/False
+num_train_iters: 100000  # Will automatically choose the checkpoint at 99K
+partial_finetune: False  # Whether we only tune some params to speed up
+train_batch_size: 1  # This is the batch size per GPU
+checkpointing_steps: 3000
+validation_step: 300
+logging_name: logging
+seed: 42
+validation_img_folder:  # Prepare your own validation dataset
+validation_store_folder: validation_results
+checkpoints_total_limit: 15
+
+# Noise Strength
+noise_mean: 0.5  # Regular Img2Video: (0.7, 1.6); Text2Video: (0.5, 1.4)
+noise_std: 1.4
+
+
+# Inference
+num_inference_steps: 25
+inference_noise_aug_strength: 0.1
+inference_max_guidance_scale: 3.0  # Training and testing are different scenarios
+
+
+# Learning Rate and Optimizer
+learning_rate: 1e-5  # Usually this is ok
+scale_lr: False  # TODO: Is it needed to scale the learning rate?
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_8bit_adam: True  # Need this to save more memory
+adam_weight_decay: 1e-2
+adam_epsilon: 1e-08
+lr_warmup_steps: 500
+lr_decay_scale: 0.5
+
+
+# Other Setting
+mixed_precision: fp16
+gradient_accumulation_steps: 1
+gradient_checkpointing: 1
+report_to: tensorboard
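app.py reads these YAML files with OmegaConf and then indexes them like a dictionary, so the same pattern works for a quick sanity check. A small sketch, assuming omegaconf is installed and the command is run from the repo root:

from omegaconf import OmegaConf

config = OmegaConf.load("config/train_image2video.yaml")
print(config["height"], config["width"])        # 256 384
print(config["pretrained_model_name_or_path"])  # stabilityai/stable-video-diffusion-img2vid
assert config["video_seq_length"] == 14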
config/train_image2video_controlnet.yaml
ADDED
@@ -0,0 +1,101 @@
+
+# Model Setting
+pretrained_model_name_or_path: stabilityai/stable-video-diffusion-img2vid  # stabilityai/pretrained
+load_unet_path: ../saved_weights/v4_VL_paper/checkpoint-99000  # None/specific path. This is the pretrained-UNet path
+load_controlnet_path:  # None/specific path. For a checkpoint loaded from a pretrained-ControlNet path
+video_seq_length: 14
+process_fps: 7
+train_noise_aug_strength: 0.1
+scheduler: EDM
+conditioning_dropout_prob: 0.1
+
+
+# Dataset Setting
+data_loader_type: thisthat  # thisthat
+dataset_name: Bridge  # Bridge
+dataset_path: [../sanity_check/bridge_v1_TT14, ../sanity_check/bridge_v2_TT14]  # ../Bridge_filter_flow, ../Bridge_v2_filter_flow/]
+output_dir: checkpoints/img2video
+height: 256  # Ratios that are functional: 256:384 576:1024 320:448 320:576 512:640 448:640
+width: 384   # The height and width should be multiples of 64
+dataloader_num_workers: 4  # For debugging, it only needs 1
+flip_aug_prob: 0.45  # Probability of flipping the GT and cond vertically
+# No acceleration_tolerance, since the TT dataset already filters those out
+
+
+# Text setting
+use_text: True  # If this is True, we will use the text value
+pretrained_tokenizer_name_or_path: stabilityai/stable-diffusion-2-1-base  # Use SD 2.1
+empty_prompts_proportion: 0.0
+mix_ambiguous: False  # Whether we mix ambiguous prompts for "this" and "that"
+
+
+# Mask setting
+mask_unet_vae: False  # Whether we use a mask to zero-pad the latents
+mask_controlnet_vae: False
+mask_proportion: 0.0
+
+
+# Condition Setting
+conditioning_channels: 3  # Usually it is 3
+num_points_left:  # 1 # For flow: you can only choose one between flow_select_rate and num_points_left; num_points_left has higher priority
+flow_select_rate: 0.99  # For flow
+threshold_factor: 0.2   # For flow
+dilate: True  # Traj must be True for dilate
+inner_conditioning_scale: 1.0  # Conditioning scale for the internal value; the default starts from 1.0
+outer_conditioning_scale: 1.0  # Outer conditioning scale for the whole conditioning trainable copy (interestingly, this was accidentally set to 2.0 before)
+
+
+# Motion setting
+motion_bucket_id: 200
+dataset_motion_mean: 25  # For 14 fps, it is N(25, 10)
+dataset_motion_std: 10   # For 25 fps, it is N(18, 7)
+svd_motion_mean: 180
+svd_motion_std: 30
+
+
+
+# Training setting
+resume_from_checkpoint: False  # latest/False
+num_train_iters: 30100  # Will automatically choose the checkpoints
+partial_finetune: False  # Whether we only tune some params to speed up
+train_batch_size: 1  # This is the batch size per GPU
+checkpointing_steps: 3000
+validation_step: 300
+logging_name: logging
+seed: 42
+validation_img_folder: datasets/validation_TT14
+validation_store_folder: validation_videos
+checkpoints_total_limit: 15
+
+
+# Noise Strength
+noise_mean: 0.5  # Regular Img2Video: (0.7, 1.6); Text2Video: (0.5, 1.4)
+noise_std: 1.4
+
+
+# Inference
+num_inference_steps: 25
+use_instructpix2pix: False  # Whether we use the InstructPix2Pix mode, which involves 3 inputs; it may need tuning to get a better result in the end
+inference_noise_aug_strength: 0.1
+inference_max_guidance_scale: 3.0  # Training and testing are different scenarios
+inference_guess_mode: False  # Whether we use guess mode in the controlnet
+image_guidance_scale: 2.5  # Empirically, 2.5 is the best value. It seems we are not using this now
+
+
+# Learning Rate and Optimizer
+learning_rate: 5e-6  # 5e-6 is the LR we tested that is just right
+scale_lr: False  # TODO: Is it needed to scale the learning rate?
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_8bit_adam: True  # Need this to save more memory
+adam_weight_decay: 1e-2
+adam_epsilon: 1e-08
+lr_warmup_steps: 500
+lr_decay_scale: 0.5
+
+
+# Other Setting
+mixed_precision: fp16
+gradient_accumulation_steps: 1  # ????
+gradient_checkpointing: 1  # ????
+report_to: tensorboard
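The gesture condition that this ControlNet-style config describes is a per-frame image tensor whose shape is set by video_seq_length, conditioning_channels, height, and width, exactly as app.py builds it at inference time. A minimal sketch of that tensor layout, assuming the YAML above is loaded with OmegaConf from the repo root:

import numpy as np
from omegaconf import OmegaConf

config = OmegaConf.load("config/train_image2video_controlnet.yaml")

# Zero-initialized condition; only the frames that carry a gesture point get
# overwritten with the drawn (and optionally dilated) point image.
gesture_condition = np.zeros(
    (config["video_seq_length"], config["conditioning_channels"], config["height"], config["width"]),
    dtype=np.float32,
)
print(gesture_condition.shape)  # (14, 3, 256, 384)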
curation_pipeline/add_lang_info.py
ADDED
@@ -0,0 +1,38 @@
+'''
+Add the processed lang information
+'''
+import os, sys, shutil
+import json
+
+
+if __name__ == "__main__":
+
+    # Main config file path information
+    processed_json_file_path = "updated_bridge_v2.json"
+
+
+    # Read the json file
+    file = open(processed_json_file_path)
+    data = json.load(file)
+
+
+    # Iterate all the folders inside
+    start_idx = 0
+    for seq_instance in data:
+        target_path = seq_instance["images0"]
+        print("We are processing ", target_path)
+
+        processed_lang_txt_path = os.path.join(target_path, "processed_lang.txt")
+        if os.path.exists(processed_lang_txt_path):
+            os.remove(processed_lang_txt_path)
+
+        # Write the action + This + That into the sequence.
+        processed_lang_txt = open(processed_lang_txt_path, "a")
+        processed_lang_txt.write(str(seq_instance["action"])+"\n")
+        processed_lang_txt.write(str(seq_instance["this"])+"\n")
+        processed_lang_txt.write(str(seq_instance["that"])+"\n")
+
+
+        start_idx += 1
+
+    print("We have ", start_idx)
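The JSON read by this script is assumed, from the keys accessed above, to be a list of per-sequence records carrying the image folder plus the parsed action/this/that fields. An illustrative, entirely hypothetical entry (the actual value types are not shown in this commit):

# Hypothetical record layout inferred from the keys accessed above
example_record = {
    "images0": "/path/to/bridge_v2/traj_group0/traj0/images0",  # folder that receives processed_lang.txt (hypothetical path)
    "action": "put the spoon on the towel",                     # hypothetical values
    "this": "the spoon",
    "that": "the towel",
}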
curation_pipeline/match_dataset_v1.py
ADDED
@@ -0,0 +1,117 @@
+'''
+This file is to match the selected frames with the bridge dataset
+We need to use some tricks to select the item
+'''
+import os, sys, shutil
+import cv2
+import numpy as np
+
+
+
+
+def compare_img(imageA, imageB):
+    # the 'Mean Squared Error' between the two images is the
+    # sum of the squared difference between the two images;
+    # NOTE: the two images must have the same dimension
+    err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2)
+    err /= float(imageA.shape[0] * imageA.shape[1])
+
+    # return the MSE; the lower the error, the more "similar"
+    # the two images are
+    return err
+
+
+
+def search_path(dataset_path, target_path, store_txt_path):
+
+    # We only need to care about the Bridge v1 dataset area
+    target_img_path = os.path.join(target_path, "im_0.jpg")
+    target_img = cv2.imread(target_img_path)
+
+    # Iterate all the folders inside
+    for scene_name in sorted(os.listdir(dataset_path)):
+        # print("We are reading scene", scene_name)
+        scene_dir = os.path.join(dataset_path, scene_name)
+
+        for task_name in os.listdir(scene_dir):
+            task_dir = os.path.join(scene_dir, task_name)
+
+            for time_clock in os.listdir(task_dir):
+                if time_clock == "lmdb":
+                    continue  # Skip lmdb folder
+
+                time_dir = os.path.join(task_dir, time_clock, "raw", "traj_group0")
+                if not os.path.exists(time_dir):
+                    continue
+
+                for traj_name in os.listdir(time_dir):
+                    traj_path = os.path.join(time_dir, traj_name)
+                    if not os.path.isdir(traj_path):
+                        continue
+
+                    # Directly move policy_out_file_path; just in case there is also valuable information there
+                    policy_out_file_path = os.path.join(traj_path, "policy_out.pkl")
+                    if not os.path.exists(policy_out_file_path):
+                        continue
+
+                    # Check the lang txt file
+                    lang_txt_file_path = os.path.join(traj_path, "lang.txt")
+                    if not os.path.exists(lang_txt_file_path):
+                        continue
+
+
+                    # Last thing: locate the right path
+                    for img_name in os.listdir(traj_path):
+                        if img_name != "images0":  # Only consider one camera angle
+                            continue
+
+                        img_folder_path = os.path.join(traj_path, img_name)
+                        if not os.path.isdir(img_folder_path):
+                            continue
+
+
+                        # Compare two images
+                        img_path = os.path.join(img_folder_path, "im_0.jpg")
+                        # print("img_folder_path is ", img_path)
+                        compare_sample_img = cv2.imread(img_path)
+                        error = compare_img(target_img, compare_sample_img)
+
+                        if error == 0:
+                            # Continue with the rest, for at least 5 images
+                            status = True
+                            for idx in range(10):
+                                idx_img_path = os.path.join(img_folder_path, "im_"+str(idx)+".jpg")
+                                idx_target_img_path = os.path.join(target_path, "im_"+str(idx)+".jpg")
+                                idx_compare_sample_img = cv2.imread(idx_img_path)
+                                idx_target_img = cv2.imread(idx_target_img_path)
+                                error = compare_img(idx_target_img, idx_compare_sample_img)
+
+                                if error != 0:
+                                    status = False
+                                    break
+
+                            if status:
+                                print("We found one at ", img_path)
+                                f = open(store_txt_path, "a")
+                                f.write(target_path + " " + img_folder_path + "\n")
+                                return True
+
+    return False
+
+
+if __name__ == "__main__":
+    input_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/datasets_rob/Bridge_v1_test_raw"
+    dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v1/berkeley"  # Fetch directly from the freshly unzipped local copy, in case the earlier one was modified by xuweiyi
+    store_txt_path = "match_info.txt"
+
+    if os.path.exists(store_txt_path):
+        os.remove(store_txt_path)
+
+    for img_name in sorted(os.listdir(input_path)):
+        target_path = os.path.join(input_path, img_name)
+        print("We are finding for ", target_path)
+
+        status = search_path(dataset_path, target_path, store_txt_path)
+
+        if not status:
+            print("we cannot find one")
curation_pipeline/match_dataset_v2.py
ADDED
@@ -0,0 +1,137 @@
+'''
+This file is to match the selected frames with the bridge dataset
+We need to use some tricks to select the item
+'''
+import os, sys, shutil
+import cv2
+import numpy as np
+
+
+
+
+def compare_img(imageA, imageB):
+    # the 'Mean Squared Error' between the two images is the
+    # sum of the squared difference between the two images;
+    # NOTE: the two images must have the same dimension
+    err = np.sum((imageA.astype("float") - imageB.astype("float")) ** 2)
+    err /= float(imageA.shape[0] * imageA.shape[1])
+
+    # return the MSE; the lower the error, the more "similar"
+    # the two images are
+    return err
+
+
+
+def search_path(dataset_path, target_path, store_txt_path):
+
+    # We only need to care about the Bridge v2 dataset area
+    target_img_path = os.path.join(target_path, "im_0.jpg")
+    if not os.path.exists(target_img_path):
+        print("The image we read is False")
+        return False
+    target_img = cv2.imread(target_img_path)
+
+    # Iterate all the folders inside
+    for scene_name in sorted(os.listdir(dataset_path)):
+        scene_dir = os.path.join(dataset_path, scene_name)
+
+        for task_name in sorted(os.listdir(scene_dir)):
+            task_dir = os.path.join(scene_dir, task_name)
+
+            for order_name in sorted(os.listdir(task_dir)):
+                order_dir = os.path.join(task_dir, order_name)
+
+                for time_clock in sorted(os.listdir(order_dir)):
+                    if time_clock == "lmdb":
+                        continue  # Skip lmdb folder
+
+                    time_dir = os.path.join(order_dir, time_clock, "raw", "traj_group0")
+                    if not os.path.exists(time_dir):
+                        continue
+
+                    for traj_name in sorted(os.listdir(time_dir)):
+                        traj_path = os.path.join(time_dir, traj_name)
+                        if not os.path.isdir(traj_path):
+                            continue
+
+                        # Directly move policy_out_file_path; just in case there is also valuable information there
+                        policy_out_file_path = os.path.join(traj_path, "policy_out.pkl")
+                        if not os.path.exists(policy_out_file_path):
+                            continue
+
+                        # Check the lang txt file
+                        lang_txt_file_path = os.path.join(traj_path, "lang.txt")
+                        if not os.path.exists(lang_txt_file_path):
+                            continue
+
+
+                        for img_name in sorted(os.listdir(traj_path)):
+                            if img_name != "images0":  # Only consider one camera angle
+                                continue
+
+                            img_folder_path = os.path.join(traj_path, img_name)
+                            if not os.path.isdir(img_folder_path):
+                                continue
+
+
+                            # Compare two images
+                            img_path = os.path.join(img_folder_path, "im_0.jpg")
+                            if not os.path.exists(img_path):
+                                print(img_folder_path + " doesn't even have im_0.jpg")
+                                continue
+                            # print("img_folder_path is ", img_path)
+                            compare_sample_img = cv2.imread(img_path)
+                            # try:
+                            #     compare_sample_img.shape
+                            # except Exception:
+                            #     print("The compare_sample_img cannot be read")
+                            #     continue
+                            error = compare_img(target_img, compare_sample_img)
+
+                            if error == 0:
+                                # Continue with the rest, for at least 5 images
+                                status = True
+                                for idx in range(10):
+                                    idx_img_path = os.path.join(img_folder_path, "im_"+str(idx)+".jpg")
+                                    idx_target_img_path = os.path.join(target_path, "im_"+str(idx)+".jpg")
+                                    if not os.path.exists(idx_img_path):
+                                        print("The idx_img_path long idx we see only at ", idx)
+                                        continue
+                                    if not os.path.exists(idx_target_img_path):
+                                        print("The idx_target_img_path long idx we see only at ", idx)
+                                        continue
+                                    idx_compare_sample_img = cv2.imread(idx_img_path)
+                                    idx_target_img = cv2.imread(idx_target_img_path)
+                                    error = compare_img(idx_target_img, idx_compare_sample_img)
+
+                                    if error != 0:
+                                        status = False
+                                        break
+
+                                if status:
+                                    print("We found one at ", img_path)
+                                    f = open(store_txt_path, "a")
+                                    f.write(target_path + " " + img_folder_path + "\n")
+                                    return True
+
+    return False
+
+
+if __name__ == "__main__":
+    input_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/datasets_rob/Bridge_v2_test_raw"
+    dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v2"  # Fetch directly from the freshly unzipped local copy, in case the earlier one was modified by xuweiyi
+    store_txt_path = "match_info_v2_p1.txt"
+    start_idx = 0
+    end_idx = 500
+
+    if os.path.exists(store_txt_path):
+        os.remove(store_txt_path)
+
+    for img_name in sorted(os.listdir(input_path))[start_idx:end_idx]:
+        target_path = os.path.join(input_path, img_name)
+        print("We are finding for ", target_path)
+
+        status = search_path(dataset_path, target_path, store_txt_path)
+
+        if not status:
+            print("we cannot find one")
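Both matching scripts rely on compare_img returning exactly zero only for pixel-identical frames, which is what makes the im_0 to im_9 check an exact-match test. A small standalone sketch of that check, assuming the repo root is on sys.path so the module is importable, both folders are hypothetical paths, and frames in both folders share the same resolution:

import cv2
from curation_pipeline.match_dataset_v2 import compare_img

def is_exact_match(folder_a, folder_b, num_frames=10):
    # Mirrors the inner loop above: every compared frame must have zero MSE
    for idx in range(num_frames):
        img_a = cv2.imread(f"{folder_a}/im_{idx}.jpg")
        img_b = cv2.imread(f"{folder_b}/im_{idx}.jpg")
        if img_a is None or img_b is None:  # missing or unreadable frame -> not a confirmed match
            return False
        if compare_img(img_a, img_b) != 0:
            return False
    return True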
curation_pipeline/prepare_bridge_csv.py
ADDED
@@ -0,0 +1,69 @@
+'''
+This file is to prepare the dataset as a csv file following the format required by Open-Sora
+'''
+
+import os, sys, shutil
+import json
+import csv
+
+# Import files from the local folder
+root_path = os.path.abspath('.')
+sys.path.append(root_path)
+# from curation_pipeline.prepare_bridge_v1 import read_bridge_v1
+# from curation_pipeline.prepare_bridge_v2 import read_bridge_v2
+
+
+
+def iter_dataset(dataset_path):
+    lists = []
+    for sub_folder_name in os.listdir(dataset_path):
+        sub_folder_path = os.path.join(dataset_path, sub_folder_name)
+
+        # Check the number of frames
+        max_length = len(os.listdir(sub_folder_path))
+        for check_idx in range(max_length):
+            if not os.path.exists(os.path.join(sub_folder_path, 'im_' + str(check_idx) + '.jpg')):  # Frames should exist sequentially
+                break
+        num_frames = check_idx
+
+        # Read the text
+        txt_path = os.path.join(sub_folder_path, "lang.txt")
+        f = open(txt_path, "r")
+        lang_prompt = f.readline()
+
+        lists.append([sub_folder_path, lang_prompt, num_frames, 480, 640])
+        # break
+    return lists
+
+
+
+if __name__ == "__main__":
+    v1_dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/sanity_check/bridge_v1_raw"
+    v2_dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/sanity_check/bridge_v2_raw"
+    store_name = "Bridge_raw.csv"
+
+    if os.path.exists(store_name):
+        os.remove(store_name)
+
+
+    # Execute
+    full_lists = [["path", "text", "num_frames", "height", "width"]]
+
+    v1_lists = iter_dataset(v1_dataset_path)
+    full_lists.extend(v1_lists)
+    v2_lists = iter_dataset(v2_dataset_path)
+    full_lists.extend(v2_lists)
+    print("Full length is ", len(full_lists))
+
+
+    # Store as a csv file
+    with open(store_name, 'w') as outfile:
+        write = csv.writer(outfile)
+        write.writerows(full_lists)
+
+
+
+    # with open('output.jsonl', 'w') as outfile:
+    #     for entry in JSON_file:
+    #         json.dump(entry, outfile)
+    #         outfile.write('\n')
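The resulting CSV therefore has one header row followed by one row per sequence folder, with the frame count probed from the im_*.jpg files and the 480x640 Bridge resolution hard-coded. A hypothetical row, only to illustrate the layout (folder name and prompt are made up):

# path,text,num_frames,height,width
# /nfs/turbo/jjparkcv-turbo-large/boyangwa/sanity_check/bridge_v1_raw/0,take this to there,35,480,640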
curation_pipeline/prepare_bridge_jsonl.py
ADDED
@@ -0,0 +1,47 @@
+'''
+This file is to prepare the dataset as a jsonl file
+'''
+
+import os, sys, shutil
+import json
+
+# Import files from the local folder
+root_path = os.path.abspath('.')
+sys.path.append(root_path)
+from curation_pipeline.prepare_bridge_v1 import read_bridge_v1
+from curation_pipeline.prepare_bridge_v2 import read_bridge_v2
+
+
+if __name__ == "__main__":
+    v1_dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v1/berkeley"
+    v2_dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v2"
+    store_name = "store.jsonl"
+
+    if os.path.exists(store_name):
+        os.remove(store_name)
+
+
+    # Execute
+    full_lists = []
+
+    v1_lists = read_bridge_v1(v1_dataset_path, "", copyfile=False)
+    full_lists.extend(v1_lists)
+    v2_lists = read_bridge_v2(v2_dataset_path, "", copyfile=False)
+    full_lists.extend(v2_lists)
+    print("Full length is ", len(full_lists))
+
+
+    with open(store_name, 'w') as outfile:
+        for list_name in full_lists:
+            instance = dict()
+            instance["file_path"] = list_name
+
+            json.dump(instance, outfile)
+            outfile.write('\n')
+
+
+
+    # with open('output.jsonl', 'w') as outfile:
+    #     for entry in JSON_file:
+    #         json.dump(entry, outfile)
+    #         outfile.write('\n')
curation_pipeline/prepare_bridge_v1.py
ADDED
@@ -0,0 +1,132 @@
'''
This repository is used to prepare the Bridge dataset
'''
import os, sys, shutil


def read_bridge_v1(dataset_path, train_store_path, test_store_path, test_dataset_lists, copyfile=True):
    # copyfile is True when we need to copy the files to the target destination

    start_idx = 0
    target_lists = []
    prefix_len = len(dataset_path) + 1

    # Iterate all the folders inside
    for scene_name in sorted(os.listdir(dataset_path)):
        print("We are reading scene ", scene_name)
        scene_dir = os.path.join(dataset_path, scene_name)
        for task_name in sorted(os.listdir(scene_dir)):
            task_dir = os.path.join(scene_dir, task_name)

            for time_clock in sorted(os.listdir(task_dir)):
                if time_clock == "lmdb":
                    continue    # Skip the lmdb folder

                time_dir = os.path.join(task_dir, time_clock, "raw", "traj_group0")
                if not os.path.exists(time_dir):
                    continue

                for traj_name in sorted(os.listdir(time_dir)):
                    traj_path = os.path.join(time_dir, traj_name)
                    if not os.path.isdir(traj_path):
                        continue

                    # Directly move policy_out_file_path, just in case there is also valuable information there
                    policy_out_file_path = os.path.join(traj_path, "policy_out.pkl")
                    if not os.path.exists(policy_out_file_path):
                        continue

                    # Check the lang txt file
                    lang_txt_file_path = os.path.join(traj_path, "lang.txt")
                    if not os.path.exists(lang_txt_file_path):
                        continue

                    for img_name in sorted(os.listdir(traj_path)):
                        if img_name != "images0":   # Only consider one camera angle
                            continue

                        img_folder_path = os.path.join(traj_path, img_name)
                        if not os.path.isdir(img_folder_path):
                            continue

                        ############################################ Main Process ####################################################

                        # # First sanity check (make sure the input source jpgs are good)
                        # length = len(os.listdir(img_folder_path))
                        # status = True
                        # for check_idx in range(length):
                        #     if not os.path.exists(os.path.join(img_folder_path, 'im_' + str(check_idx) + '.jpg')):   # Frames should exist sequentially
                        #         status = False
                        #         break

                        # Now we can copy the folder to our destination
                        target_lists.append(img_folder_path)
                        if copyfile:
                            print("img_folder_path[prefix_len:] is ", img_folder_path[prefix_len:])
                            if img_folder_path[prefix_len:] in test_dataset_lists:
                                # Store to the test set
                                target_dir = os.path.join(test_store_path, str(start_idx))
                            else:
                                # This is the training set
                                target_dir = os.path.join(train_store_path, str(start_idx))

                            print("Copy " + str(img_folder_path) + " to " + str(target_dir))
                            shutil.copytree(img_folder_path, target_dir)

                            # Sanity check
                            length = len(os.listdir(target_dir))
                            status = True
                            for check_idx in range(length):
                                if not os.path.exists(os.path.join(target_dir, 'im_' + str(check_idx) + '.jpg')):   # Frames should exist sequentially
                                    status = False
                                    break

                            if not status:
                                # If the sequential files we need are missing, remove the copy and continue without updating start_idx
                                print("This file cannot pass the sanity check. We will remove it!")
                                shutil.rmtree(target_dir)
                                continue

                            # Move other auxiliary files
                            shutil.copy(policy_out_file_path, os.path.join(target_dir, "policy_out.pkl"))
                            shutil.copy(lang_txt_file_path, os.path.join(target_dir, "lang.txt"))

                        ################################################################################################################

                        # Update the idx
                        start_idx += 1

    print("We have ", start_idx, " number of cases")

    # Return a list of file paths
    return target_lists


if __name__ == "__main__":
    dataset_path = "/Path/to/Bridge/raw/bridge_data_v1/berkeley"        # Until the Bridge v1 - berkeley section
    train_store_path = "/Path/to/Bridge/train/bridge_v1_raw"
    test_store_path = "/Path/to/Bridge/train/bridge_v1_test_raw"
    test_dataset_predefined_path = "test_path.txt"                      # This will be provided by us

    # Make dirs if needed
    if os.path.exists(train_store_path):
        shutil.rmtree(train_store_path)
    os.makedirs(train_store_path)
    if os.path.exists(test_store_path):
        shutil.rmtree(test_store_path)
    os.makedirs(test_store_path)

    # Read the test dataset paths
    test_dataset_lists = []
    read_file = open(test_dataset_predefined_path, "r")
    for line in read_file.readlines():
        test_dataset_lists.append(line[:-1])
    print("test_dataset_lists is ", test_dataset_lists)

    read_bridge_v1(dataset_path, train_store_path, test_store_path, test_dataset_lists)
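Since copyfile defaults to True, a dry run that only enumerates trajectory folders can pass copyfile=False, in which case the store paths and the test list are unused. A minimal sketch under that assumption (the local path below is hypothetical):

# Minimal sketch: enumerate Bridge v1 trajectories without copying anything.
# "/data/bridge_data_v1/berkeley" is a hypothetical local path.
from curation_pipeline.prepare_bridge_v1 import read_bridge_v1

traj_folders = read_bridge_v1("/data/bridge_data_v1/berkeley", "", "", [], copyfile=False)
print(len(traj_folders), "images0 folders found")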
curation_pipeline/prepare_bridge_v2.py
ADDED
@@ -0,0 +1,139 @@
'''
This repository is used to prepare the Bridge dataset
'''
import os, sys, shutil


def read_bridge_v2(dataset_path, train_store_path, test_store_path, test_dataset_lists, copyfile=True):
    # copyfile is True most of the time

    start_idx = 0
    target_lists = []
    prefix_len = len(dataset_path) + 1

    # Iterate all the folders inside
    for scene_name in sorted(os.listdir(dataset_path)):
        print("We are reading scene ", scene_name)
        scene_dir = os.path.join(dataset_path, scene_name)

        for task_name in sorted(os.listdir(scene_dir)):
            task_dir = os.path.join(scene_dir, task_name)

            for order_name in sorted(os.listdir(task_dir)):
                order_dir = os.path.join(task_dir, order_name)

                for time_clock in sorted(os.listdir(order_dir)):
                    if time_clock == "lmdb":
                        continue    # Skip the lmdb folder

                    time_dir = os.path.join(order_dir, time_clock, "raw", "traj_group0")
                    if not os.path.exists(time_dir):
                        print("time_dir does not exist for ", time_dir)
                        continue

                    for traj_name in sorted(os.listdir(time_dir)):
                        traj_path = os.path.join(time_dir, traj_name)
                        if not os.path.isdir(traj_path):
                            print("traj_path does not exist for ", traj_path)
                            continue

                        # Directly move policy_out_file_path, just in case there is also valuable information there
                        policy_out_file_path = os.path.join(traj_path, "policy_out.pkl")
                        if not os.path.exists(policy_out_file_path):
                            continue

                        # Check the lang txt file
                        lang_txt_file_path = os.path.join(traj_path, "lang.txt")
                        if not os.path.exists(lang_txt_file_path):
                            continue

                        for img_name in sorted(os.listdir(traj_path)):
                            if img_name != "images0":   # Only consider one camera angle
                                continue

                            img_folder_path = os.path.join(traj_path, img_name)
                            if not os.path.isdir(img_folder_path):
                                print("img_folder_path does not exist for ", img_folder_path)
                                continue

                            ############################################ Main Process ####################################################

                            # # First sanity check (make sure the input source jpgs are good)
                            # length = len(os.listdir(img_folder_path))
                            # status = True
                            # for check_idx in range(length):
                            #     if not os.path.exists(os.path.join(img_folder_path, 'im_' + str(check_idx) + '.jpg')):   # Frames should exist sequentially
                            #         status = False
                            #         break

                            # Now we can copy the folder to our destination
                            target_lists.append(img_folder_path)
                            if copyfile:
                                print("img_folder_path[prefix_len:] is ", img_folder_path[prefix_len:])
                                if img_folder_path[prefix_len:] in test_dataset_lists:
                                    # Store to the test set
                                    target_dir = os.path.join(test_store_path, str(start_idx))
                                else:
                                    # This is the training set
                                    target_dir = os.path.join(train_store_path, str(start_idx))

                                # Now we can copy the folder to our destination
                                print("Copy " + str(img_folder_path) + " to " + str(target_dir))
                                shutil.copytree(img_folder_path, target_dir)

                                # Sanity check
                                length = len(os.listdir(target_dir))
                                status = True
                                for check_idx in range(length):
                                    if not os.path.exists(os.path.join(target_dir, 'im_' + str(check_idx) + '.jpg')):   # Frames should exist sequentially
                                        status = False
                                        break

                                if not status:
                                    # If the sequential files we need are missing, remove the copy and continue without updating start_idx
                                    print("This file cannot pass the sanity check. We will remove it!")
                                    shutil.rmtree(target_dir)
                                    continue

                                # Move other auxiliary files
                                shutil.copy(policy_out_file_path, os.path.join(target_dir, "policy_out.pkl"))
                                shutil.copy(lang_txt_file_path, os.path.join(target_dir, "lang.txt"))

                            # Update the idx
                            start_idx += 1

    print("We have ", start_idx)

    # Return a list of file paths
    return target_lists


if __name__ == "__main__":
    dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v2"
    train_store_path = "../sanity_check/bridge_v2_raw"
    test_store_path = "../sanity_check/bridge_v2_test_raw"
    test_dataset_predefined_path = "test_path_v2.txt"

    # Make dirs if needed
    if os.path.exists(train_store_path):
        shutil.rmtree(train_store_path)
    os.makedirs(train_store_path)
    if os.path.exists(test_store_path):
        shutil.rmtree(test_store_path)
    os.makedirs(test_store_path)

    # Read the test dataset paths
    test_dataset_lists = []
    read_file = open(test_dataset_predefined_path, "r")
    for line in read_file.readlines():
        test_dataset_lists.append(line[:-1])
    print("test_dataset_lists is ", test_dataset_lists)

    read_bridge_v2(dataset_path, train_store_path, test_store_path, test_dataset_lists)
curation_pipeline/select_frame_with_this_that.py
ADDED
@@ -0,0 +1,421 @@
'''
This repository is used to prepare the Bridge dataset with this-that conditioning
'''
import os, sys, shutil
import pickle
from ultralytics import YOLO
from PIL import Image, ImageDraw
import numpy as np
import cv2
import math
import collections
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry


def show_mask(mask, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    return mask_image * 255


def read_center_point(model, img_path, do_visualization, store_path):

    action_img = Image.open(img_path)
    prediction = model.predict(source=action_img, save=False)[0]    # Only 1 frame

    if not hasattr(prediction, "boxes"):
        print("Detection Fail: We cannot have boxes attribute")
        return None, None   # None means the detection failed and we skip this case

    # Save at the temp place for visualization
    if do_visualization:
        prediction.save(filename=store_path)


    bounding_boxes = prediction.boxes.xywh
    num, dim = bounding_boxes.shape
    assert(dim == 4)

    # Gather the center points of all bounding boxes
    edge_point_cord = []
    center_points = []
    for idx in range(num):
        x, y, w, h = bounding_boxes[idx].detach().cpu().numpy()
        center_point = [x, y]   # TODO: empirically, shifting down by 25% of the height (y + h/4) usually helps

        edge_point_cord.extend([ (x+w//2, y+h//2), (x-w//2, y+h//2), (x-w//2, y-h//2), (x+w//2, y-h//2) ])


        if w <= 15 or h <= 15:  # If a bounding box is too small, we will disregard this case
            return None, None

        # Calculate the distance between the current point and previous points as a sanity check
        for point in center_points:     # Check all previous points
            give_up_threshold = 90
            if center_point[0] - point[0] >= give_up_threshold:
                print("Two points are too far away and neglect the case")
                return None, None
            if center_point[1] - point[1] >= give_up_threshold:
                print("Two points are too far away and neglect the case")
                return None, None

        # Append to the list
        center_points.append(center_point)


    if len(center_points) == 0 or len(center_points) > 2:
        print("Detection Fail: We cannot detect bounding boxes")
        return None, None

    # Calculate the average position and distance among center_points
    if len(center_points) == 2:
        first_box, second_box = center_points

        center_x = (first_box[0] + second_box[0]) / 2
        center_y = (first_box[1] + second_box[1]) / 2

        distance = math.sqrt(abs(first_box[0] - second_box[0])**2 + abs(first_box[1] - second_box[1])**2)

        return [center_x, center_y, distance], edge_point_cord

    return [*center_points[0], 100], edge_point_cord   # if len(center_points) == 1, the distance is 0; however, to avoid 2-1-2 box detections in sequence, we set it to a higher value



def detect_gripper(gripper_detection_model, input_dir, action_start, action_end, do_visualization, store_dir, sample_failure_collect_folder=None):

    # Process the first point first (it matters more, so it is repeated 3 times); then quickly process the last point

    # Process the first action frame by iterating the next three frames and choosing the closest one
    first_center_points = []
    edge_point_cords = []
    for idx in range(3):    # Repeat 3 times
        action_start_path = os.path.join(input_dir, "im_"+str(action_start + idx)+".jpg")
        first_center_point, edge_point_cord = read_center_point(gripper_detection_model, action_start_path, do_visualization, os.path.join(store_dir, "contact_first"+str(idx)+".jpg"))    # The first frame

        if idx == 0 and first_center_point is None:
            message = "Cannot find the first contact point!"

            print("The contact point we cannot detect is at ", action_start_path)
            if sample_failure_collect_folder != "":
                shutil.copyfile(action_start_path, os.path.join(sample_failure_collect_folder, str(len(os.listdir(sample_failure_collect_folder)))+".jpg") )

            return (None, None, message)

        if first_center_point is not None:
            first_center_points.append([action_start + idx, first_center_point])

            # Add edge points
            print(edge_point_cord)
            edge_point_cords.extend(edge_point_cord)    # Extending every point like this may not be robust for some edge cases


    # Select the candidate with the smallest box distance
    first_center_points.sort(key=lambda x: x[1][2])
    first_center_point = first_center_points[0][1][:2]
    start_idx = first_center_points[0][0]
    print("first_center_point is " + str(first_center_point) + " with idx " + str(start_idx))
    order_idx = [start_idx, action_end]


    # Find the xmin, ymin, xmax, ymax based on all collected edge points as the bounding box for SAM
    edge_point_cords.sort(key=lambda x: x[0])
    xmin = int(edge_point_cords[0][0])
    xmax = int(edge_point_cords[-1][0])

    edge_point_cords.sort(key=lambda x: x[1])
    ymin = int(edge_point_cords[0][1])
    ymax = int(edge_point_cords[-1][1])

    bbox_info = (xmin, xmax, ymin, ymax)


    # Process the last action frame
    action_end_path = os.path.join(input_dir, "im_"+str(action_end)+".jpg")
    last_center_point, edge_point_cord = read_center_point(gripper_detection_model, action_end_path, do_visualization, os.path.join(store_dir, "contact_last.jpg"))    # The last frame
    if last_center_point is None:
        message = "Cannot find the last contact point!"

        print("The contact point we cannot detect is at ", action_start_path)
        if sample_failure_collect_folder != "":
            store_name = str(len(os.listdir(sample_failure_collect_folder))) + ".jpg"
            shutil.copyfile(action_start_path, os.path.join(sample_failure_collect_folder, store_name) )

        return (None, bbox_info, message)
    last_center_point = last_center_point[:2]


    # Check whether the two center points are too close; if so, merge them into one point
    merge_threshold = 30
    if math.sqrt((first_center_point[0] - last_center_point[0])**2 + (first_center_point[1] - last_center_point[1])**2) <= merge_threshold:
        print("Merge two points to one!")
        message = "Success!"
        return ([[first_center_point], order_idx], bbox_info, message)


    # Return the needed information
    message = "Success!"
    return ([[first_center_point, last_center_point], order_idx], bbox_info, message)




def visualize_this_that(base_img, bbox_info, this_that_points):

    # Draw a green dot only for the start point
    for point in this_that_points:
        print("point is ", point)
        target_horizontal, target_vertical = point
        target_horizontal, target_vertical = int(target_horizontal), int(target_vertical)

        dot_range = 3
        for i in range(-1*dot_range, dot_range+1):
            for j in range(-1*dot_range, dot_range+1):
                dil_vertical, dil_horizontal = target_vertical + i, target_horizontal + j
                if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]):
                    base_img[dil_vertical, dil_horizontal, :] = [0, 128, 0]
                # else:
                #     # print("The traj is out of boundary and we won't consider it")   # ignored for now
                #     return (False, base_img)

    # Draw the bounding box
    xmin, xmax, ymin, ymax = bbox_info
    base_img = cv2.rectangle(base_img, (xmin, ymin), (xmax, ymax), color=(0,0,255), thickness=2)

    return (True, base_img)



def manage_seq_range(input_dir, store_dir, sample_failure_collect_folder, total_frames_needed,
                     max_original_input_tolerate, gripper_detection_model, sam_predictor, do_visualization):

    # Find valid image lists
    num_frames_input = 0
    for file_name in os.listdir(input_dir):
        if file_name.startswith("im_"):
            num_frames_input += 1
    for idx in range(num_frames_input):
        target_path = os.path.join(input_dir, "im_"+str(idx)+".jpg")
        if not os.path.exists(target_path):
            print("We don't have ", target_path)
            message = "Invalid error"   # Make sure that every file in this sequence exists; this is quite important
            return (False, message)


    if num_frames_input > max_original_input_tolerate:
        message = "The number of frames is too long for constructing the sequence length needed"
        return (False, message)

    if num_frames_input < total_frames_needed:
        message = "The number of frames is too short for constructing the sequence length needed"
        return (False, message)


    # Prepare this and that based on policy_out.pkl
    policy_out_file_path = os.path.join(input_dir, "policy_out.pkl")
    with open(policy_out_file_path, "rb") as f:
        policy = pickle.load(f)

    actions_codes = []
    action_start, action_end = None, None
    for idx, item in enumerate(policy):
        action_value = item["actions"][-1]
        if action_start is None and action_value == 0.0:
            action_start = idx

        if (action_start is not None) and (action_end is None) and (action_value == 1.0):
            action_end = idx    # Record the first 1.0 that appears after the first 0.0
        actions_codes.append(action_value)

    if action_start is None or action_end is None:
        message = "We cannot read an action_start or action_end code!"
        return (False, message)     # Requires both start and end actions (usually they come as a pair)

    print("actions_codes is ", actions_codes)
    print("the start end idx we read is ", action_start, action_end)


    # Detect the gripper (should return a list with exactly two x,y coordinate points)
    detection_return_info, bbox_info, detect_message = detect_gripper(
        gripper_detection_model,
        input_dir,
        action_start,
        action_end,
        do_visualization = do_visualization,
        store_dir = store_dir,
        sample_failure_collect_folder = sample_failure_collect_folder,
    )
    if detection_return_info is None:
        return (False, detect_message)

    detected_point, old_seq_idx = detection_return_info
    print("detected_point is ", detected_point)


    # Visualize if needed
    base_img = cv2.imread(os.path.join(input_dir, "im_0.jpg"))
    if do_visualization:
        status, visual_img = visualize_this_that(base_img, bbox_info, detected_point)
        if status:
            cv2.imwrite(os.path.join(store_dir, "visualization.png"), visual_img)


    # SAM process based on bbox_info
    xmin, xmax, ymin, ymax = bbox_info
    sam_predictor.set_image(np.uint8(base_img))
    positive_point_cords = np.array([[ int(detected_point[0][0]), int(detected_point[0][1]) ]])
    positive_point_cords = np.array(positive_point_cords)
    positive_point_labels = np.ones(len(positive_point_cords))

    # Predict the mask based on the point and the designed bounding box
    masks, scores, logits = sam_predictor.predict(
        point_coords = positive_point_cords,
        point_labels = positive_point_labels,
        box = np.array([xmin, ymin, xmax, ymax])[None, :],
        multimask_output = False,
    )
    print(scores)
    for mask_idx, mask in enumerate(masks):
        mask_img = show_mask(mask)
        cv2.imwrite(os.path.join(store_dir, "mask_" + str(mask_idx) + ".png"), mask_img)


    ################################ Move the imgs ######################################
    # Calculate needed parameters
    division_factor = num_frames_input // total_frames_needed
    remain_frames = (num_frames_input % total_frames_needed) - 1   # -1 for adaptation

    # Define the gaps
    gaps = [division_factor for _ in range(total_frames_needed-1)]
    for idx in range(remain_frames):
        if idx % 2 == 0:
            gaps[idx//2] += 1               # Start-to-end order
        else:
            gaps[-1*(1+(idx//2))] += 1      # End-to-start order

    # Map the gaps to the specific orders
    idx_orders = [1]    # Starting from 1 (a one-frame shift) should not matter much
    for global_idx, gap in enumerate(gaps):
        idx_orders.append(idx_orders[-1] + gap)
        if idx_orders[-1] >= num_frames_input:
            message = "Invalid error"
            return (False, message)
    # assert(idx_orders[-1] < num_frames_input)
    assert(len(idx_orders) == total_frames_needed)


    # Copy the essential files first
    for global_idx, cur_idx in enumerate(idx_orders):
        source_path = os.path.join(input_dir, "im_"+str(cur_idx)+".jpg")
        destination_path = os.path.join(store_dir, "im_"+str(global_idx)+".jpg")

        if not os.path.exists(source_path):     # Theoretically, source_path must exist
            message = "We couldn't find the source path. Theoretically, source_path must exist!"    # Some files may have been lost during copy or never existed; remember to count how many
            return (False, message)

        shutil.copyfile(source_path, destination_path)

    # Map order_idx to the cropped version
    mapped_seq_idx = []
    for old_idx in old_seq_idx:
        tmp = []
        for tmp_idx, new_idx in enumerate(range(len(idx_orders))):
            tmp.append((tmp_idx, abs(old_idx - idx_orders[new_idx])))
        # Sort by the smallest distance
        tmp.sort(key=lambda x: x[1])
        mapped_seq_idx.append(tmp[0][0])

    print("Before the idx is ", old_seq_idx)
    print("mapped idx is ", mapped_seq_idx)


    # Write the information to the new destination
    f = open(os.path.join(store_dir, "data.txt"), "a")
    f.write(str(mapped_seq_idx[0]) + " " + str(detected_point[0][0]) + " " + str(detected_point[0][1]) + "\n")
    if len(detected_point) == 2:    # Two points, excluding the last idx
        f.write(str(mapped_seq_idx[1]) + " " + str(detected_point[1][0]) + " " + str(detected_point[1][1]) + "\n")
    f.close()


    # Move the lang.txt file
    shutil.copyfile(os.path.join(input_dir, 'lang.txt'), os.path.join(store_dir, 'lang.txt'))


    message = "Success!"
    return (True, message)




if __name__ == "__main__":

    # General storage setting
    dataset_path = "../datasets_rob/Bridge_v2_raw"
    destination_path = "../sanity_check/bridge_v2_TT14_longer_tolerance"
    sample_failure_collect_folder = ""      # This is to collect cases that fail, for active learning

    total_frames_needed = 14
    max_original_input_tolerate = 56        # 40 for 14 fps; 60 for 25 fps
    do_visualization = True


    # YOLO model init
    yolo_pretrained_path = "pretrained/yolov8n_best.pt"
    gripper_detection_model = YOLO("yolov8n.yaml")              # build a new model from scratch
    gripper_detection_model = YOLO(yolo_pretrained_path)        # load a pretrained model (recommended for training)

    # SAM model init
    model_type = "vit_h"
    sam_pretrained_path = "pretrained/sam_vit_h_4b8939.pth"
    sam = sam_model_registry[model_type](checkpoint=sam_pretrained_path).to(device="cuda")
    sam_predictor = SamPredictor(sam)   # There are a lot of settings here


    # Make dir if needed
    if os.path.exists(destination_path):
        shutil.rmtree(destination_path)
    os.makedirs(destination_path)

    # Prepare the folder to collect failure cases
    if sample_failure_collect_folder != "":
        if os.path.exists(sample_failure_collect_folder):
            shutil.rmtree(sample_failure_collect_folder)
        os.makedirs(sample_failure_collect_folder)


    # Collect the messages
    message_dict = collections.defaultdict(int)


    store_idx = 0
    for folder_name in sorted(os.listdir(dataset_path)):
        input_folder_path = os.path.join(dataset_path, folder_name)
        store_folder_path = os.path.join(destination_path, "0"*(6-len(str(store_idx)))+str(store_idx))
        print("We are processing ", input_folder_path)

        # Prepare the store_folder_path folder
        os.makedirs(store_folder_path)

        status, message = manage_seq_range(input_folder_path, store_folder_path, sample_failure_collect_folder, total_frames_needed, max_original_input_tolerate, gripper_detection_model, sam_predictor, do_visualization)
        if status:      # We will only update store_idx when this file is successfully written
            store_idx += 1
        else:
            print("This status failed! Message: " + message)
            shutil.rmtree(store_folder_path)
        # break     # For debug

        # Collect the info into the dict
        message_dict[message] += 1

    print("We have " + str(store_idx) + " valid dataset")
    print("message_dict info is ", message_dict)
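The curation step above writes data.txt with one line per contact point: the mapped frame index followed by the horizontal (x) and vertical (y) pixel coordinates. A minimal sketch of parsing it back, mirroring the split used by the downstream scripts (the folder path below is a hypothetical example):

# Minimal sketch: parse the "frame_idx x y" lines written to data.txt above.
# "Bridge_v2_TT14/000000/data.txt" is a hypothetical example path.
points = []
with open("Bridge_v2_TT14/000000/data.txt", "r") as f:
    for line in f:
        frame_idx, horizontal, vertical = line.strip().split(' ')
        points.append((int(frame_idx), float(horizontal), float(vertical)))
print(points)   # first and, when present, second contact point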
curation_pipeline/tracking_by_keypoint.py
ADDED
@@ -0,0 +1,136 @@
import os, shutil, sys
import argparse
import gdown
import cv2
import numpy as np
import requests
import json
import torchvision
import torch
import psutil
import time
try:
    from mmcv.cnn import ConvModule
except:
    os.system("mim install mmcv")


# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from track_anything_code.model import TrackingAnything
from track_anything_code.track_anything_module import get_frames_from_video, download_checkpoint, parse_augment, sam_refine, vos_tracking_video
from scripts.compress_videos import compress_video




if __name__ == "__main__":
    dataset_path = "Bridge_v1_TT14"
    video_name = "combined.mp4"
    verbose = True      # If verbose, the intermediate masks will also be written out


    ################################################## Model setup ####################################################
    # Check and download checkpoints if needed
    sam_checkpoint = "sam_vit_h_4b8939.pth"
    sam_checkpoint_url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
    xmem_checkpoint = "XMem-s012.pth"
    xmem_checkpoint_url = "https://github.com/hkchengrex/XMem/releases/download/v1.0/XMem-s012.pth"


    folder = "./pretrained"
    SAM_checkpoint = download_checkpoint(sam_checkpoint_url, folder, sam_checkpoint)
    xmem_checkpoint = download_checkpoint(xmem_checkpoint_url, folder, xmem_checkpoint)

    # Arguments
    args = parse_augment()
    args.device = "cuda"    # Any GPU is ok

    # Initialize the tracking model
    track_model = TrackingAnything(SAM_checkpoint, xmem_checkpoint, args)
    ###################################################################################################################


    # Iterate all files under the folder
    for sub_folder_name in sorted(os.listdir(dataset_path)):

        ################################################## Setting ####################################################
        sub_folder_path = os.path.join(dataset_path, sub_folder_name)

        click_state = [[],[]]
        interactive_state = {
            "inference_times": 0,
            "negative_click_times" : 0,
            "positive_click_times": 0,
            "mask_save": args.mask_save,
            "multi_mask": {
                "mask_names": [],
                "masks": []
            },
            "track_end_number": None,
            "resize_ratio": 1
        }
        ###################################################################################################################


        video_path = os.path.join(sub_folder_path, video_name)
        if not os.path.exists(video_path):
            print("We cannot find the path of the ", video_path, " and we will compress one")
            status = compress_video(sub_folder_path, video_name)
            if not status:
                print("We still cannot generate a video")
                continue

        # Read the video state
        video_state = {
            "user_name": "",
            "video_name": "",
            "origin_images": None,
            "painted_images": None,
            "masks": None,
            "inpaint_masks": None,
            "logits": None,
            "select_frame_number": 0,
            "fps": 30
        }
        video_state, template_frame = get_frames_from_video(video_path, video_state, track_model)



        ########################################################## Get the SAM point based on data.txt ###########################################################
        data_txt_path = os.path.join(sub_folder_path, "data.txt")
        if not os.path.exists(data_txt_path):
            print("We cannot find data.txt in this folder")
            continue

        data_file = open(data_txt_path, 'r')
        lines = data_file.readlines()
        frame_idx, horizontal, vertical = lines[0][:-2].split(' ')      # Only read the first point
        point_cord = [int(float(horizontal)), int(float(vertical))]

        # Process by SAM
        track_model.samcontroler.sam_controler.reset_image()    # Reset the image to clean the history
        painted_image, video_state, interactive_state, operation_log = sam_refine(track_model, video_state, "Positive", click_state, interactive_state, point_cord)
        ################################################################################################################################################################



        ######################################################### Get the tracking output ########################################################################

        # Track the video for processing
        segment_output_path = os.path.join(sub_folder_path, "segment_output.gif")
        video_state = vos_tracking_video(track_model, segment_output_path, video_state, interactive_state, mask_dropdown=[])[0]     # mask_dropdown is empty now

        # Extract the masks we need for further point calculation
        masks = video_state["masks"]    # In the range [0, 1]

        if verbose:
            for idx, mask in enumerate(masks):
                cv2.imwrite(os.path.join(sub_folder_path, "mask"+str(idx)+".png"), mask*255)

        ##############################################################################################################################################################
data_loader/video_dataset.py
ADDED
@@ -0,0 +1,323 @@
import os, sys
import json
import cv2
import math
import shutil
import numpy as np
import random
import collections
from PIL import Image
import torch
from torch.utils.data import Dataset

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from utils.img_utils import resize_with_antialiasing, numpy_to_pt



def get_video_frames(config, video_frame_path, flip = False):

    video_seq_length = config["video_seq_length"]

    # Calculate needed parameters
    num_frames_input = 0
    for file_name in os.listdir(video_frame_path):
        if file_name.startswith("im_"):
            num_frames_input += 1
    total_frames_needed = video_seq_length
    division_factor = num_frames_input // total_frames_needed
    remain_frames = (num_frames_input % total_frames_needed) - 1    # -1 for adaptation


    # Define the gaps
    gaps = [division_factor for _ in range(total_frames_needed-1)]
    for idx in range(remain_frames):
        if idx % 2 == 0:
            gaps[idx//2] += 1               # Start-to-end order
        else:
            gaps[-1*(1+(idx//2))] += 1      # End-to-start order


    # Find the needed files
    needed_img_path = []
    cur_idx = 0
    for gap in gaps:
        img_path = os.path.join(video_frame_path, "im_" + str(cur_idx) + ".jpg")
        needed_img_path.append(img_path)

        # Update the idx
        cur_idx += gap
    # Append the last one
    img_path = os.path.join(video_frame_path, "im_" + str(cur_idx) + ".jpg")
    needed_img_path.append(img_path)


    # Read all img_path in order
    video_frames = []
    for img_path in needed_img_path:
        if not os.path.exists(img_path):
            print("We don't have ", img_path)
        frame = cv2.imread(img_path)

        try:
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        except Exception:
            print("The exception place is ", img_path)

        # Resize frames
        frame = cv2.resize(frame, (config["width"], config["height"]), interpolation = cv2.INTER_CUBIC)

        # Flip aug
        if flip:
            frame = np.fliplr(frame)

        # Collect frames
        video_frames.append(np.expand_dims(frame, axis=0))      # The frame is already RGB; there is no need to convert here.


    # Concatenate
    video_frames = np.concatenate(video_frames, axis=0)
    assert(len(video_frames) == video_seq_length)

    return video_frames



def tokenize_captions(prompt, tokenizer, config, is_train=True):
    '''
        Tokenize the text prompt with the prepared tokenizer from SD2.1
    '''

    captions = []
    if random.random() < config["empty_prompts_proportion"]:
        captions.append("")
    elif isinstance(prompt, str):
        captions.append(prompt)
    elif isinstance(prompt, (list, np.ndarray)):
        # take a random caption if there are multiple
        captions.append(random.choice(prompt) if is_train else prompt[0])
    else:
        raise ValueError(
            f"Caption column should contain either strings or lists of strings."
        )

    inputs = tokenizer(
        captions, max_length = tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    )
    return inputs.input_ids[0]



class Video_Dataset(Dataset):
    '''
        Video Dataset to load sequential frames for training, with the needed pre-processing
    '''

    def __init__(self, config, device, normalize=True, tokenizer=None):

        # Attribute variables
        self.config = config
        self.device = device
        self.normalize = normalize
        self.tokenizer = tokenizer

        # Obtain values
        self.video_seq_length = config["video_seq_length"]
        self.height = config["height"]
        self.width = config["width"]

        # Process data
        self.video_lists = []
        stats_analysis = collections.defaultdict(int)
        print("Process all files to check valid datasets....")
        for dataset_path in config["dataset_path"]:
            for video_name in sorted(os.listdir(dataset_path)):
                video_path = os.path.join(dataset_path, video_name)
                all_files = os.listdir(video_path)


                valid = True
                # Valid check 1: the frame files should be in sequential order
                num_frames_input = 0
                for file_name in os.listdir(video_path):
                    if file_name.startswith("im_"):
                        num_frames_input += 1
                for idx in range(num_frames_input):
                    img_path = 'im_' + str(idx) + '.jpg'
                    if img_path not in all_files:   # Frames should exist sequentially
                        valid = False
                        stats_analysis["incomplete_img"] += 1
                        break


                # Valid check 1.5: the number of frames must be at least video_seq_length and at most acceleration_tolerance * video_seq_length
                if num_frames_input < self.config["video_seq_length"]:
                    stats_analysis["too_little_frames"] += 1
                    valid = False
                if num_frames_input > self.config["acceleration_tolerance"] * self.config["video_seq_length"]:
                    stats_analysis["too_many_frames"] += 1
                    valid = False

                if not valid:   # Early exit here to speed things up
                    continue


                # Valid check 2: language, if needed
                if config["use_text"] and not os.path.exists(os.path.join(dataset_path, video_name, "lang.txt")):
                    stats_analysis["no_lang_txt"] += 1
                    valid = False


                # Valid check 3: motion, if needed
                if config["motion_bucket_id"] is None:
                    flow_path = os.path.join(dataset_path, video_name, "flow.txt")
                    if "flow.txt" not in all_files:
                        stats_analysis["no_flow_txt"] += 1
                        valid = False
                    else:
                        file = open(flow_path, 'r')
                        info = file.readlines()
                        if len(info) == 0:
                            stats_analysis["no_flow_txt"] += 1
                            valid = False


                if valid:
                    self.video_lists.append(video_path)
        print("stats_analysis is ", stats_analysis)
        print("Valid dataset length is ", len(self.video_lists))


    def __len__(self):
        return len(self.video_lists)



    def _get_motion_value(self, sub_folder_path):
        ''' Read the motion value from the prepared flow.txt file; the flow is precomputed to accelerate loading
        '''

        # Read the flow.txt
        flow_path = os.path.join(sub_folder_path, 'flow.txt')
        file = open(flow_path, 'r')
        info = file.readlines()
        per_video_movement = float(info[0][:-2])

        # Map the raw value to the target range based on the number of images we have
        num_frames_input = 0
        for file_name in os.listdir(sub_folder_path):   # num_frames_input is the total number of files whose name begins with im_
            if file_name.startswith("im_"):
                num_frames_input += 1

        # Correct the value based on the number of frames relative to video_seq_length
        per_video_movement_correct = per_video_movement * (num_frames_input/self.config["video_seq_length"])

        # Map from one Normal distribution to another Normal distribution
        z = (per_video_movement_correct - self.config["dataset_motion_mean"]) / (self.config["dataset_motion_std"] + 0.001)
        reflected_motion_bucket_id = int((z * self.config["svd_motion_std"]) + self.config["svd_motion_mean"])


        print("We map " + str(per_video_movement) + " to " + str(per_video_movement_correct) + " by length " + str(num_frames_input) + " to bucket_id of " + str(reflected_motion_bucket_id))
        return reflected_motion_bucket_id



    def __getitem__(self, idx):
        ''' Get the item by idx and pre-process it by resizing and normalizing
            Args:
                idx (int):  The index of the file in the directory
            Returns:
                video_frames (torch.float32):  The PyTorch tensor of the obtained frames (in the range [-1, 1] when normalize is True)
                reflected_motion_bucket_id (tensor):  Motion value if optical flow is provided, else a fixed value from the config
                prompt (tensor):  Tokenized text
        '''

        # Prepare the text if needed:
        if self.config["use_text"]:
            # Read the file
            file_path = os.path.join(self.video_lists[idx], "lang.txt")
            file = open(file_path, 'r')
            prompt = file.readlines()[0]    # Only read the first line

            if self.config["mix_ambiguous"] and os.path.exists(os.path.join(self.video_lists[idx], "processed_text.txt")):
                # If we don't have this txt file, we skip

                ######################################################## Mix up the prompt ########################################################

                # Read the file
                file_path = os.path.join(self.video_lists[idx], "processed_text.txt")
                file = open(file_path, 'r')
                prompts = [line for line in file.readlines()]

                # Get the components
                action = prompts[0][:-1]
                this = prompts[1][:-1]
                there = prompts[2][:-1]


                random_value = random.random()
                # If less than 0.4, we don't care, just use the most concrete one
                if random_value >= 0.4 and random_value < 0.6:
                    # Mask the picked object to "this"
                    prompt = action + " this to " + there
                elif random_value >= 0.6 and random_value < 0.8:
                    # Mask the place position to "there"
                    prompt = action + " " + this + " to there"
                elif random_value >= 0.8 and random_value < 1.0:
                    # Just be like "this to there"
                    prompt = action + " this to there"

                # print("New prompt is ", prompt)
                ###################################################################################################################################################

            # else:
            #     print("We don't have the llama processed prompt at ", self.video_lists[idx])

        else:
            prompt = ""

        # Tokenize the text prompt
        tokenized_prompt = tokenize_captions(prompt, self.tokenizer, self.config)


        # Dataset aug by chance (we need to check whether there is any object position word [left|right] in the prompt text)
        flip = False
        if random.random() < self.config["flip_aug_prob"]:
            if self.config["use_text"]:
                if prompt.find("left") == -1 and prompt.find("right") == -1:    # Cannot have a position word like left or right (up and down are ok)
                    flip = True
            else:
                flip = True


        # Read frames for different datasets; currently, we have WebVid / Bridge
        if self.config["dataset_name"] == "Bridge":
            video_frames = get_video_frames(self.config, self.video_lists[idx], flip=flip)
        else:
            raise NotImplementedError("We don't support this dataset loader")


        # Scale [0, 255] -> [-1, 1]
        if self.normalize:
            video_frames = video_frames.astype(np.float32) / 127.5 - 1      # Be careful to cast to float32

        # Transform to a PyTorch tensor in the range [-1, 1]
        video_frames = numpy_to_pt(video_frames)
        # print("length of input frames has ", len(video_frames))


        # Get the motion value based on the optical flow
        if self.config["motion_bucket_id"] is None:
            reflected_motion_bucket_id = self._get_motion_value(self.video_lists[idx])
        else:
            reflected_motion_bucket_id = self.config["motion_bucket_id"]


        # The tensor we return is torch float32. We won't cast here for mixed precision training!
        return {
            "video_frames" : video_frames,
            "reflected_motion_bucket_id" : reflected_motion_bucket_id,
            "prompt": tokenized_prompt,
        }
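A minimal sketch of wiring Video_Dataset into a PyTorch DataLoader. The config values below are illustrative placeholders, not the project's real training configuration (which lives in config/train_image2video.yaml), and the dataset folder path is hypothetical; the SD2.1 tokenizer is assumed based on the docstring above.

import torch
from torch.utils.data import DataLoader
from transformers import CLIPTokenizer
from data_loader.video_dataset import Video_Dataset

# Placeholder config covering the keys the dataset reads; adjust to the real YAML values.
config = {
    "dataset_path": ["../datasets_rob/Bridge_v1_TT14"],     # hypothetical curated folder
    "dataset_name": "Bridge",
    "video_seq_length": 14, "height": 256, "width": 384,
    "use_text": True, "mix_ambiguous": False,
    "empty_prompts_proportion": 0.1, "flip_aug_prob": 0.5,
    "motion_bucket_id": 200,            # fixed value, so flow.txt is not required
    "acceleration_tolerance": 4,
}
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
dataset = Video_Dataset(config, device="cuda", tokenizer=tokenizer)
loader = DataLoader(dataset, batch_size=1, shuffle=True)
batch = next(iter(loader))
print(batch["video_frames"].shape)      # expected roughly (1, 14, 3, 256, 384)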
data_loader/video_this_that_dataset.py
ADDED
@@ -0,0 +1,326 @@
import os, sys
import json
import cv2
import math
import shutil
import numpy as np
import random
from PIL import Image
import torch.nn.functional as F
import torch
import os.path as osp
import time
from moviepy.editor import VideoFileClip
from torch.utils.data import Dataset

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from utils.img_utils import resize_with_antialiasing, numpy_to_pt
from utils.optical_flow_utils import flow_to_image, filter_uv, bivariate_Gaussian
from data_loader.video_dataset import tokenize_captions

# For the 2D dilation
blur_kernel = bivariate_Gaussian(99, 10, 10, 0, grid=None, isotropic=True)


def get_thisthat_sam(config, intput_dir, store_dir=None, flip=False, verbose=False):
    '''
        Args:
            intput_dir (str): The folder we need to process
    '''

    # Read the point file
    file_path = os.path.join(intput_dir, "data.txt")
    file1 = open(file_path, 'r')
    Lines = file1.readlines()

    # Init the condition tensor we want (zero everywhere except the annotated frames)
    thisthat_condition = np.zeros((config["video_seq_length"], config["conditioning_channels"], config["height"], config["width"]), dtype=np.float32)

    # Init the image
    sample_img = cv2.imread(os.path.join(intput_dir, "im_0.jpg"))
    org_height, org_width, _ = sample_img.shape

    # Prepare masking
    controlnet_image_index = []
    coordinate_values = []

    # Iterate all points in the txt file
    for idx in range(len(Lines)):

        # Read points
        frame_idx, horizontal, vertical = Lines[idx].split(' ')
        frame_idx, vertical, horizontal = int(frame_idx), int(float(vertical)), int(float(horizontal))

        # Record the mask frame idx
        controlnet_image_index.append(frame_idx)
        coordinate_values.append((vertical, horizontal))

        # Init the base image
        base_img = np.zeros((org_height, org_width, 3)).astype(np.float32)      # Use the original image size
        base_img.fill(255)

        # Draw a square around the target position
        dot_range = 10      # Diameter
        for i in range(-1*dot_range, dot_range+1):
            for j in range(-1*dot_range, dot_range+1):
                dil_vertical, dil_horizontal = vertical + i, horizontal + j
                if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]):
                    if idx == 0:
                        base_img[dil_vertical][dil_horizontal] = [0, 0, 255]    # The first point should be red
                    else:
                        base_img[dil_vertical][dil_horizontal] = [0, 255, 0]    # The second point should be green to distinguish it from the first point

        # Dilate
        if config["dilate"]:
            base_img = cv2.filter2D(base_img, -1, blur_kernel)

        ##############################################################################################################################
        ### The core processing pipeline is: Dilate -> Resize -> Range Shift -> Transpose Shape -> Store

        # Resize frames (don't use negative values and don't resize in [0, 1])
        base_img = cv2.resize(base_img, (config["width"], config["height"]), interpolation=cv2.INTER_CUBIC)

        # Flip the image for aug if needed
        if flip:
            base_img = np.fliplr(base_img)

        # Channel Transform and Range Shift
        if config["conditioning_channels"] == 3:
            # Map to [0, 1] range
            if store_dir is not None and verbose:   # For the first frame condition visualization
                cv2.imwrite(os.path.join(store_dir, "condition_TT"+str(idx)+".png"), base_img)
            base_img = base_img / 255.0
        else:
            raise NotImplementedError()

        # ReOrganize shape
        base_img = base_img.transpose(2, 0, 1)      # hwc -> chw

        # Check the min max value range
        # if verbose:
        #     print("{} min, max range value is {} - {}".format(intput_dir, np.min(base_img), np.max(base_img)))

        # Write base img based on frame_idx
        thisthat_condition[frame_idx] = base_img    # Only the annotated frames; the rest stays 0-initialized

        ##############################################################################################################################

    if config["motion_bucket_id"] is None:
        # Take the motion from the stats collected before
        reflected_motion_bucket_id = 200
    else:
        reflected_motion_bucket_id = config["motion_bucket_id"]

    # print("Motion Bucket ID is ", reflected_motion_bucket_id)
    return (thisthat_condition, reflected_motion_bucket_id, controlnet_image_index, coordinate_values)


class Video_ThisThat_Dataset(Dataset):
    '''
        Video Dataset to load sequential frames for training with the needed pre-processing and optical flow processing
    '''

    def __init__(self, config, device, normalize=True, tokenizer=None):
        # Attribute variables
        self.config = config
        self.device = device
        self.normalize = normalize
        self.tokenizer = tokenizer

        # Obtain values
        self.video_seq_length = config["video_seq_length"]
        self.height = config["height"]
        self.width = config["width"]

        # Process data
        self.video_lists = []
        for dataset_path in config["dataset_path"]:
            for video_name in sorted(os.listdir(dataset_path)):
                if not os.path.exists(os.path.join(dataset_path, video_name, "data.txt")):
                    continue
                self.video_lists.append(os.path.join(dataset_path, video_name))
        print("length of the dataset is ", len(self.video_lists))


    def __len__(self):
        return len(self.video_lists)


    def _extract_frame_bridge(self, idx, flip=False):
        ''' Extract the frames of a video at the needed fps from already extracted frames
            Args:
                idx (int): The index to the file in the directory
                flip (bool): Whether we will flip
            Returns:
                video_frames (numpy): Extracted video frames in numpy format
        '''

        # Init the frame reader
        # The naming of the Bridge dataset follows the pattern im_x.jpg, so we read frames by index directly
        video_frame_path = self.video_lists[idx]

        # Find the needed files
        needed_img_path = []
        for idx in range(self.video_seq_length):
            img_path = os.path.join(video_frame_path, "im_" + str(idx) + ".jpg")
            needed_img_path.append(img_path)

        # Read all img_path in order
        video_frames = []
        for img_path in needed_img_path:
            if not os.path.exists(img_path):
                print("We don't have ", img_path)
            frame = cv2.imread(img_path)

            try:
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            except Exception:
                print("The exception place is ", img_path)
            # Resize frames
            frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_CUBIC)

            # Flip aug
            if flip:
                frame = np.fliplr(frame)

            # Collect frames
            video_frames.append(np.expand_dims(frame, axis=0))  # The frame is already RGB; there is no need to convert here.

        # Concatenate
        video_frames = np.concatenate(video_frames, axis=0)
        assert(len(video_frames) == self.video_seq_length)

        # Returns
        return video_frames


    def __getitem__(self, idx):
        ''' Get an item by idx and pre-process it by resizing and normalizing
            Args:
                idx (int): The index to the file in the directory
            Returns:
                return_dict (dict): video_frames (torch.float32) in [-1, 1] and controlnet_condition (torch.float32) in [0, 1]
        '''

        # Prepare the text if needed:
        if self.config["use_text"]:
            # Read the file
            file_path = os.path.join(self.video_lists[idx], "lang.txt")
            file = open(file_path, 'r')
            prompt = file.readlines()[0]    # Only read the first line

            if self.config["mix_ambiguous"] and os.path.exists(os.path.join(self.video_lists[idx], "processed_text.txt")):
                # If we don't have this txt file, we skip

                ######################################################## Mix up prompt ########################################################

                # Read the file
                file_path = os.path.join(self.video_lists[idx], "processed_text.txt")
                file = open(file_path, 'r')
                prompts = [line for line in file.readlines()]

                # Get the components
                action = prompts[0][:-1]
                this = prompts[1][:-1]
                there = prompts[2][:-1]

                random_value = random.random()
                # If less than 0.4, we don't care, just use the most concrete one
                if random_value >= 0.4 and random_value < 0.6:
                    # Mask the picked object to "this"
                    prompt = action + " this to " + there
                elif random_value >= 0.6 and random_value < 0.8:
                    # Mask the place position to "there"
                    prompt = action + " " + this + " to there"
                elif random_value >= 0.8 and random_value < 1.0:
                    # Just be like "this to there"
                    prompt = action + " this to there"

                # print("New prompt is ", prompt)
                ###############################################################################################################################

            # else:
            #     print("We don't have a llama processed prompt at ", self.video_lists[idx])

        else:
            prompt = ""

        # Tokenize the text prompt
        tokenized_prompt = tokenize_captions(prompt, self.tokenizer, self.config)


        # Dataset aug by chance (we need to check whether there is any object position word [left|right] in the prompt text)
        flip = False
        if random.random() < self.config["flip_aug_prob"]:
            if self.config["use_text"]:
                if prompt.find("left") == -1 and prompt.find("right") == -1:    # Cannot have a position word like left or right (up and down is ok)
                    flip = True
            else:
                flip = True


        # Read frames for different datasets; currently, we have WebVid / Bridge
        if self.config["dataset_name"] == "Bridge":
            video_frames_raw = self._extract_frame_bridge(idx, flip=flip)
        else:
            raise NotImplementedError("We don't support this dataset loader")


        # Scale [0, 255] -> [-1, 1] if needed
        if self.normalize:
            video_frames = video_frames_raw.astype(np.float32) / 127.5 - 1      # Be careful to cast to float32

        # Transform to a Pytorch Tensor in the range [-1, 1]
        video_frames = numpy_to_pt(video_frames)


        # Generate the pairs we need
        intput_dir = self.video_lists[idx]

        # Get the This That point information
        controlnet_condition, reflected_motion_bucket_id, controlnet_image_index, coordinate_values = get_thisthat_sam(self.config, intput_dir, flip=flip)
        controlnet_condition = torch.from_numpy(controlnet_condition)

        # Cast the other values to tensors
        reflected_motion_bucket_id = torch.tensor(reflected_motion_bucket_id, dtype=torch.float32)
        controlnet_image_index = torch.tensor(controlnet_image_index, dtype=torch.int32)
        coordinate_values = torch.tensor(coordinate_values, dtype=torch.int32)


        # The tensors we return are torch float32. We won't cast here for mixed precision training!
        return {"video_frames": video_frames,
                "controlnet_condition": controlnet_condition,
                "reflected_motion_bucket_id": reflected_motion_bucket_id,
                "controlnet_image_index": controlnet_image_index,
                "prompt": tokenized_prompt,
                "coordinate_values": coordinate_values,     # Unused now, but still passed back
                }
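For reference, a minimal usage sketch of the dataset class above (an illustration, not part of the commit). The config keys are the ones the class reads; the concrete values, dataset path, and tokenizer choice are assumptions, and each video folder is assumed to hold im_0.jpg … im_{N-1}.jpg, lang.txt, and a data.txt whose lines are "<frame_idx> <horizontal> <vertical>".

# Hypothetical usage sketch for Video_ThisThat_Dataset (values are assumptions)
import torch
from torch.utils.data import DataLoader
from transformers import CLIPTokenizer
from data_loader.video_this_that_dataset import Video_ThisThat_Dataset

config = {
    "video_seq_length": 14, "height": 256, "width": 384,
    "conditioning_channels": 3, "dilate": True, "motion_bucket_id": None,
    "dataset_path": ["../datasets_rob/Bridge_v1_TT14"],
    "dataset_name": "Bridge", "use_text": True, "mix_ambiguous": True,
    "flip_aug_prob": 0.5,
    # tokenize_captions() may require further keys depending on its implementation
}

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
dataset = Video_ThisThat_Dataset(config, device="cuda", tokenizer=tokenizer)
loader = DataLoader(dataset, batch_size=1, shuffle=True)

batch = next(iter(loader))
print(batch["video_frames"].shape)          # expected (B, video_seq_length, 3, height, width)
print(batch["controlnet_condition"].shape)  # expected (B, video_seq_length, 3, height, width)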
pretrained/PUT_YOUR_WEIGHT_HERE.md
ADDED
File without changes
requirements.txt
ADDED
@@ -0,0 +1,27 @@
# Non-strict version lib
opencv-python
transformers
accelerate
requests
moviepy
omegaconf
# xformers
tensorboard
einops
yacs
loguru
imageio
pyparsing
ultralytics
lpips
matplotlib
gradio
torch==2.0.1
torchvision

# Strict version lib
bitsandbytes==0.43.0
diffusers==0.25.1
timm==0.4.12
scipy==1.9.3
pyiqa==0.1.7
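A small optional sketch (not part of the commit): since a few libraries above are strictly pinned, a quick runtime check can confirm the environment resolved to the expected versions before training.

# Sanity-check the strictly pinned libraries (illustrative sketch only)
from importlib.metadata import version

pinned = {"bitsandbytes": "0.43.0", "diffusers": "0.25.1",
          "timm": "0.4.12", "scipy": "1.9.3", "pyiqa": "0.1.7", "torch": "2.0.1"}
for name, expected in pinned.items():
    installed = version(name)
    status = "OK" if installed == expected else "MISMATCH"
    print(f"{name}: installed {installed}, pinned {expected} -> {status}")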
scripts/active_learning_select.py
ADDED
@@ -0,0 +1,27 @@
import os, shutil
import random


if __name__ == "__main__":
    start_idx = 950
    end_idx = 1020
    select_num = 70

    label_start_idx = 632
    input_parent_dir = "../Bridge"
    store_dir = "../bridge_select3"

    if os.path.exists(store_dir):
        shutil.rmtree(store_dir)
    os.makedirs(store_dir)

    for idx in range(start_idx, end_idx):
        folder_path = os.path.join(input_parent_dir, str(idx))
        select_idx = random.randint(0, len(os.listdir(folder_path)) - 1)    # randint is inclusive on both ends
        for file_idx, img_name in enumerate(os.listdir(folder_path)):
            if file_idx == select_idx and img_name != "policy_out.pkl":
                img_path = os.path.join(folder_path, img_name)
                target_path = os.path.join(store_dir, str(label_start_idx) + ".jpg")
                label_start_idx += 1
                shutil.copy(img_path, target_path)
scripts/add_point2img.py
ADDED
@@ -0,0 +1,51 @@
'''
    This file is to add the This/That points onto the first image
'''

import os, shutil, sys
import cv2
import numpy as np

if __name__ == "__main__":
    input_folder_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/model_results/Human_Study/Input_Bridge_human_evaluation"
    store_path = "point_highlighted"

    # Manage the store folder (not the input folder, which must stay intact)
    if os.path.exists(store_path):
        shutil.rmtree(store_path)
    os.makedirs(store_path)


    for instance_name in os.listdir(input_folder_path):

        sub_folder_dir = os.path.join(input_folder_path, instance_name)

        # Read file
        file_path = os.path.join(sub_folder_dir, "data.txt")
        file1 = open(file_path, 'r')
        Lines = file1.readlines()

        # Read the first img
        first_img_path = os.path.join(sub_folder_dir, "im_0.jpg")

        # Init the image
        base_img = cv2.imread(first_img_path).astype(np.float32)   # Use the original image size

        # Draw the points
        for idx in range(len(Lines)):
            # Read points
            frame_idx, horizontal, vertical = Lines[idx].split(' ')
            frame_idx, vertical, horizontal = int(frame_idx), int(float(vertical)), int(float(horizontal))

            # Draw a square around the target position
            dot_range = 15      # Diameter
            for i in range(-1*dot_range, dot_range+1):
                for j in range(-1*dot_range, dot_range+1):
                    dil_vertical, dil_horizontal = vertical + i, horizontal + j
                    if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]):
                        if idx == 0:
                            base_img[dil_vertical][dil_horizontal] = [0, 0, 255]    # The first point should be red
                        else:
                            base_img[dil_vertical][dil_horizontal] = [0, 255, 0]    # The second point should be green to distinguish it from the first point

        # Store the highlighted image (the output file name is an assumption; the original script left this step blank)
        cv2.imwrite(os.path.join(store_path, instance_name + ".jpg"), base_img)
scripts/check_video.py
ADDED
@@ -0,0 +1,19 @@
'''
    This file is to make sure that the video files are readable by moviepy, such that the data loader can read these files.
'''
import os
from moviepy.editor import VideoFileClip

if __name__ == "__main__":
    video_dir = "../webvid_sample"
    delete_abnormal_video = True    # Whether you want to delete these abnormal videos directly

    for video_name in sorted(os.listdir(video_dir)):
        video_path = os.path.join(video_dir, video_name)
        try:
            objVideoreader = VideoFileClip(filename=video_path)
        except Exception:
            print("There is an exception of reading: ", video_path)
            if delete_abnormal_video:
                print("We will remove this abnormal video source")
                os.remove(video_path)
scripts/clean_bridge_dataset.py
ADDED
@@ -0,0 +1,22 @@
'''
    Sometimes the Bridge dataset contains strange downloads; we need to clean them
'''
import os, shutil

# TODO: merge this directly into prepare_bridge_dataset later
if __name__ == "__main__":
    dataset_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/Bridge"

    for sub_folder in sorted(os.listdir(dataset_path)):
        sub_folder_path = os.path.join(dataset_path, sub_folder)

        img_lists = os.listdir(sub_folder_path)
        if len(img_lists) < 14:
            print("The folder is too short, we will remove it entirely")
            shutil.rmtree(sub_folder_path)
            continue
        for img_name in img_lists:
            img_path = os.path.join(sub_folder_path, img_name)
            if not img_name.startswith("im_"):
                print("We remove ", img_path)
                os.remove(img_path)
scripts/collect_lang.py
ADDED
@@ -0,0 +1,31 @@
'''
    This file collects all lang.txt files and moves them to a new directory, for the convenience of compressing and scp-ing the language annotations for post-processing
'''
import os, sys, shutil

if __name__ == "__main__":
    parent_dir = "../datasets_rob"
    dataset_paths = ["Bridge_v1_TT14", "Bridge_v2_TT14"]
    store_folder = "../full_text_tmp"

    # Manage the store folder
    if os.path.exists(store_folder):
        shutil.rmtree(store_folder)
    os.makedirs(store_folder)


    for dataset_name in dataset_paths:
        store_path = os.path.join(store_folder, dataset_name)
        if os.path.exists(store_path):
            shutil.rmtree(store_path)
        os.makedirs(store_path)

        # Iterate all the files
        for sub_folder_name in os.listdir(os.path.join(parent_dir, dataset_name)):
            print("We are processing ", sub_folder_name)
            lang_txt_path = os.path.join(parent_dir, dataset_name, sub_folder_name, "lang.txt")

            # Store at the new address
            store_file_path = os.path.join(store_path, sub_folder_name)
            os.makedirs(store_file_path)
            shutil.copyfile(lang_txt_path, os.path.join(store_file_path, "lang.txt"))
scripts/combine_results.py
ADDED
@@ -0,0 +1,85 @@
'''
    This script combines multiple generated GIFs with the same index into one grid
'''

import os, shutil, sys
import imageio
import math
import cv2
from PIL import Image
import collections
import numpy as np


if __name__ == "__main__":

    # Basic setting
    data_paths = [
        "human_evaluation_v3_V_raw_prompt",
        "human_evaluation_v3_VG_raw_prompt_no_sam",
        "human_evaluation_v3_VL_ambiguous_prompt",

        "../datasets_rob/Bridge_human_evaluation",

        "human_evaluation_v3_VL_raw_prompt",
        "human_evaluation_v3_VGL_raw_prompt_no_sam",
        "human_evaluation_v3_VGL_ambiguous_prompt_no_sam",
    ]
    store_path = "combined_results_human_evaluation"
    sample_data_path = data_paths[0]
    gif_per_row = 4     # Number of GIF files per row


    # Create folder
    if os.path.exists(store_path):
        shutil.rmtree(store_path)
    os.makedirs(store_path)


    # Iterate the samples
    for instance_idx, sub_folder_name in enumerate(sorted(os.listdir(sample_data_path))):
        print("we are processing ", sub_folder_name)

        collected_gif_paths = []
        for data_path in data_paths:
            collected_gif_paths.append(os.path.join(data_path, sub_folder_name, 'combined.gif'))

        # Merge frames together
        rows = math.ceil(len(collected_gif_paths) / gif_per_row)
        cols = gif_per_row

        # Read all input GIFs and find the maximum dimensions
        gifs = []
        max_width, max_height = 0, 0
        for path in collected_gif_paths:
            gif = imageio.mimread(path)
            max_width = max(max_width, gif[0].shape[1])
            max_height = max(max_height, gif[0].shape[0])
            gifs.append(gif)

        # Create a blank canvas for the concatenated GIF
        frames_length = len(gifs[0])
        canvas_width = max_width * cols
        canvas_height = max_height * rows
        canvas = np.zeros((frames_length, canvas_height, canvas_width, 3), dtype=np.uint8)


        # Push each frame into the canvas placeholder
        gif_index = 0
        for row in range(rows):
            for col in range(cols):
                gif = gifs[gif_index]
                gif_height, gif_width, _ = gif[0].shape
                start_y = row * max_height
                start_x = col * max_width
                for i in range(frames_length):
                    canvas[i, start_y:start_y+gif_height, start_x:start_x+gif_width, :] = gif[i]

                # Update index
                gif_index += 1
                if gif_index == len(collected_gif_paths):
                    break


        # Write the concatenated GIF
        imageio.mimsave(os.path.join(store_path, sub_folder_name + ".gif"), canvas, duration=0.05, quality=100)
scripts/compress_gif.py
ADDED
@@ -0,0 +1,52 @@
import os, shutil, sys
import cv2
import imageio
import numpy as np


def compress_gif(sub_folder_path):

    # Check valid length
    all_files = os.listdir(sub_folder_path)
    num_frames_input = 0
    valid = True
    for file_name in os.listdir(sub_folder_path):
        if file_name.startswith("im_"):
            num_frames_input += 1
    for idx in range(num_frames_input):
        img_path = 'im_' + str(idx) + '.jpg'
        if img_path not in all_files:   # Frames should exist sequentially
            valid = False
            break
    if not valid:
        print("We cannot generate a video because the video is not sequential")
        return False

    if num_frames_input == 0:
        print("We cannot generate a video because the input length is 0")
        return False

    img_lists = []
    for idx in range(num_frames_input):
        img_path = os.path.join(sub_folder_path, "im_" + str(idx) + ".jpg")
        img_lists.append(cv2.resize(cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB), (384, 256)))

    imageio.mimsave(os.path.join(sub_folder_path, 'combined.gif'), np.array(img_lists), duration=0.05, quality=100)

    return True


if __name__ == "__main__":
    dataset_path = "../datasets_rob/Bridge_human_evaluation"    # ../datasets_rob/Bridge_v1_raw

    for sub_folder_name in sorted(os.listdir(dataset_path)):
        print("We are processing ", sub_folder_name)
        sub_folder_path = os.path.join(dataset_path, sub_folder_name)

        status = compress_gif(sub_folder_path)
scripts/compress_videos.py
ADDED
@@ -0,0 +1,55 @@
import os, shutil, sys
from moviepy.editor import ImageSequenceClip


def compress_video(sub_folder_path, video_name="combined.mp4"):
    # The default video_name is an assumption so that the one-argument call below works;
    # the original code left this argument unfilled.
    store_path = os.path.join(sub_folder_path, video_name)

    if os.path.exists(store_path):
        os.remove(store_path)


    # Check valid length
    all_files = os.listdir(sub_folder_path)
    num_frames_input = 0
    valid = True
    for file_name in os.listdir(sub_folder_path):
        if file_name.startswith("im_"):
            num_frames_input += 1
    for idx in range(num_frames_input):
        img_path = 'im_' + str(idx) + '.jpg'
        if img_path not in all_files:   # Frames should exist sequentially
            valid = False
            break
    if not valid:
        print("We cannot generate a video because the video is not sequential")
        return False

    if num_frames_input == 0:
        print("We cannot generate a video because the input length is 0")
        return False

    img_lists = []
    for idx in range(num_frames_input):
        img_path = os.path.join(sub_folder_path, "im_" + str(idx) + ".jpg")
        img_lists.append(img_path)

    clip = ImageSequenceClip(img_lists, fps=4)
    clip.write_videofile(store_path)

    return True


if __name__ == "__main__":
    dataset_path = "../datasets_rob/Bridge_v2_raw"  # ../datasets_rob/Bridge_v1_raw

    for sub_folder_name in sorted(os.listdir(dataset_path)):
        sub_folder_path = os.path.join(dataset_path, sub_folder_name)

        status = compress_video(sub_folder_path)
scripts/crop_video_frames.py
ADDED
@@ -0,0 +1,22 @@
'''
    This file trims each folder of extracted frames down to the needed frame length, for the mass evaluation
'''
import os, shutil, sys
import cv2


if __name__ == "__main__":
    input_folder = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/StreamingT2V_results"
    needed_frame_length = 14

    for file_name in sorted(os.listdir(input_folder)):
        print("We are processing ", file_name)
        sub_folder_path = os.path.join(input_folder, file_name)

        for idx in range(len(os.listdir(sub_folder_path))):
            if idx >= needed_frame_length:
                target_path = os.path.join(sub_folder_path, str(idx)+".png")
                os.remove(target_path)
scripts/extract_test_dataset.py
ADDED
@@ -0,0 +1,18 @@
'''
    Extract the test dataset from the txt file
'''

if __name__ == "__main__":
    txt_path = "match_info_v2.txt"
    store_path = "test_path_v2.txt"
    start_idx = len("/nfs/turbo/jjparkcv-turbo-large/boyangwa/raw/bridge_data_v2/")

    read_file = open(txt_path, "r")
    write_file = open(store_path, "w")
    for line in read_file.readlines():
        test_dataset_path = line.split(' ')[1]
        test_instance = test_dataset_path[start_idx:]

        write_file.write(test_instance)
scripts/generate_noise.py
ADDED
@@ -0,0 +1,14 @@
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Set the dimensions of the image
height = 256
width = 256

# Generate random pixel values
noise = np.random.rand(height, width, 3) * 255  # Scale to the [0, 255] range

for idx in range(4):
    cv2.imwrite("noise"+str(idx)+".png", noise)
scripts/generate_sam.py
ADDED
@@ -0,0 +1,56 @@
import os, sys, shutil
import cv2
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry


def show_anns(anns):
    if len(anns) == 0:
        return

    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(True)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 3))
    # img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3)])
        img[m] = color_mask

    return img*255


if __name__ == "__main__":
    input_parent_folder = "../Bridge_filter_flow"


    # Init SAM for the segmentation task
    model_type = "vit_h"
    weight_path = "pretrained/sam_vit_h_4b8939.pth"

    sam = sam_model_registry[model_type](checkpoint=weight_path).to(device="cuda")
    mask_generator = SamAutomaticMaskGenerator(sam)     # There are a lot of settings here


    for sub_dir_name in sorted(os.listdir(input_parent_folder)):
        print("We are processing ", sub_dir_name)
        ref_img_path = os.path.join(input_parent_folder, sub_dir_name, 'im_0.jpg')
        store_path = os.path.join(input_parent_folder, sub_dir_name, 'sam.png')

        image = cv2.imread(ref_img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = mask_generator.generate(image)
        mask_img = show_anns(mask)

        cv2.imwrite(store_path, mask_img)
scripts/generate_sam_this_that.py
ADDED
@@ -0,0 +1,108 @@
import os, sys, shutil
import cv2
import numpy as np
import matplotlib.pyplot as plt
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry


def show_anns(anns):
    if len(anns) == 0:
        return

    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(True)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 3))
    # img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3)])
        img[m] = color_mask

    return img*255


def show_mask(mask, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    return mask_image * 255


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1)


if __name__ == "__main__":
    input_parent_folder = "validation_tmp"


    # Init SAM for the segmentation task
    model_type = "vit_h"
    weight_path = "pretrained/sam_vit_h_4b8939.pth"

    sam = sam_model_registry[model_type](checkpoint=weight_path).to(device="cuda")
    sam_predictor = SamPredictor(sam)
    mask_generator = SamAutomaticMaskGenerator(sam)


    # Iterate the folder
    for sub_dir_name in sorted(os.listdir(input_parent_folder)):
        print("We are processing ", sub_dir_name)
        ref_img_path = os.path.join(input_parent_folder, sub_dir_name, 'im_0.jpg')
        data_txt_path = os.path.join(input_parent_folder, sub_dir_name, 'data.txt')


        # Read the image and process
        image = cv2.imread(ref_img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


        # Read the positive point
        data_file = open(data_txt_path, 'r')
        lines = data_file.readlines()
        for idx in range(len(lines)):
            frame_idx, horizontal, vertical = lines[idx].split(' ')
            vertical, horizontal = int(float(vertical)), int(float(horizontal))
            positive_point_cords = [[horizontal, vertical]]

            positive_point_cords = np.array(positive_point_cords)
            positive_point_labels = np.ones(len(positive_point_cords))
            print(positive_point_cords)


            # Set the SAM predictor
            sam_predictor.set_image(np.uint8(image))
            masks, scores, logits = sam_predictor.predict(
                point_coords = positive_point_cords,    # Only positive points here
                point_labels = positive_point_labels,
                multimask_output = False,
            )
            # print("Detected mask length is ", len(masks))

            # Visualize
            mask_img = show_mask(masks[0])
            cv2.imwrite(os.path.join(input_parent_folder, sub_dir_name, "first_contact0.png"), mask_img)

            break   # Only the first ("this") point is segmented


        # SAM all
        sam_all = mask_generator.generate(image)
        all_sam_imgs = show_anns(sam_all)
        cv2.imwrite("sam_all.png", all_sam_imgs)
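For reference, the point-prompt call pattern used above (and reused by scripts/generate_traj.py below) reduces to the following sketch; the image path and point coordinates are placeholders, not values from the repository.

# Minimal SamPredictor point-prompt sketch (illustrative only)
import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry

sam = sam_model_registry["vit_h"](checkpoint="pretrained/sam_vit_h_4b8939.pth").to(device="cuda")
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("im_0.jpg"), cv2.COLOR_BGR2RGB)     # placeholder image path
predictor.set_image(np.uint8(image))

point = np.array([[120, 80]])   # (horizontal, vertical), i.e. (x, y); placeholder coordinates
label = np.ones(len(point))     # 1 = positive click
masks, scores, logits = predictor.predict(point_coords=point, point_labels=label, multimask_output=False)
print(masks[0].shape, scores)   # one boolean mask at the image resolution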
scripts/generate_traj.py
ADDED
@@ -0,0 +1,601 @@
1 |
+
import sys
|
2 |
+
import argparse
|
3 |
+
import copy
|
4 |
+
import os, shutil
|
5 |
+
import imageio
|
6 |
+
import cv2
|
7 |
+
from PIL import Image, ImageDraw
|
8 |
+
import os.path as osp
|
9 |
+
import random
|
10 |
+
import numpy as np
|
11 |
+
import torch.multiprocessing as mp
|
12 |
+
from multiprocessing import set_start_method
|
13 |
+
import math, time, gc
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
import matplotlib.pyplot as plt
|
17 |
+
from segment_anything import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
|
18 |
+
|
19 |
+
|
20 |
+
# Import files from the local path
|
21 |
+
root_path = os.path.abspath('.')
|
22 |
+
sys.path.append(root_path)
|
23 |
+
from config.flowformer_config import get_cfg
|
24 |
+
from flowformer_code.utils import flow_viz, frame_utils
|
25 |
+
from flowformer_code.utils.utils import InputPadder
|
26 |
+
from flowformer_code.FlowFormer import build_flowformer
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
+
|
31 |
+
TRAIN_SIZE = [432, 960]
|
32 |
+
|
33 |
+
def show_anns(anns):
|
34 |
+
if len(anns) == 0:
|
35 |
+
return
|
36 |
+
|
37 |
+
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
38 |
+
ax = plt.gca()
|
39 |
+
ax.set_autoscale_on(True)
|
40 |
+
|
41 |
+
img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
|
42 |
+
img[:,:,3] = 0
|
43 |
+
for ann in sorted_anns:
|
44 |
+
m = ann['segmentation']
|
45 |
+
color_mask = np.concatenate([np.random.random(3), [0.35]])
|
46 |
+
img[m] = color_mask
|
47 |
+
|
48 |
+
return img*255
|
49 |
+
|
50 |
+
|
51 |
+
def show_mask(mask, random_color=False):
|
52 |
+
if random_color:
|
53 |
+
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
|
54 |
+
else:
|
55 |
+
color = np.array([30/255, 144/255, 255/255, 0.6])
|
56 |
+
h, w = mask.shape[-2:]
|
57 |
+
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
|
58 |
+
|
59 |
+
return mask_image * 255
|
60 |
+
|
61 |
+
|
62 |
+
def compute_grid_indices(image_shape, patch_size=TRAIN_SIZE, min_overlap=20):
|
63 |
+
if min_overlap >= TRAIN_SIZE[0] or min_overlap >= TRAIN_SIZE[1]:
|
64 |
+
raise ValueError(
|
65 |
+
f"Overlap should be less than size of patch (got {min_overlap}"
|
66 |
+
f"for patch size {patch_size}).")
|
67 |
+
if image_shape[0] == TRAIN_SIZE[0]:
|
68 |
+
hs = list(range(0, image_shape[0], TRAIN_SIZE[0]))
|
69 |
+
else:
|
70 |
+
hs = list(range(0, image_shape[0], TRAIN_SIZE[0] - min_overlap))
|
71 |
+
if image_shape[1] == TRAIN_SIZE[1]:
|
72 |
+
ws = list(range(0, image_shape[1], TRAIN_SIZE[1]))
|
73 |
+
else:
|
74 |
+
ws = list(range(0, image_shape[1], TRAIN_SIZE[1] - min_overlap))
|
75 |
+
|
76 |
+
# Make sure the final patch is flush with the image boundary
|
77 |
+
hs[-1] = image_shape[0] - patch_size[0]
|
78 |
+
ws[-1] = image_shape[1] - patch_size[1]
|
79 |
+
return [(h, w) for h in hs for w in ws]
|
80 |
+
|
81 |
+
|
82 |
+
|
83 |
+
def compute_flow(model, image1, image2, weights=None):
|
84 |
+
print(f"computing flow...")
|
85 |
+
|
86 |
+
image_size = image1.shape[1:]
|
87 |
+
|
88 |
+
image1, image2 = image1[None].cuda(), image2[None].cuda()
|
89 |
+
|
90 |
+
hws = compute_grid_indices(image_size)
|
91 |
+
if weights is None: # no tile
|
92 |
+
padder = InputPadder(image1.shape)
|
93 |
+
image1, image2 = padder.pad(image1, image2)
|
94 |
+
|
95 |
+
flow_pre, _ = model(image1, image2)
|
96 |
+
|
97 |
+
flow_pre = padder.unpad(flow_pre)
|
98 |
+
flow = flow_pre[0].permute(1, 2, 0).cpu().numpy()
|
99 |
+
else: # tile
|
100 |
+
flows = 0
|
101 |
+
flow_count = 0
|
102 |
+
|
103 |
+
for idx, (h, w) in enumerate(hws):
|
104 |
+
image1_tile = image1[:, :, h:h+TRAIN_SIZE[0], w:w+TRAIN_SIZE[1]]
|
105 |
+
image2_tile = image2[:, :, h:h+TRAIN_SIZE[0], w:w+TRAIN_SIZE[1]]
|
106 |
+
flow_pre, _ = model(image1_tile, image2_tile)
|
107 |
+
padding = (w, image_size[1]-w-TRAIN_SIZE[1], h, image_size[0]-h-TRAIN_SIZE[0], 0, 0)
|
108 |
+
flows += F.pad(flow_pre * weights[idx], padding)
|
109 |
+
flow_count += F.pad(weights[idx], padding)
|
110 |
+
|
111 |
+
flow_pre = flows / flow_count
|
112 |
+
flow = flow_pre[0].permute(1, 2, 0).cpu().numpy()
|
113 |
+
|
114 |
+
return flow
|
115 |
+
|
116 |
+
|
117 |
+
def compute_adaptive_image_size(image_size):
|
118 |
+
target_size = TRAIN_SIZE
|
119 |
+
scale0 = target_size[0] / image_size[0]
|
120 |
+
scale1 = target_size[1] / image_size[1]
|
121 |
+
|
122 |
+
if scale0 > scale1:
|
123 |
+
scale = scale0
|
124 |
+
else:
|
125 |
+
scale = scale1
|
126 |
+
|
127 |
+
image_size = (int(image_size[1] * scale), int(image_size[0] * scale))
|
128 |
+
|
129 |
+
return image_size
|
130 |
+
|
131 |
+
|
132 |
+
def prepare_image(viz_root_dir, fn1, fn2, keep_size):
|
133 |
+
print(f"preparing image...")
|
134 |
+
|
135 |
+
image1 = frame_utils.read_gen(fn1)
|
136 |
+
image2 = frame_utils.read_gen(fn2)
|
137 |
+
image1 = np.array(image1).astype(np.uint8)[..., :3]
|
138 |
+
image2 = np.array(image2).astype(np.uint8)[..., :3]
|
139 |
+
if not keep_size:
|
140 |
+
dsize = compute_adaptive_image_size(image1.shape[0:2])
|
141 |
+
image1 = cv2.resize(image1, dsize=dsize, interpolation=cv2.INTER_CUBIC)
|
142 |
+
image2 = cv2.resize(image2, dsize=dsize, interpolation=cv2.INTER_CUBIC)
|
143 |
+
image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
|
144 |
+
image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
|
145 |
+
|
146 |
+
|
147 |
+
dirname = osp.dirname(fn1)
|
148 |
+
filename = osp.splitext(osp.basename(fn1))[0]
|
149 |
+
|
150 |
+
viz_dir = osp.join(viz_root_dir, dirname)
|
151 |
+
# if not osp.exists(viz_dir):
|
152 |
+
# os.makedirs(viz_dir)
|
153 |
+
|
154 |
+
viz_fn = osp.join(viz_dir, filename + '.png')
|
155 |
+
|
156 |
+
return image1, image2, viz_fn
|
157 |
+
|
158 |
+
|
159 |
+
def build_model():
|
160 |
+
print(f"building model...")
|
161 |
+
cfg = get_cfg()
|
162 |
+
model = torch.nn.DataParallel(build_flowformer(cfg))
|
163 |
+
model.load_state_dict(torch.load(cfg.model))
|
164 |
+
|
165 |
+
model.cuda()
|
166 |
+
model.eval()
|
167 |
+
|
168 |
+
return model
|
169 |
+
|
170 |
+
|
171 |
+
def filter_uv(flow, threshold_factor = 0.2):
|
172 |
+
u = flow[:,:,0]
|
173 |
+
v = flow[:,:,1]
|
174 |
+
|
175 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
176 |
+
rad_max = np.max(rad)
|
177 |
+
|
178 |
+
threshold = threshold_factor * rad_max
|
179 |
+
flow[:,:,0][rad < threshold] = 0
|
180 |
+
flow[:,:,1][rad < threshold] = 0
|
181 |
+
|
182 |
+
return flow
|
183 |
+
|
184 |
+
|
185 |
+
def visualize_traj(base_img, traj_path, connect_points = True):
|
186 |
+
target_vertical, target_horizontal = traj_path[-1]
|
187 |
+
|
188 |
+
if connect_points and len(traj_path) > 1:
|
189 |
+
# Draw a line to connect two point to show motion direction
|
190 |
+
start_coordinate = (traj_path[-2][1], traj_path[-2][0])
|
191 |
+
end_coordinate = (traj_path[-1][1], traj_path[-1][0])
|
192 |
+
pil_img = Image.fromarray(base_img)
|
193 |
+
|
194 |
+
# Draw the line
|
195 |
+
color = 'red'
|
196 |
+
draw = ImageDraw.Draw(pil_img)
|
197 |
+
draw.line([start_coordinate, end_coordinate], fill = color, width = 3)
|
198 |
+
|
199 |
+
base_img = np.array(pil_img)
|
200 |
+
|
201 |
+
|
202 |
+
# Draw a green dot only for the start point
|
203 |
+
if len(traj_path) == 1:
|
204 |
+
dot_range = 3
|
205 |
+
for i in range(-1*dot_range, dot_range+1):
|
206 |
+
for j in range(-1*dot_range, dot_range+1):
|
207 |
+
dil_vertical, dil_horizontal = target_vertical + i, target_horizontal + j
|
208 |
+
if (0 <= dil_vertical and dil_vertical < base_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < base_img.shape[1]):
|
209 |
+
base_img[dil_vertical][dil_horizontal] = [0, 128, 0]
|
210 |
+
else:
|
211 |
+
print("The traj is out of boundary!!!!!!!!!!!!!!!!!!!!! and we won't consider it") # 现在
|
212 |
+
return (False, base_img)
|
213 |
+
|
214 |
+
return (True, base_img)
|
215 |
+
|
216 |
+
|
217 |
+
|
218 |
+
def calculate_flow(viz_root_dir, store_dir, img_pairs, optical_flow_model, sam_predictor, SAM_positive_sample_num, SAM_negative_sample_num, mask_generator, traj_visualization, keep_size, verbose=False):
|
219 |
+
|
220 |
+
# Trajectory prepare
|
221 |
+
traj_path = [] # It collects all points traversed in a temporal order
|
222 |
+
is_hard_to_track = False # If this is True, it means that, we have a time in tracking hard to find dx and dy movement. Under this circumstance, we are not very recommended to use it
|
223 |
+
hard_track_idxs = set()
|
224 |
+
traj_image_lists = []
|
225 |
+
|
226 |
+
|
227 |
+
# Iterate all image pairs
|
228 |
+
for idx, img_pair in enumerate(img_pairs):
|
229 |
+
|
230 |
+
fn1, fn2 = img_pair
|
231 |
+
print(f"processing {fn1}, {fn2}...")
|
232 |
+
|
233 |
+
image1, image2, viz_fn = prepare_image(viz_root_dir, fn1, fn2, keep_size) # Be very careful, image1 and image2 may be different resolution shape if keep_size is False
|
234 |
+
# Generate the optical flow and filter those that is small motion
|
235 |
+
flow_uv = filter_uv(compute_flow(optical_flow_model, image1, image2, None))
|
236 |
+
|
237 |
+
# if verbose:
|
238 |
+
# Store the visualization of flow_uv
|
239 |
+
# flow_img = flow_viz.flow_to_image(flow_uv)
|
240 |
+
# cv2.imwrite("optical_flow_" + str(idx+1) + ".png", flow_img[:, :, [2,1,0]])
|
241 |
+
|
242 |
+
if idx == 0:
|
243 |
+
# We will store the first image to memory for further visualization purpose
|
244 |
+
|
245 |
+
# Base img
|
246 |
+
# base_img = np.uint8(np.transpose(image1.numpy(), (1,2,0)))
|
247 |
+
|
248 |
+
# SAM figure
|
249 |
+
# sam_all = mask_generator.generate(image1)
|
250 |
+
# base_img = show_anns(sam_all)
|
251 |
+
# base_img = np.transpose(base_img, (1,2,0))
|
252 |
+
|
253 |
+
# Plain white image
|
254 |
+
base_img = np.zeros(np.transpose(image1.numpy(), (1,2,0)).shape, dtype=np.uint8)
|
255 |
+
base_img.fill(255)
|
256 |
+
|
257 |
+
|
258 |
+
|
259 |
+
|
260 |
+
# Extract moving points (positive point)
|
261 |
+
positive_point_cords = []
|
262 |
+
nonzeros = np.nonzero(flow_uv) # [(vertical), (horizontal)]
|
263 |
+
if len(nonzeros[0]) < SAM_positive_sample_num:
|
264 |
+
# We require the number of points to be more than SAM_positive_sample_num
|
265 |
+
return False
|
266 |
+
positive_orders = np.random.choice(len(nonzeros[0]), SAM_positive_sample_num, replace=False) # we have randomly select instead of use all in the sam_predictor prediction
|
267 |
+
for i in range(len(nonzeros[0])):
|
268 |
+
if i in positive_orders:
|
269 |
+
positive_point_cords.append([nonzeros[1][i], nonzeros[0][i]]) # 根据document来看,这个就应该是先horizontal再vertical,也就是这个顺序
|
270 |
+
positive_point_cords = np.array(positive_point_cords)
|
271 |
+
positive_point_labels = np.ones(len(positive_point_cords))
|
272 |
+
|
273 |
+
|
274 |
+
# Define negative sample (outside the optical flow choice)
|
275 |
+
if SAM_negative_sample_num != 0:
|
276 |
+
skip_prob = 2 * SAM_negative_sample_num / (flow_uv.shape[0]*flow_uv.shape[1] - len(nonzeros[0]))
|
277 |
+
negative_point_cords = []
|
278 |
+
for i in range(flow_uv.shape[0]):
|
279 |
+
for j in range(flow_uv.shape[1]):
|
280 |
+
if flow_uv[i][j][0] == 0 and flow_uv[i][j][1] == 0: # 0 means the no motion zone and we have already filter low motion as zero before
|
281 |
+
if random.random() < skip_prob:
|
282 |
+
negative_point_cords.append([j, i]) # 根据document来看,这个就应该是先horizontal再vertical,也就是这个顺序
|
283 |
+
negative_point_cords = np.array(negative_point_cords) # [:SAM_negative_sample_num]
|
284 |
+
negative_point_labels = np.zeros(len(negative_point_cords)) # Make sure that it is less than / equals to SAM_negative_sample_num quantity
|
285 |
+
|
286 |
+
|
287 |
+
|
288 |
+
################## Use SAM to filter out what we need (& use negative points) ##################
|
289 |
+
if idx == 0: # Only consider the first frame now.
|
290 |
+
# With sample coordinate
|
291 |
+
sam_predictor.set_image(np.uint8(np.transpose(image1.numpy(), (1,2,0))))
|
292 |
+
if SAM_negative_sample_num != 0 and len(negative_point_cords) != 0:
|
293 |
+
all_point_cords = np.concatenate((positive_point_cords, negative_point_cords), axis=0)
|
294 |
+
all_point_labels = np.concatenate((positive_point_labels, negative_point_labels), axis=0)
|
295 |
+
else:
|
296 |
+
all_point_cords = positive_point_cords
|
297 |
+
all_point_labels = positive_point_labels
|
298 |
+
|
299 |
+
masks, scores, logits = sam_predictor.predict(
|
300 |
+
point_coords=all_point_cords,
|
301 |
+
point_labels=all_point_labels,
|
302 |
+
multimask_output=False,
|
303 |
+
)
|
304 |
+
mask = masks[0] # TODO: 一定要确定我们这里选择了最大的mask,而没有考虑的第二大和其他的, 这里可能有bug,我们默认了第一个就是最大的mask
|
305 |
+
# if verbose:
|
306 |
+
# cv2.imwrite("mask_"+str(idx+1)+".png", (np.uint8(mask)*255))
|
307 |
+
# annotated_img = show_mask(mask)
|
308 |
+
# cv2.imwrite("annotated.png", annotated_img)
|
309 |
+
|
310 |
+
|
311 |
+
################## Choose the one we need as the reference for the future tracking ##################
|
312 |
+
# Choose a random point in the mask
|
313 |
+
target_zone = np.nonzero(mask) # [(vertical), (horizontal)]
|
314 |
+
target_zone = [(target_zone[0][i], target_zone[1][i]) for i in range(len(target_zone[0]))] # Now, the sturcture is [(vertical, horizontal), ...]
|
315 |
+
|
316 |
+
repeat_time = 0
|
317 |
+
loop2find = True
|
318 |
+
while loop2find:
|
319 |
+
loop2find = False
|
320 |
+
start_point = target_zone[np.random.choice(len(target_zone), 1, replace=False)[0]]
|
321 |
+
start_vertical, start_horizontal = start_point
|
322 |
+
|
323 |
+
repeat_time += 1
|
324 |
+
if repeat_time == 100:
|
325 |
+
# In some minor case, it may have infinite loop, so we need to manually break if it is looping
|
326 |
+
print("We are still hard to find a optimal first point, but we cannot let it loop")
|
327 |
+
break
|
328 |
+
|
329 |
+
# Try to choose a start_point that is more centralized (Not close to the border)
|
330 |
+
fast_break = False
|
331 |
+
for i in range(-15, 15):
|
332 |
+
for j in range(-15, 15):
|
333 |
+
dil_vertical, dil_horizontal = start_vertical + i, start_horizontal + j
|
334 |
+
if (0 <= dil_vertical and dil_vertical < mask.shape[0]) and (0 <= dil_horizontal and dil_horizontal < mask.shape[1]):
|
335 |
+
if mask[dil_vertical][dil_horizontal] == 0:
|
336 |
+
print("We need to change to a new position for the start p Since this one is close to the border of the object...........")
|
337 |
+
loop2find = True
|
338 |
+
fast_break = True
|
339 |
+
break
|
340 |
+
else:
|
341 |
+
# We won't want to consider those that is close to the boundary
|
342 |
+
print("We need to change to a new position Since this one is close to the border of the image...........")
|
343 |
+
loop2find = True
|
344 |
+
fast_break = True
|
345 |
+
break
|
346 |
+
if fast_break:
|
347 |
+
break
|
348 |
+
traj_path.append(start_point)
|
349 |
+
|
350 |
+
status, base_img = visualize_traj(base_img, traj_path)
|
351 |
+
if status == False: # If the traj is False, we won't consider it anymore.
|
352 |
+
file = open("log.txt", "a")
|
353 |
+
file.write("Invalid start point\n")
|
354 |
+
return False
|
355 |
+
|
356 |
+
# Read from the last one in traj
|
357 |
+
ref_vertical, ref_horizontal = traj_path[-1][0], traj_path[-1][1]
|
358 |
+
|
359 |
+
|
360 |
+
# Get the average motion vector for point surrounding (8+1 directions) the ref_point; This is because this is the most accurate statistics
|
361 |
+
horizon_lists, vertical_lists = [], []
|
362 |
+
start_range, end_range = -5, 5
|
363 |
+
|
364 |
+
# Calculate the average motion based on surrounding motion
|
365 |
+
search_times = 0
|
366 |
+
while len(horizon_lists) == 0: # If we cannot find a direction, we use average value inside this mask, but we will flag it.
|
367 |
+
search_times += 1
|
368 |
+
|
369 |
+
if search_times > 1:
|
370 |
+
print("This is hard to track!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! and we have tracked " + str(search_times) + " times")
|
371 |
+
# TODO: 如果out of boundary那种,search times到了8-10次的就砍掉那后面frame吧,这种非常inaccurate了, 你也可以retrack一个新的点,但是没有什么意义,看整体数量来定吧
|
372 |
+
is_hard_to_track = True
|
373 |
+
hard_track_idxs.add(idx)
|
374 |
+
|
375 |
+
if abs(start_range) >= flow_uv.shape[0]//2:
|
376 |
+
file = open("log.txt", "a")
|
377 |
+
file.write("This folder has search all space but didn't find any place to track optical flow\n")
|
378 |
+
return False # If we have already search for the whole graph but didn't find anything to track, we discard this sample
|
379 |
+
|
380 |
+
# Search for a larger space which is nearby 我觉得扩大搜索范围应该是最稳定的选择吧
|
381 |
+
for i in range(start_range, end_range):
|
382 |
+
for j in range(start_range, end_range):
|
383 |
+
target_vertical, target_horizontal = ref_vertical + i, ref_horizontal + j
|
384 |
+
if 0 <= target_vertical and target_vertical < flow_uv.shape[0] and 0 <= target_horizontal and target_horizontal < flow_uv.shape[1]:
|
385 |
+
if flow_uv[target_vertical, target_horizontal, 0] == 0 or flow_uv[target_vertical, target_horizontal, 1] == 0:
|
386 |
+
continue # Ignore zero vector to ensure only calculate moving position
|
387 |
+
horizon_lists.append(flow_uv[target_vertical, target_horizontal, 0]) # Horizontal motion strength
|
388 |
+
vertical_lists.append(flow_uv[target_vertical, target_horizontal, 1]) # Vertical motion strength
|
389 |
+
|
390 |
+
# If there isn't any to search, we kepp on a larger space
|
391 |
+
start_range -= 10
|
392 |
+
end_range += 10
|
393 |
+
|
394 |
+
average_dx = sum(horizon_lists)/len(horizon_lists)
|
395 |
+
average_dy = sum(vertical_lists)/len(vertical_lists)
|
396 |
+
print("average movement is ", (average_dx, average_dy))
|
397 |
+
traj_path.append(( int(traj_path[-1][0] + average_dy), int(traj_path[-1][1] + average_dx))) # Append the motion in independent order
|
398 |
+
|
399 |
+
print(traj_path)
|
400 |
+
|
401 |
+
|
402 |
+
##################### Visualize the trajectory path (Debug Purpose) #####################
|
403 |
+
status, base_img = visualize_traj(base_img, traj_path)
|
404 |
+
if status == False: # If the traj is False, we won't consider it anymore.
|
405 |
+
return False
|
406 |
+
|
407 |
+
cv2.imwrite(os.path.join(store_dir, "traj_path.png"), cv2.cvtColor(base_img, cv2.COLOR_BGR2RGB))
|
408 |
+
|
409 |
+
if traj_visualization:
|
410 |
+
status, single_traj_img = visualize_traj(np.uint8(np.transpose(image1.numpy(), (1,2,0))), traj_path[:-1], connect_points=False)
|
411 |
+
if status == False: # If the traj is False, we won't consider it anymore.
|
412 |
+
return False
|
413 |
+
|
414 |
+
traj_write_path = os.path.join(store_dir, "traj_"+str(idx)+".png")
|
415 |
+
# cv2.imwrite(traj_write_path, cv2.cvtColor(single_traj_img, cv2.COLOR_BGR2RGB))
|
416 |
+
traj_image_lists.append(traj_write_path)
|
417 |
+
|
418 |
+
|
419 |
+
# if traj_visualization:
|
420 |
+
# images = []
|
421 |
+
# for filename in traj_image_lists:
|
422 |
+
# images.append(imageio.imread(filename))
|
423 |
+
# # os.remove(filename) # Remove when used
|
424 |
+
# imageio.mimsave(os.path.join(store_dir, 'traj_motion.gif'), images, duration=0.05)
|
425 |
+
|
426 |
+
|
427 |
+
# TODO: 可以如果hard to track,就aggressivly多试即便,我们根据这个hard_track_idxs的长度来粗略判断哪个最好,三次里面选最好的
|
428 |
+
if is_hard_to_track:
|
429 |
+
if len(hard_track_idxs) >= len(img_pairs)//3: # If more than half of the traj is hard to track, we need to consider discard this one
|
430 |
+
file = open("log.txt", "a")
|
431 |
+
file.write("we have a lot of times hard to find dx and dy movement. Under this circumstance, we are not very recommended to use the track\n")
|
432 |
+
return False
|
433 |
+
|
434 |
+
|
435 |
+
# Write a file store all position for further utilization
|
436 |
+
txt_path = os.path.join(store_dir, "traj_data.txt")
|
437 |
+
if os.path.exists(txt_path):
|
438 |
+
os.remove(txt_path)
|
439 |
+
file = open(txt_path, "a")
|
440 |
+
for traj in traj_path:
|
441 |
+
file.write(str(traj[0]) + " " + str(traj[1]) + "\n")
|
442 |
+
# Save in numpy information
|
443 |
+
# with open(os.path.join(store_dir, 'traj_data.npy'), 'wb') as f:
|
444 |
+
# np.save(f, flow_uv)
|
445 |
+
print("We write ", traj_path)
|
446 |
+
return True
|
447 |
+
|
448 |
+
|
449 |
+
|
450 |
+
def manage_seq_range(input_dir, store_dir, total_frame_needed):
|
451 |
+
|
452 |
+
lists = os.listdir(input_dir)
|
453 |
+
lists = lists[2:-2]
|
454 |
+
num_frames_input = len(lists)
|
455 |
+
|
456 |
+
if num_frames_input < total_frame_needed:
|
457 |
+
print("The number of frames is too short for constructing the sequnece length needed")
|
458 |
+
return False
|
459 |
+
|
460 |
+
|
461 |
+
division_factor = num_frames_input // total_frame_needed
|
462 |
+
remain_frame = num_frames_input % total_frame_needed
|
463 |
+
|
464 |
+
gaps = [division_factor for _ in range(total_frame_needed)]
|
465 |
+
for idx in range(remain_frame):
|
466 |
+
gaps[idx] += 1
|
467 |
+
|
468 |
+
|
469 |
+
cur_idx = 2
|
470 |
+
for global_idx, gap in enumerate(gaps):
|
471 |
+
source_path = os.path.join(input_dir, "im_"+str(cur_idx)+".jpg")
|
472 |
+
destination_path = os.path.join(store_dir, "im_"+str(global_idx)+".jpg")
|
473 |
+
|
474 |
+
shutil.copyfile(source_path, destination_path)
|
475 |
+
cur_idx += gap
|
476 |
+
|
477 |
+
return True
|
478 |
+
|
479 |
+
|
480 |
+
def generate_pairs(dirname, start_idx, end_idx):
|
481 |
+
img_pairs = []
|
482 |
+
for idx in range(start_idx, end_idx):
|
483 |
+
img1 = osp.join(dirname, f'im_{idx}.jpg')
|
484 |
+
img2 = osp.join(dirname, f'im_{idx+1}.jpg')
|
485 |
+
# img1 = f'{idx:06}.png'
|
486 |
+
# img2 = f'{idx+1:06}.png'
|
487 |
+
img_pairs.append((img1, img2))
|
488 |
+
|
489 |
+
return img_pairs
|
490 |
+
|
491 |
+
|
492 |
+
def process_partial_request(request_list, num_frames, traj_visualization, viz_root_dir):
|
493 |
+
|
494 |
+
|
495 |
+
# Init the optical flow model
|
496 |
+
optical_flow_model = build_model()
|
497 |
+
|
498 |
+
# Init SAM for segmentation task
|
499 |
+
model_type = "vit_h"
|
500 |
+
weight_path = "pretrained/sam_vit_h_4b8939.pth"
|
501 |
+
SAM_positive_sample_num = 20 # Number of positive sample points used for SAM
|
502 |
+
SAM_negative_sample_num = 0 # How many points we use for the negative sample num
|
503 |
+
|
504 |
+
print("In multi processing, we will build an instance of mask_generator independently")
|
505 |
+
sam = sam_model_registry[model_type](checkpoint=weight_path).to(device="cuda")
|
506 |
+
mask_generator = SamAutomaticMaskGenerator(sam)
|
507 |
+
print("In multi processing, we will build an instance of sam_predictor independently")
|
508 |
+
sam_predictor = SamPredictor(sam)
|
509 |
+
|
510 |
+
|
511 |
+
counter = 0
|
512 |
+
while True:
|
513 |
+
counter += 1
|
514 |
+
if counter == 10:
|
515 |
+
counter = 0
|
516 |
+
gc.collect()
|
517 |
+
print("We will sleep here to clear memory")
|
518 |
+
time.sleep(5)
|
519 |
+
info = request_list[0]
|
520 |
+
request_list = request_list[1:]
|
521 |
+
if info == None:
|
522 |
+
print("This queue ends")
|
523 |
+
break
|
524 |
+
|
525 |
+
|
526 |
+
# Process each sub_input_dir and store the information there
|
527 |
+
sub_input_dir = info
|
528 |
+
|
529 |
+
|
530 |
+
img_pairs = generate_pairs(sub_input_dir, 0, num_frames-1)
|
531 |
+
print(img_pairs)
|
532 |
+
|
533 |
+
with torch.no_grad():
|
534 |
+
|
535 |
+
# Calculate the optical flow and return a status saying whether this generated flow is usable
|
536 |
+
status = calculate_flow(viz_root_dir, sub_input_dir, img_pairs, optical_flow_model, sam_predictor, SAM_positive_sample_num, SAM_negative_sample_num,
|
537 |
+
mask_generator, traj_visualization, keep_size = True)
|
538 |
+
|
539 |
+
# file = open("log.txt", "a")
|
540 |
+
print("The status for folder " + sub_input_dir + " is " + str(status) + "\n")
|
541 |
+
|
542 |
+
if status == False:
|
543 |
+
# If the status is failed, we will remove it afterwards
|
544 |
+
print("The status is Failed, so we won't store this one as promising data")
|
545 |
+
else:
|
546 |
+
print("We have successfully processed one!")
|
547 |
+
|
548 |
+
|
549 |
+
if __name__ == '__main__':
|
550 |
+
|
551 |
+
# Manage the parameters
|
552 |
+
parser = argparse.ArgumentParser()
|
553 |
+
parser.add_argument('--input_dir', default = '../validation_flow14/')
|
554 |
+
parser.add_argument('--num_workers', type = int, default = 1) # Number of parallel worker processes
|
555 |
+
parser.add_argument('--viz_root_dir', default = 'viz_results')
|
556 |
+
parser.add_argument('--traj_visualization', default = True) # If True, save per-step trajectory visualizations
|
557 |
+
|
558 |
+
# list_start = 0
|
559 |
+
# list_end = 25000
|
560 |
+
num_frames = 14
|
561 |
+
|
562 |
+
args = parser.parse_args()
|
563 |
+
input_dir = args.input_dir
|
564 |
+
num_workers = args.num_workers
|
565 |
+
viz_root_dir = args.viz_root_dir
|
566 |
+
traj_visualization = args.traj_visualization
|
567 |
+
|
568 |
+
|
569 |
+
|
570 |
+
store_idx = 0
|
571 |
+
dir_list = []
|
572 |
+
for sub_input_name in sorted(os.listdir(input_dir)):
|
573 |
+
sub_input_dir = os.path.join(input_dir, sub_input_name)
|
574 |
+
# sub_store_dir = os.path.join(store_dir, "0"*(7-len(str(store_idx)))+str(store_idx))
|
575 |
+
store_idx += 1
|
576 |
+
dir_list.append(sub_input_dir)
|
577 |
+
|
578 |
+
# Truncate the list to the target
|
579 |
+
# dir_list = dir_list[list_start:]
|
580 |
+
|
581 |
+
|
582 |
+
# Use multiprocessing to speed things up
|
583 |
+
num = math.ceil(len(dir_list) / num_workers)
|
584 |
+
for idx in range(num_workers):
|
585 |
+
# set_start_method('spawn', force=True)
|
586 |
+
|
587 |
+
request_list = dir_list[:num]
|
588 |
+
request_list.append(None)
|
589 |
+
dir_list = dir_list[num:]
|
590 |
+
|
591 |
+
|
592 |
+
process_partial_request(request_list, num_frames, traj_visualization, viz_root_dir) # This is for debug purpose
|
593 |
+
# p = mp.Process(target=process_partial_request, args=(request_list, num_frames, traj_visualization, viz_root_dir, ))
|
594 |
+
# p.start()
|
595 |
+
|
596 |
+
print("Submitted all jobs!")
|
597 |
+
# p.join() # Without this, the multiprocessing jobs seem to terminate on their own unexpectedly
|
598 |
+
print("All task finished!")
|
599 |
+
|
600 |
+
|
601 |
+
|
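For reference, here is a minimal sketch of the gap scheme that `manage_seq_range` above uses to subsample a clip, with assumed counts (33 usable input frames reduced to 14); the numbers are illustrative only. Because the remainder is added to the leading gaps, sampling is slightly sparser at the start of the clip.

# Illustrative sketch only: how manage_seq_range spreads 33 input frames over 14 kept slots.
num_frames_input, total_frame_needed = 33, 14                # assumed counts
division_factor = num_frames_input // total_frame_needed     # 2
remain_frame = num_frames_input % total_frame_needed         # 5
gaps = [division_factor for _ in range(total_frame_needed)]
for idx in range(remain_frame):
    gaps[idx] += 1
print(gaps)       # [3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2]
print(sum(gaps))  # 33 -- cur_idx advances by one gap per kept frame, so the whole clip is covered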
scripts/interpolate_by_repeat.py
ADDED
@@ -0,0 +1,55 @@
|
1 |
+
'''
|
2 |
+
This file repeats frames so that the sequence reaches the target number of frames needed
|
3 |
+
'''
|
4 |
+
import os, shutil, sys
|
5 |
+
|
6 |
+
if __name__ == "__main__":
|
7 |
+
input_path = "/nfs/turbo/coe-jjparkcv/boyangwa/AVDC/AVDC_results"
|
8 |
+
store_path = "/nfs/turbo/coe-jjparkcv/boyangwa/AVDC/AVDC_results_interpolated"
|
9 |
+
total_frames_needed = 14
|
10 |
+
|
11 |
+
# Handle the file folder management
|
12 |
+
if os.path.exists(store_path):
|
13 |
+
shutil.rmtree(store_path)
|
14 |
+
os.makedirs(store_path)
|
15 |
+
|
16 |
+
for video_name in sorted(os.listdir(input_path)):
|
17 |
+
sub_input_path = os.path.join(input_path, video_name)
|
18 |
+
sub_store_path = os.path.join(store_path, video_name)
|
19 |
+
|
20 |
+
# Create the store place
|
21 |
+
os.makedirs(sub_store_path)
|
22 |
+
|
23 |
+
# Find valid image lists
|
24 |
+
num_frames_input = 0
|
25 |
+
for file_name in os.listdir(sub_input_path):
|
26 |
+
if file_name.endswith("png"):
|
27 |
+
num_frames_input += 1
|
28 |
+
print("num_frames_input is ", num_frames_input)
|
29 |
+
|
30 |
+
# Calculate needed parameters
|
31 |
+
division_factor = total_frames_needed // num_frames_input
|
32 |
+
remain_frames = (total_frames_needed % num_frames_input) - 1 # -1 for adaptation
|
33 |
+
|
34 |
+
# Define the gap
|
35 |
+
gaps = [division_factor for _ in range(num_frames_input)]
|
36 |
+
for idx in range(remain_frames):
|
37 |
+
if idx % 2 == 0:
|
38 |
+
gaps[idx//2] += 1 # Start to end order
|
39 |
+
else:
|
40 |
+
gaps[-1*(1+(idx//2))] += 1 # End to start order
|
41 |
+
|
42 |
+
print("gaps is ", gaps)
|
43 |
+
|
44 |
+
|
45 |
+
# Write to the new folder
|
46 |
+
store_idx = 0
|
47 |
+
for frame_idx, gap in enumerate(gaps):
|
48 |
+
for tmp in range(gap): # Repeat copy gap num of times
|
49 |
+
img_path = os.path.join(sub_input_path, str(frame_idx)+".png")
|
50 |
+
shutil.copyfile(img_path, os.path.join(sub_store_path, str(store_idx)+".png"))
|
51 |
+
store_idx += 1
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
|
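As a quick check of the repeat-count logic above, a minimal sketch assuming 5 generated frames padded toward a 14-frame target (illustrative numbers only; note that the `-1` adaptation makes the copies sum to 13, which matches what the loop writes).

total_frames_needed, num_frames_input = 14, 5                 # assumed counts
division_factor = total_frames_needed // num_frames_input     # 2
remain_frames = (total_frames_needed % num_frames_input) - 1  # 3 after the -1 adaptation
gaps = [division_factor for _ in range(num_frames_input)]
for idx in range(remain_frames):
    if idx % 2 == 0:
        gaps[idx // 2] += 1               # extra copy assigned from the start
    else:
        gaps[-1 * (1 + (idx // 2))] += 1  # extra copy assigned from the end
print(gaps)       # [3, 3, 2, 2, 3]
print(sum(gaps))  # 13 copies are written in total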
scripts/length_stats.py
ADDED
@@ -0,0 +1,21 @@
|
1 |
+
import os, sys, shutil
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
|
5 |
+
|
6 |
+
if __name__ == "__main__":
|
7 |
+
input_folder_path = "../Bridge_v2"
|
8 |
+
|
9 |
+
average_length = []
|
10 |
+
|
11 |
+
# Iterate each file
|
12 |
+
for sub_folder_name in sorted(os.listdir(input_folder_path)):
|
13 |
+
sub_folder_path = os.path.join(input_folder_path, sub_folder_name)
|
14 |
+
|
15 |
+
average_length.append(len(os.listdir(sub_folder_path))) # May count one more file than expected, but we keep this
|
16 |
+
print("average length of {} is {}".format(sub_folder_name, average_length[-1]))
|
17 |
+
|
18 |
+
print("average_movement_list is ", average_length)
|
19 |
+
n, bins, patches = plt.hist(average_length, bins=100)
|
20 |
+
plt.savefig("dataset_length2.png")
|
21 |
+
|
scripts/motion_stats.py
ADDED
@@ -0,0 +1,75 @@
|
1 |
+
import os, sys, shutil
|
2 |
+
import numpy as np
|
3 |
+
import math
|
4 |
+
from statistics import mean
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
|
8 |
+
if __name__ == "__main__":
|
9 |
+
input_folder_paths = ["../datasets_rob/Bridge_v1_raw", "../datasets_rob/Bridge_v2_raw"] # "../datasets_rob/Bridge_v1_raw", "../datasets_rob/Bridge_v2_raw"
|
10 |
+
num_frames = 14
|
11 |
+
store_name = "movement.png"
|
12 |
+
|
13 |
+
|
14 |
+
average_movement_list = []
|
15 |
+
not_valid_num = 0
|
16 |
+
not_exists_num = 0
|
17 |
+
# Iterate each file
|
18 |
+
for input_folder_path in input_folder_paths:
|
19 |
+
for sub_folder_name in sorted(os.listdir(input_folder_path)):
|
20 |
+
sub_folder_path = os.path.join(input_folder_path, sub_folder_name)
|
21 |
+
flow_path = os.path.join(sub_folder_path, 'flow.txt')
|
22 |
+
|
23 |
+
if not os.path.exists(flow_path):
|
24 |
+
not_exists_num += 1
|
25 |
+
continue
|
26 |
+
|
27 |
+
|
28 |
+
# Read the movement
|
29 |
+
file = open(flow_path, 'r')
|
30 |
+
info = file.readlines()
|
31 |
+
print(info)
|
32 |
+
if len(info) == 0:
|
33 |
+
not_valid_num += 1
|
34 |
+
continue
|
35 |
+
info = info[0][:-2]
|
36 |
+
per_video_movement = float(info)
|
37 |
+
|
38 |
+
|
39 |
+
# Calculate the number of frames in this video
|
40 |
+
num_frames_input = 0
|
41 |
+
valid = True
|
42 |
+
for file_name in os.listdir(sub_folder_path): # num_frames_input is the total number of files with name begin with im_
|
43 |
+
if file_name.startswith("im_"):
|
44 |
+
num_frames_input += 1
|
45 |
+
for idx in range(num_frames_input): # Ensure that the frames are numbered consecutively
|
46 |
+
img_path = os.path.join(sub_folder_path, 'im_' + str(idx) + '.jpg')
|
47 |
+
if not os.path.exists(img_path): # Should be sequential existing
|
48 |
+
valid = False
|
49 |
+
break
|
50 |
+
if num_frames_input < 2:
|
51 |
+
valid = False
|
52 |
+
if not valid:
|
53 |
+
not_valid_num += 1
|
54 |
+
print("This is not valid path")
|
55 |
+
continue
|
56 |
+
|
57 |
+
average_movement_list.append(per_video_movement * (num_frames_input/num_frames)) # Normalize the movement by the clip length relative to num_frames
|
58 |
+
print("average movement of {} is {}".format(sub_folder_name, average_movement_list[-1]))
|
59 |
+
|
60 |
+
print("not_exists_num is ", not_exists_num)
|
61 |
+
print("not_valid_num is ", not_valid_num)
|
62 |
+
print("average_movement_list length is ", len(average_movement_list))
|
63 |
+
|
64 |
+
# Get mean and variance data
|
65 |
+
mean_value = mean(average_movement_list)
|
66 |
+
std_value = math.sqrt(np.var(average_movement_list))
|
67 |
+
print("Mean is ", mean_value)
|
68 |
+
print("std_value is ", std_value)
|
69 |
+
|
70 |
+
# Plot the figure
|
71 |
+
n, bins, patches = plt.hist(average_movement_list, bins=100)
|
72 |
+
plt.title("Mean" + str(mean_value) + "_STD"+str(std_value))
|
73 |
+
plt.savefig(store_name)
|
74 |
+
|
75 |
+
|
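The key quantity histogrammed above is the per-video movement rescaled by clip length; a one-line example with assumed values (an average flow of 6.0 read from flow.txt for a 28-frame clip, normalized against the 14-frame target):

per_video_movement = 6.0                 # assumed value read from flow.txt
num_frames_input, num_frames = 28, 14    # assumed clip length vs. target length
normalized = per_video_movement * (num_frames_input / num_frames)
print(normalized)                        # 12.0 -- longer clips get proportionally larger movement scores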
scripts/process_llama.py
ADDED
@@ -0,0 +1,74 @@
|
1 |
+
'''
|
2 |
+
Process the llama file for the next step
|
3 |
+
'''
|
4 |
+
import os, shutil, sys
|
5 |
+
import json
|
6 |
+
import pandas as pd
|
7 |
+
import collections
|
8 |
+
|
9 |
+
|
10 |
+
if __name__ == "__main__":
|
11 |
+
|
12 |
+
# Define important path
|
13 |
+
json_path = "../SVD1/v1.jsonl"
|
14 |
+
folder_path = "/home/kiteret/Desktop/StableVideoDiffusion/full_text_tmp/"
|
15 |
+
|
16 |
+
|
17 |
+
# Read the json file
|
18 |
+
with open(json_path, 'r') as json_file:
|
19 |
+
json_list = list(json_file)
|
20 |
+
|
21 |
+
# Iterate all the json files
|
22 |
+
length_stats = collections.defaultdict(int)
|
23 |
+
for json_info in json_list:
|
24 |
+
json_info = json.loads(json_info)
|
25 |
+
|
26 |
+
|
27 |
+
# Define the path to write
|
28 |
+
key_start = len("/home/chfeng/llama3/full_text_tmp/")
|
29 |
+
key_end = len("lang.txt")
|
30 |
+
sub_path = json_info["file_path"][key_start:int(-1*key_end)]
|
31 |
+
new_text_path = os.path.join(folder_path, sub_path, "processed_text.txt")
|
32 |
+
if os.path.exists(new_text_path):
|
33 |
+
os.remove(new_text_path)
|
34 |
+
|
35 |
+
|
36 |
+
# Sanity check for the case where input is missed
|
37 |
+
if json_info["input"] == "":
|
38 |
+
print("It is weird that the input is empty in the LLM process for ", sub_path)
|
39 |
+
continue
|
40 |
+
|
41 |
+
|
42 |
+
# Re-Define the content
|
43 |
+
outputs = json_info["output"]
|
44 |
+
if outputs.find("action:") != 0:
|
45 |
+
print("It is weird that there is no 'action:' keyword at the start of the outputs for ", sub_path, " with prompt ", outputs)
|
46 |
+
continue
|
47 |
+
|
48 |
+
# Prepare write file
|
49 |
+
contents = outputs.split('\n')
|
50 |
+
f = open(new_text_path, "a")
|
51 |
+
|
52 |
+
# Iterate over the output lines
|
53 |
+
effective_length = 0
|
54 |
+
for idx, content in enumerate(contents):
|
55 |
+
key_word = content.split(":")[1][1:]
|
56 |
+
if key_word != "":
|
57 |
+
effective_length += 1
|
58 |
+
else:
|
59 |
+
if idx == 1:
|
60 |
+
print("It is abnormal for this content to be empty ", sub_path, " with prompt ", outputs)
|
61 |
+
f.write(key_word + "\n")
|
62 |
+
# if effective_length == 2:
|
63 |
+
# print("short prompt case is ", sub_path, " with prompt ", outputs)
|
64 |
+
if effective_length < 2: # For those with only one or zero entries, we won't consider them
|
65 |
+
print("The prompt is too short for ", sub_path, " with prompt ", outputs)
|
66 |
+
os.remove(new_text_path)
|
67 |
+
|
68 |
+
length_stats[effective_length] += 1
|
69 |
+
|
70 |
+
print("length_stats is ", length_stats)
|
71 |
+
|
72 |
+
|
73 |
+
|
74 |
+
|
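The parser above assumes the LLM output is a newline-separated block of `key: value` lines that starts with `action:`. A hypothetical example of such an output and how the loop extracts the values (the keys other than `action:` and the wording are assumptions, not confirmed by this script):

# Hypothetical LLM output in the format process_llama.py expects.
outputs = "action: put the spoon on the towel\nthis: spoon\nthat: towel"
assert outputs.find("action:") == 0
for content in outputs.split('\n'):
    key_word = content.split(":")[1][1:]  # keep the text after "key: "
    print(key_word)  # "put the spoon on the towel", then "spoon", then "towel"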
scripts/process_sim.py
ADDED
@@ -0,0 +1,59 @@
|
1 |
+
'''
|
2 |
+
This is a script to process Mark's data.
|
3 |
+
'''
|
4 |
+
import os, sys, shutil
|
5 |
+
|
6 |
+
if __name__ == "__main__":
|
7 |
+
file_path = "/nfs/turbo/coe-jjparkcv/datasets/isaac-gym-pick-place/full/dataset_v3_proc"
|
8 |
+
store_path = "../datasets_rob/sim_raw"
|
9 |
+
most_descriptive_prompt_idx = 6 # 0-indexed
|
10 |
+
|
11 |
+
|
12 |
+
# Folder management
|
13 |
+
if os.path.exists(store_path):
|
14 |
+
shutil.rmtree(store_path)
|
15 |
+
os.makedirs(store_path)
|
16 |
+
|
17 |
+
# Check length
|
18 |
+
file_names = os.listdir(file_path)
|
19 |
+
target_length = len(file_names) // 10 # 10 files as a cycle
|
20 |
+
|
21 |
+
|
22 |
+
for idx in range(target_length):
|
23 |
+
sub_folder_path = os.path.join(file_path, "run_"+str(10*idx))
|
24 |
+
if not os.path.exists(sub_folder_path):
|
25 |
+
continue
|
26 |
+
|
27 |
+
# Prepare the target position
|
28 |
+
sub_store_path = os.path.join(store_path, str(idx))
|
29 |
+
os.makedirs(sub_store_path)
|
30 |
+
|
31 |
+
# Find the key prompt to read it
|
32 |
+
prompt_content = []
|
33 |
+
for tmp_idx in range(10):
|
34 |
+
tmp_text_path = os.path.join(file_path, "run_"+str(10*idx + tmp_idx), "lang.txt") # Usually, the 6th is the most concrete version
|
35 |
+
if not os.path.exists(tmp_text_path):
|
36 |
+
continue
|
37 |
+
file = open(tmp_text_path, 'r')
|
38 |
+
prompt_content.append(file.readlines()[0])
|
39 |
+
file.close()
|
40 |
+
print("prompt_content we have num ", len(prompt_content))
|
41 |
+
|
42 |
+
|
43 |
+
|
44 |
+
# Copy the image into the target position and copy the data.txt
|
45 |
+
for file_name in os.listdir(sub_folder_path):
|
46 |
+
if file_name == "lang.txt":
|
47 |
+
continue
|
48 |
+
shutil.copyfile(os.path.join(sub_folder_path, file_name), os.path.join(sub_store_path, file_name))
|
49 |
+
|
50 |
+
# Handle the lang.txt
|
51 |
+
target_lang_txt_path = os.path.join(sub_store_path, "lang.txt")
|
52 |
+
f = open(target_lang_txt_path, "a")
|
53 |
+
f.write(prompt_content[most_descriptive_prompt_idx]+"\n")
|
54 |
+
for tmp_idx in range(10):
|
55 |
+
if tmp_idx == most_descriptive_prompt_idx:
|
56 |
+
continue
|
57 |
+
f.write(prompt_content[tmp_idx]+"\n")
|
58 |
+
f.close()
|
59 |
+
|
scripts/resize_img.py
ADDED
@@ -0,0 +1,17 @@
|
1 |
+
import os, sys, shutil
|
2 |
+
import cv2
|
3 |
+
|
4 |
+
if __name__ == "__main__":
|
5 |
+
input_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/resize"
|
6 |
+
output_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/resize_resized"
|
7 |
+
|
8 |
+
if os.path.exists(output_path):
|
9 |
+
shutil.rmtree(output_path)
|
10 |
+
os.makedirs(output_path)
|
11 |
+
|
12 |
+
for img_name in os.listdir(input_path):
|
13 |
+
img_path = os.path.join(input_path, img_name)
|
14 |
+
img = cv2.imread(img_path)
|
15 |
+
img = cv2.resize(img, (384, 256))
|
16 |
+
store_path = os.path.join(output_path, img_name)
|
17 |
+
cv2.imwrite(store_path, img)
|
scripts/resize_video_seq.py
ADDED
@@ -0,0 +1,33 @@
|
1 |
+
'''
|
2 |
+
This file is designed to resize the video sequence to the target resolution
|
3 |
+
'''
|
4 |
+
import os, sys, shutil
|
5 |
+
import cv2
|
6 |
+
|
7 |
+
if __name__ == "__main__":
|
8 |
+
input_folder = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/model_results/SVD_results"
|
9 |
+
store_path = "/nfs/turbo/jjparkcv-turbo-large/boyangwa/model_results/SVD_results_resized"
|
10 |
+
target_height, target_width = 256, 384
|
11 |
+
|
12 |
+
if os.path.exists(store_path):
|
13 |
+
shutil.rmtree(store_path)
|
14 |
+
os.makedirs(store_path)
|
15 |
+
|
16 |
+
for video_name in sorted(os.listdir(input_folder)):
|
17 |
+
print("We are processing ", video_name)
|
18 |
+
sub_video_folder = os.path.join(input_folder, video_name)
|
19 |
+
sub_store_folder = os.path.join(store_path, video_name)
|
20 |
+
os.makedirs(sub_store_folder)
|
21 |
+
|
22 |
+
for img_name in os.listdir(sub_video_folder):
|
23 |
+
if not img_name.endswith("jpg") and not img_name.endswith("png"):
|
24 |
+
continue
|
25 |
+
|
26 |
+
img_path = os.path.join(sub_video_folder, img_name)
|
27 |
+
store_img_path = os.path.join(sub_store_folder, img_name)
|
28 |
+
img = cv2.imread(img_path)
|
29 |
+
|
30 |
+
# Resize
|
31 |
+
img = cv2.resize(img, (target_width, target_height))
|
32 |
+
cv2.imwrite(store_img_path, img)
|
33 |
+
|
scripts/train_test_split.py
ADDED
@@ -0,0 +1,23 @@
|
1 |
+
import os, sys, shutil
|
2 |
+
import random
|
3 |
+
|
4 |
+
|
5 |
+
if __name__ == "__main__":
|
6 |
+
base_dataset_path = "../datasets_rob/Bridge_v1_raw"
|
7 |
+
test_store_path = "../datasets_rob/Bridge_v1_test_raw"
|
8 |
+
split_ratio = 0.1 # [0, 1] range
|
9 |
+
|
10 |
+
# Prepare the folder
|
11 |
+
if os.path.exists(test_store_path):
|
12 |
+
shutil.rmtree(test_store_path)
|
13 |
+
os.makedirs(test_store_path)
|
14 |
+
|
15 |
+
full_img_lists = os.listdir(base_dataset_path)
|
16 |
+
random.shuffle(full_img_lists)
|
17 |
+
target_test_length = int(len(full_img_lists) * split_ratio)
|
18 |
+
test_img_lists = full_img_lists[-1 * target_test_length : ]
|
19 |
+
|
20 |
+
# Move the lists based on test_img_lists
|
21 |
+
for test_img_name in test_img_lists:
|
22 |
+
shutil.move(os.path.join(base_dataset_path, test_img_name), os.path.join(test_store_path, test_img_name))
|
23 |
+
|
scripts/visualize_thisthat_point.py
ADDED
@@ -0,0 +1,43 @@
|
1 |
+
'''
|
2 |
+
This script is provided to change and visualize the destination point.
|
3 |
+
'''
|
4 |
+
|
5 |
+
import os, cv2
|
6 |
+
|
7 |
+
|
8 |
+
def draw_dot(ref_img, new_h, new_w):
|
9 |
+
# Draw the dot
|
10 |
+
dot_range = 3
|
11 |
+
for i in range(-1*dot_range, dot_range+1):
|
12 |
+
for j in range(-1*dot_range, dot_range+1):
|
13 |
+
dil_vertical, dil_horizontal = new_h + i, new_w + j
|
14 |
+
if (0 <= dil_vertical and dil_vertical < ref_img.shape[0]) and (0 <= dil_horizontal and dil_horizontal < ref_img.shape[1]):
|
15 |
+
ref_img[dil_vertical, dil_horizontal, :] = [0, 128, 0]
|
16 |
+
|
17 |
+
return ref_img
|
18 |
+
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
instance_path = "datasets/validation_thisthat14/000049/"
|
22 |
+
new_w, new_h = 385, 310
|
23 |
+
# 256.1850280761719 241.71287155151367
|
24 |
+
|
25 |
+
# Read the items
|
26 |
+
data_path = os.path.join(instance_path, "data.txt")
|
27 |
+
ref_img_path = os.path.join(instance_path, "im_0.jpg")
|
28 |
+
ref_img = cv2.imread(ref_img_path)
|
29 |
+
|
30 |
+
|
31 |
+
# Read the first point
|
32 |
+
file1 = open(data_path, 'r')
|
33 |
+
Lines = file1.readlines()
|
34 |
+
frame_idx, horizontal, vertical = Lines[0].split(' ')
|
35 |
+
ref_img = draw_dot(ref_img, int(float(vertical)), int(float(horizontal)))
|
36 |
+
|
37 |
+
# Second dot
|
38 |
+
ref_img = draw_dot(ref_img, new_h, new_w)
|
39 |
+
|
40 |
+
|
41 |
+
|
42 |
+
# Store the image
|
43 |
+
cv2.imwrite("visual.png", ref_img)
|
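The script above assumes each line of `data.txt` stores `frame_idx horizontal vertical` separated by single spaces; a hypothetical line and how it is parsed (the coordinates are made up):

line = "0 256.185 241.713"                        # hypothetical first line of data.txt
frame_idx, horizontal, vertical = line.split(' ')
row, col = int(float(vertical)), int(float(horizontal))
print(row, col)                                   # 241 256 -- passed to draw_dot as (new_h, new_w)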
svd/diffusion_arch/transformer_temporal.py
ADDED
@@ -0,0 +1,381 @@
|
1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
from dataclasses import dataclass
|
15 |
+
from typing import Any, Dict, Optional
|
16 |
+
import torch
|
17 |
+
from torch import nn
|
18 |
+
|
19 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
20 |
+
from diffusers.utils import BaseOutput
|
21 |
+
from diffusers.models.attention import BasicTransformerBlock, TemporalBasicTransformerBlock
|
22 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
23 |
+
from diffusers.models.modeling_utils import ModelMixin
|
24 |
+
from diffusers.models.resnet import AlphaBlender
|
25 |
+
|
26 |
+
|
27 |
+
@dataclass
|
28 |
+
class TransformerTemporalModelOutput(BaseOutput):
|
29 |
+
"""
|
30 |
+
The output of [`TransformerTemporalModel`].
|
31 |
+
|
32 |
+
Args:
|
33 |
+
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
|
34 |
+
The hidden states output conditioned on `encoder_hidden_states` input.
|
35 |
+
"""
|
36 |
+
|
37 |
+
sample: torch.FloatTensor
|
38 |
+
|
39 |
+
|
40 |
+
class TransformerTemporalModel(ModelMixin, ConfigMixin):
|
41 |
+
"""
|
42 |
+
A Transformer model for video-like data.
|
43 |
+
|
44 |
+
Parameters:
|
45 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
46 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
47 |
+
in_channels (`int`, *optional*):
|
48 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
49 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
50 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
51 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
52 |
+
attention_bias (`bool`, *optional*):
|
53 |
+
Configure if the `TransformerBlock` attention should contain a bias parameter.
|
54 |
+
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
|
55 |
+
This is fixed during training since it is used to learn a number of position embeddings.
|
56 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`):
|
57 |
+
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
|
58 |
+
activation functions.
|
59 |
+
norm_elementwise_affine (`bool`, *optional*):
|
60 |
+
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
|
61 |
+
double_self_attention (`bool`, *optional*):
|
62 |
+
Configure if each `TransformerBlock` should contain two self-attention layers.
|
63 |
+
positional_embeddings: (`str`, *optional*):
|
64 |
+
The type of positional embeddings to apply to the sequence input before passing use.
|
65 |
+
num_positional_embeddings: (`int`, *optional*):
|
66 |
+
The maximum length of the sequence over which to apply positional embeddings.
|
67 |
+
"""
|
68 |
+
|
69 |
+
@register_to_config
|
70 |
+
def __init__(
|
71 |
+
self,
|
72 |
+
num_attention_heads: int = 16,
|
73 |
+
attention_head_dim: int = 88,
|
74 |
+
in_channels: Optional[int] = None,
|
75 |
+
out_channels: Optional[int] = None,
|
76 |
+
num_layers: int = 1,
|
77 |
+
dropout: float = 0.0,
|
78 |
+
norm_num_groups: int = 32,
|
79 |
+
cross_attention_dim: Optional[int] = None,
|
80 |
+
attention_bias: bool = False,
|
81 |
+
sample_size: Optional[int] = None,
|
82 |
+
activation_fn: str = "geglu",
|
83 |
+
norm_elementwise_affine: bool = True,
|
84 |
+
double_self_attention: bool = True,
|
85 |
+
positional_embeddings: Optional[str] = None,
|
86 |
+
num_positional_embeddings: Optional[int] = None,
|
87 |
+
):
|
88 |
+
super().__init__()
|
89 |
+
self.num_attention_heads = num_attention_heads
|
90 |
+
self.attention_head_dim = attention_head_dim
|
91 |
+
inner_dim = num_attention_heads * attention_head_dim
|
92 |
+
|
93 |
+
self.in_channels = in_channels
|
94 |
+
|
95 |
+
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
96 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
97 |
+
|
98 |
+
# 3. Define transformers blocks
|
99 |
+
self.transformer_blocks = nn.ModuleList(
|
100 |
+
[
|
101 |
+
BasicTransformerBlock(
|
102 |
+
inner_dim,
|
103 |
+
num_attention_heads,
|
104 |
+
attention_head_dim,
|
105 |
+
dropout=dropout,
|
106 |
+
cross_attention_dim=cross_attention_dim,
|
107 |
+
activation_fn=activation_fn,
|
108 |
+
attention_bias=attention_bias,
|
109 |
+
double_self_attention=double_self_attention,
|
110 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
111 |
+
positional_embeddings=positional_embeddings,
|
112 |
+
num_positional_embeddings=num_positional_embeddings,
|
113 |
+
)
|
114 |
+
for d in range(num_layers)
|
115 |
+
]
|
116 |
+
)
|
117 |
+
|
118 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
119 |
+
|
120 |
+
def forward(
|
121 |
+
self,
|
122 |
+
hidden_states: torch.FloatTensor,
|
123 |
+
encoder_hidden_states: Optional[torch.LongTensor] = None,
|
124 |
+
timestep: Optional[torch.LongTensor] = None,
|
125 |
+
class_labels: torch.LongTensor = None,
|
126 |
+
num_frames: int = 1,
|
127 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
128 |
+
return_dict: bool = True,
|
129 |
+
) -> TransformerTemporalModelOutput:
|
130 |
+
"""
|
131 |
+
The [`TransformerTemporal`] forward method.
|
132 |
+
|
133 |
+
Args:
|
134 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
135 |
+
Input hidden_states.
|
136 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
137 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
138 |
+
self-attention.
|
139 |
+
timestep ( `torch.LongTensor`, *optional*):
|
140 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
141 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
142 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
143 |
+
`AdaLayerZeroNorm`.
|
144 |
+
num_frames (`int`, *optional*, defaults to 1):
|
145 |
+
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
146 |
+
cross_attention_kwargs (`dict`, *optional*):
|
147 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
148 |
+
`self.processor` in
|
149 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
150 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
151 |
+
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
152 |
+
tuple.
|
153 |
+
|
154 |
+
Returns:
|
155 |
+
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
156 |
+
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
157 |
+
returned, otherwise a `tuple` where the first element is the sample tensor.
|
158 |
+
"""
|
159 |
+
# 1. Input
|
160 |
+
batch_frames, channel, height, width = hidden_states.shape
|
161 |
+
batch_size = batch_frames // num_frames
|
162 |
+
|
163 |
+
residual = hidden_states
|
164 |
+
|
165 |
+
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
|
166 |
+
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
|
167 |
+
|
168 |
+
hidden_states = self.norm(hidden_states)
|
169 |
+
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
|
170 |
+
|
171 |
+
hidden_states = self.proj_in(hidden_states)
|
172 |
+
|
173 |
+
# 2. Blocks
|
174 |
+
for block in self.transformer_blocks:
|
175 |
+
hidden_states = block(
|
176 |
+
hidden_states,
|
177 |
+
encoder_hidden_states=encoder_hidden_states,
|
178 |
+
timestep=timestep,
|
179 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
180 |
+
class_labels=class_labels,
|
181 |
+
)
|
182 |
+
|
183 |
+
# 3. Output
|
184 |
+
hidden_states = self.proj_out(hidden_states)
|
185 |
+
hidden_states = (
|
186 |
+
hidden_states[None, None, :]
|
187 |
+
.reshape(batch_size, height, width, num_frames, channel)
|
188 |
+
.permute(0, 3, 4, 1, 2)
|
189 |
+
.contiguous()
|
190 |
+
)
|
191 |
+
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
|
192 |
+
|
193 |
+
output = hidden_states + residual
|
194 |
+
|
195 |
+
if not return_dict:
|
196 |
+
return (output,)
|
197 |
+
|
198 |
+
return TransformerTemporalModelOutput(sample=output)
|
199 |
+
|
200 |
+
|
201 |
+
class TransformerSpatioTemporalModel(nn.Module):
|
202 |
+
"""
|
203 |
+
A Transformer model for video-like data.
|
204 |
+
|
205 |
+
Parameters:
|
206 |
+
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
207 |
+
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
208 |
+
in_channels (`int`, *optional*):
|
209 |
+
The number of channels in the input and output (specify if the input is **continuous**).
|
210 |
+
out_channels (`int`, *optional*):
|
211 |
+
The number of channels in the output (specify if the input is **continuous**).
|
212 |
+
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
213 |
+
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
|
214 |
+
"""
|
215 |
+
|
216 |
+
def __init__(
|
217 |
+
self,
|
218 |
+
num_attention_heads: int = 16,
|
219 |
+
attention_head_dim: int = 88,
|
220 |
+
in_channels: int = 320,
|
221 |
+
out_channels: Optional[int] = None,
|
222 |
+
num_layers: int = 1,
|
223 |
+
cross_attention_dim: Optional[int] = None,
|
224 |
+
):
|
225 |
+
super().__init__()
|
226 |
+
self.num_attention_heads = num_attention_heads
|
227 |
+
self.attention_head_dim = attention_head_dim
|
228 |
+
|
229 |
+
inner_dim = num_attention_heads * attention_head_dim
|
230 |
+
self.inner_dim = inner_dim
|
231 |
+
|
232 |
+
# 2. Define input layers
|
233 |
+
self.in_channels = in_channels
|
234 |
+
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
|
235 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
236 |
+
|
237 |
+
# 3. Define transformers blocks
|
238 |
+
self.transformer_blocks = nn.ModuleList(
|
239 |
+
[
|
240 |
+
BasicTransformerBlock(
|
241 |
+
inner_dim,
|
242 |
+
num_attention_heads,
|
243 |
+
attention_head_dim,
|
244 |
+
cross_attention_dim=cross_attention_dim,
|
245 |
+
)
|
246 |
+
for d in range(num_layers)
|
247 |
+
]
|
248 |
+
)
|
249 |
+
|
250 |
+
time_mix_inner_dim = inner_dim
|
251 |
+
self.temporal_transformer_blocks = nn.ModuleList(
|
252 |
+
[
|
253 |
+
TemporalBasicTransformerBlock(
|
254 |
+
inner_dim,
|
255 |
+
time_mix_inner_dim,
|
256 |
+
num_attention_heads,
|
257 |
+
attention_head_dim,
|
258 |
+
cross_attention_dim=cross_attention_dim,
|
259 |
+
)
|
260 |
+
for _ in range(num_layers)
|
261 |
+
]
|
262 |
+
)
|
263 |
+
|
264 |
+
time_embed_dim = in_channels * 4
|
265 |
+
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
|
266 |
+
self.time_proj = Timesteps(in_channels, True, 0)
|
267 |
+
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
|
268 |
+
|
269 |
+
# 4. Define output layers
|
270 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
271 |
+
# TODO: should use out_channels for continuous projections
|
272 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
273 |
+
|
274 |
+
self.gradient_checkpointing = False
|
275 |
+
|
276 |
+
def forward(
|
277 |
+
self,
|
278 |
+
hidden_states: torch.Tensor,
|
279 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
280 |
+
image_only_indicator: Optional[torch.Tensor] = None,
|
281 |
+
return_dict: bool = True,
|
282 |
+
):
|
283 |
+
"""
|
284 |
+
Args:
|
285 |
+
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
286 |
+
Input hidden_states.
|
287 |
+
num_frames (`int`):
|
288 |
+
The number of frames to be processed per batch. This is used to reshape the hidden states.
|
289 |
+
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
290 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
291 |
+
self-attention.
|
292 |
+
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
|
293 |
+
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
|
294 |
+
images, 0 indicates that the input contains video frames.
|
295 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
296 |
+
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
|
297 |
+
tuple.
|
298 |
+
|
299 |
+
Returns:
|
300 |
+
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
|
301 |
+
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
|
302 |
+
returned, otherwise a `tuple` where the first element is the sample tensor.
|
303 |
+
"""
|
304 |
+
# 1. Input
|
305 |
+
batch_frames, _, height, width = hidden_states.shape
|
306 |
+
num_frames = image_only_indicator.shape[-1]
|
307 |
+
batch_size = batch_frames // num_frames
|
308 |
+
|
309 |
+
time_context = encoder_hidden_states
|
310 |
+
time_context_first_timestep = time_context[None, :].reshape(
|
311 |
+
batch_size, num_frames, -1, time_context.shape[-1]
|
312 |
+
)[:, 0] # The cross-attention context for the temporal blocks only considers the first frame
|
313 |
+
|
314 |
+
|
315 |
+
encoder_hidden_states_dim = time_context_first_timestep.shape[1]
|
316 |
+
time_context = time_context_first_timestep[None, :].broadcast_to(
|
317 |
+
height * width, batch_size, encoder_hidden_states_dim, time_context.shape[-1]
|
318 |
+
)
|
319 |
+
time_context = time_context.reshape(height * width * batch_size, encoder_hidden_states_dim, time_context.shape[-1])
|
320 |
+
|
321 |
+
residual = hidden_states
|
322 |
+
|
323 |
+
hidden_states = self.norm(hidden_states)
|
324 |
+
inner_dim = hidden_states.shape[1]
|
325 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
|
326 |
+
hidden_states = self.proj_in(hidden_states)
|
327 |
+
|
328 |
+
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
|
329 |
+
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
|
330 |
+
num_frames_emb = num_frames_emb.reshape(-1)
|
331 |
+
t_emb = self.time_proj(num_frames_emb)
|
332 |
+
|
333 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
334 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
335 |
+
# there might be better ways to encapsulate this.
|
336 |
+
t_emb = t_emb.to(dtype=hidden_states.dtype)
|
337 |
+
|
338 |
+
emb = self.time_pos_embed(t_emb)
|
339 |
+
emb = emb[:, None, :]
|
340 |
+
|
341 |
+
# 2. Blocks
|
342 |
+
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
|
343 |
+
if self.training and self.gradient_checkpointing:
|
344 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
345 |
+
block,
|
346 |
+
hidden_states,
|
347 |
+
None,
|
348 |
+
encoder_hidden_states,
|
349 |
+
None,
|
350 |
+
use_reentrant=False,
|
351 |
+
)
|
352 |
+
else:
|
353 |
+
hidden_states = block(
|
354 |
+
hidden_states,
|
355 |
+
encoder_hidden_states=encoder_hidden_states,
|
356 |
+
)
|
357 |
+
|
358 |
+
hidden_states_mix = hidden_states
|
359 |
+
hidden_states_mix = hidden_states_mix + emb
|
360 |
+
|
361 |
+
hidden_states_mix = temporal_block(
|
362 |
+
hidden_states_mix,
|
363 |
+
num_frames=num_frames,
|
364 |
+
encoder_hidden_states=time_context,
|
365 |
+
)
|
366 |
+
hidden_states = self.time_mixer(
|
367 |
+
x_spatial=hidden_states,
|
368 |
+
x_temporal=hidden_states_mix,
|
369 |
+
image_only_indicator=image_only_indicator,
|
370 |
+
)
|
371 |
+
|
372 |
+
# 3. Output
|
373 |
+
hidden_states = self.proj_out(hidden_states)
|
374 |
+
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
|
375 |
+
|
376 |
+
output = hidden_states + residual
|
377 |
+
|
378 |
+
if not return_dict:
|
379 |
+
return (output,)
|
380 |
+
|
381 |
+
return TransformerTemporalModelOutput(sample=output)
|
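A minimal shape sketch of `TransformerSpatioTemporalModel.forward` as defined above, assuming one 14-frame clip of 320-channel 32x48 latents and a single 1024-dimensional conditioning token per frame (all sizes are illustrative assumptions; the heads are chosen so that num_attention_heads * attention_head_dim equals in_channels, which the addition of the temporal position embedding requires):

import torch

# Sketch only: instantiate the class defined above with assumed sizes.
model = TransformerSpatioTemporalModel(
    num_attention_heads=5, attention_head_dim=64,  # 5 * 64 == in_channels
    in_channels=320, num_layers=1, cross_attention_dim=1024,
)
hidden_states = torch.randn(1 * 14, 320, 32, 48)      # (batch * num_frames, C, H, W)
encoder_hidden_states = torch.randn(1 * 14, 1, 1024)  # one conditioning token per frame
image_only_indicator = torch.zeros(1, 14)             # 0 marks video frames
out = model(hidden_states, encoder_hidden_states, image_only_indicator).sample
print(out.shape)  # torch.Size([14, 320, 32, 48]) -- same shape as the input latents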