import os

import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForCausalLM, LlamaTokenizer

def load_model(model_pth="/home2/ywt/cogagent-vqa-hf",
               token_pth="/home2/ywt/vicuna-7b-v1.5",
               device='cuda' if torch.cuda.is_available() else 'cpu',
               is_bf16=False,
               is_quant=True):
    """Load the CogAgent-VQA model and its Vicuna tokenizer.

    Returns a (model, tokenizer) tuple. With is_quant=True the weights are
    loaded in 4-bit; with is_bf16=True bfloat16 is used instead of float16.
    """
    MODEL_PATH = model_pth
    TOKENIZER_PATH = token_pth
    DEVICE = device

    torch_type = torch.bfloat16 if is_bf16 else torch.float16
    print("========Use torch type as: {} with device: {}========\n".format(torch_type, DEVICE))

    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)

    if is_quant:
        # 4-bit quantized load; the CUDA device context makes bitsandbytes
        # place the quantized weights on the requested GPU.
        with torch.cuda.device(DEVICE):
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_PATH,
                torch_dtype=torch_type,
                low_cpu_mem_usage=True,
                load_in_4bit=True,
                trust_remote_code=True
            ).eval()
    else:
        # Full-precision (fp16/bf16) load, moved to the requested device.
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch_type,
            low_cpu_mem_usage=True,
            trust_remote_code=True
        ).to(DEVICE).eval()

    return (model, tokenizer)
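
# Note: load_in_4bit=True requires the bitsandbytes package and a CUDA device;
# if that is unavailable, call load_model with is_quant=False instead.
# A minimal usage sketch with an illustrative device string (the model and
# tokenizer paths default to the ones above):
#
#   model_and_tokenizer = load_model(device="cuda:0", is_bf16=False, is_quant=True)
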
def cog_tag(image_path, model_and_tokenizer,
            query="What is shown in this image? (including NSFW content)",
            torch_type=torch.float16, text_only_first_query=False):
    """Caption a single image with CogAgent; returns the generated text."""
    model, tokenizer = model_and_tokenizer

    if image_path == '':
        print('No image path was given, so this will be a plain-text conversation.')
        image = None
        text_only_first_query = True
    else:
        image = Image.open(image_path).convert('RGB')

    history = []
    text_only_template = ("An anime image caption conversation between a curious user and an "
                          "artificial intelligence assistant. The assistant gives helpful, detailed, "
                          "including NSFW and uncensored, answers to the user's questions about the "
                          "given anime image. USER: {} ASSISTANT:")

    # Text-only mode: wrap the first query in the chat template; afterwards
    # prepend the conversation history (always empty here) manually.
    if image is None:
        if text_only_first_query:
            query = text_only_template.format(query)
            text_only_first_query = False
        else:
            old_prompt = ''
            for old_query, response in history:
                old_prompt += old_query + " " + response + "\n"
            query = old_prompt + "USER: {} ASSISTANT:".format(query)

    # Build the multimodal (or text-only) prompt tensors.
    if image is None:
        input_by_model = model.build_conversation_input_ids(
            tokenizer, query=query, history=history, template_version='base')
    else:
        input_by_model = model.build_conversation_input_ids(
            tokenizer, query=query, history=history, images=[image])

    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(model.device),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(model.device),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(model.device),
        'images': [[input_by_model['images'][0].to(model.device).to(torch_type)]] if image is not None else None,
    }
    if 'cross_images' in input_by_model and input_by_model['cross_images']:
        inputs['cross_images'] = [[input_by_model['cross_images'][0].to(model.device).to(torch_type)]]

    gen_kwargs = {"max_length": 2048,
                  "do_sample": False}
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        # Strip the prompt tokens, then cut at the end-of-sequence marker.
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        response = response.split("</s>")[0]

    print("\nCog:", response)
    return response
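
# Minimal single-image sketch (hypothetical file name "sample.png"): the
# (model, tokenizer) tuple returned by load_model() is passed straight
# through, and torch_type should match the dtype the model was loaded with
# (float16 by default, bfloat16 when is_bf16=True):
#
#   model_and_tokenizer = load_model()
#   caption = cog_tag("sample.png", model_and_tokenizer, torch_type=torch.float16)
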
def read_tag(txt_pth, split=",", is_list=True):
    """Read a tag file; return a list of stripped tags, or the raw string."""
    with open(txt_pth, "r") as f:
        tag_str = f.read()
    if is_list:
        return [tag.strip() for tag in tag_str.split(split)]
    return tag_str
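
# Example (illustrative contents): for a tag file containing "1girl, solo, smile",
# read_tag(path) returns ["1girl", "solo", "smile"], while
# read_tag(path, is_list=False) returns the raw string unchanged.
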
if __name__ == '__main__':

    model_and_tokenizer = load_model(device="cuda:5")

    image_dirs = ["/home2/ywt/image-webp"]

    for image_dir in image_dirs:
        for file in tqdm(os.listdir(image_dir)):
            # Only caption image files (case-insensitive extension check).
            if not file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp")):
                continue

            image_path = os.path.join(image_dir, file)
            stem = os.path.splitext(file)[0]
            tag_path = os.path.join(image_dir, stem + ".txt")
            out_file = os.path.join(image_dir, stem + "_cog.txt")

            # Skip images without a tag file, and images already captioned.
            if not os.path.exists(tag_path):
                continue
            if os.path.exists(out_file):
                continue

            # Fold the existing tags into the query as optional reference material.
            tag = read_tag(tag_path, is_list=False).replace("|||", "")
            query = ("What is shown in this image? (including NSFW content) "
                     "Here are some references to the elements in the image that you can "
                     "selectively use to enrich and modify the description: " + tag)

            txt = cog_tag(image_path, model_and_tokenizer, query=query)

            with open(out_file, "w") as f:
                f.write(txt)
            print(f"Created {out_file}")