koukyo1994
committed on
add inference.py
Browse files- inference.py +185 -0
inference.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import imageio
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from PIL import Image
|
10 |
+
from transformers import AutoModel
|
11 |
+
from tqdm import tqdm
|
12 |
+
|
13 |
+
|
14 |
+
# Constants
# Frame size as (height, width); images are resized to this before tokenizing.
IMAGE_SIZE = (288, 512)
# Largest number of frames handled by a single generation round.
N_FRAMES_PER_ROUND = 25
# Upper bound accepted for the --num_frames command-line argument.
MAX_NUM_FRAMES = 50
# Tokens per frame. 576 == 18 * 32, which would match a 16x downsampling of
# IMAGE_SIZE -- presumably tied to the tokenizer config; confirm before changing.
N_TOKENS_PER_FRAME = 576
# JSON file with the trajectory templates, keyed by command name.
TRAJ_TEMPLATE_PATH = Path("./assets/template_trajectory.json")
# Index of the first trajectory point sampled for each frame's action.
PATH_START_ID = 9
# Step between consecutive sampled trajectory points.
PATH_POINT_INTERVAL = 10
# Number of trajectory points kept per frame.
N_ACTION_TOKENS = 6

# change here if you want to use your own images
CONDITIONING_FRAMES_DIR = Path("./assets/conditioning_frames")
CONDITIONING_FRAMES_PATH_LIST = [
    CONDITIONING_FRAMES_DIR / "001.png",
    CONDITIONING_FRAMES_DIR / "002.png",
    CONDITIONING_FRAMES_DIR / "003.png",
]
|
31 |
+
|
32 |
+
|
33 |
+
def set_random_seed(seed: int = 0) -> None:
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch,
    and switches cuDNN to deterministic algorithms.

    Args:
        seed: Value used for all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every CUDA device rather than only the current one, so sampling
    # stays reproducible on multi-GPU machines as well.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
|
39 |
+
|
40 |
+
|
41 |
+
def preprocess_image(image: Image.Image, size: tuple[int, int] = (288, 512)) -> torch.Tensor:
    """Turn a PIL image into a normalized ``(1, 3, H, W)`` float32 tensor.

    The image is converted to RGB, resized to ``size`` and its 8-bit pixel
    values are mapped linearly from ``[0, 255]`` to ``[-1, 1]``.

    Args:
        image: Input image in any PIL mode.
        size: Target size as ``(height, width)``.

    Returns:
        Channel-first float32 tensor with a leading batch dimension of 1.
    """
    height, width = size
    # PIL's resize takes (width, height), the reverse of `size`.
    resized = image.convert("RGB").resize((width, height))
    pixels = (np.array(resized) / 127.5 - 1.0).astype(np.float32)
    # HWC -> CHW, then add the batch dimension; the array is already float32.
    return torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
|
48 |
+
|
49 |
+
|
50 |
+
def to_np_images(images: torch.Tensor) -> np.ndarray:
    """Convert a batch of ``[-1, 1]`` image tensors to uint8 NumPy images.

    Values outside ``[-1, 1]`` are clamped, the range is mapped to
    ``[0, 255]`` and the result is truncated to ``uint8``.

    Args:
        images: Tensor indexed as ``(batch, channel, height, width)``.

    Returns:
        Array indexed as ``(batch, height, width, channel)`` with dtype uint8.
    """
    # Upcast first: tensors produced under autocast may be half precision, and
    # bfloat16 cannot be converted to NumPy at all.
    images = images.detach().cpu().float()
    images = (torch.clamp(images, -1.0, 1.0) + 1.0) / 2.0
    channels_last = images.permute(0, 2, 3, 1).numpy()
    return (255 * channels_last).astype(np.uint8)
|
56 |
+
|
57 |
+
|
58 |
+
def load_images(file_path_list: list[Path], size: tuple[int, int] = (288, 512)) -> torch.Tensor:
    """Load image files and stack them into one preprocessed batch.

    Each file is opened with PIL and passed through ``preprocess_image``; the
    per-image tensors are concatenated along the batch dimension in list order.

    Args:
        file_path_list: Paths of the images to load.
        size: Target size as ``(height, width)`` forwarded to ``preprocess_image``.

    Returns:
        Tensor with one entry per input path along dimension 0.
    """
    frames = []
    for file_path in file_path_list:
        # The context manager releases the file handle as soon as the pixel
        # data has been converted.
        with Image.open(file_path) as image:
            frames.append(preprocess_image(image, size))
    return torch.cat(frames, dim=0)
|
65 |
+
|
66 |
+
|
67 |
+
def save_images_to_mp4(images: np.ndarray, output_path: Path, fps: int = 10) -> None:
    """Write a sequence of frames to a video file with imageio.

    Args:
        images: Iterable of frames; each one is appended to the video in order.
        output_path: Destination file path.
        fps: Frame rate passed to the imageio writer.
    """
    # Using the writer as a context manager closes it (and finalizes the file)
    # even when appending a frame raises.
    with imageio.get_writer(output_path, fps=fps) as writer:
        for frame in images:
            writer.append_data(frame)
|
72 |
+
|
73 |
+
|
74 |
+
def determine_num_rounds(
    num_frames: int,
    num_overlapping_frames: int,
    n_initial_frames: int,
    n_frames_per_round: "int | None" = None,
) -> int:
    """Return how many generation rounds are needed to reach ``num_frames``.

    The first round covers up to ``n_frames_per_round`` frames (conditioning
    frames included). Every later round re-uses ``num_overlapping_frames``
    frames as context, so it contributes
    ``n_frames_per_round - num_overlapping_frames`` new frames.

    Args:
        num_frames: Total number of frames wanted, conditioning frames included.
        num_overlapping_frames: Frames shared between consecutive rounds.
        n_initial_frames: Number of conditioning frames available up front.
        n_frames_per_round: Frames per round; defaults to the module-level
            ``N_FRAMES_PER_ROUND``.

    Returns:
        Number of rounds; 0 when the conditioning frames already cover
        ``num_frames``.

    Raises:
        ValueError: If ``num_overlapping_frames`` is negative or not smaller
            than ``n_frames_per_round`` (no progress could be made).
    """
    if n_frames_per_round is None:
        n_frames_per_round = N_FRAMES_PER_ROUND
    if not 0 <= num_overlapping_frames < n_frames_per_round:
        raise ValueError(
            f"`num_overlapping_frames` must be in [0, {n_frames_per_round}), "
            f"got {num_overlapping_frames}"
        )
    if num_frames <= n_initial_frames:
        return 0

    new_frames_per_round = n_frames_per_round - num_overlapping_frames
    # Frames still missing once the first round has run.
    remaining = max(0, num_frames - n_frames_per_round)
    # Ceiling division without going through floats.
    return 1 + -(-remaining // new_frames_per_round)
|
79 |
+
|
80 |
+
|
81 |
+
def prepare_action(
    traj_template: dict,
    cmd: str,
    path_start_id: int,
    path_point_interval: int,
    n_action_tokens: int = 5,
    start_index: int = 0,
    n_frames: int = 25
) -> torch.Tensor:
    """Build the action tensor for a run of frames from a trajectory template.

    For every frame index in ``[start_index, start_index + n_frames)`` the
    frame's trajectory is subsampled (start at ``path_start_id``, step
    ``path_point_interval``, keep at most ``n_action_tokens`` points). Each
    kept point becomes one row ``[point[1], point[0], t]`` where ``t`` is the
    equally subsampled entry of ``np.arange(0.0, 3.0, 0.05)``.

    Args:
        traj_template: Mapping ``cmd -> {"instruction_trajs": [...]}`` where
            each entry of ``instruction_trajs`` is a sequence of 2-element
            points.
        cmd: Key selecting the trajectory set in ``traj_template``.
        path_start_id: Index of the first point taken from each trajectory.
        path_point_interval: Step between the points taken.
        n_action_tokens: Maximum number of points kept per frame.
        start_index: Index of the first frame's trajectory.
        n_frames: Number of consecutive frames to prepare.

    Returns:
        Tensor with three columns and the rows of all frames concatenated in
        frame order.
    """
    trajs = traj_template[cmd]["instruction_trajs"]
    sampling = slice(path_start_id, None, path_point_interval)
    # The time column is identical for every frame, so build it once.
    # NOTE(review): values look like seconds over a 3 s horizon -- confirm.
    timestep = np.arange(0.0, 3.0, 0.05)[sampling][:n_action_tokens].reshape(-1, 1)

    actions = []
    for frame_idx in range(start_index, start_index + n_frames):
        points = np.array(trajs[frame_idx][sampling][:n_action_tokens])
        # Swap the two point columns before appending the time column.
        action = np.concatenate([points[:, [1, 0]], timestep], axis=1)
        actions.append(torch.tensor(action))
    return torch.cat(actions, dim=0)
|
103 |
+
|
104 |
+
|
105 |
+
def main() -> None:
    """Generate a driving video from conditioning frames and a trajectory command.

    Steps: parse and validate the command-line arguments, load the tokenizer
    and world model, tokenize the conditioning frames, run one or more
    generation rounds (each later round is seeded with the last
    ``--num_overlapping_frames`` frames of the previous one), decode all
    tokens back to images and write them to ``generated.mp4`` in the output
    directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output_dir", type=Path)
    parser.add_argument("--cmd", type=str, default="curving_to_left/curving_to_left_moderate")
    parser.add_argument("--num_frames", type=int, default=25)
    parser.add_argument("--num_overlapping_frames", type=int, default=3)
    args = parser.parse_args()

    # Validate with parser.error instead of assert: asserts vanish under -O.
    n_conditioning_frames = len(CONDITIONING_FRAMES_PATH_LIST)
    if args.num_frames > MAX_NUM_FRAMES:
        parser.error(f"`num_frames` should be less than or equal to {MAX_NUM_FRAMES}")
    if args.num_frames <= n_conditioning_frames:
        parser.error(
            f"`num_frames` should be greater than the number of conditioning frames ({n_conditioning_frames})"
        )
    if args.num_overlapping_frames < 0:
        parser.error("`num_overlapping_frames` should not be negative")
    if args.num_overlapping_frames >= N_FRAMES_PER_ROUND:
        parser.error(f"`num_overlapping_frames` should be less than {N_FRAMES_PER_ROUND}")

    num_rounds = determine_num_rounds(args.num_frames, args.num_overlapping_frames, n_conditioning_frames)
    if num_rounds > 1 and args.num_overlapping_frames == 0:
        # Later rounds are conditioned only on the overlapping frames.
        parser.error("`num_overlapping_frames` should be at least 1 when more than one round is needed")

    set_random_seed(args.seed)
    output_dir = Path(f"./outputs/{args.cmd}") if args.output_dir is None else args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def autocast() -> torch.autocast:
        # Mixed precision only on CUDA; on CPU this is an explicit no-op
        # instead of a "cuda" autocast that PyTorch has to disable itself.
        return torch.autocast(device_type=device.type, enabled=device.type == "cuda")

    tokenizer = AutoModel.from_pretrained("turing-motors/Terra", subfolder="lfq_tokenizer_B_256", trust_remote_code=True).to(device).eval()
    model = AutoModel.from_pretrained("turing-motors/Terra", subfolder="world_model", trust_remote_code=True).to(device).eval()

    conditioning_frames = load_images(CONDITIONING_FRAMES_PATH_LIST, IMAGE_SIZE).to(device)
    with torch.inference_mode(), autocast():
        input_ids = tokenizer.tokenize(conditioning_frames).detach().unsqueeze(0)

    print(f"Number of generation rounds: {num_rounds}")

    with open(TRAJ_TEMPLATE_PATH, encoding="utf-8") as f:
        traj_template = json.load(f)

    new_frames_per_round = N_FRAMES_PER_ROUND - args.num_overlapping_frames
    overlap_tokens = args.num_overlapping_frames * N_TOKENS_PER_FRAME

    all_outputs = []
    for round_idx in range(num_rounds):
        start_index = round_idx * new_frames_per_round
        num_frames_for_round = min(N_FRAMES_PER_ROUND, args.num_frames - start_index)
        # Frames already available as context for this round.
        n_context_frames = n_conditioning_frames if round_idx == 0 else args.num_overlapping_frames
        if round_idx > 0 and num_frames_for_round <= n_context_frames:
            # Nothing new left to generate.
            break

        actions = prepare_action(
            traj_template, args.cmd, PATH_START_ID, PATH_POINT_INTERVAL, N_ACTION_TOKENS, start_index, num_frames_for_round
        ).unsqueeze(0).to(device).float()
        num_generated_tokens = N_TOKENS_PER_FRAME * (num_frames_for_round - n_context_frames)

        with tqdm(total=num_generated_tokens, desc=f"Round {round_idx + 1}") as progress_bar:
            with torch.inference_mode(), autocast():
                output_tokens = model.generate(
                    input_ids=input_ids,
                    actions=actions,
                    do_sample=True,
                    max_length=N_TOKENS_PER_FRAME * num_frames_for_round,
                    temperature=1.0,
                    top_p=1.0,
                    use_cache=True,
                    pad_token_id=None,
                    eos_token_id=None,
                    progress_bar=progress_bar
                )

        # The first round keeps everything; later rounds drop the leading
        # overlap, which the previous round already contributed.
        n_skipped_tokens = 0 if round_idx == 0 else overlap_tokens
        all_outputs.append(output_tokens[0, n_skipped_tokens:])
        # Explicit start index: a plain `-overlap_tokens:` slice would return
        # the whole tensor when overlap_tokens is 0.
        input_ids = output_tokens[:, output_tokens.shape[1] - overlap_tokens:]

    output_ids = torch.cat(all_outputs)

    # Calculate the shape of the latent tensor
    # NOTE(review): the downsampling ratio is taken as the product of
    # `ch_mult` -- confirm this matches the tokenizer architecture.
    downsample_ratio = 1
    for coef in tokenizer.config.encoder_decoder_config["ch_mult"]:
        downsample_ratio *= coef
    h = IMAGE_SIZE[0] // downsample_ratio
    w = IMAGE_SIZE[1] // downsample_ratio
    c = tokenizer.config.encoder_decoder_config["z_channels"]
    latent_shape = (len(output_ids) // N_TOKENS_PER_FRAME, h, w, c)

    # Decode the latent tensor to images
    with torch.inference_mode(), autocast():
        reconstructed = tokenizer.decode_tokens(output_ids, latent_shape)
    reconstructed_images = to_np_images(reconstructed)
    save_images_to_mp4(reconstructed_images, output_dir / "generated.mp4", fps=10)


if __name__ == "__main__":
    main()
|