# Spaces: Running
import gradio as gr | |
import transformers | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
from PIL import Image | |
import warnings | |
# Keep console output quiet: ignore Python warnings, turn off the
# transformers progress bars, and limit transformers logging to errors.
warnings.filterwarnings('ignore')
transformers.logging.disable_progress_bar()
transformers.logging.set_verbosity_error()
# Initialize model and tokenizer | |
def load_model(device='cpu'):
    """Load the nanoLLaVA causal-LM and its tokenizer from the model hub.

    Args:
        device: Accepted for API compatibility.
            NOTE(review): this argument is not read anywhere in the body;
            placement is decided by ``device_map='auto'``. Confirm whether
            callers expect it to be honored.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    repo_id = 'qnguyen3/nanoLLaVA'

    # trust_remote_code=True allows transformers to execute Python code
    # shipped inside the model repository; only use with trusted repos.
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.float16,
        device_map='auto',
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        repo_id,
        trust_remote_code=True,
    )
    return model, tokenizer
def generate_caption(image, model, tokenizer):
    """Generate a detailed text description of ``image``.

    The fixed prompt "Describe this image in detail" is rendered through the
    tokenizer's chat template, the ``<image>`` marker is replaced by a
    sentinel token id, and the model generates the caption.

    Args:
        image: Image handed to ``model.process_images`` (a PIL image when
            called from ``process_image_to_persona``).
        model: Model exposing ``process_images``, ``config``, ``dtype``,
            ``device`` and ``generate(..., images=...)``.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``__call__``
            and ``decode``.

    Returns:
        The decoded caption with special tokens removed and surrounding
        whitespace stripped.
    """
    # Sentinel id spliced into the token sequence at the position where the
    # "<image>" marker appeared in the rendered prompt.
    image_token_index = -200

    # Prepare the prompt
    prompt = "Describe this image in detail"
    messages = [
        {"role": "system", "content": "Answer the question"},
        {"role": "user", "content": f'<image>\n{prompt}'},
    ]

    # Render the conversation to plain text so it can be split on "<image>".
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Tokenize the text before and after the image marker separately.
    text_chunks = [tokenizer(chunk).input_ids for chunk in text.split('<image>')]

    # Build the inputs on the model's own device: the model may have been
    # placed on an accelerator, and tensors left on the CPU would cause a
    # device mismatch inside generate().
    device = model.device
    input_ids = torch.tensor(
        text_chunks[0] + [image_token_index] + text_chunks[1],
        dtype=torch.long,
        device=device,
    ).unsqueeze(0)
    image_tensor = model.process_images([image], model.config).to(
        dtype=model.dtype,
        device=device,
    )

    # Generate caption
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        max_new_tokens=2048,
        use_cache=True,
    )[0]

    # Drop the prompt tokens so only the newly generated text is decoded.
    caption = tokenizer.decode(
        output_ids[input_ids.shape[1]:],
        skip_special_tokens=True,
    ).strip()
    return caption
def create_persona(caption):
    """Wrap an image caption in a role-play system prompt.

    Args:
        caption: Text description of the image; interpolated verbatim into
            the prompt as the character's "stats".

    Returns:
        The persona prompt string, delimited by ``<|im_start|>system`` and
        ``<|im_end|>`` markers.
    """
    return f"""<|im_start|>system
Role : An entity exactly as described in your image
Background : Your appearance and characteristics match the image description
Personality : Reflect the mood, style, and elements captured in the image
Goal : Interact authentically based on your visual characteristics
You are a character with the following stats:
{caption}
Please stay in character and respond as this entity would,
incorporating visual elements from your description into your responses.<|im_end|>"""
def process_image_to_persona(image, model, tokenizer):
    """Turn an uploaded image into a caption and a persona prompt.

    Args:
        image: A PIL image, any other object accepted by
            ``Image.fromarray``, or ``None`` when nothing was uploaded.
        model: Captioning model forwarded to ``generate_caption``.
        tokenizer: Tokenizer forwarded to ``generate_caption``.

    Returns:
        ``(caption, persona)``; when ``image`` is ``None`` the pair is
        ``("Please upload an image.", "")``.
    """
    if image is None:
        return "Please upload an image.", ""

    # Non-PIL inputs are converted so the captioner always gets a PIL image.
    pil_image = image if isinstance(image, Image.Image) else Image.fromarray(image)

    caption = generate_caption(pil_image, model, tokenizer)
    return caption, create_persona(caption)
# Fallback system prompt used by chat() when the caller supplies none:
# a sample image description of a fluffy white pig.
default_system_prompt = """
Your image shows us that you are:
A small, fluffy white pig with a pink nose and small ears,
standing upright. The pig has a long pink tongue, which is also pink in color.
The pig's eyes are open and appear to be looking at the camera.
The pig's fur is fluffy and white, and there are pink and white spots on the fur. The pig's paws are also pink and white,
and they have pink nails.The pig's legs are long and pink. The pig's body is positioned in front of a black background.
"""
# Text-generation pipelines keyed by model id. Building a pipeline loads the
# full model, so each one is created once per process and then reused.
_PIPELINE_CACHE = {}


def _get_pipeline(model_id):
    """Return the text-generation pipeline for ``model_id``, building it on first use."""
    try:
        return _PIPELINE_CACHE[model_id]
    except KeyError:
        generator = transformers.pipeline(
            "text-generation",
            model=model_id,
            model_kwargs={"torch_dtype": torch.bfloat16},
            device_map="auto",
        )
        _PIPELINE_CACHE[model_id] = generator
        return generator


def chat(prompt,
         system_prompt=default_system_prompt,
         model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
         max_tokens=512):
    """Send one user message to a chat model and return its reply.

    Args:
        prompt: The user's message.
        system_prompt: System message placed before the user message.
        model_id: Model identifier passed to ``transformers.pipeline``.
        max_tokens: Upper bound on newly generated tokens.

    Returns:
        The last element of the pipeline's ``generated_text`` for the first
        result. NOTE(review): with chat-style input this is expected to be
        the assistant message dict (``role``/``content``) - confirm against
        the installed transformers version.
    """
    # Reuse an already-loaded pipeline instead of reloading the model on
    # every call.
    generator = _get_pipeline(model_id)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    outputs = generator(
        messages,
        max_new_tokens=max_tokens,
    )
    output = outputs[0]["generated_text"][-1]
    print(output)
    return output