import numpy as np
from einops import rearrange
import torch
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence


def frame_pad(x, seq_len, shapes):
    """
    Pad the packed tokens of each frame to a common (max_h, max_w) grid.

    x: [seq_len, N, C] tokens packed frame by frame; shapes: per-frame (h, w).
    Returns a [F, max_h, max_w, N * C] tensor.
    """
max_h, max_w = np.max(shapes, 0)
frames = []
cur_len = 0
for h, w in shapes:
frame_len = h * w
frames.append(
F.pad(
x[cur_len:cur_len + frame_len].view(h, w, -1),
(0, 0, 0, max_w - w, 0, max_h - h)) # .view(max_h * max_w, -1)
)
cur_len += frame_len
if cur_len >= seq_len:
break
return torch.stack(frames)


def frame_unpad(x, shapes):
    """
    Undo frame_pad: drop the spatial padding and re-pack the tokens.

    x: [F * max_h * max_w, N, C]; returns [sum(h * w), N, C].
    """
max_h, max_w = np.max(shapes, 0)
x = rearrange(x, '(b h w) n c -> b h w n c', h=max_h, w=max_w)
frames = []
for i, (h, w) in enumerate(shapes):
if i >= len(x):
break
frames.append(rearrange(x[i, :h, :w], 'h w n c -> (h w) n c'))
return torch.concat(frames)
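

# Sketch of the intended round trip (illustrative sizes, not from the original
# file): frame_pad packs a ragged [sum(h*w), N, C] token sequence into a dense
# [F, max_h, max_w, N*C] grid; reshaping that grid back to
# [F*max_h*max_w, N, C] and calling frame_unpad recovers the original packing.
#
#   shapes = [(4, 6), (3, 5)]                      # hypothetical frame grids
#   x = torch.randn(sum(h * w for h, w in shapes), 8, 64)
#   padded = frame_pad(x, x.size(0), shapes)       # [2, 4, 6, 512]
#   restored = frame_unpad(padded.reshape(-1, 8, 64), shapes)
#   assert torch.equal(restored, x)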


@amp.autocast(enabled=False)
def rope_apply_multires(x, x_lens, x_shapes, freqs, pad=True):
"""
x: [B*L, N, C].
x_lens: [B].
x_shapes: [B, F, 2].
freqs: [M, C // 2].
"""
n, c = x.size(1), x.size(2) // 2
    # split freqs into temporal, height and width bands
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
# loop over samples
output = []
st = 0
for i, (seq_len,
shapes) in enumerate(zip(x_lens.tolist(), x_shapes.tolist())):
x_i = frame_pad(x[st:st + seq_len], seq_len, shapes) # f, h, w, c
f, h, w = x_i.shape[:3]
pad_seq_len = f * h * w
# precompute multipliers
x_i = torch.view_as_complex(
x_i.to(torch.float64).reshape(pad_seq_len, n, -1, 2))
freqs_i = torch.cat([
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
],
dim=-1).reshape(pad_seq_len, 1, -1)
# apply rotary embedding
x_i = torch.view_as_real(x_i * freqs_i).flatten(2).type_as(x)
x_i = frame_unpad(x_i, shapes)
# append to collection
output.append(x_i)
st += seq_len
return pad_sequence(output) if pad else torch.concat(output)


@amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
"""
Precompute the frequency tensor for complex exponentials.
"""
assert dim % 2 == 0
freqs = torch.outer(
torch.arange(max_seq_len),
1.0 / torch.pow(theta,
torch.arange(0, dim, 2).to(torch.float64).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs
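

if __name__ == '__main__':
    # Minimal smoke test (not part of the original file). All sizes below are
    # illustrative assumptions chosen only to exercise the multi-resolution path.
    num_heads, head_dim = 4, 128          # head_dim must be even
    freqs = rope_params(1024, head_dim)   # [1024, head_dim // 2] complex table

    # Two samples, each with two frames of (h, w) token grids.
    shapes_a = [[4, 6], [4, 6]]
    shapes_b = [[3, 5], [3, 5]]
    len_a = sum(h * w for h, w in shapes_a)
    len_b = sum(h * w for h, w in shapes_b)

    x = torch.randn(len_a + len_b, num_heads, head_dim)
    x_lens = torch.tensor([len_a, len_b])
    x_shapes = torch.tensor([shapes_a, shapes_b])  # [B, F, 2]

    out = rope_apply_multires(x, x_lens, x_shapes, freqs, pad=False)
    print(out.shape)  # torch.Size([78, 4, 128]), same packing as the input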