Helw150 committed
Commit 5b2106a
1 Parent(s): bf4916e

Less Magic Tokens

Files changed (1)
modeling_diva.py +55 -22
modeling_diva.py CHANGED
@@ -88,17 +88,30 @@ class DiVAModel(PreTrainedModel):
             torch_dtype=torch.float16,
         )
         self.processor = AutoProcessor.from_pretrained(config_dict["reference_encoder"])
-        self.tokenizer = AutoTokenizer.from_pretrained("WillHeld/via-llama")
-        self.prefix = torch.tensor([128000, 128006, 882, 128007, 271]).to(
+        self.tokenizer = AutoTokenizer.from_pretrained(config_dict["reference_decoder"])
+        if self.tokenizer.pad_token_id == None:
+            override_token = list(self.tokenizer.added_tokens_decoder.items())[-1]
+            self.tokenizer.pad_token_id = override_token[0]
+            self.tokenizer.pad_tokn = str(override_token[1])
+        prefix, suffix = self.tokenizer.apply_chat_template(
+            [{"role": "user", "content": "PLACEHOLDER"}],
+            tokenize=False,
+            add_generation_prompt=True,
+        ).split("PLACEHOLDER")
+        non_null = [line for line in prefix.split("\n") if line.strip()]
+        prefix_tok = self.tokenizer.encode(prefix, add_special_tokens=False)
+        suffix_tok = self.tokenizer.encode(suffix, add_special_tokens=False)
+        self.prefix = torch.tensor(prefix_tok).to(
             self.llm_decoder.model.embed_tokens.weight.device
         )

-        self.pre_user_suffix = torch.tensor(
-            self.tokenizer.encode(
-                "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n"
-            )
+        self.pre_system = torch.tensor(
+            self.tokenizer.encode(non_null[0] + "\n", add_special_tokens=False)
+        ).to(self.llm_decoder.model.embed_tokens.weight.device)
+        self.post_system = torch.tensor(
+            self.tokenizer.encode("\n" + non_null[-1] + "\n", add_special_tokens=False)
         ).to(self.llm_decoder.model.embed_tokens.weight.device)
-        self.final_header = torch.tensor([128009, 128006, 78191, 128007, 271]).to(
+        self.final_header = torch.tensor(suffix_tok).to(
             self.llm_decoder.model.embed_tokens.weight.device
         )
         self.speech_encoder_device = speech_encoder_device
@@ -116,9 +129,7 @@ class DiVAModel(PreTrainedModel):
         **kwargs,
     ):
         if os.path.isdir(pretrained_model_name_or_path):
-            via_path = (
-                pretrained_model_name_or_path + "/model.safetensors"
-            )
+            via_path = pretrained_model_name_or_path + "/model.safetensors"
            config_path = pretrained_model_name_or_path + "/config.json"
         else:
             # Loading from huggingface repo
@@ -207,16 +218,16 @@ class DiVAModel(PreTrainedModel):
                 padding=True,
                 padding_side="right",
             )["input_ids"],
-            device=self.pre_user_suffix.device,
+            device=self.pre_system.device,
         )
         prefix = torch.cat(
             [
-                self.pre_user_suffix.expand(
+                self.pre_system.expand(
                     bsz,
                     -1,
                 ),
                 user_prompt_text,
-                self.prefix.expand(
+                self.post_system.expand(
                     bsz,
                     -1,
                 ),
@@ -292,11 +303,27 @@ class DiVAModel(PreTrainedModel):

         if text_prompt != None and text_prompt != "":
             user_prompt_text = torch.tensor(
-                self.tokenizer(text_prompt, add_special_tokens=False)["input_ids"],
-                device=self.pre_user_suffix.device,
+                self.tokenizer(
+                    text_prompt,
+                    add_special_tokens=False,
+                    padding=True,
+                    padding_side="right",
+                )["input_ids"],
+                device=self.pre_system.device,
             )
             prefix = torch.cat(
-                [self.pre_user_suffix, user_prompt_text, self.prefix], axis=0
+                [
+                    self.pre_system.expand(
+                        bsz,
+                        -1,
+                    ),
+                    user_prompt_text,
+                    self.post_system.expand(
+                        bsz,
+                        -1,
+                    ),
+                ],
+                axis=1,
             )
         else:
             prefix = self.prefix
@@ -344,14 +371,20 @@ class DiVAModel(PreTrainedModel):
                     "<|eot_id|>", ""
                 )
             else:
-                yield (self.tokenizer.decode(outs, skip_special_tokens=True).replace(
-                    "<|eot_id|>", ""
-                ), outputs)
+                yield (
+                    self.tokenizer.decode(outs, skip_special_tokens=True).replace(
+                        "<|eot_id|>", ""
+                    ),
+                    outputs,
+                )
         if not return_outputs:
             return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
                 "<|eot_id|>", ""
             )
         else:
-            return (self.tokenizer.decode(outs, skip_special_tokens=True).replace(
-                "<|eot_id|>", ""
-            ), outputs)
+            return (
+                self.tokenizer.decode(outs, skip_special_tokens=True).replace(
+                    "<|eot_id|>", ""
+                ),
+                outputs,
+            )
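
The point of the change: instead of hard-coding Llama 3 token ids, the prompt scaffolding is now recovered from the tokenizer's own chat template by rendering a one-turn conversation around a sentinel string and splitting on it. A minimal standalone sketch of that technique (the model id below is illustrative; any tokenizer that ships a chat template works):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

# Render one user turn around a sentinel, then split on the sentinel to
# recover the text the template places before and after the user content.
rendered = tok.apply_chat_template(
    [{"role": "user", "content": "PLACEHOLDER"}],
    tokenize=False,
    add_generation_prompt=True,
)
prefix, suffix = rendered.split("PLACEHOLDER")

prefix_tok = tok.encode(prefix, add_special_tokens=False)
suffix_tok = tok.encode(suffix, add_special_tokens=False)

# For Llama 3 these reproduce exactly the constants deleted above:
# prefix_tok == [128000, 128006, 882, 128007, 271]    (self.prefix)
# suffix_tok == [128009, 128006, 78191, 128007, 271]  (self.final_header)
print(prefix_tok, suffix_tok)

Because the template is data shipped with the tokenizer, the same code works unchanged for any chat-tuned decoder, which is what makes the magic numbers unnecessary.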
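
The new pad-token fallback exists because the batched tokenizer calls now pass padding=True, and Llama 3 style tokenizers ship without a pad token. A sketch of the same fallback, reusing tok from above (added_tokens_decoder maps token id to AddedToken, so the last entry is typically a reserved special token that is unused at inference):

# Assign a pad token only if none is configured, reusing the last
# added special token (for Llama 3, a <|reserved_special_token_*|> id).
if tok.pad_token_id is None:
    pad_id, pad_tok = list(tok.added_tokens_decoder.items())[-1]
    tok.pad_token_id = pad_id
    tok.pad_token = str(pad_tok)

Padded positions are masked out by the attention mask, so any id that never appears in real input is acceptable here.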
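
The text-prompt branch also changes from concatenating the 1-D scaffold tensors along axis 0 to expanding them to the batch size and concatenating along axis 1, matching the batched path earlier in the diff. A toy sketch of that shape logic, with made-up ids and sizes:

import torch

bsz = 2
pre_system = torch.tensor([1, 2, 3])             # hypothetical scaffold ids
post_system = torch.tensor([4, 5])
user_ids = torch.tensor([[7, 8, 9], [7, 8, 0]])  # (bsz, seq), right-padded

# expand broadcasts each 1-D scaffold to (bsz, n) without copying data;
# concatenating along dim 1 then yields one row of prompt ids per example.
prompt = torch.cat(
    [pre_system.expand(bsz, -1), user_ids, post_system.expand(bsz, -1)],
    dim=1,
)
print(prompt.shape)  # torch.Size([2, 8])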