Tonic committed
Commit cbd9440 · unverified · 1 Parent(s): 93a33fb

add cis_2D preprocessing

Files changed (1): app.py (+21 -23)
app.py CHANGED
@@ -53,30 +53,29 @@ class GELU(nn.Module):
         else:
             return F.gelu(self.linear(x))
 
+def precompute_freqs_cis_2d(dim: int, height: int, width: int, theta: float) -> torch.Tensor:
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
+    h = torch.arange(height, device=freqs.device)
+    w = torch.arange(width, device=freqs.device)
+
+    freqs_h = torch.outer(h, freqs[::2]).float()
+    freqs_w = torch.outer(w, freqs[1::2]).float()
+    freqs_2d = torch.cat([
+        freqs_h[:, None, :].repeat(1, width, 1),
+        freqs_w[None, :, :].repeat(height, 1, 1),
+    ], dim=-1)
+    return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
+
 class Rope2D(nn.Module):
     def __init__(self, dim, max_position_embeddings=1024, base=10000):
         super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer("inv_freq", inv_freq)
-        self.max_seq_len_cached = max_position_embeddings
-        t = torch.arange(self.max_seq_len_cached, dtype=self.inv_freq.dtype)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1)
-        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
-
-    def forward(self, x, seq_len=None):
-        if seq_len > self.max_seq_len_cached:
-            self.max_seq_len_cached = seq_len
-            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
-            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
-        return (
-            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
-        )
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+
+    def forward(self, x, height, width):
+        freqs_cis = precompute_freqs_cis_2d(self.dim, height, width, self.base)
+        return freqs_cis.to(x.device)
 
 class VisionEncoder(nn.Module):
     def __init__(self, config):
@@ -92,14 +91,13 @@ class VisionEncoder(nn.Module):
         x = self.embed(pixel_values)
         b, c, h, w = x.shape
         x = x.flatten(2).transpose(1, 2)
-        cos, sin = self.rope(x, seq_len=h*w)
+        freqs_cis = self.rope(x, h, w)
         for layer in self.layers:
             x = layer(x)
         x = self.norm(x)
         x = self.gelu(x)
         return x
 
-
 class PixtralModel(nn.Module):
     def __init__(self, params):
         super().__init__()
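
For context: precompute_freqs_cis_2d returns a complex tensor of shape (height, width, dim // 2) built with torch.polar, but this commit stops at computing freqs_cis in VisionEncoder.forward; it is not yet passed into the attention layers. Below is a minimal sketch of the conventional apply step for such a tensor, in the usual RoPE style. The helper name apply_rotary_emb_2d and the (batch, seq_len, n_heads, head_dim) layout are assumptions for illustration, not part of this commit.

import torch

def apply_rotary_emb_2d(q, k, freqs_cis):
    # Hypothetical helper, not in the commit.
    # q, k: (batch, seq_len, n_heads, head_dim)
    # freqs_cis: (height, width, head_dim // 2) complex, with
    # seq_len == height * width, matching flatten(2).transpose(1, 2) above.
    freqs_cis = freqs_cis.reshape(-1, freqs_cis.shape[-1])  # (seq_len, head_dim // 2)
    # View consecutive channel pairs as complex numbers and rotate them.
    q_ = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_ = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
    fc = freqs_cis[None, :, None, :]  # broadcast over batch and heads
    q_out = torch.view_as_real(q_ * fc).flatten(-2)
    k_out = torch.view_as_real(k_ * fc).flatten(-2)
    return q_out.type_as(q), k_out.type_as(k)

# Example shapes (hypothetical head_dim=64, 16x16 patch grid):
#   freqs_cis = precompute_freqs_cis_2d(64, 16, 16, theta=10000.0)  # (16, 16, 32), complex64
#   q = torch.randn(1, 256, 8, 64); k = torch.randn(1, 256, 8, 64)
#   q_rot, k_rot = apply_rotary_emb_2d(q, k, freqs_cis)             # shapes preserved

Wiring freqs_cis into the layer calls (so that x = layer(x) receives the rotations) is presumably left for a follow-up commit.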