English
yeswondwerr committed on
Commit
43f1a0b
1 Parent(s): 6cd45a2

Update __main__.py

Browse files
Files changed (1) hide show
  1. __main__.py +15 -23
__main__.py CHANGED
@@ -1,4 +1,5 @@
1
  import argparse
 
2
  import torch
3
  import torch.nn as nn
4
  from PIL import Image
@@ -12,26 +13,14 @@ from transformers import (
12
  from transformers import TextStreamer
13
 
14
 
15
- def tokenizer_image_token(
16
- prompt, tokenizer, image_token_index=-200, return_tensors=None
17
- ):
18
- prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
19
-
20
- def insert_separator(X, sep):
21
- return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
22
 
23
- input_ids = []
24
- offset = 0
25
- if (
26
- len(prompt_chunks) > 0
27
- and len(prompt_chunks[0]) > 0
28
- and prompt_chunks[0][0] == tokenizer.bos_token_id
29
- ):
30
- offset = 1
31
- input_ids.append(prompt_chunks[0][0])
32
-
33
- for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
34
- input_ids.extend(x[offset:])
35
 
36
  return torch.tensor(input_ids, dtype=torch.long)
37
 
@@ -126,18 +115,20 @@ def answer_question(
126
  tokenizer.eos_token = "<|eot_id|>"
127
 
128
  try:
129
- q = input("user: ")
130
  except EOFError:
131
  q = ""
132
  if not q:
133
  print("no input detected. exiting.")
 
 
134
 
135
  question = "<image>" + q
136
 
137
- prompt = f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
138
 
139
  input_ids = (
140
- tokenizer_image_token(prompt, tokenizer, -200, return_tensors="pt")
141
  .unsqueeze(0)
142
  .to(model.device)
143
  )
@@ -183,13 +174,14 @@ def answer_question(
183
  }
184
 
185
  while True:
 
186
  generated_ids = model.generate(
187
  inputs_embeds=new_embeds, attention_mask=attn_mask, **model_kwargs
188
  )[0]
189
 
190
  generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)
191
  try:
192
- q = input("user: ")
193
  except EOFError:
194
  q = ""
195
  if not q:
 
1
  import argparse
2
+ import sys
3
  import torch
4
  import torch.nn as nn
5
  from PIL import Image
 
13
  from transformers import TextStreamer
14
 
15
 
16
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=-200):
17
+ prompt_chunks = prompt.split("<image>")
18
+ tokenized_chunks = [tokenizer(chunk).input_ids for chunk in prompt_chunks]
19
+ input_ids = tokenized_chunks[0]
 
 
 
20
 
21
+ for chunk in tokenized_chunks[1:]:
22
+ input_ids.append(image_token_index)
23
+ input_ids.extend(chunk[1:]) # Exclude BOS token on nonzero index
 
 
 
 
 
 
 
 
 
24
 
25
  return torch.tensor(input_ids, dtype=torch.long)
26
 
 
115
  tokenizer.eos_token = "<|eot_id|>"
116
 
117
  try:
118
+ q = input("\nuser: ")
119
  except EOFError:
120
  q = ""
121
  if not q:
122
  print("no input detected. exiting.")
123
+ sys.exit()
124
+
125
 
126
  question = "<image>" + q
127
 
128
+ prompt = f"<|start_header_id|>user<|end_header_id|>\n\n{question}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
129
 
130
  input_ids = (
131
+ tokenizer_image_token(prompt, tokenizer)
132
  .unsqueeze(0)
133
  .to(model.device)
134
  )
 
174
  }
175
 
176
  while True:
177
+ print('assistant: ')
178
  generated_ids = model.generate(
179
  inputs_embeds=new_embeds, attention_mask=attn_mask, **model_kwargs
180
  )[0]
181
 
182
  generated_text = tokenizer.decode(generated_ids, skip_special_tokens=False)
183
  try:
184
+ q = input("\nuser: ")
185
  except EOFError:
186
  q = ""
187
  if not q: