# Aria-torchao-int8wo / inference.py
# Uploaded by aria-dev — commit e83fa52 ("first version"), 1.4 kB.
import io

import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
# Run a single image + text query through the local checkpoint and print the reply.

# Both the model weights and the processor/tokenizer files are read from the
# current directory.
model_id_or_path = "./"
tokenizer_id_or_path = "./"

# Load onto the GPU in bfloat16 with FlashAttention-2. trust_remote_code=True
# lets transformers execute modeling code bundled with the checkpoint.
model = AutoModelForCausalLM.from_pretrained(
    model_id_or_path,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)
# NOTE(review): torch.compile(module) returns a wrapper whose compiled path is
# only the wrapper's own __call__/forward. `model.generate(...)` below is looked
# up on the wrapped module and calls the original, uncompiled forward, so this
# line most likely does not speed up generation. Compiling `model.forward`
# instead would make it take effect, but with fullgraph=True that may fail on
# graph breaks in the remote modeling code — confirm before changing.
model = torch.compile(model, mode="max-autotune", fullgraph=True)

# One user turn: an image placeholder followed by the text question. The
# processor call below supplies the actual image for the placeholder.
messages = [
    {
        "role": "user",
        "content": [
            {"text": None, "type": "image"},
            {"text": "what's in the image?", "type": "text"},
        ],
    }
]

# Remote image (despite the variable name, this is a URL, not a local path).
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
# Download the whole body instead of streaming `.raw`:
# - `.content` honours any Content-Encoding, and the connection is released.
# - The timeout stops a stalled server from hanging the script forever.
# - raise_for_status() stops an HTTP error page from reaching PIL as "image" data.
response = requests.get(image_path, timeout=30)
response.raise_for_status()
image = Image.open(io.BytesIO(response.content))

processor = AutoProcessor.from_pretrained(tokenizer_id_or_path, trust_remote_code=True)
# Render the conversation to a prompt string; add_generation_prompt appends the
# assistant-turn header so the model continues as the assistant.
text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(text=text, images=image, return_tensors="pt")
# Match the image tensor's dtype to the bfloat16 model weights.
inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
inputs = {k: v.to(model.device) for k, v in inputs.items()}

# Generate at most 100 new tokens, stopping early once "<|im_end|>" appears.
# stop_strings needs the tokenizer so it can match text across token boundaries.
out = model.generate(
    **inputs,
    max_new_tokens=100,
    tokenizer=processor.tokenizer,
    stop_strings=["<|im_end|>"],
)
# The returned sequence begins with the prompt tokens; slice them off so only
# the newly generated continuation is decoded.
output_ids = out[0][inputs["input_ids"].shape[1] :]
result = processor.decode(output_ids, skip_special_tokens=True)
print(result)