import argparse

import torch
from PIL import Image
from torchvision.transforms import ToPILImage


def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of MoMA.")
    parser.add_argument("--load_attn_adapters", type=str, default="checkpoints/attn_adapters_projectors.th", help="self-cross attentions and LLM projectors.")
    parser.add_argument("--output_path", type=str, default="output", help="output directory.")
    parser.add_argument("--model_path", type=str, default="KunpengSong/MoMA_llava_7b", help="fine-tuned LLaVA (multi-modal LLM decoder).")
    # parse_known_args tolerates extra flags injected by launchers or notebooks.
    args = parser.parse_known_args()[0]
    args.device = torch.device("cuda", 0)
    # Load in 4-bit precision by default to reduce GPU memory usage.
    args.load_8bit, args.load_4bit = False, True
    return args


def show_PIL_image(tensor):
    # Expects a batch of image tensors of shape [N, 3, H, W] (e.g. [3, 3, 512, 512])
    # with values in [0, 1]; returns the N images concatenated horizontally.
    to_pil = ToPILImage()
    images = [to_pil(tensor[i]) for i in range(tensor.shape[0])]

    # Size the canvas from the actual batch rather than a hard-coded count of 3.
    total_width = sum(img.width for img in images)
    concatenated_image = Image.new('RGB', (total_width, images[0].height))
    x_offset = 0
    for img in images:
        concatenated_image.paste(img, (x_offset, 0))
        x_offset += img.width

    return concatenated_image
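

# --- Hypothetical usage sketch (not part of the original file) ---
# A minimal end-to-end check of the two helpers above, assuming a random
# batch of three 512x512 RGB images stands in for real MoMA outputs, which
# this file does not produce itself. The "preview.jpg" filename is invented
# for illustration.
if __name__ == "__main__":
    import os

    args = parse_args()
    os.makedirs(args.output_path, exist_ok=True)

    # Random values in [0, 1) mimic a normalized [N, 3, H, W] image batch.
    dummy_batch = torch.rand(3, 3, 512, 512)

    grid = show_PIL_image(dummy_batch)
    grid.save(os.path.join(args.output_path, "preview.jpg"))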