---
library_name: peft
base_model: llava-hf/llava-1.5-7b-hf
---

# LLaVA-LoRA Adapter

This is a LoRA adapter for the LLaVA model, fine-tuned for spatial description tasks.

## Base Model

This adapter was trained on top of [llava-hf/llava-1.5-7b-hf](https://huggingface.co/llava-hf/llava-1.5-7b-hf).

## Training

The model was fine-tuned using LoRA with the following configuration (a `LoraConfig` sketch follows the list):

- Rank: 8
- Alpha: 32
- Target modules: `q_proj`, `v_proj`, `k_proj`
- Dataset: PersReFex validation set
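
For reference, these settings correspond roughly to the following `peft` `LoraConfig` (a minimal sketch; `lora_dropout` and `task_type` are illustrative assumptions, not values stated in this card):

```python
from peft import LoraConfig

# Sketch of the training configuration listed above.
# lora_dropout and task_type are assumptions, not taken from this card.
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj"],
    lora_dropout=0.05,      # assumption
    task_type="CAUSAL_LM",  # assumption
)
```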

## Usage

The snippet below loads the base model in bfloat16, attaches the adapter, and generates a description for a single image:

```python
from PIL import Image
import torch
from peft import PeftModel
from transformers import AutoProcessor, LlavaForConditionalGeneration

# Load the base model in bfloat16 and move it to the GPU
base_model = LlavaForConditionalGeneration.from_pretrained(
    "llava-hf/llava-1.5-7b-hf",
    torch_dtype=torch.bfloat16,
).to("cuda")
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")

# Attach the LoRA adapter on top of the base model
model = PeftModel.from_pretrained(
    base_model,
    "ZinengTang/llava-lora-spatial",
)

init_prompt_instruct = "Describe the location of the blue sphere relative to the environment features."
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": init_prompt_instruct},
            {"type": "image"},  # Placeholder; the processor pairs this with the actual image
        ],
    },
]
speaker_image = Image.open('your_image_path')
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# print(prompt)  # Uncomment to inspect the templated prompt

# Process the image and prompt; cast floating-point tensors (pixel_values)
# to bfloat16 so they match the model's dtype
inputs = processor(
    images=speaker_image,
    text=prompt,
    return_tensors="pt",
).to("cuda", torch.bfloat16)

with torch.no_grad():
    generated = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        pixel_values=inputs["pixel_values"],
        max_length=512,
        num_beams=1,
        do_sample=True,
        temperature=0.7,
    )
    generated_message = processor.batch_decode(
        generated,
        skip_special_tokens=True,
    )
    print(generated_message)
    # Keep only the assistant's reply and truncate it to 100 characters
    generated_message = generated_message[0].split('ASSISTANT: ')[-1][:100]
```
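
If you prefer to ship a single checkpoint instead of a separate adapter, the LoRA weights can be folded into the base model with `merge_and_unload` (a minimal sketch; the output directory name is illustrative):

```python
# Fold the LoRA weights into the base model and save a standalone checkpoint
merged_model = model.merge_and_unload()
merged_model.save_pretrained("llava-lora-spatial-merged")  # illustrative path
processor.save_pretrained("llava-lora-spatial-merged")
```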