alexkueck committed on
Commit
c84e019
1 Parent(s): 3613d4f

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +14 -0
utils.py CHANGED
@@ -22,6 +22,7 @@ import transformers
22
  from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2Tokenizer, GPT2LMHeadModel
23
  #import auto_gptq
24
  #from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 
25
 
26
 
27
  def reset_state():
@@ -99,6 +100,19 @@ def load_tokenizer_and_model(base_model,load_8bit=False):
99
  return tokenizer,model,device
100
 
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  def load_tokenizer_and_model_gpt2(base_model,load_8bit=False):
103
  if torch.cuda.is_available():
104
  device = "cuda"
 
22
  from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2Tokenizer, GPT2LMHeadModel
23
  #import auto_gptq
24
  #from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
25
+ from transformers import LlamaForCausalLM, LlamaTokenizer
26
 
27
 
28
  def reset_state():
 
100
  return tokenizer,model,device
101
 
102
 
103
def load_tokenizer_and_model_Baize(base_model, load_8bit=True):
    """Load the Baize (LLaMA-based) tokenizer and causal-LM model.

    Parameters
    ----------
    base_model : str
        Hugging Face model id or local path of the base LLaMA/Baize checkpoint.
    load_8bit : bool, optional
        Whether to load the model weights in 8-bit quantized form
        (default True). Previously this flag was accepted but ignored;
        it is now actually forwarded to ``from_pretrained``.

    Returns
    -------
    tuple
        ``(tokenizer, model, device)`` where device is "cuda" if available,
        else "cpu".
    """
    # Pick the compute device; device_map="auto" below does the actual
    # placement, this string is returned for the caller's convenience.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    # add_eos_token=True so generations are terminated consistently;
    # use_auth_token=True because gated LLaMA weights require HF auth.
    tokenizer = LlamaTokenizer.from_pretrained(
        base_model, add_eos_token=True, use_auth_token=True
    )
    # Bug fix: honor the load_8bit parameter instead of hardcoding True.
    model = LlamaForCausalLM.from_pretrained(
        base_model, load_in_8bit=load_8bit, device_map="auto"
    )

    return tokenizer, model, device
114
+
115
+
116
  def load_tokenizer_and_model_gpt2(base_model,load_8bit=False):
117
  if torch.cuda.is_available():
118
  device = "cuda"