Fix incorrect image embedding when running with a single GPU and 24GB VRAM

#3
by xdedss - opened
Files changed (1) hide show
  1. modeling_internvl.py +16 -0
modeling_internvl.py CHANGED
@@ -114,13 +114,29 @@ class CrossAttention(nn.Module):
114
  k_bias = self.k_bias
115
  v_bias = self.v_bias
116
 
 
 
 
 
 
 
117
  q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
 
 
118
  q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
119
 
 
 
120
  k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
 
 
121
  k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
122
 
 
 
123
  v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
 
 
124
  v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
125
 
126
  q = q * self.scale
 
114
  k_bias = self.k_bias
115
  v_bias = self.v_bias
116
 
117
+ # simulate module forward hooks to let accelerate load the actual weight
118
+ # see https://github.com/huggingface/accelerate/blob/1f7a79b428749f45187ec69485f2c966fe21926e/src/accelerate/hooks.py#L163
119
+ simulate_hooks = hasattr(self.q, '_hf_hook')
120
+
121
+ if simulate_hooks:
122
+ self.q._hf_hook.pre_forward(self.q, x)
123
  q = F.linear(input=x, weight=self.q.weight, bias=q_bias)
124
+ if simulate_hooks:
125
+ self.q._hf_hook.post_forward(self.q, x)
126
  q = q.reshape(B, N, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0) # (B, N_head, N_q, dim)
127
 
128
+ if simulate_hooks:
129
+ self.k._hf_hook.pre_forward(self.k, k)
130
  k = F.linear(input=k, weight=self.k.weight, bias=k_bias)
131
+ if simulate_hooks:
132
+ self.k._hf_hook.post_forward(self.k, k)
133
  k = k.reshape(B, N_k, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
134
 
135
+ if simulate_hooks:
136
+ self.v._hf_hook.pre_forward(self.v, v)
137
  v = F.linear(input=v, weight=self.v.weight, bias=v_bias)
138
+ if simulate_hooks:
139
+ self.v._hf_hook.post_forward(self.v, v)
140
  v = v.reshape(B, N_v, 1, self.num_heads, -1).permute(2, 0, 3, 1, 4).squeeze(0)
141
 
142
  q = q * self.scale