AlexandrosChariton committed on
Commit
e198071
1 Parent(s): 8078aec

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +2 -2
README.md CHANGED
@@ -59,7 +59,7 @@ import torch
59
 
60
  # Load the model and processor
61
  model_id = "mistral-community/pixtral-12b"
62
- model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map="mps")
63
  processor = AutoProcessor.from_pretrained(model_id)
64
 
65
  # Load the LoRA configuration
@@ -76,7 +76,7 @@ image = Image.open(image_path)
76
  PROMPT = "<s>[INST]Describe the chess position in the image, piece by piece.[IMG][/INST]"
77
 
78
  # Pass single image instead of list of URLs
79
- inputs = processor(text=PROMPT, images=image, return_tensors="pt").to("mps")
80
  generate_ids = lora_model.generate(**inputs, max_new_tokens=650)
81
  output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
82
  print(output)
 
59
 
60
  # Load the model and processor
61
  model_id = "mistral-community/pixtral-12b"
62
+ model = LlavaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda")
63
  processor = AutoProcessor.from_pretrained(model_id)
64
 
65
  # Load the LoRA configuration
 
76
  PROMPT = "<s>[INST]Describe the chess position in the image, piece by piece.[IMG][/INST]"
77
 
78
  # Pass single image instead of list of URLs
79
+ inputs = processor(text=PROMPT, images=image, return_tensors="pt").to("cuda")
80
  generate_ids = lora_model.generate(**inputs, max_new_tokens=650)
81
  output = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
82
  print(output)