ikuinen99 commited on
Commit
d58d6b8
1 Parent(s): 192e5fb
bubogpt/configs/models/mmgpt4.yaml CHANGED
@@ -10,7 +10,7 @@ model:
10
  num_query_token: 32
11
 
12
  # Vicuna
13
- llama_model: "vicuna"
14
 
15
  # generation configs
16
  prompt: ""
 
10
  num_query_token: 32
11
 
12
  # Vicuna
13
+ llama_model: "magicr/vicuna-7b"
14
 
15
  # generation configs
16
  prompt: ""
bubogpt/models/mm_gpt4.py CHANGED
@@ -86,10 +86,10 @@ class MMGPT4(BaseModel):
86
  gc.collect()
87
 
88
  print(f'Loading LLAMA from {llama_model}')
89
- self.llama_tokenizer = LlamaTokenizer.from_pretrained('magicr/vicuna-7b', use_fast=False, use_auth_token=True)
90
  self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
91
 
92
- self.llama_model = LlamaForCausalLM.from_pretrained('magicr/vicuna-7b', load_in_8bit=True,
93
  torch_dtype=torch.float16, device_map="auto", use_auth_token=True)
94
 
95
  if freeze_llm:
 
86
  gc.collect()
87
 
88
  print(f'Loading LLAMA from {llama_model}')
89
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False, use_auth_token=True)
90
  self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
91
 
92
+ self.llama_model = LlamaForCausalLM.from_pretrained(llama_model, load_in_8bit=True,
93
  torch_dtype=torch.float16, device_map="auto", use_auth_token=True)
94
 
95
  if freeze_llm: