tree3po committed on
Commit
d983af9
1 Parent(s): 62b915d

Delete open-oasis-master

open-oasis-master/.gitattributes DELETED
@@ -1 +0,0 @@
1
- video.mp4 filter=lfs diff=lfs merge=lfs -text
 
 
open-oasis-master/LICENSE DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2024 Etched & Decart
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
open-oasis-master/README.md DELETED
@@ -1,37 +0,0 @@
1
- # Oasis 500M
2
-
3
- ![](./media/arch.png)
4
-
5
- ![](./media/thumb.png)
6
-
7
- Oasis is an interactive world model developed by [Decart](https://www.decart.ai/) and [Etched](https://www.etched.com/). Based on diffusion transformers, Oasis takes in user keyboard input and generates gameplay in an autoregressive manner. We release the weights for Oasis 500M, a downscaled version of the model, along with inference code for action-conditional frame generation.
8
-
9
- For more details, see our [joint blog post](https://oasis-model.github.io/).
10
-
11
- And to use the most powerful version of the model, be sure to check out the [live demo](https://oasis.us.decart.ai/) as well!
12
-
13
- ## Setup
14
- ```
15
- git clone https://github.com/etched-ai/open-oasis.git
16
- cd open-oasis
17
- pip install -r requirements.txt
18
- ```
19
-
20
- ## Download the model weights
21
- ```
22
- huggingface-cli login
23
- huggingface-cli download Etched/oasis-500m oasis500m.pt # DiT checkpoint
24
- huggingface-cli download Etched/oasis-500m vit-l-20.pt # ViT VAE checkpoint
25
- ```
26
-
27
- ## Basic Usage
28
- We include a basic inference script that loads a prompt frame from a video and generates additional frames conditioned on actions.
29
- ```
30
- python generate.py
31
- ```
32
- The resulting video will be saved to `video.mp4`. Here are some example generations from this 500M model!
33
-
34
- ![](media/sample_0.gif)
35
- ![](media/sample_1.gif)
36
-
37
- > Hint: try swapping out the `.mp4` input file in the script to try different environments!
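For reference, a minimal sketch of the loading flow the deleted README describes, mirroring the opening of `generate.py` later in this diff (assumes the repo's `dit.py`/`vae.py` are importable and the two checkpoints have been downloaded with `huggingface-cli` as shown above):

```
import torch
from dit import DiT_models
from vae import VAE_models

device = "cuda:0"  # generate.py assumes a CUDA device

# DiT denoiser (the released 500M "DiT-S/2" configuration)
model = DiT_models["DiT-S/2"]()
model.load_state_dict(torch.load("oasis500m.pt"), strict=False)
model = model.to(device).eval()

# ViT VAE used to map frames to and from the latent space
vae = VAE_models["vit-l-20-shallow-encoder"]()
vae.load_state_dict(torch.load("vit-l-20.pt"))
vae = vae.to(device).eval()
```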
open-oasis-master/attention.py DELETED
@@ -1,137 +0,0 @@
1
- """
2
- Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py
3
- """
4
- from typing import Optional
5
- from collections import namedtuple
6
- import torch
7
- from torch import nn
8
- from torch.nn import functional as F
9
- from einops import rearrange
10
- from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
11
- from embeddings import TimestepEmbedding, Timesteps, Positions2d
12
-
13
- class TemporalAxialAttention(nn.Module):
14
- def __init__(
15
- self,
16
- dim: int,
17
- heads: int = 4,
18
- dim_head: int = 32,
19
- is_causal: bool = True,
20
- rotary_emb: Optional[RotaryEmbedding] = None,
21
- ):
22
- super().__init__()
23
- self.inner_dim = dim_head * heads
24
- self.heads = heads
25
- self.head_dim = dim_head
26
- self.inner_dim = dim_head * heads
27
- self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
28
- self.to_out = nn.Linear(self.inner_dim, dim)
29
-
30
- self.rotary_emb = rotary_emb
31
- self.time_pos_embedding = (
32
- nn.Sequential(
33
- Timesteps(dim),
34
- TimestepEmbedding(in_channels=dim, time_embed_dim=dim * 4, out_dim=dim),
35
- )
36
- if rotary_emb is None
37
- else None
38
- )
39
- self.is_causal = is_causal
40
-
41
- def forward(self, x: torch.Tensor):
42
- B, T, H, W, D = x.shape
43
-
44
- if self.time_pos_embedding is not None:
45
- time_emb = self.time_pos_embedding(
46
- torch.arange(T, device=x.device)
47
- )
48
- x = x + rearrange(time_emb, "t d -> 1 t 1 1 d")
49
-
50
- q, k, v = self.to_qkv(x).chunk(3, dim=-1)
51
-
52
- q = rearrange(q, "B T H W (h d) -> (B H W) h T d", h=self.heads)
53
- k = rearrange(k, "B T H W (h d) -> (B H W) h T d", h=self.heads)
54
- v = rearrange(v, "B T H W (h d) -> (B H W) h T d", h=self.heads)
55
-
56
- if self.rotary_emb is not None:
57
- q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
58
- k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)
59
-
60
- q, k, v = map(lambda t: t.contiguous(), (q, k, v))
61
-
62
- x = F.scaled_dot_product_attention(
63
- query=q, key=k, value=v, is_causal=self.is_causal
64
- )
65
-
66
- x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
67
- x = x.to(q.dtype)
68
-
69
- # linear proj
70
- x = self.to_out(x)
71
- return x
72
-
73
- class SpatialAxialAttention(nn.Module):
74
- def __init__(
75
- self,
76
- dim: int,
77
- heads: int = 4,
78
- dim_head: int = 32,
79
- rotary_emb: Optional[RotaryEmbedding] = None,
80
- ):
81
- super().__init__()
82
- self.inner_dim = dim_head * heads
83
- self.heads = heads
84
- self.head_dim = dim_head
85
- self.inner_dim = dim_head * heads
86
- self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
87
- self.to_out = nn.Linear(self.inner_dim, dim)
88
-
89
- self.rotary_emb = rotary_emb
90
- self.space_pos_embedding = (
91
- nn.Sequential(
92
- Positions2d(dim),
93
- TimestepEmbedding(in_channels=dim, time_embed_dim=dim * 4, out_dim=dim),
94
- )
95
- if rotary_emb is None
96
- else None
97
- )
98
-
99
- def forward(self, x: torch.Tensor):
100
- B, T, H, W, D = x.shape
101
-
102
- if self.space_pos_embedding is not None:
103
- h_steps = torch.arange(H, device=x.device)
104
- w_steps = torch.arange(W, device=x.device)
105
- grid = torch.meshgrid(h_steps, w_steps, indexing="ij")
106
- space_emb = self.space_pos_embedding(grid)
107
- x = x + rearrange(space_emb, "h w d -> 1 1 h w d")
108
-
109
- q, k, v = self.to_qkv(x).chunk(3, dim=-1)
110
-
111
- q = rearrange(q, "B T H W (h d) -> (B T) h H W d", h=self.heads)
112
- k = rearrange(k, "B T H W (h d) -> (B T) h H W d", h=self.heads)
113
- v = rearrange(v, "B T H W (h d) -> (B T) h H W d", h=self.heads)
114
-
115
- if self.rotary_emb is not None:
116
- freqs = self.rotary_emb.get_axial_freqs(H, W)
117
- q = apply_rotary_emb(freqs, q)
118
- k = apply_rotary_emb(freqs, k)
119
-
120
- # prepare for attn
121
- q = rearrange(q, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
122
- k = rearrange(k, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
123
- v = rearrange(v, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads)
124
-
125
- q, k, v = map(lambda t: t.contiguous(), (q, k, v))
126
-
127
- x = F.scaled_dot_product_attention(
128
- query=q, key=k, value=v, is_causal=False
129
- )
130
-
131
- x = rearrange(x, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
132
- x = x.to(q.dtype)
133
-
134
- # linear proj
135
- x = self.to_out(x)
136
- return x
137
-
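For context, a small shape check of the two axial-attention modules deleted above: temporal attention mixes tokens only along T for each spatial location, while spatial attention mixes only along (H, W) within each frame. A hedged sketch, assuming `attention.py` and `rotary_embedding_torch.py` from this repo are importable; the rotary dims mirror how `dit.py` wires them (temporal rotary over the full head dim, spatial rotary over half of it):

```
import torch
from attention import TemporalAxialAttention, SpatialAxialAttention
from rotary_embedding_torch import RotaryEmbedding

B, T, H, W, D = 1, 4, 3, 5, 64
x = torch.randn(B, T, H, W, D)

temporal = TemporalAxialAttention(
    dim=D, heads=4, dim_head=16, rotary_emb=RotaryEmbedding(dim=16)
)
spatial = SpatialAxialAttention(
    dim=D, heads=4, dim_head=16,
    rotary_emb=RotaryEmbedding(dim=8, freqs_for="pixel", max_freq=256),
)

out = spatial(temporal(x))   # both modules preserve the (B, T, H, W, D) layout
assert out.shape == (B, T, H, W, D)
```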
open-oasis-master/dit.py DELETED
@@ -1,310 +0,0 @@
1
- """
2
- References:
3
- - DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
4
- - Diffusion Forcing: https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/unet3d.py
5
- - Latte: https://github.com/Vchitect/Latte/blob/main/models/latte.py
6
- """
7
- from typing import Optional, Literal
8
- import torch
9
- from torch import nn
10
- from rotary_embedding_torch import RotaryEmbedding
11
- from einops import rearrange
12
- from embeddings import Timesteps, TimestepEmbedding
13
- from attention import SpatialAxialAttention, TemporalAxialAttention
14
- from timm.models.vision_transformer import Mlp
15
- from timm.layers.helpers import to_2tuple
16
- import math
17
-
18
- def modulate(x, shift, scale):
19
- fixed_dims = [1] * len(shift.shape[1:])
20
- shift = shift.repeat(x.shape[0] // shift.shape[0], *fixed_dims)
21
- scale = scale.repeat(x.shape[0] // scale.shape[0], *fixed_dims)
22
- while shift.dim() < x.dim():
23
- shift = shift.unsqueeze(-2)
24
- scale = scale.unsqueeze(-2)
25
- return x * (1 + scale) + shift
26
-
27
- def gate(x, g):
28
- fixed_dims = [1] * len(g.shape[1:])
29
- g = g.repeat(x.shape[0] // g.shape[0], *fixed_dims)
30
- while g.dim() < x.dim():
31
- g = g.unsqueeze(-2)
32
- return g * x
33
-
34
- class PatchEmbed(nn.Module):
35
- """2D Image to Patch Embedding"""
36
-
37
- def __init__(
38
- self,
39
- img_height=256,
40
- img_width=256,
41
- patch_size=16,
42
- in_chans=3,
43
- embed_dim=768,
44
- norm_layer=None,
45
- flatten=True,
46
- ):
47
- super().__init__()
48
- img_size = (img_height, img_width)
49
- patch_size = to_2tuple(patch_size)
50
- self.img_size = img_size
51
- self.patch_size = patch_size
52
- self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
53
- self.num_patches = self.grid_size[0] * self.grid_size[1]
54
- self.flatten = flatten
55
-
56
- self.proj = nn.Conv2d(
57
- in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
58
- )
59
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
60
-
61
- def forward(self, x, random_sample=False):
62
- B, C, H, W = x.shape
63
- assert random_sample or (
64
- H == self.img_size[0] and W == self.img_size[1]
65
- ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
66
- x = self.proj(x)
67
- if self.flatten:
68
- x = rearrange(x, "B C H W -> B (H W) C")
69
- else:
70
- x = rearrange(x, "B C H W -> B H W C")
71
- x = self.norm(x)
72
- return x
73
-
74
- class TimestepEmbedder(nn.Module):
75
- """
76
- Embeds scalar timesteps into vector representations.
77
- """
78
- def __init__(self, hidden_size, frequency_embedding_size=256):
79
- super().__init__()
80
- self.mlp = nn.Sequential(
81
- nn.Linear(frequency_embedding_size, hidden_size, bias=True), # hidden_size is diffusion model hidden size
82
- nn.SiLU(),
83
- nn.Linear(hidden_size, hidden_size, bias=True),
84
- )
85
- self.frequency_embedding_size = frequency_embedding_size
86
-
87
- @staticmethod
88
- def timestep_embedding(t, dim, max_period=10000):
89
- """
90
- Create sinusoidal timestep embeddings.
91
- :param t: a 1-D Tensor of N indices, one per batch element.
92
- These may be fractional.
93
- :param dim: the dimension of the output.
94
- :param max_period: controls the minimum frequency of the embeddings.
95
- :return: an (N, D) Tensor of positional embeddings.
96
- """
97
- # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
98
- half = dim // 2
99
- freqs = torch.exp(
100
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
101
- ).to(device=t.device)
102
- args = t[:, None].float() * freqs[None]
103
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
104
- if dim % 2:
105
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
106
- return embedding
107
-
108
- def forward(self, t):
109
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
110
- t_emb = self.mlp(t_freq)
111
- return t_emb
112
-
113
- class FinalLayer(nn.Module):
114
- """
115
- The final layer of DiT.
116
- """
117
- def __init__(self, hidden_size, patch_size, out_channels):
118
- super().__init__()
119
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
120
- self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
121
- self.adaLN_modulation = nn.Sequential(
122
- nn.SiLU(),
123
- nn.Linear(hidden_size, 2 * hidden_size, bias=True)
124
- )
125
-
126
- def forward(self, x, c):
127
- shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
128
- x = modulate(self.norm_final(x), shift, scale)
129
- x = self.linear(x)
130
- return x
131
-
132
- class SpatioTemporalDiTBlock(nn.Module):
133
- def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, is_causal=True, spatial_rotary_emb: Optional[RotaryEmbedding] = None, temporal_rotary_emb: Optional[RotaryEmbedding] = None):
134
- super().__init__()
135
- self.is_causal = is_causal
136
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
137
- approx_gelu = lambda: nn.GELU(approximate="tanh")
138
-
139
- self.s_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
140
- self.s_attn = SpatialAxialAttention(hidden_size, heads=num_heads, dim_head=hidden_size // num_heads, rotary_emb=spatial_rotary_emb)
141
- self.s_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
142
- self.s_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
143
- self.s_adaLN_modulation = nn.Sequential(
144
- nn.SiLU(),
145
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
146
- )
147
-
148
- self.t_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
149
- self.t_attn = TemporalAxialAttention(hidden_size, heads=num_heads, dim_head=hidden_size // num_heads, is_causal=is_causal, rotary_emb=temporal_rotary_emb)
150
- self.t_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
151
- self.t_mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
152
- self.t_adaLN_modulation = nn.Sequential(
153
- nn.SiLU(),
154
- nn.Linear(hidden_size, 6 * hidden_size, bias=True)
155
- )
156
-
157
- def forward(self, x, c):
158
- B, T, H, W, D = x.shape
159
-
160
- # spatial block
161
- s_shift_msa, s_scale_msa, s_gate_msa, s_shift_mlp, s_scale_mlp, s_gate_mlp = self.s_adaLN_modulation(c).chunk(6, dim=-1)
162
- x = x + gate(self.s_attn(modulate(self.s_norm1(x), s_shift_msa, s_scale_msa)), s_gate_msa)
163
- x = x + gate(self.s_mlp(modulate(self.s_norm2(x), s_shift_mlp, s_scale_mlp)), s_gate_mlp)
164
-
165
- # temporal block
166
- t_shift_msa, t_scale_msa, t_gate_msa, t_shift_mlp, t_scale_mlp, t_gate_mlp = self.t_adaLN_modulation(c).chunk(6, dim=-1)
167
- x = x + gate(self.t_attn(modulate(self.t_norm1(x), t_shift_msa, t_scale_msa)), t_gate_msa)
168
- x = x + gate(self.t_mlp(modulate(self.t_norm2(x), t_shift_mlp, t_scale_mlp)), t_gate_mlp)
169
-
170
- return x
171
-
172
- class DiT(nn.Module):
173
- """
174
- Diffusion model with a Transformer backbone.
175
- """
176
- def __init__(
177
- self,
178
- input_h=18,
179
- input_w=32,
180
- patch_size=2,
181
- in_channels=16,
182
- hidden_size=1024,
183
- depth=12,
184
- num_heads=16,
185
- mlp_ratio=4.0,
186
- external_cond_dim=25,
187
- max_frames=32,
188
- ):
189
- super().__init__()
190
- self.in_channels = in_channels
191
- self.out_channels = in_channels
192
- self.patch_size = patch_size
193
- self.num_heads = num_heads
194
- self.max_frames = max_frames
195
-
196
- self.x_embedder = PatchEmbed(input_h, input_w, patch_size, in_channels, hidden_size, flatten=False)
197
- self.t_embedder = TimestepEmbedder(hidden_size)
198
- frame_h, frame_w = self.x_embedder.grid_size
199
-
200
- self.spatial_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads // 2, freqs_for="pixel", max_freq=256)
201
- self.temporal_rotary_emb = RotaryEmbedding(dim=hidden_size // num_heads)
202
- self.external_cond = nn.Linear(external_cond_dim, hidden_size) if external_cond_dim > 0 else nn.Identity()
203
-
204
- self.blocks = nn.ModuleList(
205
- [
206
- SpatioTemporalDiTBlock(
207
- hidden_size,
208
- num_heads,
209
- mlp_ratio=mlp_ratio,
210
- is_causal=True,
211
- spatial_rotary_emb=self.spatial_rotary_emb,
212
- temporal_rotary_emb=self.temporal_rotary_emb,
213
- )
214
- for _ in range(depth)
215
- ]
216
- )
217
-
218
- self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
219
- self.initialize_weights()
220
-
221
- def initialize_weights(self):
222
- # Initialize transformer layers:
223
- def _basic_init(module):
224
- if isinstance(module, nn.Linear):
225
- torch.nn.init.xavier_uniform_(module.weight)
226
- if module.bias is not None:
227
- nn.init.constant_(module.bias, 0)
228
- self.apply(_basic_init)
229
-
230
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
231
- w = self.x_embedder.proj.weight.data
232
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
233
- nn.init.constant_(self.x_embedder.proj.bias, 0)
234
-
235
- # Initialize timestep embedding MLP:
236
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
237
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
238
-
239
- # Zero-out adaLN modulation layers in DiT blocks:
240
- for block in self.blocks:
241
- nn.init.constant_(block.s_adaLN_modulation[-1].weight, 0)
242
- nn.init.constant_(block.s_adaLN_modulation[-1].bias, 0)
243
- nn.init.constant_(block.t_adaLN_modulation[-1].weight, 0)
244
- nn.init.constant_(block.t_adaLN_modulation[-1].bias, 0)
245
-
246
- # Zero-out output layers:
247
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
248
- nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
249
- nn.init.constant_(self.final_layer.linear.weight, 0)
250
- nn.init.constant_(self.final_layer.linear.bias, 0)
251
-
252
- def unpatchify(self, x):
253
- """
254
- x: (N, H, W, patch_size**2 * C)
255
- imgs: (N, H, W, C)
256
- """
257
- c = self.out_channels
258
- p = self.x_embedder.patch_size[0]
259
- h = x.shape[1]
260
- w = x.shape[2]
261
-
262
- x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
263
- x = torch.einsum('nhwpqc->nchpwq', x)
264
- imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
265
- return imgs
266
-
267
- def forward(self, x, t, external_cond=None):
268
- """
269
- Forward pass of DiT.
270
- x: (B, T, C, H, W) tensor of spatial inputs (images or latent representations of images)
271
- t: (B, T,) tensor of diffusion timesteps
272
- """
273
-
274
- B, T, C, H, W = x.shape
275
-
276
- # add spatial embeddings
277
- x = rearrange(x, "b t c h w -> (b t) c h w")
278
- x = self.x_embedder(x) # (B*T, C, H, W) -> (B*T, H/2, W/2, D) , C = 16, D = d_model
279
- # restore shape
280
- x = rearrange(x, "(b t) h w d -> b t h w d", t = T)
281
- # embed noise steps
282
- t = rearrange(t, "b t -> (b t)")
283
- c = self.t_embedder(t) # (N, D)
284
- c = rearrange(c, "(b t) d -> b t d", t = T)
285
- if torch.is_tensor(external_cond):
286
- c += self.external_cond(external_cond)
287
- for block in self.blocks:
288
- x = block(x, c) # (N, T, H, W, D)
289
- x = self.final_layer(x, c) # (N, T, H, W, patch_size ** 2 * out_channels)
290
- # unpatchify
291
- x = rearrange(x, "b t h w d -> (b t) h w d")
292
- x = self.unpatchify(x) # (N, out_channels, H, W)
293
- x = rearrange(x, "(b t) c h w -> b t c h w", t = T)
294
-
295
- return x
296
-
297
- def DiT_S_2():
298
- return DiT(
299
- patch_size=2,
300
- hidden_size=1024,
301
- depth=16,
302
- num_heads=16,
303
- )
304
-
305
- DiT_models = {
306
- "DiT-S/2": DiT_S_2
307
- }
308
-
309
-
310
-
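To make the interface of the deleted model concrete: the DiT consumes per-frame VAE latents of shape (B, T, 16, 18, 32), a per-frame noise level (B, T), and a 25-dim action vector per frame, and returns a v-prediction of the same shape as the input. A hedged sketch with a deliberately tiny config for a quick shape check (the release uses `DiT_models["DiT-S/2"]`, i.e. width 1024 and depth 16), assuming `dit.py` above is importable:

```
import torch
from dit import DiT

model = DiT(hidden_size=256, depth=2, num_heads=8)  # tiny stand-in, not the released config

B, T = 1, 4
x = torch.randn(B, T, 16, 18, 32)       # VAE latents (see vae.py / generate.py)
t = torch.randint(0, 1000, (B, T))      # independent diffusion timestep per frame
actions = torch.randn(B, T, 25)         # action conditioning (see utils.one_hot_actions)

v = model(x, t, external_cond=actions)  # v-prediction
assert v.shape == x.shape
```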
open-oasis-master/embeddings.py DELETED
@@ -1,103 +0,0 @@
1
- """
2
- Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
3
- """
4
-
5
- from typing import Optional
6
- import math
7
- import torch
8
- from torch import nn
9
-
10
- # pylint: disable=unused-import
11
- from diffusers.models.embeddings import TimestepEmbedding
12
-
13
-
14
- class Timesteps(nn.Module):
15
- def __init__(
16
- self,
17
- num_channels: int,
18
- flip_sin_to_cos: bool = True,
19
- downscale_freq_shift: float = 0,
20
- ):
21
- super().__init__()
22
- self.num_channels = num_channels
23
- self.flip_sin_to_cos = flip_sin_to_cos
24
- self.downscale_freq_shift = downscale_freq_shift
25
-
26
- def forward(self, timesteps):
27
- t_emb = get_timestep_embedding(
28
- timesteps,
29
- self.num_channels,
30
- flip_sin_to_cos=self.flip_sin_to_cos,
31
- downscale_freq_shift=self.downscale_freq_shift,
32
- )
33
- return t_emb
34
-
35
- class Positions2d(nn.Module):
36
- def __init__(
37
- self,
38
- num_channels: int,
39
- flip_sin_to_cos: bool = True,
40
- downscale_freq_shift: float = 0,
41
- ):
42
- super().__init__()
43
- self.num_channels = num_channels
44
- self.flip_sin_to_cos = flip_sin_to_cos
45
- self.downscale_freq_shift = downscale_freq_shift
46
-
47
- def forward(self, grid):
48
- h_emb = get_timestep_embedding(
49
- grid[0],
50
- self.num_channels // 2,
51
- flip_sin_to_cos=self.flip_sin_to_cos,
52
- downscale_freq_shift=self.downscale_freq_shift,
53
- )
54
- w_emb = get_timestep_embedding(
55
- grid[1],
56
- self.num_channels // 2,
57
- flip_sin_to_cos=self.flip_sin_to_cos,
58
- downscale_freq_shift=self.downscale_freq_shift,
59
- )
60
- emb = torch.cat((h_emb, w_emb), dim=-1)
61
- return emb
62
-
63
-
64
- def get_timestep_embedding(
65
- timesteps: torch.Tensor,
66
- embedding_dim: int,
67
- flip_sin_to_cos: bool = False,
68
- downscale_freq_shift: float = 1,
69
- scale: float = 1,
70
- max_period: int = 10000,
71
- ):
72
- """
73
- This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
74
-
75
- :param timesteps: a 1-D or 2-D Tensor of N indices, one per batch element.
76
- These may be fractional.
77
- :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
78
- embeddings. :return: an [N x dim] or [N x M x dim] Tensor of positional embeddings.
79
- """
80
- if len(timesteps.shape) not in [1, 2]:
81
- raise ValueError("Timesteps should be a 1D or 2D tensor")
82
-
83
- half_dim = embedding_dim // 2
84
- exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
85
- exponent = exponent / (half_dim - downscale_freq_shift)
86
-
87
- emb = torch.exp(exponent)
88
- emb = timesteps[..., None].float() * emb
89
-
90
- # scale embeddings
91
- emb = scale * emb
92
-
93
- # concat sine and cosine embeddings
94
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
95
-
96
- # flip sine and cosine embeddings
97
- if flip_sin_to_cos:
98
- emb = torch.cat([emb[..., half_dim:], emb[..., :half_dim]], dim=-1)
99
-
100
- # zero pad
101
- if embedding_dim % 2 == 1:
102
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
103
- return emb
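A quick check of the sinusoidal helpers above: `get_timestep_embedding` concatenates sin/cos halves over the last dim, and `Positions2d` spends half the channels on the row index and half on the column index. A sketch assuming `embeddings.py` is importable:

```
import torch
from embeddings import get_timestep_embedding, Positions2d

t = torch.arange(8)
emb = get_timestep_embedding(t, 64, flip_sin_to_cos=True, downscale_freq_shift=0)
assert emb.shape == (8, 64)

grid = torch.meshgrid(torch.arange(3), torch.arange(5), indexing="ij")
pos = Positions2d(num_channels=64)(grid)
assert pos.shape == (3, 5, 64)
```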
open-oasis-master/generate.py DELETED
@@ -1,119 +0,0 @@
1
- """
2
- References:
3
- - Diffusion Forcing: https://github.com/buoyancy99/diffusion-forcing
4
- """
5
- import torch
6
- from dit import DiT_models
7
- from vae import VAE_models
8
- from torchvision.io import read_video, write_video
9
- from utils import one_hot_actions, sigmoid_beta_schedule
10
- from tqdm import tqdm
11
- from einops import rearrange
12
- from torch import autocast
13
- assert torch.cuda.is_available()
14
- device = "cuda:0"
15
-
16
- # load DiT checkpoint
17
- ckpt = torch.load("oasis500m.pt")
18
- model = DiT_models["DiT-S/2"]()
19
- model.load_state_dict(ckpt, strict=False)
20
- model = model.to(device).eval()
21
-
22
- # load VAE checkpoint
23
- vae_ckpt = torch.load("vit-l-20.pt")
24
- vae = VAE_models["vit-l-20-shallow-encoder"]()
25
- vae.load_state_dict(vae_ckpt)
26
- vae = vae.to(device).eval()
27
-
28
- # sampling params
29
- B = 1
30
- total_frames = 32
31
- max_noise_level = 1000
32
- ddim_noise_steps = 100
33
- noise_range = torch.linspace(-1, max_noise_level - 1, ddim_noise_steps + 1)
34
- noise_abs_max = 20
35
- ctx_max_noise_idx = ddim_noise_steps // 10 * 3
36
-
37
- # get input video
38
- video_id = "snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001"
39
- mp4_path = f"sample_data/{video_id}.mp4"
40
- actions_path = f"sample_data/{video_id}.actions.pt"
41
- video = read_video(mp4_path, pts_unit="sec")[0].float() / 255
42
- actions = one_hot_actions(torch.load(actions_path))
43
- offset = 100
44
- video = video[offset:offset+total_frames].unsqueeze(0)
45
- actions = actions[offset:offset+total_frames].unsqueeze(0)
46
-
47
- # sampling inputs
48
- n_prompt_frames = 1
49
- x = video[:, :n_prompt_frames]
50
- x = x.to(device)
51
- actions = actions.to(device)
52
-
53
- # vae encoding
54
- scaling_factor = 0.07843137255
55
- x = rearrange(x, "b t h w c -> (b t) c h w")
56
- H, W = x.shape[-2:]
57
- with torch.no_grad():
58
- x = vae.encode(x * 2 - 1).mean * scaling_factor
59
- x = rearrange(x, "(b t) (h w) c -> b t c h w", t=n_prompt_frames, h=H//vae.patch_size, w=W//vae.patch_size)
60
-
61
- # get alphas
62
- betas = sigmoid_beta_schedule(max_noise_level).to(device)
63
- alphas = 1.0 - betas
64
- alphas_cumprod = torch.cumprod(alphas, dim=0)
65
- alphas_cumprod = rearrange(alphas_cumprod, "T -> T 1 1 1")
66
-
67
- # sampling loop
68
- for i in tqdm(range(n_prompt_frames, total_frames)):
69
- chunk = torch.randn((B, 1, *x.shape[-3:]), device=device)
70
- chunk = torch.clamp(chunk, -noise_abs_max, +noise_abs_max)
71
- x = torch.cat([x, chunk], dim=1)
72
- start_frame = max(0, i + 1 - model.max_frames)
73
-
74
- for noise_idx in reversed(range(1, ddim_noise_steps + 1)):
75
- # set up noise values
76
- ctx_noise_idx = min(noise_idx, ctx_max_noise_idx)
77
- t_ctx = torch.full((B, i), noise_range[ctx_noise_idx], dtype=torch.long, device=device)
78
- t = torch.full((B, 1), noise_range[noise_idx], dtype=torch.long, device=device)
79
- t_next = torch.full((B, 1), noise_range[noise_idx - 1], dtype=torch.long, device=device)
80
- t_next = torch.where(t_next < 0, t, t_next)
81
- t = torch.cat([t_ctx, t], dim=1)
82
- t_next = torch.cat([t_ctx, t_next], dim=1)
83
-
84
- # sliding window
85
- x_curr = x.clone()
86
- x_curr = x_curr[:, start_frame:]
87
- t = t[:, start_frame:]
88
- t_next = t_next[:, start_frame:]
89
-
90
- # add some noise to the context
91
- ctx_noise = torch.randn_like(x_curr[:, :-1])
92
- ctx_noise = torch.clamp(ctx_noise, -noise_abs_max, +noise_abs_max)
93
- x_curr[:, :-1] = alphas_cumprod[t[:, :-1]].sqrt() * x_curr[:, :-1] + (1 - alphas_cumprod[t[:, :-1]]).sqrt() * ctx_noise
94
-
95
- # get model predictions
96
- with torch.no_grad():
97
- with autocast("cuda", dtype=torch.half):
98
- v = model(x_curr, t, actions[:, start_frame : i + 1])
99
-
100
- x_start = alphas_cumprod[t].sqrt() * x_curr - (1 - alphas_cumprod[t]).sqrt() * v
101
- x_noise = ((1 / alphas_cumprod[t]).sqrt() * x_curr - x_start) \
102
- / (1 / alphas_cumprod[t] - 1).sqrt()
103
-
104
- # get frame prediction
105
- x_pred = alphas_cumprod[t_next].sqrt() * x_start + x_noise * (1 - alphas_cumprod[t_next]).sqrt()
106
- x[:, -1:] = x_pred[:, -1:]
107
-
108
- # vae decoding
109
- x = rearrange(x, "b t c h w -> (b t) (h w) c")
110
- with torch.no_grad():
111
- x = (vae.decode(x / scaling_factor) + 1) / 2
112
- x = rearrange(x, "(b t) c h w -> b t h w c", t=total_frames)
113
-
114
- # save video
115
- x = torch.clamp(x, 0, 1)
116
- x = (x * 255).byte()
117
- write_video("video.mp4", x[0], fps=20)
118
- print("generation saved to video.mp4.")
119
-
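The inner loop above is a deterministic DDIM-style update on a v-prediction model, with the already-generated context frames re-noised up to a capped level (`ctx_max_noise_idx`) at every step. In symbols (my reading of the code, with \bar\alpha_t the cumulative product of the alphas from `sigmoid_beta_schedule`):

```
\hat{x}_0 = \sqrt{\bar\alpha_t}\, x_t - \sqrt{1-\bar\alpha_t}\, v_\theta(x_t, t, a)

\hat{\epsilon} = \frac{x_t/\sqrt{\bar\alpha_t} - \hat{x}_0}{\sqrt{1/\bar\alpha_t - 1}}

x_{t_\text{next}} = \sqrt{\bar\alpha_{t_\text{next}}}\, \hat{x}_0 + \sqrt{1-\bar\alpha_{t_\text{next}}}\, \hat{\epsilon}
```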
open-oasis-master/media/arch.png DELETED
Binary file (89.2 kB)
 
open-oasis-master/media/sample_0.gif DELETED

Git LFS Details

  • SHA256: 684d0b42eed5f82d6285dbc46b0c69dbe4661c91fdb92043c3c298c300249574
  • Pointer size: 132 Bytes
  • Size of remote file: 3.15 MB
open-oasis-master/media/sample_1.gif DELETED

Git LFS Details

  • SHA256: d771ac40069b4e7a424d18d7c91c64904e560e5c61cc52f51f67eb6c667c39f9
  • Pointer size: 132 Bytes
  • Size of remote file: 2.95 MB
open-oasis-master/media/thumb.png DELETED
Binary file (768 kB)
 
open-oasis-master/requirements.txt DELETED
@@ -1,31 +0,0 @@
1
- av==13.1.0
2
- certifi==2024.8.30
3
- charset-normalizer==3.4.0
4
- diffusers==0.31.0
5
- einops==0.8.0
6
- filelock==3.13.1
7
- fsspec==2024.2.0
8
- huggingface-hub==0.26.2
9
- idna==3.10
10
- importlib_metadata==8.5.0
11
- Jinja2==3.1.3
12
- MarkupSafe==2.1.5
13
- mpmath==1.3.0
14
- networkx==3.2.1
15
- numpy==1.26.3
16
- packaging==24.1
17
- pillow==10.2.0
18
- PyYAML==6.0.2
19
- regex==2024.9.11
20
- requests==2.32.3
21
- safetensors==0.4.5
22
- sympy==1.13.1
23
- timm==1.0.11
24
- torch==2.5.1
25
- torchaudio==2.5.1
26
- torchvision==0.20.1
27
- tqdm==4.66.6
28
- triton==3.1.0
29
- typing_extensions==4.9.0
30
- urllib3==2.2.3
31
- zipp==3.20.2
open-oasis-master/rotary_embedding_torch.py DELETED
@@ -1,316 +0,0 @@
1
- """
2
- Adapted from https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
3
- """
4
-
5
- from __future__ import annotations
6
- from math import pi, log
7
-
8
- import torch
9
- from torch.nn import Module, ModuleList
10
- from torch.amp import autocast
11
- from torch import nn, einsum, broadcast_tensors, Tensor
12
-
13
- from einops import rearrange, repeat
14
-
15
- from typing import Literal
16
-
17
- # helper functions
18
-
19
- def exists(val):
20
- return val is not None
21
-
22
- def default(val, d):
23
- return val if exists(val) else d
24
-
25
- # broadcat, as tortoise-tts was using it
26
-
27
- def broadcat(tensors, dim = -1):
28
- broadcasted_tensors = broadcast_tensors(*tensors)
29
- return torch.cat(broadcasted_tensors, dim = dim)
30
-
31
- # rotary embedding helper functions
32
-
33
- def rotate_half(x):
34
- x = rearrange(x, '... (d r) -> ... d r', r = 2)
35
- x1, x2 = x.unbind(dim = -1)
36
- x = torch.stack((-x2, x1), dim = -1)
37
- return rearrange(x, '... d r -> ... (d r)')
38
-
39
- @autocast('cuda', enabled = False)
40
- def apply_rotary_emb(freqs, t, start_index = 0, scale = 1., seq_dim = -2):
41
- dtype = t.dtype
42
-
43
- if t.ndim == 3:
44
- seq_len = t.shape[seq_dim]
45
- freqs = freqs[-seq_len:]
46
-
47
- rot_dim = freqs.shape[-1]
48
- end_index = start_index + rot_dim
49
-
50
- assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
51
-
52
- # Split t into three parts: left, middle (to be transformed), and right
53
- t_left = t[..., :start_index]
54
- t_middle = t[..., start_index:end_index]
55
- t_right = t[..., end_index:]
56
-
57
- # Apply rotary embeddings without modifying t in place
58
- t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
59
-
60
- out = torch.cat((t_left, t_transformed, t_right), dim=-1)
61
-
62
- return out.type(dtype)
63
-
64
- # learned rotation helpers
65
-
66
- def apply_learned_rotations(rotations, t, start_index = 0, freq_ranges = None):
67
- if exists(freq_ranges):
68
- rotations = einsum('..., f -> ... f', rotations, freq_ranges)
69
- rotations = rearrange(rotations, '... r f -> ... (r f)')
70
-
71
- rotations = repeat(rotations, '... n -> ... (n r)', r = 2)
72
- return apply_rotary_emb(rotations, t, start_index = start_index)
73
-
74
- # classes
75
-
76
- class RotaryEmbedding(Module):
77
- def __init__(
78
- self,
79
- dim,
80
- custom_freqs: Tensor | None = None,
81
- freqs_for: Literal['lang', 'pixel', 'constant'] = 'lang',
82
- theta = 10000,
83
- max_freq = 10,
84
- num_freqs = 1,
85
- learned_freq = False,
86
- use_xpos = False,
87
- xpos_scale_base = 512,
88
- interpolate_factor = 1.,
89
- theta_rescale_factor = 1.,
90
- seq_before_head_dim = False,
91
- cache_if_possible = True,
92
- cache_max_seq_len = 8192
93
- ):
94
- super().__init__()
95
- # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
96
- # has some connection to NTK literature
97
- # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
98
-
99
- theta *= theta_rescale_factor ** (dim / (dim - 2))
100
-
101
- self.freqs_for = freqs_for
102
-
103
- if exists(custom_freqs):
104
- freqs = custom_freqs
105
- elif freqs_for == 'lang':
106
- freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
107
- elif freqs_for == 'pixel':
108
- freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
109
- elif freqs_for == 'spacetime':
110
- time_freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
111
- freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
112
- elif freqs_for == 'constant':
113
- freqs = torch.ones(num_freqs).float()
114
-
115
- if freqs_for == 'spacetime':
116
- self.time_freqs = nn.Parameter(time_freqs, requires_grad = learned_freq)
117
- self.freqs = nn.Parameter(freqs, requires_grad = learned_freq)
118
-
119
- self.cache_if_possible = cache_if_possible
120
- self.cache_max_seq_len = cache_max_seq_len
121
-
122
- self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False)
123
- self.register_buffer('cached_freqs_seq_len', torch.tensor(0), persistent = False)
124
-
125
- self.learned_freq = learned_freq
126
-
127
- # dummy for device
128
-
129
- self.register_buffer('dummy', torch.tensor(0), persistent = False)
130
-
131
- # default sequence dimension
132
-
133
- self.seq_before_head_dim = seq_before_head_dim
134
- self.default_seq_dim = -3 if seq_before_head_dim else -2
135
-
136
- # interpolation factors
137
-
138
- assert interpolate_factor >= 1.
139
- self.interpolate_factor = interpolate_factor
140
-
141
- # xpos
142
-
143
- self.use_xpos = use_xpos
144
-
145
- if not use_xpos:
146
- return
147
-
148
- scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
149
- self.scale_base = xpos_scale_base
150
-
151
- self.register_buffer('scale', scale, persistent = False)
152
- self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False)
153
- self.register_buffer('cached_scales_seq_len', torch.tensor(0), persistent = False)
154
-
155
- # add apply_rotary_emb as static method
156
-
157
- self.apply_rotary_emb = staticmethod(apply_rotary_emb)
158
-
159
- @property
160
- def device(self):
161
- return self.dummy.device
162
-
163
- def get_seq_pos(self, seq_len, device, dtype, offset = 0):
164
- return (torch.arange(seq_len, device = device, dtype = dtype) + offset) / self.interpolate_factor
165
-
166
- def rotate_queries_or_keys(self, t, freqs, seq_dim = None, offset = 0, scale = None):
167
- seq_dim = default(seq_dim, self.default_seq_dim)
168
-
169
- assert not self.use_xpos or exists(scale), 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings'
170
-
171
- device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim]
172
-
173
- seq = self.get_seq_pos(seq_len, device = device, dtype = dtype, offset = offset)
174
-
175
- seq_freqs = self.forward(seq, freqs, seq_len = seq_len, offset = offset)
176
-
177
- if seq_dim == -3:
178
- seq_freqs = rearrange(seq_freqs, 'n d -> n 1 d')
179
-
180
- return apply_rotary_emb(seq_freqs, t, scale = default(scale, 1.), seq_dim = seq_dim)
181
-
182
- def rotate_queries_with_cached_keys(self, q, k, seq_dim = None, offset = 0):
183
- dtype, device, seq_dim = q.dtype, q.device, default(seq_dim, self.default_seq_dim)
184
-
185
- q_len, k_len = q.shape[seq_dim], k.shape[seq_dim]
186
- assert q_len <= k_len
187
-
188
- q_scale = k_scale = 1.
189
-
190
- if self.use_xpos:
191
- seq = self.get_seq_pos(k_len, dtype = dtype, device = device)
192
-
193
- q_scale = self.get_scale(seq[-q_len:]).type(dtype)
194
- k_scale = self.get_scale(seq).type(dtype)
195
-
196
- rotated_q = self.rotate_queries_or_keys(q, seq_dim = seq_dim, scale = q_scale, offset = k_len - q_len + offset)
197
- rotated_k = self.rotate_queries_or_keys(k, seq_dim = seq_dim, scale = k_scale ** -1)
198
-
199
- rotated_q = rotated_q.type(q.dtype)
200
- rotated_k = rotated_k.type(k.dtype)
201
-
202
- return rotated_q, rotated_k
203
-
204
- def rotate_queries_and_keys(self, q, k, freqs, seq_dim = None):
205
- seq_dim = default(seq_dim, self.default_seq_dim)
206
-
207
- assert self.use_xpos
208
- device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]
209
-
210
- seq = self.get_seq_pos(seq_len, dtype = dtype, device = device)
211
-
212
- seq_freqs = self.forward(seq, freqs, seq_len = seq_len)
213
- scale = self.get_scale(seq, seq_len = seq_len).to(dtype)
214
-
215
- if seq_dim == -3:
216
- seq_freqs = rearrange(seq_freqs, 'n d -> n 1 d')
217
- scale = rearrange(scale, 'n d -> n 1 d')
218
-
219
- rotated_q = apply_rotary_emb(seq_freqs, q, scale = scale, seq_dim = seq_dim)
220
- rotated_k = apply_rotary_emb(seq_freqs, k, scale = scale ** -1, seq_dim = seq_dim)
221
-
222
- rotated_q = rotated_q.type(q.dtype)
223
- rotated_k = rotated_k.type(k.dtype)
224
-
225
- return rotated_q, rotated_k
226
-
227
- def get_scale(
228
- self,
229
- t: Tensor,
230
- seq_len: int | None = None,
231
- offset = 0
232
- ):
233
- assert self.use_xpos
234
-
235
- should_cache = (
236
- self.cache_if_possible and
237
- exists(seq_len) and
238
- (offset + seq_len) <= self.cache_max_seq_len
239
- )
240
-
241
- if (
242
- should_cache and \
243
- exists(self.cached_scales) and \
244
- (seq_len + offset) <= self.cached_scales_seq_len.item()
245
- ):
246
- return self.cached_scales[offset:(offset + seq_len)]
247
-
248
- scale = 1.
249
- if self.use_xpos:
250
- power = (t - len(t) // 2) / self.scale_base
251
- scale = self.scale ** rearrange(power, 'n -> n 1')
252
- scale = repeat(scale, 'n d -> n (d r)', r = 2)
253
-
254
- if should_cache and offset == 0:
255
- self.cached_scales[:seq_len] = scale.detach()
256
- self.cached_scales_seq_len.copy_(seq_len)
257
-
258
- return scale
259
-
260
- def get_axial_freqs(self, *dims):
261
- Colon = slice(None)
262
- all_freqs = []
263
-
264
- for ind, dim in enumerate(dims):
265
- # only allow pixel freqs for last two dimensions
266
- use_pixel = (self.freqs_for == 'pixel' or self.freqs_for == 'spacetime') and ind >= len(dims) - 2
267
- if use_pixel:
268
- pos = torch.linspace(-1, 1, steps = dim, device = self.device)
269
- else:
270
- pos = torch.arange(dim, device = self.device)
271
-
272
- if self.freqs_for == 'spacetime' and not use_pixel:
273
- seq_freqs = self.forward(pos, self.time_freqs, seq_len = dim)
274
- else:
275
- seq_freqs = self.forward(pos, self.freqs, seq_len = dim)
276
-
277
- all_axis = [None] * len(dims)
278
- all_axis[ind] = Colon
279
-
280
- new_axis_slice = (Ellipsis, *all_axis, Colon)
281
- all_freqs.append(seq_freqs[new_axis_slice])
282
-
283
- all_freqs = broadcast_tensors(*all_freqs)
284
- return torch.cat(all_freqs, dim = -1)
285
-
286
- @autocast('cuda', enabled = False)
287
- def forward(
288
- self,
289
- t: Tensor,
290
- freqs: Tensor,
291
- seq_len = None,
292
- offset = 0
293
- ):
294
- should_cache = (
295
- self.cache_if_possible and
296
- not self.learned_freq and
297
- exists(seq_len) and
298
- self.freqs_for != 'pixel' and
299
- (offset + seq_len) <= self.cache_max_seq_len
300
- )
301
-
302
- if (
303
- should_cache and \
304
- exists(self.cached_freqs) and \
305
- (offset + seq_len) <= self.cached_freqs_seq_len.item()
306
- ):
307
- return self.cached_freqs[offset:(offset + seq_len)].detach()
308
-
309
- freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs)
310
- freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
311
-
312
- if should_cache and offset == 0:
313
- self.cached_freqs[:seq_len] = freqs.detach()
314
- self.cached_freqs_seq_len.copy_(seq_len)
315
-
316
- return freqs
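A small usage sketch for the adapted rotary module above, showing the axial-frequency path that `SpatialAxialAttention` relies on (assumes `rotary_embedding_torch.py` is importable; shapes are illustrative):

```
import torch
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb

rot = RotaryEmbedding(dim=32, freqs_for="pixel", max_freq=256)
freqs = rot.get_axial_freqs(9, 16)   # (9, 16, 64): row and column frequencies concatenated

q = torch.randn(2, 4, 9, 16, 64)     # (batch, heads, H, W, head_dim)
q_rot = apply_rotary_emb(freqs, q)   # rotation preserves the shape
assert q_rot.shape == q.shape
```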
open-oasis-master/sample_data/Player729-f153ac423f61-20210806-224813.chunk_000.actions.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:cc3ea8894f87e2c2c2387dd32b193f27a8a95009397c32b5fbaf8a6f23608b0c
3
- size 230180
 
 
 
 
open-oasis-master/sample_data/Player729-f153ac423f61-20210806-224813.chunk_000.mp4 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9fb1cf3a87be9deca2fec2e946427521a85026ee607cf9281aa87f6df447e4ea
3
- size 6818283
 
 
 
 
open-oasis-master/sample_data/snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001.actions.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:955929d771293156d3f27d295091a978dcd97fdaa78e3a17395ac90c0403004d
3
- size 230308
 
 
 
 
open-oasis-master/sample_data/snippy-chartreuse-mastiff-f79998db196d-20220401-224517.chunk_001.mp4 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:745b0348a014d943f70ccf6ccba17ad260540caba502b312d972235326003ab0
3
- size 7109171
 
 
 
 
open-oasis-master/sample_data/treechop-f153ac423f61-20210916-183423.chunk_000.actions.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:46ae60cc9d3a02df949923c707df4c5cd3f49d279aa6500c81f0ef00c14f7747
3
- size 230176
 
 
 
 
open-oasis-master/sample_data/treechop-f153ac423f61-20210916-183423.chunk_000.mp4 DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a0ad584df52d7b2636fae5d7a3116f596f25a09ba7d28ff5fc42193105605d92
3
- size 8716515
 
 
 
 
open-oasis-master/utils.py DELETED
@@ -1,82 +0,0 @@
1
- """
2
- Adapted from https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/utils.py
3
- Action format derived from VPT https://github.com/openai/Video-Pre-Training
4
- """
5
- import math
6
- import torch
7
- from torch import nn
8
- from einops import rearrange, parse_shape
9
- from typing import Mapping, Sequence
10
- import torch
11
- from einops import rearrange
12
-
13
-
14
- def sigmoid_beta_schedule(timesteps, start=-3, end=3, tau=1, clamp_min=1e-5):
15
- """
16
- sigmoid schedule
17
- proposed in https://arxiv.org/abs/2212.11972 - Figure 8
18
- better for images > 64x64, when used during training
19
- """
20
- steps = timesteps + 1
21
- t = torch.linspace(0, timesteps, steps, dtype=torch.float32) / timesteps
22
- v_start = torch.tensor(start / tau).sigmoid()
23
- v_end = torch.tensor(end / tau).sigmoid()
24
- alphas_cumprod = (-((t * (end - start) + start) / tau).sigmoid() + v_end) / (v_end - v_start)
25
- alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
26
- betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
27
- return torch.clip(betas, 0, 0.999)
28
-
29
-
30
- ACTION_KEYS = [
31
- "inventory",
32
- "ESC",
33
- "hotbar.1",
34
- "hotbar.2",
35
- "hotbar.3",
36
- "hotbar.4",
37
- "hotbar.5",
38
- "hotbar.6",
39
- "hotbar.7",
40
- "hotbar.8",
41
- "hotbar.9",
42
- "forward",
43
- "back",
44
- "left",
45
- "right",
46
- "cameraX",
47
- "cameraY",
48
- "jump",
49
- "sneak",
50
- "sprint",
51
- "swapHands",
52
- "attack",
53
- "use",
54
- "pickItem",
55
- "drop",
56
- ]
57
-
58
- def one_hot_actions(actions: Sequence[Mapping[str, int]]) -> torch.Tensor:
59
- actions_one_hot = torch.zeros(len(actions), len(ACTION_KEYS))
60
- for i, current_actions in enumerate(actions):
61
- for j, action_key in enumerate(ACTION_KEYS):
62
- if action_key.startswith("camera"):
63
- if action_key == "cameraX":
64
- value = current_actions["camera"][0]
65
- elif action_key == "cameraY":
66
- value = current_actions["camera"][1]
67
- else:
68
- raise ValueError(f"Unknown camera action key: {action_key}")
69
- # NOTE these numbers specific to the camera quantization used in
70
- # https://github.com/etched-ai/dreamcraft/blob/216e952f795bb3da598639a109bcdba4d2067b69/spark/preprocess_vpt_to_videos_actions.py#L312
71
- # see method `compress_mouse`
72
- max_val = 20
73
- bin_size = 0.5
74
- num_buckets = int(max_val / bin_size)
75
- value = (value - num_buckets) / num_buckets
76
- assert -1 - 1e-3 <= value <= 1 + 1e-3, f"Camera action value must be in [-1, 1], got {value}"
77
- else:
78
- value = current_actions[action_key]
79
- assert 0 <= value <= 1, f"Action value must be in [0, 1] got {value}"
80
- actions_one_hot[i, j] = value
81
-
82
- return actions_one_hot
 
open-oasis-master/vae.py DELETED
@@ -1,381 +0,0 @@
1
- """
2
- References:
3
- - VQGAN: https://github.com/CompVis/taming-transformers
4
- - MAE: https://github.com/facebookresearch/mae
5
- """
6
- import numpy as np
7
- import math
8
- import functools
9
- from collections import namedtuple
10
- import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
- from einops import rearrange
14
- from timm.models.vision_transformer import Mlp
15
- from timm.layers.helpers import to_2tuple
16
- from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
17
- from dit import PatchEmbed
18
-
19
- class DiagonalGaussianDistribution(object):
20
- def __init__(self, parameters, deterministic=False, dim=1):
21
- self.parameters = parameters
22
- self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
23
- if dim == 1:
24
- self.dims = [1, 2, 3]
25
- elif dim == 2:
26
- self.dims = [1, 2]
27
- else:
28
- raise NotImplementedError
29
- self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
30
- self.deterministic = deterministic
31
- self.std = torch.exp(0.5 * self.logvar)
32
- self.var = torch.exp(self.logvar)
33
- if self.deterministic:
34
- self.var = self.std = torch.zeros_like(self.mean).to(
35
- device=self.parameters.device
36
- )
37
-
38
- def sample(self):
39
- x = self.mean + self.std * torch.randn(self.mean.shape).to(
40
- device=self.parameters.device
41
- )
42
- return x
43
-
44
- def mode(self):
45
- return self.mean
46
-
47
- class Attention(nn.Module):
48
- def __init__(
49
- self,
50
- dim,
51
- num_heads,
52
- frame_height,
53
- frame_width,
54
- qkv_bias=False,
55
- attn_drop=0.0,
56
- proj_drop=0.0,
57
- is_causal=False,
58
- ):
59
- super().__init__()
60
- self.num_heads = num_heads
61
- head_dim = dim // num_heads
62
- self.frame_height = frame_height
63
- self.frame_width = frame_width
64
-
65
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
66
- self.attn_drop = attn_drop
67
- self.proj = nn.Linear(dim, dim)
68
- self.proj_drop = nn.Dropout(proj_drop)
69
- self.is_causal = is_causal
70
-
71
- rotary_freqs = RotaryEmbedding(
72
- dim=head_dim // 4,
73
- freqs_for="pixel",
74
- max_freq=frame_height*frame_width,
75
- ).get_axial_freqs(frame_height, frame_width)
76
- self.register_buffer("rotary_freqs", rotary_freqs, persistent=False)
77
-
78
- def forward(self, x):
79
- B, N, C = x.shape
80
- assert N == self.frame_height * self.frame_width
81
-
82
- qkv = (
83
- self.qkv(x)
84
- .reshape(B, N, 3, self.num_heads, C // self.num_heads)
85
- .permute(2, 0, 3, 1, 4)
86
- )
87
- q, k, v = (
88
- qkv[0],
89
- qkv[1],
90
- qkv[2],
91
- ) # make torchscript happy (cannot use tensor as tuple)
92
-
93
- if self.rotary_freqs is not None:
94
- q = rearrange(q, "b h (H W) d -> b h H W d", H=self.frame_height, W=self.frame_width)
95
- k = rearrange(k, "b h (H W) d -> b h H W d", H=self.frame_height, W=self.frame_width)
96
- q = apply_rotary_emb(self.rotary_freqs, q)
97
- k = apply_rotary_emb(self.rotary_freqs, k)
98
- q = rearrange(q, "b h H W d -> b h (H W) d")
99
- k = rearrange(k, "b h H W d -> b h (H W) d")
100
-
101
- attn = F.scaled_dot_product_attention(
102
- q,
103
- k,
104
- v,
105
- dropout_p=self.attn_drop,
106
- is_causal=self.is_causal,
107
- )
108
- x = attn.transpose(1, 2).reshape(B, N, C)
109
-
110
- x = self.proj(x)
111
- x = self.proj_drop(x)
112
- return x
113
-
114
-
115
- class AttentionBlock(nn.Module):
116
- def __init__(
117
- self,
118
- dim,
119
- num_heads,
120
- frame_height,
121
- frame_width,
122
- mlp_ratio=4.0,
123
- qkv_bias=False,
124
- drop=0.0,
125
- attn_drop=0.0,
126
- attn_causal=False,
127
- drop_path=0.0,
128
- act_layer=nn.GELU,
129
- norm_layer=nn.LayerNorm,
130
- ):
131
- super().__init__()
132
- self.norm1 = norm_layer(dim)
133
- self.attn = Attention(
134
- dim,
135
- num_heads,
136
- frame_height,
137
- frame_width,
138
- qkv_bias=qkv_bias,
139
- attn_drop=attn_drop,
140
- proj_drop=drop,
141
- is_causal=attn_causal,
142
- )
143
- # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
144
- self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
145
- self.norm2 = norm_layer(dim)
146
- mlp_hidden_dim = int(dim * mlp_ratio)
147
- self.mlp = Mlp(
148
- in_features=dim,
149
- hidden_features=mlp_hidden_dim,
150
- act_layer=act_layer,
151
- drop=drop,
152
- )
153
-
154
- def forward(self, x):
155
- x = x + self.drop_path(self.attn(self.norm1(x)))
156
- x = x + self.drop_path(self.mlp(self.norm2(x)))
157
- return x
158
-
159
-
160
- class AutoencoderKL(nn.Module):
161
- def __init__(
162
- self,
163
- latent_dim,
164
- input_height=256,
165
- input_width=256,
166
- patch_size=16,
167
- enc_dim=768,
168
- enc_depth=6,
169
- enc_heads=12,
170
- dec_dim=768,
171
- dec_depth=6,
172
- dec_heads=12,
173
- mlp_ratio=4.0,
174
- norm_layer=functools.partial(nn.LayerNorm, eps=1e-6),
175
- use_variational=True,
176
- **kwargs,
177
- ):
178
- super().__init__()
179
- self.input_height = input_height
180
- self.input_width = input_width
181
- self.patch_size = patch_size
182
- self.seq_h = input_height // patch_size
183
- self.seq_w = input_width // patch_size
184
- self.seq_len = self.seq_h * self.seq_w
185
- self.patch_dim = 3 * patch_size**2
186
-
187
- self.latent_dim = latent_dim
188
- self.enc_dim = enc_dim
189
- self.dec_dim = dec_dim
190
-
191
- # patch
192
- self.patch_embed = PatchEmbed(input_height, input_width, patch_size, 3, enc_dim)
193
-
194
- # encoder
195
- self.encoder = nn.ModuleList(
196
- [
197
- AttentionBlock(
198
- enc_dim,
199
- enc_heads,
200
- self.seq_h,
201
- self.seq_w,
202
- mlp_ratio,
203
- qkv_bias=True,
204
- norm_layer=norm_layer,
205
- )
206
- for i in range(enc_depth)
207
- ]
208
- )
209
- self.enc_norm = norm_layer(enc_dim)
210
-
211
- # bottleneck
212
- self.use_variational = use_variational
213
- mult = 2 if self.use_variational else 1
214
- self.quant_conv = nn.Linear(enc_dim, mult * latent_dim)
215
- self.post_quant_conv = nn.Linear(latent_dim, dec_dim)
216
-
217
- # decoder
218
- self.decoder = nn.ModuleList(
219
- [
220
- AttentionBlock(
221
- dec_dim,
222
- dec_heads,
223
- self.seq_h,
224
- self.seq_w,
225
- mlp_ratio,
226
- qkv_bias=True,
227
- norm_layer=norm_layer,
228
- )
229
- for i in range(dec_depth)
230
- ]
231
- )
232
- self.dec_norm = norm_layer(dec_dim)
233
- self.predictor = nn.Linear(dec_dim, self.patch_dim) # decoder to patch
234
-
235
- # initialize this weight first
236
- self.initialize_weights()
237
-
238
-
239
- def initialize_weights(self):
240
- # initialization
241
- # initialize nn.Linear and nn.LayerNorm
242
- self.apply(self._init_weights)
243
-
244
- # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
245
- w = self.patch_embed.proj.weight.data
246
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
247
-
248
- def _init_weights(self, m):
249
- if isinstance(m, nn.Linear):
250
- # we use xavier_uniform following official JAX ViT:
251
- nn.init.xavier_uniform_(m.weight)
252
- if isinstance(m, nn.Linear) and m.bias is not None:
253
- nn.init.constant_(m.bias, 0.0)
254
- elif isinstance(m, nn.LayerNorm):
255
- nn.init.constant_(m.bias, 0.0)
256
- nn.init.constant_(m.weight, 1.0)
257
-
258
- def patchify(self, x):
259
- # patchify
260
- bsz, _, h, w = x.shape
261
- x = x.reshape(
262
- bsz,
263
- 3,
264
- self.seq_h,
265
- self.patch_size,
266
- self.seq_w,
267
- self.patch_size,
268
- ).permute(
269
- [0, 1, 3, 5, 2, 4]
270
- ) # [b, c, h, p, w, p] --> [b, c, p, p, h, w]
271
- x = x.reshape(
272
- bsz, self.patch_dim, self.seq_h, self.seq_w
273
- ) # --> [b, cxpxp, h, w]
274
- x = x.permute([0, 2, 3, 1]).reshape(
275
- bsz, self.seq_len, self.patch_dim
276
- ) # --> [b, hxw, cxpxp]
277
- return x
278
-
279
- def unpatchify(self, x):
280
- bsz = x.shape[0]
281
- # unpatchify
282
- x = x.reshape(bsz, self.seq_h, self.seq_w, self.patch_dim).permute(
283
- [0, 3, 1, 2]
284
- ) # [b, h, w, cxpxp] --> [b, cxpxp, h, w]
285
- x = x.reshape(
286
- bsz,
287
- 3,
288
- self.patch_size,
289
- self.patch_size,
290
- self.seq_h,
291
- self.seq_w,
292
- ).permute(
293
- [0, 1, 4, 2, 5, 3]
294
- ) # [b, c, p, p, h, w] --> [b, c, h, p, w, p]
295
- x = x.reshape(
296
- bsz,
297
- 3,
298
- self.input_height,
299
- self.input_width,
300
- ) # [b, c, hxp, wxp]
301
- return x
302
-
303
- def encode(self, x):
304
- # patchify
305
- x = self.patch_embed(x)
306
-
307
- # encoder
308
- for blk in self.encoder:
309
- x = blk(x)
310
- x = self.enc_norm(x)
311
-
312
- # bottleneck
313
- moments = self.quant_conv(x)
314
- if not self.use_variational:
315
- moments = torch.cat((moments, torch.zeros_like(moments)), 2)
316
- posterior = DiagonalGaussianDistribution(
317
- moments, deterministic=(not self.use_variational), dim=2
318
- )
319
- return posterior
320
-
321
- def decode(self, z):
322
- # bottleneck
323
- z = self.post_quant_conv(z)
324
-
325
- # decoder
326
- for blk in self.decoder:
327
- z = blk(z)
328
- z = self.dec_norm(z)
329
-
330
- # predictor
331
- z = self.predictor(z)
332
-
333
- # unpatchify
334
- dec = self.unpatchify(z)
335
- return dec
336
-
337
- def autoencode(self, input, sample_posterior=True):
338
- posterior = self.encode(input)
339
- if self.use_variational and sample_posterior:
340
- z = posterior.sample()
341
- else:
342
- z = posterior.mode()
343
- dec = self.decode(z)
344
- return dec, posterior, z
345
-
346
- def get_input(self, batch, k):
347
- x = batch[k]
348
- if len(x.shape) == 3:
349
- x = x[..., None]
350
- x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
351
- return x
352
-
353
- def forward(self, inputs, labels, split="train"):
354
- rec, post, latent = self.autoencode(inputs)
355
- return rec, post, latent
356
-
357
- def get_last_layer(self):
358
- return self.predictor.weight
359
-
360
- def ViT_L_20_Shallow_Encoder(**kwargs):
361
- if "latent_dim" in kwargs:
362
- latent_dim = kwargs.pop("latent_dim")
363
- else:
364
- latent_dim = 16
365
- return AutoencoderKL(
366
- latent_dim=latent_dim,
367
- patch_size=20,
368
- enc_dim=1024,
369
- enc_depth=6,
370
- enc_heads=16,
371
- dec_dim=1024,
372
- dec_depth=12,
373
- dec_heads=16,
374
- input_height=360,
375
- input_width=640,
376
- **kwargs,
377
- )
378
-
379
- VAE_models = {
380
- "vit-l-20-shallow-encoder": ViT_L_20_Shallow_Encoder,
381
- }
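Shape sketch for the ViT VAE above, matching how `generate.py` uses it: a 360x640 RGB frame becomes 18x32 = 576 tokens with a 16-dim latent each under the `vit-l-20-shallow-encoder` config. Assumes `vae.py` is importable; inputs are expected in [-1, 1]:

```
import torch
from vae import VAE_models

vae = VAE_models["vit-l-20-shallow-encoder"]().eval()

img = torch.rand(1, 3, 360, 640) * 2 - 1         # dummy frame scaled to [-1, 1]
with torch.no_grad():
    posterior = vae.encode(img)
    z = posterior.mean                           # (1, 576, 16); generate.py uses the mean
    rec = vae.decode(z)                          # (1, 3, 360, 640)

assert z.shape == (1, (360 // 20) * (640 // 20), 16)
assert rec.shape == img.shape
```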