ikuinen99 commited on
Commit
1f3b768
1 Parent(s): 6d68957
Files changed (1) hide show
  1. bubogpt/models/mm_gpt4.py +2 -3
bubogpt/models/mm_gpt4.py CHANGED
@@ -84,12 +84,11 @@ class MMGPT4(BaseModel):
84
  print('Loading ImageBind Done')
85
 
86
  print(f'Loading LLAMA from {llama_model}')
87
- self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
88
  self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
89
 
90
  self.llama_model = LlamaForCausalLM.from_pretrained(
91
- llama_model,
92
- torch_dtype=torch.float16,
93
  )
94
 
95
  if freeze_llm:
 
84
  print('Loading ImageBind Done')
85
 
86
  print(f'Loading LLAMA from {llama_model}')
87
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained('magicr/vicuna-7b', use_fast=False, use_auth_token=True)
88
  self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
89
 
90
  self.llama_model = LlamaForCausalLM.from_pretrained(
91
+ 'magicr/vicuna-7b', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto", use_auth_token=True
 
92
  )
93
 
94
  if freeze_llm: