jiangchengchengNLP committed on
Commit d3fa7ee
1 Parent(s): 9f0f115

Upload 3 files

qwenva.py, qwenva.pth and bird.jpeg

Files changed (3)
  1. bird.jpeg +0 -0
  2. qwenva.pth +3 -0
  3. qwenva.py +431 -0
bird.jpeg ADDED
qwenva.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:00105bceb629eff80893863e622e0e8861682b18f0b1f168d00bc960ab07bde2
size 1447761636
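The three lines above are only the Git LFS pointer; the actual ~1.4 GB checkpoint is resolved when the file is downloaded. As a minimal sketch (the repo id below is a placeholder assumption, not something stated in this commit), the weight can be fetched with huggingface_hub before running qwenva.py, which expects it at ./qwenva.pth:

from huggingface_hub import hf_hub_download

# NOTE: placeholder repo_id (assumption); substitute the repository this commit belongs to.
ckpt_path = hf_hub_download(repo_id="jiangchengchengNLP/qwenva", filename="qwenva.pth")
print(ckpt_path)  # cached local path; copy or symlink it to ./qwenva.pth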
qwenva.py ADDED
@@ -0,0 +1,431 @@
"""
Vision encoder
"""

# Vision encoder: reuse CLIP's vision tower and its projection head.
from transformers import CLIPModel
from transformers import CLIPConfig
vision_config = CLIPConfig.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel._from_config(vision_config)
vision_model = clip_model.vision_model
vision_projection = clip_model.visual_projection


# Re-implementation of Qwen2.5-0.5B

"""
Language model
"""
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed by scaled_dot_product_attention below
# from torch.nn.attention import SDPBackend, sdpa_kernel
# All decoder layers share one Qwen2RotaryEmbedding to keep the model small.
# LLaMA-style RoPE implementation.
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

class Qwen2RotaryEmbedding(nn.Module):
    def __init__(self, head_dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = head_dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        # Different from the paper, but it uses a different permutation in order to obtain the same calculation.
        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, q, k, use_cache=False):
        # q/k: [batch, num_heads, seq_len, head_dim]
        seq_len = k.size(2)
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=q.device, dtype=q.dtype)
        cos_pos = self.cos_cached[:seq_len].to(dtype=q.dtype).unsqueeze(0).unsqueeze(0)
        sin_pos = self.sin_cached[:seq_len].to(dtype=q.dtype).unsqueeze(0).unsqueeze(0)
        if use_cache:
            # With the KV cache, q holds only the newest token, so rotate it with the last position.
            q_embed = q * cos_pos[:, :, -1, :].unsqueeze(1) + rotate_half(q) * sin_pos[:, :, -1, :].unsqueeze(1)
        else:
            q_embed = q * cos_pos + rotate_half(q) * sin_pos
        k_embed = k * cos_pos + rotate_half(k) * sin_pos
        return q_embed, k_embed
"""
Grouped-query attention layer
"""
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape

    if n_rep == 1:
        return hidden_states  # nothing to repeat

    # Insert a new dimension after the KV-head axis and expand to
    # (batch, num_key_value_heads, n_rep, slen, head_dim).
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)

    # Reshape to (batch, num_key_value_heads * n_rep, slen, head_dim).
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

class Qwen2SdpaAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, num_kv_heads):
        super(Qwen2SdpaAttention, self).__init__()
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.num_kv_heads = num_kv_heads
        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=True)
        self.k_proj = nn.Linear(hidden_size, hidden_size // (num_attention_heads // num_kv_heads), bias=True)
        self.v_proj = nn.Linear(hidden_size, hidden_size // (num_attention_heads // num_kv_heads), bias=True)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.rotary_emb = nn.Identity()  # placeholder; the shared rotary embedding is passed in via `position_embedding`
        # self.rotary_emb = Qwen2RotaryEmbedding(head_dim=self.attention_head_size, max_position_embeddings=max_position_embeddings, dtype=dtype)

    def forward(self, input_ids, attention_mask, position_embedding, use_cache=False, past_kv=None, id=None):
        """
        With the KV cache enabled, the input is the embedding of a single token of shape
        [batch_size, 1, hidden_size], so
            q has shape [batch_size, 1, hidden_size],
            k has shape [batch_size, seq_len, hidden_size // (num_attention_heads // num_kv_heads)],
            v has shape [batch_size, seq_len, hidden_size // (num_attention_heads // num_kv_heads)].
        The prefill step (first call) is handled separately below.
        """
        batch_size, seq_len, _ = input_ids.size()
        q = self.q_proj(input_ids)
        k = self.k_proj(input_ids)
        v = self.v_proj(input_ids)
        is_causal = True  # full causal attention by default (no cache, or cache prefill)
        if use_cache:
            if id not in past_kv.keys():
                # Prefill: initialize this layer's cache.
                past_kv[id] = k, v
            else:
                # Decode: append the new token's k/v to the cache; a single query needs no causal mask.
                k_cache, v_cache = past_kv[id]
                k = torch.cat((k_cache, k), dim=1)
                v = torch.cat((v_cache, v), dim=1)
                past_kv[id] = (k, v)
                is_causal = False
        # Split into heads: [batch, heads, seq, head_dim]
        q = q.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).permute(0, 2, 1, 3)
        k = k.view(batch_size, -1, self.num_kv_heads, self.attention_head_size).permute(0, 2, 1, 3)
        v = v.view(batch_size, -1, self.num_kv_heads, self.attention_head_size).permute(0, 2, 1, 3)
        # Rotary position embedding
        if position_embedding is not None:
            q, k = position_embedding(q, k, use_cache=use_cache)
        else:
            q, k = self.rotary_emb(q, k, use_cache=use_cache)
        # Grouped-query attention: repeat k/v heads to match the number of query heads.
        k = repeat_kv(k, self.num_attention_heads // self.num_kv_heads)
        v = repeat_kv(v, self.num_attention_heads // self.num_kv_heads)
        # with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        attention_logits = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)

        attention_logits = attention_logits.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_len, self.hidden_size)
        attention_output = self.o_proj(attention_logits)
        return attention_output
# Activation function
class SiLU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input):
        return F.silu(input, inplace=False)

# Feed-forward layer
class Qwen2MLP(nn.Module):
    def __init__(self, input_dim, expand_dim):
        super(Qwen2MLP, self).__init__()
        self.gate_proj = nn.Linear(input_dim, expand_dim, bias=False)
        self.up_proj = nn.Linear(input_dim, expand_dim, bias=False)
        self.down_proj = nn.Linear(expand_dim, input_dim, bias=False)
        self.act_fn = SiLU()

    def forward(self, x):
        # SwiGLU-style gating: down(act(gate(x)) * up(x))
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

# Qwen RMSNorm
class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        old_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(old_dtype)

# Decoder layer
class Qwen2DecoderLayer(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, num_kv_heads, expand_dim):
        super(Qwen2DecoderLayer, self).__init__()
        self.self_attn = Qwen2SdpaAttention(hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_kv_heads=num_kv_heads)
        self.mlp = Qwen2MLP(input_dim=hidden_size, expand_dim=expand_dim)
        self.input_layernorm = Qwen2RMSNorm(hidden_size)
        self.post_attention_layernorm = Qwen2RMSNorm(hidden_size)

    def forward(self, hidden_states, attention_mask, position_embedding, use_cache=False, past_kv=None, id=None):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        output = self.self_attn(hidden_states, attention_mask, position_embedding, use_cache=use_cache, past_kv=past_kv, id=id)
        output_ = residual + output
        residual = output_
        output_ = self.post_attention_layernorm(output_)
        output_ = self.mlp(output_)
        output_ = residual + output_
        return output_

# Model backbone
class Qwen2Model(nn.Module):
    def __init__(self, vocab_size, hidden_size, num_layers, num_attention_heads, num_kv_heads, max_position_embeddings, expand_dim):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(hidden_size=hidden_size, num_attention_heads=num_attention_heads, num_kv_heads=num_kv_heads, expand_dim=expand_dim)
             for _ in range(num_layers)]
        )
        self.norm = Qwen2RMSNorm(hidden_size)
        # A single rotary embedding shared by all decoder layers.
        self.rotary_emb = Qwen2RotaryEmbedding(head_dim=hidden_size // num_attention_heads, max_position_embeddings=max_position_embeddings)

    def forward(self, input_ids, attention_mask, use_cache=False, past_kv=None):
        token_embed = self.embed_tokens(input_ids)
        for index, layer in enumerate(self.layers):
            token_embed = layer(token_embed, attention_mask, self.rotary_emb, use_cache=use_cache, past_kv=past_kv, id=index)
        token_embed = self.norm(token_embed)
        return token_embed
# Causal language model head for text generation
class Qwen2ForCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = Qwen2Model(vocab_size=config.vocab_size, hidden_size=config.hidden_size, num_layers=config.num_layers, num_attention_heads=config.num_attention_heads, num_kv_heads=config.num_kv_heads, expand_dim=config.expand_dim, max_position_embeddings=config.max_position_embeddings)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.dtype = config.dtype

    def forward(self, input_ids, attention_mask, use_cache=False, past_kv=None):
        if use_cache:
            if past_kv is None:
                past_kv = {}
            output = self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=use_cache, past_kv=past_kv)
            logits = self.lm_head(output)
            return logits, past_kv
        else:
            output = self.model(input_ids=input_ids, attention_mask=attention_mask)
            logits = self.lm_head(output)
            return logits

class Qwen2config:
    def __init__(self):
        self.name = "Qwen2.5-0.5B"
        self.vocab_size = 151936
        self.hidden_size = 896
        self.num_layers = 24
        self.num_kv_heads = 2
        self.num_attention_heads = 14
        self.max_position_embeddings = 32768
        self.expand_dim = 4864
        self.dtype = torch.float16


config = Qwen2config()

qwen_model = Qwen2ForCausalLM(config)
# Qwenva model
# Alignment layer: projects the CLIP image feature into the language model's embedding space.
class AlignLayer(torch.nn.Module):
    def __init__(self, text1_dim, text2_dim, expand_dim):
        super(AlignLayer, self).__init__()
        self.vision_proj = vision_projection.to(dtype=config.dtype)
        self.expand_proj = torch.nn.Linear(text1_dim, expand_dim)
        self.text_proj = torch.nn.Linear(expand_dim, text2_dim)
        self.activate = torch.nn.SiLU()

    def forward(self, vision_embedding):
        embed = self.vision_proj(vision_embedding)
        embed = self.expand_proj(embed)
        embed = self.activate(embed)
        embed = self.text_proj(embed)
        return embed

# Reuse the language model's submodules inside Qwenva.
text_model = qwen_model
rotary_emb = text_model.model.rotary_emb
text_embedding = text_model.model.embed_tokens
transformer = text_model.model.layers
lm_head = text_model.lm_head

from transformers import AutoTokenizer
model_name = "Qwen/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})

from huggingface_hub import PyTorchModelHubMixin

class Qwenva(torch.nn.Module, PyTorchModelHubMixin):
    def __init__(self, text1_dim, text2_dim, expand_dim, dtype=config.dtype):
        super(Qwenva, self).__init__()
        self.vision_encoder = vision_model.to(dtype=config.dtype)
        self.text_embedding = text_embedding
        self.align_layer = AlignLayer(text1_dim, text2_dim, expand_dim).to(dtype)
        self.transformer = transformer
        self.rotary_emb = rotary_emb
        self.lm_head = lm_head
        self.tokenizer = tokenizer

    def forward(self, input_ids, attention_mask, pixel_values=None, image_idx=None, use_cache=True, past_kv=None):
        if past_kv is None and pixel_values is not None:
            # Prefill with an image: splice the aligned image embedding into the token embeddings.
            token_embedding = self.text_embedding(input_ids)
            batch_size = input_ids.shape[0]
            vision_embedding = self.vision_encoder(pixel_values)[1]  # pooled CLIP feature
            align_embedding = self.align_layer(vision_embedding)
            mix_embedding = token_embedding.clone()
            # image_idx has shape [batch_size, 1]; entries equal to -100 mean "no image" for that sample.
            valid_indices = image_idx.ne(-100)
            valid_positions = torch.arange(batch_size).to(input_ids.device)
            valid_positions = valid_positions[valid_indices.squeeze()].squeeze()
            valid_image_idx = image_idx[valid_positions]
            # Overwrite the <image> token's embedding with the aligned image embedding.
            mix_embedding[valid_positions, valid_image_idx] = align_embedding[valid_positions]
            past_kv = {}
        else:
            mix_embedding = self.text_embedding(input_ids)
        for index, layer in enumerate(self.transformer):
            mix_embedding = layer(mix_embedding, attention_mask, position_embedding=self.rotary_emb, use_cache=use_cache, past_kv=past_kv, id=index)
        logits = self.lm_head(mix_embedding)
        if use_cache:
            return logits, past_kv
        else:
            return logits

    def generate(self, input_ids, attention_mask, pixel_values=None, image_idx=None, temperature=1, top_k=2, repetition_penalty=1.0, max_length=300):
        device = input_ids.device
        token_eos = torch.tensor(tokenizer.encode('<|im_end|>')).to(device)  # stop token: end decoding once it is generated
        out_token = None
        past_kv = None
        with torch.no_grad():
            while out_token != token_eos and len(input_ids[0, :]) < max_length:
                if past_kv is None:
                    # Prefill: run the full prompt (and image) once to build the KV cache.
                    logits, past_kv = self.forward(input_ids, attention_mask, pixel_values, image_idx, use_cache=True, past_kv=past_kv)
                else:
                    # Decode: feed only the most recent token; the cache covers the rest.
                    logits, past_kv = self.forward(input_ids[:, -1].unsqueeze(0), attention_mask[:, -1].unsqueeze(0), pixel_values, image_idx, use_cache=True, past_kv=past_kv)
                # Repetition penalty: down-weight tokens already present in the sequence.
                if len(input_ids[0, :]) > 1:
                    for i in input_ids[0]:
                        logits[0, -1, i] /= repetition_penalty
                # Top-k sampling
                top_k_logits, top_k_indices = torch.topk(logits[0, -1, :], k=top_k)
                out_token = top_k_indices[torch.multinomial(F.softmax(top_k_logits / temperature, dim=-1), num_samples=1)].unsqueeze(0)
                # Greedy alternative:
                # out_token = torch.argmax(logits[0, -1, :]).unsqueeze(0).unsqueeze(0)
                input_ids = torch.cat([input_ids, out_token], dim=1)  # append the new token for the next step
                attention_mask = torch.cat([attention_mask, torch.ones(1, 1).to(device)], dim=1)  # grow the mask accordingly
        return input_ids
# Processor: prepares text and image inputs for the model.
from transformers import CLIPProcessor, AutoTokenizer
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model_name = "Qwen/Qwen2.5-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})

from transformers import ProcessorMixin

class Proccessor(ProcessorMixin):
    feature_extractor_class: str = "CLIPProcessor"
    tokenizer_class: str = "Qwen2TokenizerFast"

    def __init__(self, feature_extractor, tokenizer):
        super().__init__(feature_extractor=feature_extractor, tokenizer=tokenizer)
        self.tokenizer = tokenizer
        self.feature_extractor = feature_extractor
        self.image_token = self.tokenizer.encode('<image>')[0]

    def __call__(self, input_data, input_image=None, device="cuda"):
        if isinstance(input_data, str):
            # Single user message: prepend the <image> placeholder token.
            input_ = self.tokenizer.apply_chat_template(
                [{'role': 'user', 'content': '<image>\n{}'.format(input_data)}],
                add_generation_prompt=True,
            )
        elif isinstance(input_data, list):
            # Pre-built chat messages.
            input_ = self.tokenizer.apply_chat_template(
                input_data,
                add_generation_prompt=True,
            )
        input_ids = torch.tensor(input_).unsqueeze(0).to(device)
        attention_mask = torch.ones(1, len(input_ids[0])).to(device)
        img_idx = input_.index(self.image_token)  # position of the <image> token to be replaced by the image embedding
        img_idx = torch.tensor(img_idx).unsqueeze(0).to(device)
        if input_image is not None:
            inputs = self.feature_extractor(images=input_image, return_tensors="pt")
            pixel_values = inputs['pixel_values'].to(device)
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "image_idx": img_idx
            }
        else:
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask}

processor = Proccessor(processor, tokenizer)
model = Qwenva(512, 896, 4096, dtype=config.dtype)
model.load_state_dict(torch.load("./qwenva.pth", weights_only=True))
model.eval()
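Since qwenva.py builds the processor and loads the checkpoint at import time, a minimal inference sketch might look like the following (assumptions: a CUDA device, the bird.jpeg uploaded in this commit as input, and an illustrative prompt; none of this is part of the uploaded file):

from PIL import Image
import torch

device = "cuda"  # assumed; the processor also defaults to cuda
model.to(device)

image = Image.open("./bird.jpeg")  # image uploaded in this commit
inputs = processor("What is in this picture?", image, device=device)  # illustrative prompt

with torch.no_grad():
    output_ids = model.generate(
        inputs["input_ids"],
        inputs["attention_mask"],
        pixel_values=inputs["pixel_values"],
        image_idx=inputs["image_idx"],
        max_length=300,
    )

print(tokenizer.decode(output_ids[0], skip_special_tokens=True))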