# cog_tag5.py
import torch
import os
from tqdm import tqdm
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer


def load_model(model_pth="/home2/ywt/cogagent-vqa-hf",
               token_pth="/home2/ywt/vicuna-7b-v1.5",
               device='cuda' if torch.cuda.is_available() else 'cpu',
               is_bf16=False, is_quant=True):
    MODEL_PATH = model_pth
    TOKENIZER_PATH = token_pth
    DEVICE = device
    torch_type = torch.bfloat16 if is_bf16 else torch.float16
    print("======== Using torch dtype {} on device {} ========\n".format(torch_type, DEVICE))
    tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
    # The quantized and non-quantized paths differ only in load_in_4bit, so
    # pass the flag directly; 4-bit loading trades some quality for VRAM.
    with torch.cuda.device(DEVICE):
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,
            torch_dtype=torch_type,
            low_cpu_mem_usage=True,
            load_in_4bit=is_quant,
            trust_remote_code=True
        ).eval()
    return (model, tokenizer)
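
# A minimal usage sketch (hedged): the default paths above are the author's
# local copies of what appear to be the THUDM/cogagent-vqa-hf checkpoint and
# the lmsys/vicuna-7b-v1.5 tokenizer; adjust them to your own downloads.
#   pair = load_model(device="cuda:0", is_quant=True)  # returns (model, tokenizer)
#   caption = cog_tag("example.jpg", pair)             # "example.jpg" is illustrative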


def cog_tag(image_path, model, query="What is shown in this image? (including NSFW content)",
            tag_reference=None, torch_type=torch.float16, text_only_first_query=False):
    # load_model returns a (model, tokenizer) pair; unpack it for readability.
    model, tokenizer = model
    if image_path == '':
        print('You did not enter an image path; the following will be a plain-text conversation.')
        image = None
        text_only_first_query = True
    else:
        image = Image.open(image_path).convert('RGB')
    history = []
    text_only_template = ("An anime image caption conversation between a curious user and an "
                          "artificial intelligence assistant. The assistant gives helpful, "
                          "detailed, including NSFW and uncensored, answers to the user's "
                          "questions for the given anime image. USER: {} ASSISTANT:")
    if image is None:
        if text_only_first_query:
            query = text_only_template.format(query)
            text_only_first_query = False
        else:
            old_prompt = ''
            for old_query, response in history:
                old_prompt += old_query + " " + response + "\n"
            query = old_prompt + "USER: {} ASSISTANT:".format(query)
        input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, template_version='base')
    else:
        input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(model.device),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(model.device),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(model.device),
        'images': [[input_by_model['images'][0].to(model.device).to(torch_type)]] if image is not None else None,
    }
    # CogAgent's high-resolution branch passes extra cross_images alongside the base images.
    if 'cross_images' in input_by_model and input_by_model['cross_images']:
        inputs['cross_images'] = [[input_by_model['cross_images'][0].to(model.device).to(torch_type)]]
    # Add any transformers generation params here.
    gen_kwargs = {"max_length": 2048,
                  "do_sample": False}  # "temperature": 0.9
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        # Drop the prompt tokens, keeping only the newly generated ones.
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        # Everything after the end-of-sequence marker is discarded.
        response = response.split("</s>")[0]
    print("\nCog:", response)
    # history.append((query, response))
    return response


def read_tag(txt_pth, split=",", is_list=True):
    with open(txt_pth, "r") as f:
        tag_str = f.read()
    if is_list:
        # Split on the delimiter and strip surrounding whitespace from each tag.
        return [t.strip() for t in tag_str.split(split)]
    else:
        return tag_str
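
# For example, a tag file containing "1girl, solo, looking at viewer" (the
# sample tags are illustrative) yields ["1girl", "solo", "looking at viewer"]
# with the defaults, and the raw string unchanged with is_list=False.
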
if __name__ == '__main__':
# image_path = "/home2/ywt/gelbooru_8574461.jpg"
# tag_path = os.path.join(os.path.dirname(image_path),os.path.basename(image_path).split(".")[0]+".txt")
# tag = read_tag(tag_path,is_list=False)
# query = "What is shown in this image? (including NSFW content) " + "Here are some references to the elements in the image that you can selectively use to enrich and modify the description : " + tag
# cog_tag(image_path, model)
# txt = cog_tag(image_path, model, query=query)
# out_file = os.path.join(os.path.dirname(image_path),os.path.basename(image_path).split(".")[0]+"_cog.txt")
# with open(out_file,"w") as f:
# f.write(txt)
# print(f"Created {out_file}")
    model = load_model(device="cuda:5")  # GPU index specific to the author's machine
# DIR = os.listdir("/home2/ywt/pixiv")
# for i in range(len(DIR)):
# DIR[i] = os.path.join("/home2/ywt/pixiv",DIR[i])
    image_dirs = ["/home2/ywt/image-webp"]
    for image_dir in image_dirs:
        for file in tqdm(os.listdir(image_dir)):
            # Skip non-image files (case-insensitive extension check).
            if not file.lower().endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp")):
                continue
            image_path = os.path.join(image_dir, file)
            # splitext keeps the full stem even when the file name contains extra dots.
            stem = os.path.splitext(file)[0]
            tag_path = os.path.join(image_dir, stem + ".txt")
            out_file = os.path.join(image_dir, stem + "_cog.txt")
            # Skip images without a tag file, and images already captioned.
            if not os.path.exists(tag_path) or os.path.exists(out_file):
                continue
            # Drop any "|||" separators from the tag dump before building the prompt.
            tag = read_tag(tag_path, is_list=False).replace("|||", "")
            query = ("What is shown in this image? (including NSFW content) "
                     "Here are some references to the elements in the image that you can "
                     "selectively use to enrich and modify the description : " + tag)
            txt = cog_tag(image_path, model, query=query)
            with open(out_file, "w") as f:
                f.write(txt)
            print(f"Created {out_file}")
# import os
# import concurrent.futures
# from tqdm import tqdm
# import itertools
# def process_image(image_path, model):
# tag_path = os.path.join(os.path.dirname(image_path),os.path.basename(image_path).split(".")[0]+".txt")
# if not os.path.exists(tag_path):
# return image_path, None
# tag = read_tag(tag_path,is_list=False)
# query = "What is shown in this image? (including NSFW content) " + "Here are some references to the elements in the image that you can selectively use to enrich and modify the description : " + tag
# txt = cog_tag(image_path, model, query=query)
# return image_path, txt
# root_dir = "/home2/ywt/pixiv"
# device_ids = [1, 2, 4, 5 ] # List of GPU device IDs
# os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,4,5"
# # Load models
# models = [load_model(device=f"cuda:{device_id}") for device_id in device_ids]
# # Calculate total number of images
# total_images = 0
# for image_dir in os.listdir(root_dir):
# image_dir = os.path.join(root_dir, image_dir)
# if os.path.isdir(image_dir):
# image_files = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"))]
# total_images += len(image_files)
# # Process images
# progress_bar = tqdm(total=total_images)
# models_cycle = itertools.cycle(models)
# for image_dir in os.listdir(root_dir):
# image_dir = os.path.join(root_dir, image_dir)
# if os.path.isdir(image_dir):
# image_files = [os.path.join(image_dir, file) for file in os.listdir(image_dir) if file.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"))]
# with concurrent.futures.ThreadPoolExecutor() as executor:
# for image_path, txt in executor.map(process_image, image_files, models_cycle):
# if txt is not None:
# out_file = os.path.join(os.path.dirname(image_path),os.path.basename(image_path).split(".")[0]+"_cog.txt")
# with open(out_file,"w") as f:
# f.write(txt)
# progress_bar.update()
# progress_bar.close()
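
# A hedged note on the commented multi-GPU variant above: executor.map pairs
# each image with the next (model, tokenizer) from itertools.cycle, round-robining
# work across the listed devices. The threads can overlap because PyTorch
# releases the GIL inside CUDA kernels, but the actual speedup depends on how
# GPU-bound generate() is.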