tinyllava committed on
Commit c1585b5
1 Parent(s): a916181

Delete generate_model.py

Files changed (1)
  1. generate_model.py +0 -730
generate_model.py DELETED
@@ -1,730 +0,0 @@
-import argparse
-import time
-import logging
-import requests
-import os
-from PIL import Image
-from io import BytesIO
-
-from PIL import Image
-import torch
-from transformers import AutoTokenizer
-
-from transformers import AutoTokenizer, AutoModelForCausalLM
-
-from PIL import Image
-from io import BytesIO
-import base64
-
-import torch
-from transformers import StoppingCriteria
-
-import math
-import ast
-
-# Model Constants
-IGNORE_INDEX = -100
-IMAGE_TOKEN_INDEX = -200
-DEFAULT_IMAGE_TOKEN = "<image>"
-DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
-DEFAULT_IM_START_TOKEN = "<im_start>"
-DEFAULT_IM_END_TOKEN = "<im_end>"
-IMAGE_PLACEHOLDER = "<image-placeholder>"
-import dataclasses
-from enum import auto, Enum
-from typing import List, Tuple
-
-
-class SeparatorStyle(Enum):
-    """Different separator style."""
-    SINGLE = auto()
-    TWO = auto()
-    MPT = auto()
-    PLAIN = auto()
-    LLAMA_2 = auto()
-    TINY_LLAMA = auto()
-    QWEN_2 = auto()
-
-
-@dataclasses.dataclass
-class Conversation:
-    """A class that keeps all conversation history."""
-    system: str
-    roles: List[str]
-    messages: List[List[str]]
-    offset: int
-    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
-    sep: str = "###"
-    sep2: str = None
-    version: str = "Unknown"
-
-    skip_next: bool = False
-
-    def get_prompt(self):
-        messages = self.messages
-        if len(messages) > 0 and type(messages[0][1]) is tuple:
-            messages = self.messages.copy()
-            init_role, init_msg = messages[0].copy()
-            init_msg = init_msg[0].replace("<image>", "").strip()
-            if 'mmtag' in self.version:
-                messages[0] = (init_role, init_msg)
-                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
-                messages.insert(1, (self.roles[1], "Received."))
-            else:
-                messages[0] = (init_role, "<image>\n" + init_msg)
-
-        if self.sep_style == SeparatorStyle.SINGLE:
-            ret = self.system + self.sep
-            for role, message in messages:
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    ret += role + ": " + message + self.sep
-                else:
-                    ret += role + ":"
-        elif self.sep_style == SeparatorStyle.TWO:
-            seps = [self.sep, self.sep2]
-            ret = self.system + seps[0]
-            for i, (role, message) in enumerate(messages):
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    ret += role + ": " + message + seps[i % 2]
-                else:
-                    ret += role + ":"
-        elif self.sep_style == SeparatorStyle.MPT:
-            ret = self.system + self.sep
-            for role, message in messages:
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    ret += role + message + self.sep
-                else:
-                    ret += role
-        elif self.sep_style == SeparatorStyle.LLAMA_2:
-            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
-            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
-            ret = ""
-
-            for i, (role, message) in enumerate(messages):
-                if i == 0:
-                    assert message, "first message should not be none"
-                    assert role == self.roles[0], "first message should come from user"
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    if i == 0: message = wrap_sys(self.system) + message
-                    if i % 2 == 0:
-                        message = wrap_inst(message)
-                        ret += self.sep + message
-                    else:
-                        ret += " " + message + " " + self.sep2
-                else:
-                    ret += ""
-            ret = ret.lstrip(self.sep)
-        elif self.sep_style == SeparatorStyle.TINY_LLAMA:
-            sep = "</s>"
-            wrap_sys = lambda msg: f"<|system|>\n{msg}\n"
-            wrap_user = lambda msg: f"<|user|>\n{msg}\n"
-            wrap_assistant = lambda msg: f"<|assistant|>\n{msg}"
-            ret = ""
-
-            for i, (role, message) in enumerate(messages):
-                if i == 0:
-                    assert message, "first message should not be none"
-                    assert role == self.roles[0], "first message should come from user"
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    if i % 2 == 0:
-                        message = wrap_user(message)
-                        if i == 0:
-                            message = wrap_sys(self.system) + message
-                        ret += self.sep + message
-                    else:
-                        message = wrap_assistant(message) + self.sep2
-                        ret += message
-                else:
-                    ret += "<|assistant|>\n"
-            ret = ret.lstrip(self.sep)
-        elif self.sep_style == SeparatorStyle.QWEN_2:
-            ret = self.system + self.sep
-            for role, message in messages:
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    ret += role + message + self.sep
-                else:
-                    ret += role
-        elif self.sep_style == SeparatorStyle.PLAIN:
-            seps = [self.sep, self.sep2]
-            ret = self.system
-            for i, (role, message) in enumerate(messages):
-                if message:
-                    if type(message) is tuple:
-                        message, _, _ = message
-                    ret += message + seps[i % 2]
-                else:
-                    ret += ""
-        else:
-            raise ValueError(f"Invalid style: {self.sep_style}")
-
-        return ret
-
-    def append_message(self, role, message):
-        self.messages.append([role, message])
-
-    def get_images(self, return_pil=False):
-        images = []
-        for i, (role, msg) in enumerate(self.messages[self.offset:]):
-            if i % 2 == 0:
-                if type(msg) is tuple:
-                    import base64
-                    from io import BytesIO
-                    from PIL import Image
-                    msg, image, image_process_mode = msg
-                    if image_process_mode == "Pad":
-                        def expand2square(pil_img, background_color=(122, 116, 104)):
-                            width, height = pil_img.size
-                            if width == height:
-                                return pil_img
-                            elif width > height:
-                                result = Image.new(pil_img.mode, (width, width), background_color)
-                                result.paste(pil_img, (0, (width - height) // 2))
-                                return result
-                            else:
-                                result = Image.new(pil_img.mode, (height, height), background_color)
-                                result.paste(pil_img, ((height - width) // 2, 0))
-                                return result
-                        image = expand2square(image)
-                    elif image_process_mode in ["Default", "Crop"]:
-                        pass
-                    elif image_process_mode == "Resize":
-                        image = image.resize((336, 336))
-                    else:
-                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
-                    max_hw, min_hw = max(image.size), min(image.size)
-                    aspect_ratio = max_hw / min_hw
-                    max_len, min_len = 800, 400
-                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
-                    longest_edge = int(shortest_edge * aspect_ratio)
-                    W, H = image.size
-                    if longest_edge != max(image.size):
-                        if H > W:
-                            H, W = longest_edge, shortest_edge
-                        else:
-                            H, W = shortest_edge, longest_edge
-                        image = image.resize((W, H))
-                    if return_pil:
-                        images.append(image)
-                    else:
-                        buffered = BytesIO()
-                        image.save(buffered, format="PNG")
-                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
-                        images.append(img_b64_str)
-        return images
-
-    def to_gradio_chatbot(self):
-        ret = []
-        for i, (role, msg) in enumerate(self.messages[self.offset:]):
-            if i % 2 == 0:
-                if type(msg) is tuple:
-                    import base64
-                    from io import BytesIO
-                    msg, image, image_process_mode = msg
-                    max_hw, min_hw = max(image.size), min(image.size)
-                    aspect_ratio = max_hw / min_hw
-                    max_len, min_len = 800, 400
-                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
-                    longest_edge = int(shortest_edge * aspect_ratio)
-                    W, H = image.size
-                    if H > W:
-                        H, W = longest_edge, shortest_edge
-                    else:
-                        H, W = shortest_edge, longest_edge
-                    image = image.resize((W, H))
-                    buffered = BytesIO()
-                    image.save(buffered, format="JPEG")
-                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
-                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
-                    msg = img_str + msg.replace('<image>', '').strip()
-                    ret.append([msg, None])
-                else:
-                    ret.append([msg, None])
-            else:
-                ret[-1][-1] = msg
-        return ret
-
-    def copy(self):
-        return Conversation(
-            system=self.system,
-            roles=self.roles,
-            messages=[[x, y] for x, y in self.messages],
-            offset=self.offset,
-            sep_style=self.sep_style,
-            sep=self.sep,
-            sep2=self.sep2,
-            version=self.version)
-
-    def dict(self):
-        if len(self.get_images()) > 0:
-            return {
-                "system": self.system,
-                "roles": self.roles,
-                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
-                "offset": self.offset,
-                "sep": self.sep,
-                "sep2": self.sep2,
-            }
-        return {
-            "system": self.system,
-            "roles": self.roles,
-            "messages": self.messages,
-            "offset": self.offset,
-            "sep": self.sep,
-            "sep2": self.sep2,
-        }
-
-
-
-
-conv_phi_v0 = Conversation(
-    system="A chat between a curious user and an artificial intelligence assistant. "
-           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
-    roles=("USER", "ASSISTANT"),
-    version="phi",
-    messages=(),
-    offset=0,
-    sep_style=SeparatorStyle.TWO,
-    sep=" ",
-    sep2="<|endoftext|>",
-)
-
-
-
-def select_best_resolution(original_size, possible_resolutions):
-    """
-    Selects the best resolution from a list of possible resolutions based on the original size.
-
-    Args:
-        original_size (tuple): The original size of the image in the format (width, height).
-        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
-
-    Returns:
-        tuple: The best fit resolution in the format (width, height).
-    """
-    original_width, original_height = original_size
-    best_fit = None
-    max_effective_resolution = 0
-    min_wasted_resolution = float('inf')
-
-    for width, height in possible_resolutions:
-        scale = min(width / original_width, height / original_height)
-        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
-        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
-        wasted_resolution = (width * height) - effective_resolution
-
-        if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
-            max_effective_resolution = effective_resolution
-            min_wasted_resolution = wasted_resolution
-            best_fit = (width, height)
-
-    return best_fit
-
-
-## added by llava-1.6
-def resize_and_pad_image(image, target_resolution):
-    """
-    Resize and pad an image to a target resolution while maintaining aspect ratio.
-
-    Args:
-        image (PIL.Image.Image): The input image.
-        target_resolution (tuple): The target resolution (width, height) of the image.
-
-    Returns:
-        PIL.Image.Image: The resized and padded image.
-    """
-    original_width, original_height = image.size
-    target_width, target_height = target_resolution
-
-    scale_w = target_width / original_width
-    scale_h = target_height / original_height
-
-    if scale_w < scale_h:
-        new_width = target_width
-        new_height = min(math.ceil(original_height * scale_w), target_height)
-    else:
-        new_height = target_height
-        new_width = min(math.ceil(original_width * scale_h), target_width)
-
-    # Resize the image
-    resized_image = image.resize((new_width, new_height))
-
-    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
-    paste_x = (target_width - new_width) // 2
-    paste_y = (target_height - new_height) // 2
-    new_image.paste(resized_image, (paste_x, paste_y))
-
-    return new_image
-
-
-## added by llava-1.6
-def divide_to_patches(image, patch_size):
-    """
-    Divides an image into patches of a specified size.
-
-    Args:
-        image (PIL.Image.Image): The input image.
-        patch_size (int): The size of each patch.
-
-    Returns:
-        list: A list of PIL.Image.Image objects representing the patches.
-    """
-    patches = []
-    width, height = image.size
-    for i in range(0, height, patch_size):
-        for j in range(0, width, patch_size):
-            box = (j, i, j + patch_size, i + patch_size)
-            patch = image.crop(box)
-            patches.append(patch)
-
-    return patches
-
-
-## added by llava-1.6
-def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
-    """
-    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
-
-    Args:
-        image_size (tuple): The size of the input image in the format (width, height).
-        grid_pinpoints (str): A string representation of a list of possible resolutions.
-        patch_size (int): The size of each image patch.
-
-    Returns:
-        tuple: The shape of the image patch grid in the format (width, height).
-    """
-    if type(grid_pinpoints) is list:
-        possible_resolutions = grid_pinpoints
-    else:
-        possible_resolutions = ast.literal_eval(grid_pinpoints)
-    width, height = select_best_resolution(image_size, possible_resolutions)
-    return width // patch_size, height // patch_size
-
-
-## added by llava-1.6
-def process_anyres_image(image, processor, grid_pinpoints):
-    """
-    Process an image with variable resolutions.
-
-    Args:
-        image (PIL.Image.Image): The input image to be processed.
-        processor: The image processor object.
-        grid_pinpoints (str): A string representation of a list of possible resolutions.
-
-    Returns:
-        torch.Tensor: A tensor containing the processed image patches.
-    """
-    if type(grid_pinpoints) is list:
-        possible_resolutions = grid_pinpoints
-    else:
-        possible_resolutions = ast.literal_eval(grid_pinpoints)
-    best_resolution = select_best_resolution(image.size, possible_resolutions)
-    image_padded = resize_and_pad_image(image, best_resolution)
-
-    patches = divide_to_patches(image_padded, processor.crop_size['height'])
-
-    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
-
-    image_patches = [image_original_resize] + patches
-    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
-                     for image_patch in image_patches]
-    return torch.stack(image_patches, dim=0)
-
-
-def load_image_from_base64(image):
-    return Image.open(BytesIO(base64.b64decode(image)))
-
-
-def expand2square(pil_img, background_color):
-    width, height = pil_img.size
-    if width == height:
-        return pil_img
-    elif width > height:
-        result = Image.new(pil_img.mode, (width, width), background_color)
-        result.paste(pil_img, (0, (width - height) // 2))
-        return result
-    else:
-        result = Image.new(pil_img.mode, (height, height), background_color)
-        result.paste(pil_img, ((height - width) // 2, 0))
-        return result
-
-
-def process_images(images, image_processor, model_cfg):
-    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
-    new_images = []
-    if image_aspect_ratio == 'pad':
-        for image in images:
-            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
-            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
-            new_images.append(image)
-    elif image_aspect_ratio == "anyres":
-        for image in images:
-            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
-            new_images.append(image)
-    else:
-        return image_processor(images, return_tensors='pt')['pixel_values']
-    if all(x.shape == new_images[0].shape for x in new_images):
-        new_images = torch.stack(new_images, dim=0)
-    return new_images
-
-
-def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
-    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
-
-    def insert_separator(X, sep):
-        return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
-
-    input_ids = []
-    offset = 0
-    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
-        offset = 1
-        input_ids.append(prompt_chunks[0][0])
-
-    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
-        input_ids.extend(x[offset:])
-
-    if return_tensors is not None:
-        if return_tensors == 'pt':
-            return torch.tensor(input_ids, dtype=torch.long)
-        raise ValueError(f'Unsupported tensor type: {return_tensors}')
-    return input_ids
-
-
-def get_model_name_from_path(model_path):
-    model_path = model_path.strip("/")
-    model_paths = model_path.split("/")
-    if model_paths[-1].startswith('checkpoint-'):
-        return model_paths[-2] + "_" + model_paths[-1]
-    else:
-        return model_paths[-1]
-
-
-class KeywordsStoppingCriteria(StoppingCriteria):
-    def __init__(self, keywords, tokenizer, input_ids):
-        self.keywords = keywords
-        self.keyword_ids = []
-        self.max_keyword_len = 0
-        for keyword in keywords:
-            cur_keyword_ids = tokenizer(keyword).input_ids
-            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
-                cur_keyword_ids = cur_keyword_ids[1:]
-            if len(cur_keyword_ids) > self.max_keyword_len:
-                self.max_keyword_len = len(cur_keyword_ids)
-            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
-        self.tokenizer = tokenizer
-        self.start_len = input_ids.shape[1]
-
-    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
-        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
-        for keyword_id in self.keyword_ids:
-            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
-                return True
-        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
-        for keyword in self.keywords:
-            if keyword in outputs:
-                return True
-        return False
-
-    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        outputs = []
-        for i in range(output_ids.shape[0]):
-            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
-        return all(outputs)
-
-
-
-def load_image(image_file):
-    if image_file.startswith("http") or image_file.startswith("https"):
-        response = requests.get(image_file)
-        image = Image.open(BytesIO(response.content)).convert("RGB")
-    else:
-        image = Image.open(image_file).convert("RGB")
-    return image
-
-
-def generate(
-    prompt: str,
-    model: str,
-    tokenizer = None,
-    image: str = None,
-    device: str = None,
-    max_new_tokens: int = 1024,
-    num_beams = 1,
-    top_p=None,
-    temperature=0.2
-):
-    if not device:
-        if torch.cuda.is_available() and torch.cuda.device_count():
-            device = "cuda:0"
-            logging.warning(
-                'inference device is not set, using cuda:0, %s',
-                torch.cuda.get_device_name(0)
-            )
-        else:
-            device = 'cpu'
-            logging.warning(
-                (
-                    'No CUDA device detected, using cpu, '
-                    'expect slower speeds.'
-                )
-            )
-
-    if 'cuda' in device and not torch.cuda.is_available():
-        raise ValueError('CUDA device requested but no CUDA device detected.')
-
-    if isinstance(model, str):
-        checkpoint_path = model
-        # print(f'loading model from {checkpoint_path}...')
-        model = AutoModelForCausalLM.from_pretrained(
-            checkpoint_path,
-            trust_remote_code=True
-        )
-        # print('model load over')
-    config = model.config
-    if tokenizer is None:
-        tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False, model_max_length = config.tokenizer_model_max_length,
-                                                  padding_side = config.tokenizer_padding_side)
-    image_processor = model.vision_tower._image_processor
-    context_len = getattr(config, 'max_sequence_length', 2048)
-    model.to(device).eval()
-
-
-    if image is not None:
-        prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
-    conv = conv_phi_v0.copy()
-    conv.append_message(conv.roles[0], prompt)
-    conv.append_message(conv.roles[1], None)
-    prompt = conv.get_prompt()
-    if image is not None:
-        # print('loading image...')
-        image = load_image(image)
-        # print('load image over')
-        image_tensor = process_images(image, image_processor, config).to(model.device, dtype=torch.float16)
-
-    input_ids = (
-        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
-        .unsqueeze(0)
-        .to(model.device, dtype=torch.float16)
-    )
-    # Generate
-    stime = time.time()
-    # stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
-    # keywords = [stop_str]
-    # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
-    # print('start inference...')
-    with torch.inference_mode():
-        output_ids = model.generate(
-            input_ids,
-            images=image_tensor,
-            do_sample=True if temperature > 0 else False,
-            temperature=temperature,
-            top_p=top_p,
-            num_beams=num_beams,
-            pad_token_id=tokenizer.pad_token_id,
-            max_new_tokens=max_new_tokens,
-            use_cache=True,
-            # stopping_criteria=[stopping_criteria],
-        )
-
-    # print('inference over')
-    generation_time = time.time() - stime
-    outputs = tokenizer.batch_decode(
-        output_ids, skip_special_tokens=True
-    )[0]
-    # outputs = outputs.strip()
-    # if outputs.endswith(stop_str):
-    #     outputs = outputs[: -len(stop_str)]
-    outputs = outputs.strip()
-
-    return outputs, generation_time
-def tinyllava_elm_generate_parser():
-    """Argument Parser"""
-
-    class KwargsParser(argparse.Action):
-        """Parser action class to parse kwargs of form key=value"""
-        def __call__(self, parser, namespace, values, option_string=None):
-            setattr(namespace, self.dest, dict())
-            for val in values:
-                if '=' not in val:
-                    raise ValueError(
-                        (
-                            'Argument parsing error, kwargs are expected in'
-                            ' the form of key=value.'
-                        )
-                    )
-                kwarg_k, kwarg_v = val.split('=')
-                try:
-                    converted_v = int(kwarg_v)
-                except ValueError:
-                    try:
-                        converted_v = float(kwarg_v)
-                    except ValueError:
-                        converted_v = kwarg_v
-                getattr(namespace, self.dest)[kwarg_k] = converted_v
-
-    parser = argparse.ArgumentParser('TinyLLaVA-OpenELM Generate Module')
-    parser.add_argument(
-        '--model',
-        dest='model',
-        help='Path to the hf converted model.',
-        required=True,
-        type=str,
-    )
-    parser.add_argument(
-        '--prompt',
-        dest='prompt',
-        help='Prompt for LLM call.',
-        default='',
-        type=str,
-    )
-    parser.add_argument(
-        '--device',
-        dest='device',
-        help='Device used for inference.',
-        type=str,
-    )
-    parser.add_argument("--image", type=str, default=None)
-    parser.add_argument("--temperature", type=float, default=0)
-    parser.add_argument("--top_p", type=float, default=None)
-    parser.add_argument("--num_beams", type=int, default=1)
-    parser.add_argument("--max_new_tokens", type=int, default=512)
-    return parser.parse_args()
-
-
-if __name__ == '__main__':
-    args = tinyllava_elm_generate_parser()
-
-    output_text, genertaion_time = generate(
-        prompt=args.prompt,
-        image=args.image,
-        model=args.model,
-        device=args.device,
-        max_new_tokens = args.max_new_tokens,
-        num_beams = args.num_beams,
-        top_p=args.top_p,
-        temperature=args.temperature
-    )
-
-    print_txt = (
-        f'\r\n{"=" * os.get_terminal_size().columns}\r\n'
-        '\033[1m Prompt + Generated Output\033[0m\r\n'
-        f'{"-" * os.get_terminal_size().columns}\r\n'
-        f'{output_text}\r\n'
-        f'{"-" * os.get_terminal_size().columns}\r\n'
-        '\r\nGeneration took'
-        f'\033[1m\033[92m {round(genertaion_time, 2)} \033[0m'
-        'seconds.\r\n'
-    )
-    print(print_txt)
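
As it existed in the parent commit a916181, the removed module exposed both a command-line entry point and a generate() helper. A minimal usage sketch follows; the checkpoint and image paths are placeholders, not paths shipped with this repository.

# Hypothetical invocation of the script removed by this commit, mirroring the
# argparse flags it defined:
#   python generate_model.py --model ./path/to/hf_checkpoint \
#       --prompt "Describe the image." --image ./example.jpg --max_new_tokens 256

# Programmatic call, matching the generate() signature from the deleted file
# (placeholder paths; the function returned the decoded text and the wall-clock time):
from generate_model import generate

output_text, generation_time = generate(
    prompt="Describe the image.",
    model="./path/to/hf_checkpoint",  # placeholder HF checkpoint directory
    image="./example.jpg",            # placeholder local image path
    max_new_tokens=256,
)
print(output_text, generation_time)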