ikuinen99 commited on
Commit
19aa79a
1 Parent(s): 1f3b768
Files changed (1) hide show
  1. bubogpt/models/mm_gpt4.py +2 -3
bubogpt/models/mm_gpt4.py CHANGED
@@ -87,9 +87,8 @@ class MMGPT4(BaseModel):
87
  self.llama_tokenizer = LlamaTokenizer.from_pretrained('magicr/vicuna-7b', use_fast=False, use_auth_token=True)
88
  self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
89
 
90
- self.llama_model = LlamaForCausalLM.from_pretrained(
91
- 'magicr/vicuna-7b', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto", use_auth_token=True
92
- )
93
 
94
  if freeze_llm:
95
  for name, param in self.llama_model.named_parameters():
 
87
  self.llama_tokenizer = LlamaTokenizer.from_pretrained('magicr/vicuna-7b', use_fast=False, use_auth_token=True)
88
  self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
89
 
90
+ self.llama_model = LlamaForCausalLM.from_pretrained('magicr/vicuna-7b', load_in_8bit=True,
91
+ torch_dtype=torch.float16, device_map="auto", use_auth_token=True)
 
92
 
93
  if freeze_llm:
94
  for name, param in self.llama_model.named_parameters():