# K00B404's picture
# Update app2.py
# ac2217d verified
import gradio as gr
import transformers
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image
import warnings
# Keep console output quiet: ignore Python warnings, hide the Transformers
# progress bars, and restrict Transformers logging to errors only.
warnings.filterwarnings('ignore')
transformers.logging.disable_progress_bar()
transformers.logging.set_verbosity_error()
# Initialize model and tokenizer
def load_model(device='cpu'):
    """Load the nanoLLaVA checkpoint and its tokenizer.

    Args:
        device: Accepted for interface compatibility. NOTE: the body never
            reads it; weight placement is delegated to ``device_map='auto'``.

    Returns:
        A ``(model, tokenizer)`` tuple, both loaded from
        ``'qnguyen3/nanoLLaVA'`` with ``trust_remote_code=True``.
    """
    checkpoint = 'qnguyen3/nanoLLaVA'

    # The model is requested in half precision and placed automatically.
    model_options = {
        'torch_dtype': torch.float16,
        'device_map': 'auto',
        'trust_remote_code': True,
    }
    model = AutoModelForCausalLM.from_pretrained(checkpoint, **model_options)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

    return model, tokenizer
def generate_caption(image, model, tokenizer):
    """Generate a detailed text description of ``image``.

    A fixed "describe this image" request is rendered through the tokenizer's
    chat template, split around the ``<image>`` placeholder, tokenized chunk by
    chunk, and rejoined with the id ``-200`` inserted where the placeholder was
    (presumably the image-token sentinel the model's remote code expects --
    confirm against the model card).

    Args:
        image: Image handed to ``model.process_images``.
        model: Model exposing ``process_images``, ``generate``, ``config``,
            ``dtype`` and ``device``.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``__call__``
            (returning an object with ``input_ids``) and ``decode``.

    Returns:
        The decoded text of the tokens that follow the prompt in the
        generated sequence, with surrounding whitespace stripped.
    """
    # Prepare the prompt
    prompt = "Describe this image in detail"
    messages = [
        {"role": "system", "content": "Answer the question"},
        {"role": "user", "content": f'<image>\n{prompt}'},
    ]

    # Apply chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize the text on either side of the image placeholder separately.
    text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]

    # Tensors are created on torch's default device, while the model may have
    # been placed elsewhere (e.g. loaded with device_map='auto'). Move both
    # inputs to the model's device so generation does not mix devices.
    target_device = model.device
    input_ids = (
        torch.tensor(text_chunks[0] + [-200] + text_chunks[1], dtype=torch.long)
        .unsqueeze(0)
        .to(target_device)
    )
    image_tensor = model.process_images([image], model.config).to(
        dtype=model.dtype,
        device=target_device,
    )

    # Generate caption
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        max_new_tokens=2048,
        use_cache=True,
    )[0]

    # Decode only what comes after the prompt tokens.
    caption = tokenizer.decode(
        output_ids[input_ids.shape[1]:],
        skip_special_tokens=True,
    ).strip()
    return caption
def create_persona(caption):
    """Wrap an image caption in a ChatML-style system block describing a persona.

    Args:
        caption: Description to embed; it is formatted into the text as its
            own line between the fixed instruction lines.

    Returns:
        The persona prompt, starting with ``<|im_start|>system`` and ending
        with ``<|im_end|>``.
    """
    prompt_lines = (
        "<|im_start|>system",
        "Role : An entity exactly as described in your image",
        "Background : Your appearance and characteristics match the image description",
        "Personality : Reflect the mood, style, and elements captured in the image",
        "Goal : Interact authentically based on your visual characteristics",
        "You are a character with the following stats:",
        f"{caption}",
        "Please stay in character and respond as this entity would,",
        "incorporating visual elements from your description into your responses.<|im_end|>",
    )
    return "\n".join(prompt_lines)
def process_image_to_persona(image, model, tokenizer):
    """Caption an image and derive a persona prompt from that caption.

    Args:
        image: A PIL image, an object accepted by ``Image.fromarray``, or
            ``None`` when nothing was uploaded.
        model: Captioning model forwarded to ``generate_caption``.
        tokenizer: Tokenizer forwarded to ``generate_caption``.

    Returns:
        ``(caption, persona)``. When ``image`` is ``None`` the pair is
        ``("Please upload an image.", "")`` instead.
    """
    # Nothing to caption: hand back the upload hint with an empty persona.
    if image is None:
        return ("Please upload an image.", "")

    # Anything that is not already a PIL image is converted from an array.
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

    description = generate_caption(pil_image, model, tokenizer)
    return description, create_persona(description)
# Fallback system prompt: a sample image-derived character description.
# chat() uses it as the default value of its `system_prompt` parameter.
# The literal starts and ends with a newline.
default_system_prompt='''
Your image shows us that you are:
A small, fluffy white pig with a pink nose and small ears,
standing upright. The pig has a long pink tongue, which is also pink in color.
The pig's eyes are open and appear to be looking at the camera.
The pig's fur is fluffy and white, and there are pink and white spots on the fur. The pig's paws are also pink and white,
and they have pink nails.The pig's legs are long and pink. The pig's body is positioned in front of a black background.
'''
# Text-generation pipelines already built, keyed by model id, so the model
# weights are loaded once per process instead of once per chat() call.
_PIPELINE_CACHE = {}


def _get_text_generation_pipeline(model_id):
    """Return the cached text-generation pipeline for ``model_id``, building it on first use."""
    if model_id not in _PIPELINE_CACHE:
        _PIPELINE_CACHE[model_id] = transformers.pipeline(
            "text-generation",
            model=model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )
    return _PIPELINE_CACHE[model_id]


def chat(prompt,
         system_prompt=default_system_prompt,
         model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
         max_tokens=512):
    """Send one user message to a chat model and return its reply.

    Args:
        prompt: Text of the user message.
        system_prompt: Text of the system message; defaults to
            ``default_system_prompt``.
        model_id: Model identifier passed to ``transformers.pipeline``.
        max_tokens: Forwarded to the pipeline as ``max_new_tokens``.

    Returns:
        The last element of the first result's ``"generated_text"`` (for
        chat-style input the Transformers pipeline documents this as the
        newly generated assistant message). The value is also printed.
    """
    # Reuse an already-loaded pipeline; constructing one reloads the model.
    generator = _get_text_generation_pipeline(model_id)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    outputs = generator(
        messages,
        max_new_tokens=max_tokens,
    )
    output = outputs[0]["generated_text"][-1]
    print(output)
    return output