Helw150
committed on
Commit
•
5b2106a
1
Parent(s):
bf4916e
Less Magic Tokens
Browse files- modeling_diva.py +55 -22
modeling_diva.py
CHANGED
@@ -88,17 +88,30 @@ class DiVAModel(PreTrainedModel):
|
|
88 |
torch_dtype=torch.float16,
|
89 |
)
|
90 |
self.processor = AutoProcessor.from_pretrained(config_dict["reference_encoder"])
|
91 |
-
self.tokenizer = AutoTokenizer.from_pretrained("
|
92 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
self.llm_decoder.model.embed_tokens.weight.device
|
94 |
)
|
95 |
|
96 |
-
self.
|
97 |
-
self.tokenizer.encode(
|
98 |
-
|
99 |
-
|
|
|
100 |
).to(self.llm_decoder.model.embed_tokens.weight.device)
|
101 |
-
self.final_header = torch.tensor(
|
102 |
self.llm_decoder.model.embed_tokens.weight.device
|
103 |
)
|
104 |
self.speech_encoder_device = speech_encoder_device
|
@@ -116,9 +129,7 @@ class DiVAModel(PreTrainedModel):
|
|
116 |
**kwargs,
|
117 |
):
|
118 |
if os.path.isdir(pretrained_model_name_or_path):
|
119 |
-
via_path =
|
120 |
-
pretrained_model_name_or_path + "/model.safetensors"
|
121 |
-
)
|
122 |
config_path = pretrained_model_name_or_path + "/config.json"
|
123 |
else:
|
124 |
# Loading from huggingface repo
|
@@ -207,16 +218,16 @@ class DiVAModel(PreTrainedModel):
|
|
207 |
padding=True,
|
208 |
padding_side="right",
|
209 |
)["input_ids"],
|
210 |
-
device=self.
|
211 |
)
|
212 |
prefix = torch.cat(
|
213 |
[
|
214 |
-
self.
|
215 |
bsz,
|
216 |
-1,
|
217 |
),
|
218 |
user_prompt_text,
|
219 |
-
self.
|
220 |
bsz,
|
221 |
-1,
|
222 |
),
|
@@ -292,11 +303,27 @@ class DiVAModel(PreTrainedModel):
|
|
292 |
|
293 |
if text_prompt != None and text_prompt != "":
|
294 |
user_prompt_text = torch.tensor(
|
295 |
-
self.tokenizer(
|
296 |
-
|
|
|
|
|
|
|
|
|
|
|
297 |
)
|
298 |
prefix = torch.cat(
|
299 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
300 |
)
|
301 |
else:
|
302 |
prefix = self.prefix
|
@@ -344,14 +371,20 @@ class DiVAModel(PreTrainedModel):
|
|
344 |
"<|eot_id|>", ""
|
345 |
)
|
346 |
else:
|
347 |
-
yield (
|
348 |
-
|
349 |
-
|
|
|
|
|
|
|
350 |
if not return_outputs:
|
351 |
return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
|
352 |
"<|eot_id|>", ""
|
353 |
)
|
354 |
else:
|
355 |
-
return (
|
356 |
-
|
357 |
-
|
|
|
|
|
|
|
|
88 |
torch_dtype=torch.float16,
|
89 |
)
|
90 |
self.processor = AutoProcessor.from_pretrained(config_dict["reference_encoder"])
|
91 |
+
self.tokenizer = AutoTokenizer.from_pretrained(config_dict["reference_decoder"])
|
92 |
+
if self.tokenizer.pad_token_id == None:
|
93 |
+
override_token = list(self.tokenizer.added_tokens_decoder.items())[-1]
|
94 |
+
self.tokenizer.pad_token_id = override_token[0]
|
95 |
+
self.tokenizer.pad_tokn = str(override_token[1])
|
96 |
+
prefix, suffix = self.tokenizer.apply_chat_template(
|
97 |
+
[{"role": "user", "content": "PLACEHOLDER"}],
|
98 |
+
tokenize=False,
|
99 |
+
add_generation_prompt=True,
|
100 |
+
).split("PLACEHOLDER")
|
101 |
+
non_null = [line for line in prefix.split("\n") if line.strip()]
|
102 |
+
prefix_tok = self.tokenizer.encode(prefix, add_special_tokens=False)
|
103 |
+
suffix_tok = self.tokenizer.encode(suffix, add_special_tokens=False)
|
104 |
+
self.prefix = torch.tensor(prefix_tok).to(
|
105 |
self.llm_decoder.model.embed_tokens.weight.device
|
106 |
)
|
107 |
|
108 |
+
self.pre_system = torch.tensor(
|
109 |
+
self.tokenizer.encode(non_null[0] + "\n", add_special_tokens=False)
|
110 |
+
).to(self.llm_decoder.model.embed_tokens.weight.device)
|
111 |
+
self.post_system = torch.tensor(
|
112 |
+
self.tokenizer.encode("\n" + non_null[-1] + "\n", add_special_tokens=False)
|
113 |
).to(self.llm_decoder.model.embed_tokens.weight.device)
|
114 |
+
self.final_header = torch.tensor(suffix_tok).to(
|
115 |
self.llm_decoder.model.embed_tokens.weight.device
|
116 |
)
|
117 |
self.speech_encoder_device = speech_encoder_device
|
|
|
129 |
**kwargs,
|
130 |
):
|
131 |
if os.path.isdir(pretrained_model_name_or_path):
|
132 |
+
via_path = pretrained_model_name_or_path + "/model.safetensors"
|
|
|
|
|
133 |
config_path = pretrained_model_name_or_path + "/config.json"
|
134 |
else:
|
135 |
# Loading from huggingface repo
|
|
|
218 |
padding=True,
|
219 |
padding_side="right",
|
220 |
)["input_ids"],
|
221 |
+
device=self.pre_system.device,
|
222 |
)
|
223 |
prefix = torch.cat(
|
224 |
[
|
225 |
+
self.pre_system.expand(
|
226 |
bsz,
|
227 |
-1,
|
228 |
),
|
229 |
user_prompt_text,
|
230 |
+
self.post_system.expand(
|
231 |
bsz,
|
232 |
-1,
|
233 |
),
|
|
|
303 |
|
304 |
if text_prompt != None and text_prompt != "":
|
305 |
user_prompt_text = torch.tensor(
|
306 |
+
self.tokenizer(
|
307 |
+
text_prompt,
|
308 |
+
add_special_tokens=False,
|
309 |
+
padding=True,
|
310 |
+
padding_side="right",
|
311 |
+
)["input_ids"],
|
312 |
+
device=self.pre_system.device,
|
313 |
)
|
314 |
prefix = torch.cat(
|
315 |
+
[
|
316 |
+
self.pre_system.expand(
|
317 |
+
bsz,
|
318 |
+
-1,
|
319 |
+
),
|
320 |
+
user_prompt_text,
|
321 |
+
self.post_system.expand(
|
322 |
+
bsz,
|
323 |
+
-1,
|
324 |
+
),
|
325 |
+
],
|
326 |
+
axis=1,
|
327 |
)
|
328 |
else:
|
329 |
prefix = self.prefix
|
|
|
371 |
"<|eot_id|>", ""
|
372 |
)
|
373 |
else:
|
374 |
+
yield (
|
375 |
+
self.tokenizer.decode(outs, skip_special_tokens=True).replace(
|
376 |
+
"<|eot_id|>", ""
|
377 |
+
),
|
378 |
+
outputs,
|
379 |
+
)
|
380 |
if not return_outputs:
|
381 |
return self.tokenizer.decode(outs, skip_special_tokens=True).replace(
|
382 |
"<|eot_id|>", ""
|
383 |
)
|
384 |
else:
|
385 |
+
return (
|
386 |
+
self.tokenizer.decode(outs, skip_special_tokens=True).replace(
|
387 |
+
"<|eot_id|>", ""
|
388 |
+
),
|
389 |
+
outputs,
|
390 |
+
)
|