IAMJB commited on
Commit
fb97388
1 Parent(s): b11fb0b

Upload tokenization_chexagent.py

Browse files
Files changed (1) hide show
  1. tokenization_chexagent.py +646 -0
tokenization_chexagent.py ADDED
@@ -0,0 +1,646 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from functools import lru_cache
3
+ from typing import TYPE_CHECKING
4
+
5
+ import regex as re
6
+ from transformers.tokenization_utils_base import TextInput
7
+ from transformers.utils import is_tf_available, is_torch_available, to_py_obj
8
+
9
+ if TYPE_CHECKING:
10
+ if is_torch_available():
11
+ import torch
12
+ if is_tf_available():
13
+ import tensorflow as tf
14
+
15
+ import os
16
+ import random
17
+ from typing import Dict, List, Tuple, Union, Any, Callable, Optional
18
+
19
+ import matplotlib as mpl
20
+ import matplotlib.colors as mcolors
21
+ import matplotlib.colors as mplc
22
+ import matplotlib.figure as mplfigure
23
+ import numpy as np
24
+ import requests
25
+ import torch
26
+ from PIL import Image
27
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
28
+ from transformers import PreTrainedTokenizer, AddedToken
29
+ from transformers.utils import logging
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+ VOCAB_FILES_NAMES = {
34
+ "vocab_file": "vocab.json",
35
+ "merges_file": "merges.txt",
36
+ }
37
+
38
+ PRETRAINED_VOCAB_FILES_MAP = {
39
+ "vocab_file": {
40
+ "Salesforce/codegen-350M-mono": "https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/vocab.json",
41
+ },
42
+ "merges_file": {
43
+ "Salesforce/codegen-350M-mono": "https://huggingface.co/Salesforce/codegen-350M-mono/resolve/main/merges.txt",
44
+ },
45
+ }
46
+
47
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
48
+ "Salesforce/codegen-350M-mono": 2048,
49
+ }
50
+
51
+ IMG_TOKEN_SPAN = 1024
52
+
53
+ DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['from'] == 'human' %}\n{{ '<|user|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'system' %}\n{{ '<|system|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'gpt' %}\n{{ '<|assistant|>\n' + message['value'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
54
+
55
+
56
+ @lru_cache()
57
+ def bytes_to_unicode():
58
+ """
59
+ Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
60
+ characters the bpe code barfs on.
61
+
62
+ The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
63
+ if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
64
+ decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
65
+ tables between utf-8 bytes and unicode strings.
66
+ """
67
+ bs = (
68
+ list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(
69
+ range(ord("®"), ord("ÿ") + 1))
70
+ )
71
+ cs = bs[:]
72
+ n = 0
73
+ for b in range(2 ** 8):
74
+ if b not in bs:
75
+ bs.append(b)
76
+ cs.append(2 ** 8 + n)
77
+ n += 1
78
+ cs = [chr(n) for n in cs]
79
+ return dict(zip(bs, cs))
80
+
81
+
82
+ def get_pairs(word):
83
+ """
84
+ Return set of symbol pairs in a word.
85
+
86
+ Word is represented as tuple of symbols (symbols being variable-length strings).
87
+ """
88
+ pairs = set()
89
+ prev_char = word[0]
90
+ for char in word[1:]:
91
+ pairs.add((prev_char, char))
92
+ prev_char = char
93
+ return pairs
94
+
95
+
96
+ def _list_find(
97
+ input_list: List[Any],
98
+ candidates: Tuple[Any],
99
+ start: int = 0,
100
+ ):
101
+ for i in range(start, len(input_list)):
102
+ if input_list[i] in candidates:
103
+ return i
104
+ return -1
105
+
106
+
107
+ def _replace_closed_tag(
108
+ input_tokens: List[Any],
109
+ start_tags: Union[Any, Tuple[Any]],
110
+ end_tags: Union[Any, Tuple[Any]],
111
+ inclusive_replace_func: Callable,
112
+ exclusive_replace_func: Callable = lambda x: x,
113
+ ):
114
+ if isinstance(start_tags, (str, int)):
115
+ start_tags = (start_tags,)
116
+ if isinstance(end_tags, (str, int)):
117
+ end_tags = (end_tags,)
118
+ assert len(start_tags) == len(end_tags)
119
+
120
+ output_tokens = []
121
+ end = 0
122
+ while True:
123
+ start = _list_find(input_tokens, start_tags, end)
124
+ if start == -1:
125
+ break
126
+ output_tokens.extend(exclusive_replace_func(input_tokens[end: start]))
127
+ tag_idx = start_tags.index(input_tokens[start])
128
+ end = _list_find(input_tokens, (end_tags[tag_idx],), start)
129
+ if end == -1:
130
+ raise ValueError("Unclosed image token")
131
+ output_tokens.extend(inclusive_replace_func(input_tokens[start: end + 1]))
132
+ end += 1
133
+ output_tokens.extend(exclusive_replace_func(input_tokens[end:]))
134
+ return output_tokens
135
+
136
+
137
+ class CheXagentTokenizer(PreTrainedTokenizer):
138
+ vocab_files_names = VOCAB_FILES_NAMES
139
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
140
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
141
+ model_input_names = ["input_ids", "attention_mask"]
142
+
143
+ def __init__(
144
+ self,
145
+ vocab_file,
146
+ merges_file,
147
+ errors="replace",
148
+ unk_token="<|endoftext|>",
149
+ bos_token="<|endoftext|>",
150
+ eos_token="<|endoftext|>",
151
+ pad_token=None,
152
+ add_prefix_space=False,
153
+ add_bos_token=False,
154
+ image_start_tag='<|img|>',
155
+ image_end_tag='<|/img|>',
156
+ image_pad_tag='<|imgpad|>',
157
+ ref_start_tag='<|ref|>',
158
+ ref_end_tag='<|/ref|>',
159
+ box_start_tag='<|box|>',
160
+ box_end_tag='<|/box|>',
161
+ quad_start_tag='<|quad|>',
162
+ quad_end_tag='<|/quad|>',
163
+ **kwargs,
164
+ ):
165
+ bos_token = AddedToken(bos_token, special=True) if isinstance(bos_token, str) else bos_token
166
+ eos_token = AddedToken(eos_token, special=True) if isinstance(eos_token, str) else eos_token
167
+ unk_token = AddedToken(unk_token, special=True) if isinstance(unk_token, str) else unk_token
168
+ pad_token = AddedToken(pad_token, special=True) if isinstance(pad_token, str) else pad_token
169
+ self.add_bos_token = add_bos_token
170
+
171
+ with open(vocab_file, encoding="utf-8") as vocab_handle:
172
+ self.encoder = json.load(vocab_handle)
173
+ self.decoder = {v: k for k, v in self.encoder.items()}
174
+ self.errors = errors # how to handle errors in decoding
175
+ self.byte_encoder = bytes_to_unicode()
176
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
177
+ with open(merges_file, encoding="utf-8") as merges_handle:
178
+ bpe_merges = merges_handle.read().split("\n")[1:-1]
179
+ bpe_merges = [tuple(merge.split()) for merge in bpe_merges]
180
+ self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
181
+ self.cache = {}
182
+ self.add_prefix_space = add_prefix_space
183
+
184
+ # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
185
+ self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
186
+ super().__init__(
187
+ errors=errors,
188
+ unk_token=unk_token,
189
+ bos_token=bos_token,
190
+ eos_token=eos_token,
191
+ pad_token=pad_token,
192
+ add_prefix_space=add_prefix_space,
193
+ add_bos_token=add_bos_token,
194
+ **kwargs,
195
+ )
196
+
197
+ self.image_start_tag = image_start_tag
198
+ self.image_end_tag = image_end_tag
199
+ self.image_pad_tag = image_pad_tag
200
+ self.ref_start_tag = ref_start_tag
201
+ self.ref_end_tag = ref_end_tag
202
+ self.box_start_tag = box_start_tag
203
+ self.box_end_tag = box_end_tag
204
+ self.quad_start_tag = quad_start_tag
205
+ self.quad_end_tag = quad_end_tag
206
+ self.IMAGE_ST = (
207
+ image_start_tag, image_end_tag, image_pad_tag,
208
+ ref_start_tag, ref_end_tag, box_start_tag, box_end_tag,
209
+ quad_start_tag, quad_end_tag,
210
+ )
211
+ for special_token in self.IMAGE_ST:
212
+ if special_token not in self.get_vocab():
213
+ self.add_special_tokens({"additional_special_tokens": [special_token]})
214
+ for coordinate in range(10):
215
+ if f"<{coordinate}>" not in self.get_vocab():
216
+ self.add_special_tokens({"additional_special_tokens": [f"<|coord_{coordinate}|>"]})
217
+ if len(self) % 64 != 0:
218
+ for extra in range(((len(self) // 64) + 1) * 64 - len(self)):
219
+ if f"<extra_{extra}>" not in self.get_vocab():
220
+ self.add_special_tokens({"additional_special_tokens": [f"<|extra_{extra}|>"]})
221
+ self.img_start_id = self.convert_tokens_to_ids(self.image_start_tag)
222
+ self.img_end_id = self.convert_tokens_to_ids(self.image_end_tag)
223
+ self.img_pad_id = self.convert_tokens_to_ids(self.image_pad_tag)
224
+ self.ref_start_id = self.convert_tokens_to_ids(self.ref_start_tag)
225
+ self.ref_end_id = self.convert_tokens_to_ids(self.ref_end_tag)
226
+ self.box_start_id = self.convert_tokens_to_ids(self.box_start_tag)
227
+ self.box_end_id = self.convert_tokens_to_ids(self.box_end_tag)
228
+ self.quad_start_id = self.convert_tokens_to_ids(self.quad_start_tag)
229
+ self.quad_end_id = self.convert_tokens_to_ids(self.quad_end_tag)
230
+ self.chat_template = DEFAULT_CHAT_TEMPLATE
231
+
232
+ @property
233
+ def vocab_size(self):
234
+ return len(self.encoder)
235
+
236
+ def get_vocab(self):
237
+ return dict(self.encoder, **self.added_tokens_encoder)
238
+
239
+ def bpe(self, token):
240
+ if token in self.cache:
241
+ return self.cache[token]
242
+ word = tuple(token)
243
+ pairs = get_pairs(word)
244
+
245
+ if not pairs:
246
+ return token
247
+
248
+ while True:
249
+ bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
250
+ if bigram not in self.bpe_ranks:
251
+ break
252
+ first, second = bigram
253
+ new_word = []
254
+ i = 0
255
+ while i < len(word):
256
+ try:
257
+ j = word.index(first, i)
258
+ except ValueError:
259
+ new_word.extend(word[i:])
260
+ break
261
+ else:
262
+ new_word.extend(word[i:j])
263
+ i = j
264
+
265
+ if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
266
+ new_word.append(first + second)
267
+ i += 2
268
+ else:
269
+ new_word.append(word[i])
270
+ i += 1
271
+ new_word = tuple(new_word)
272
+ word = new_word
273
+ if len(word) == 1:
274
+ break
275
+ else:
276
+ pairs = get_pairs(word)
277
+ word = " ".join(word)
278
+ self.cache[token] = word
279
+ return word
280
+
281
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
282
+ if self.add_bos_token:
283
+ bos_token_ids = [self.bos_token_id]
284
+ else:
285
+ bos_token_ids = []
286
+
287
+ output = bos_token_ids + token_ids_0
288
+
289
+ if token_ids_1 is None:
290
+ return output
291
+
292
+ return output + bos_token_ids + token_ids_1
293
+
294
+ def tokenize(self, text: TextInput, **kwargs) -> List[str]:
295
+ def _encode_imgurl(img_tokens):
296
+ assert img_tokens[0] == self.image_start_tag and img_tokens[-1] == self.image_end_tag
297
+ img_tokens = img_tokens[1:-1]
298
+ img_url = ''.join(img_tokens)
299
+ out_img_tokens = list(img_url)
300
+ if len(out_img_tokens) > IMG_TOKEN_SPAN:
301
+ raise ValueError("The content in {}..{} is too long".format(self.image_start_tag, self.image_end_tag))
302
+ out_img_tokens.extend([self.image_pad_tag] * (IMG_TOKEN_SPAN - len(out_img_tokens)))
303
+ out_img_tokens = [self.image_start_tag] + out_img_tokens + [self.image_end_tag]
304
+ return out_img_tokens
305
+
306
+ tokens = super().tokenize(text, **kwargs)
307
+ tokens = _replace_closed_tag(tokens, self.image_start_tag, self.image_end_tag, _encode_imgurl)
308
+ return tokens
309
+
310
+ def _tokenize(self, text):
311
+ """Tokenize a string."""
312
+
313
+ bpe_tokens = []
314
+ for token in re.findall(self.pat, text):
315
+ token = "".join(
316
+ self.byte_encoder[b] for b in token.encode("utf-8")
317
+ ) # Maps all our bytes to unicode strings, avoiding control tokens of the BPE (spaces in our case)
318
+ bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(" "))
319
+ return bpe_tokens
320
+
321
+ def _convert_token_to_id(self, token):
322
+ """Converts a token (str) in an id using the vocab."""
323
+ return self.encoder.get(token, self.encoder.get(self.unk_token))
324
+
325
+ def _convert_id_to_token(self, index):
326
+ """Converts an index (integer) in a token (str) using the vocab."""
327
+ return self.decoder.get(index)
328
+
329
+ def convert_tokens_to_string(self, tokens):
330
+ """Converts a sequence of tokens (string) in a single string."""
331
+ text = "".join(tokens)
332
+ text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
333
+ return text
334
+
335
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
336
+ if not os.path.isdir(save_directory):
337
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
338
+ return
339
+ vocab_file = os.path.join(
340
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
341
+ )
342
+ merge_file = os.path.join(
343
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["merges_file"]
344
+ )
345
+
346
+ with open(vocab_file, "w", encoding="utf-8") as f:
347
+ f.write(json.dumps(self.encoder, indent=2, sort_keys=True, ensure_ascii=False) + "\n")
348
+
349
+ index = 0
350
+ with open(merge_file, "w", encoding="utf-8") as writer:
351
+ writer.write("#version: 0.2\n")
352
+ for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
353
+ if index != token_index:
354
+ logger.warning(
355
+ f"Saving vocabulary to {merge_file}: BPE merge indices are not consecutive."
356
+ " Please check that the tokenizer is not corrupted!"
357
+ )
358
+ index = token_index
359
+ writer.write(" ".join(bpe_tokens) + "\n")
360
+ index += 1
361
+
362
+ return vocab_file, merge_file
363
+
364
+ def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs):
365
+ add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space)
366
+ if is_split_into_words or add_prefix_space:
367
+ text = " " + text
368
+ return (text, kwargs)
369
+
370
+ def decode(
371
+ self,
372
+ token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor", "tf.Tensor"],
373
+ skip_special_tokens: bool = False,
374
+ clean_up_tokenization_spaces: bool = None,
375
+ truncate_before_pattern: Optional[List[str]] = None,
376
+ **kwargs,
377
+ ) -> str:
378
+ """
379
+ Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
380
+ tokens and clean up tokenization spaces.
381
+
382
+ Similar to doing `self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))`.
383
+
384
+ Args:
385
+ token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
386
+ List of tokenized input ids. Can be obtained using the `__call__` method.
387
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
388
+ Whether or not to remove special tokens in the decoding.
389
+ clean_up_tokenization_spaces (`bool`, *optional*):
390
+ Whether or not to clean up the tokenization spaces. If `None`, will default to
391
+ `self.clean_up_tokenization_spaces` (available in the `tokenizer_config`).
392
+ truncate_before_pattern (`List[str]`, *optional*, defaults to `None`):
393
+ A list of regular expression strings that will be used to truncate the returned string. This can be
394
+ used to remove extra pieces of code (e.g. truncate if observing a comment symbol "#" at the beginning
395
+ of a new line). An example pattern could be `["^#", re.escape("<|endoftext|>"), "^'''", "\n\n\n"]`.
396
+ kwargs (additional keyword arguments, *optional*):
397
+ Will be passed to the underlying model specific decode method.
398
+
399
+ Returns:
400
+ `str`: The decoded sentence.
401
+ """
402
+
403
+ token_ids = to_py_obj(token_ids)
404
+
405
+ decoded_text = self._decode(
406
+ token_ids=token_ids,
407
+ skip_special_tokens=skip_special_tokens,
408
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
409
+ **kwargs,
410
+ )
411
+
412
+ if truncate_before_pattern is not None and len(truncate_before_pattern) > 0:
413
+ decoded_text = self.truncate(decoded_text, truncate_before_pattern)
414
+
415
+ return decoded_text
416
+
417
+ def _decode(
418
+ self,
419
+ token_ids: List[int],
420
+ skip_special_tokens: bool = False,
421
+ clean_up_tokenization_spaces: bool = None,
422
+ spaces_between_special_tokens: bool = True,
423
+ **kwargs,
424
+ ) -> str:
425
+
426
+ def _decode_imgurl(img_token_ids):
427
+ assert img_token_ids[0] == self.img_start_id and img_token_ids[-1] == self.img_end_id
428
+ img_token_ids = img_token_ids[1:-1]
429
+ img_token_ids = img_token_ids[: img_token_ids.index(self.img_pad_id)]
430
+ return [self.img_start_id] + img_token_ids + [self.img_end_id]
431
+
432
+ token_ids = _replace_closed_tag(token_ids, self.img_start_id, self.img_end_id, _decode_imgurl)
433
+
434
+ return super()._decode(
435
+ token_ids, skip_special_tokens, clean_up_tokenization_spaces, spaces_between_special_tokens, **kwargs
436
+ )
437
+
438
+ def truncate(self, completion, truncate_before_pattern):
439
+ def find_re(string, pattern, start_pos):
440
+ m = pattern.search(string, start_pos)
441
+ return m.start() if m else -1
442
+
443
+ terminals = [re.compile(pattern, re.MULTILINE) for pattern in truncate_before_pattern]
444
+
445
+ prints = list(re.finditer("^print", completion, re.MULTILINE))
446
+
447
+ if len(prints) > 1:
448
+ completion = completion[: prints[1].start()]
449
+
450
+ defs = list(re.finditer("^def", completion, re.MULTILINE))
451
+
452
+ if len(defs) > 1:
453
+ completion = completion[: defs[1].start()]
454
+
455
+ start_pos = 0
456
+
457
+ terminals_pos = [
458
+ pos for pos in [find_re(completion, terminal, start_pos) for terminal in terminals] if pos != -1
459
+ ]
460
+
461
+ if len(terminals_pos) > 0:
462
+ return completion[: min(terminals_pos)]
463
+ else:
464
+ return completion
465
+
466
+ def from_list_format(self, list_format: List[Dict]):
467
+ text = ''
468
+ num_images = 0
469
+ for ele in list_format:
470
+ if 'image' in ele:
471
+ num_images += 1
472
+ text += f'Picture {num_images}:'
473
+ text += self.image_start_tag + ele['image'] + self.image_end_tag
474
+ text += '\n'
475
+ elif 'text' in ele:
476
+ text += ele['text']
477
+ elif 'box' in ele:
478
+ if 'ref' in ele:
479
+ text += self.ref_start_tag + ele['ref'] + self.ref_end_tag
480
+ for box in ele['box']:
481
+ text += self.box_start_tag + '(%d,%d),(%d,%d)' % (box[0], box[1], box[2], box[3]) + self.box_end_tag
482
+ else:
483
+ raise ValueError("Unsupport element: " + str(ele))
484
+ return text
485
+
486
+ def _fetch_latest_picture(self, response, history):
487
+ if history is None:
488
+ history = []
489
+ _history = history + [(response, None)]
490
+ for q, r in _history[::-1]:
491
+ for ele in self.to_list_format(q)[::-1]:
492
+ if 'image' in ele:
493
+ return ele['image']
494
+ return None
495
+
496
+ def _fetch_all_box_with_ref(self, text):
497
+ list_format = self.to_list_format(text)
498
+ output = []
499
+ for i, ele in enumerate(list_format):
500
+ if 'box' in ele:
501
+ bbox = tuple(map(int, ele['box'].replace('(', '').replace(')', '').split(',')))
502
+ assert len(bbox) == 4
503
+ output.append({'box': bbox})
504
+ if i > 0 and 'ref' in list_format[i - 1]:
505
+ output[-1]['ref'] = list_format[i - 1]['ref'].strip()
506
+ return output
507
+
508
+ def draw_bbox_on_latest_picture(
509
+ self,
510
+ response,
511
+ history=None,
512
+ ) -> Optional[Image.Image]:
513
+ image = self._fetch_latest_picture(response, history)
514
+ if image is None:
515
+ return None
516
+ if image.startswith("http://") or image.startswith("https://"):
517
+ image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
518
+ h, w = image.height, image.width
519
+ else:
520
+ image = np.asarray(Image.open(image).convert("RGB"))
521
+ h, w = image.shape[0], image.shape[1]
522
+ visualizer = Visualizer(image)
523
+
524
+ boxes = self._fetch_all_box_with_ref(response)
525
+ if not boxes:
526
+ return None
527
+ color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()]) # init color
528
+ for box in boxes:
529
+ if 'ref' in box: # random new color for new refexps
530
+ color = random.choice([_ for _ in mcolors.TABLEAU_COLORS.keys()])
531
+ x1, y1, x2, y2 = box['box']
532
+ x1, y1, x2, y2 = (int(x1 / 1000 * w), int(y1 / 1000 * h), int(x2 / 1000 * w), int(y2 / 1000 * h))
533
+ visualizer.draw_box((x1, y1, x2, y2), alpha=1, edge_color=color)
534
+ if 'ref' in box:
535
+ visualizer.draw_text(box['ref'], (x1, y1), color=color, horizontal_alignment="left")
536
+ return visualizer.output
537
+
538
+
539
+ class VisImage:
540
+ def __init__(self, img, scale=1.0):
541
+ self.img = img
542
+ self.scale = scale
543
+ self.width, self.height = img.shape[1], img.shape[0]
544
+ self._setup_figure(img)
545
+
546
+ def _setup_figure(self, img):
547
+ fig = mplfigure.Figure(frameon=False)
548
+ self.dpi = fig.get_dpi()
549
+ # add a small 1e-2 to avoid precision lost due to matplotlib's truncation
550
+ # (https://github.com/matplotlib/matplotlib/issues/15363)
551
+ fig.set_size_inches(
552
+ (self.width * self.scale + 1e-2) / self.dpi,
553
+ (self.height * self.scale + 1e-2) / self.dpi,
554
+ )
555
+ self.canvas = FigureCanvasAgg(fig)
556
+ # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
557
+ ax = fig.add_axes([0.0, 0.0, 1.0, 1.0])
558
+ ax.axis("off")
559
+ self.fig = fig
560
+ self.ax = ax
561
+ self.reset_image(img)
562
+
563
+ def reset_image(self, img):
564
+ img = img.astype("uint8")
565
+ self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest")
566
+
567
+ def save(self, filepath):
568
+ self.fig.savefig(filepath)
569
+
570
+ def get_image(self):
571
+ canvas = self.canvas
572
+ s, (width, height) = canvas.print_to_buffer()
573
+
574
+ buffer = np.frombuffer(s, dtype="uint8")
575
+
576
+ img_rgba = buffer.reshape(height, width, 4)
577
+ rgb, alpha = np.split(img_rgba, [3], axis=2)
578
+ return rgb.astype("uint8")
579
+
580
+
581
+ class Visualizer:
582
+ def __init__(self, img_rgb, metadata=None, scale=1.0):
583
+ self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8)
584
+ self.output = VisImage(self.img, scale=scale)
585
+ self.cpu_device = torch.device("cpu")
586
+
587
+ # too small texts are useless, therefore clamp to 14
588
+ self._default_font_size = max(
589
+ np.sqrt(self.output.height * self.output.width) // 30, 15 // scale
590
+ )
591
+
592
+ def draw_text(
593
+ self,
594
+ text,
595
+ position,
596
+ *,
597
+ font_size=None,
598
+ color="g",
599
+ horizontal_alignment="center",
600
+ rotation=0,
601
+ ):
602
+ if not font_size:
603
+ font_size = self._default_font_size
604
+
605
+ # since the text background is dark, we don't want the text to be dark
606
+ color = np.maximum(list(mplc.to_rgb(color)), 0.2)
607
+ color[np.argmax(color)] = max(0.8, np.max(color))
608
+
609
+ x, y = position
610
+ self.output.ax.text(
611
+ x,
612
+ y,
613
+ text,
614
+ size=font_size * self.output.scale,
615
+ bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"},
616
+ verticalalignment="top",
617
+ horizontalalignment=horizontal_alignment,
618
+ color=color,
619
+ zorder=10,
620
+ rotation=rotation,
621
+ )
622
+ return self.output
623
+
624
+ def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"):
625
+ x0, y0, x1, y1 = box_coord
626
+ width = x1 - x0
627
+ height = y1 - y0
628
+
629
+ linewidth = max(self._default_font_size / 4, 1)
630
+
631
+ self.output.ax.add_patch(
632
+ mpl.patches.Rectangle(
633
+ (x0, y0),
634
+ width,
635
+ height,
636
+ fill=False,
637
+ edgecolor=edge_color,
638
+ linewidth=linewidth * self.output.scale,
639
+ alpha=alpha,
640
+ linestyle=line_style,
641
+ )
642
+ )
643
+ return self.output
644
+
645
+ def get_output(self):
646
+ return self.output