vojay commited on
Commit
9a14ed3
·
verified ·
1 Parent(s): 41ce050

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -17,7 +17,7 @@ adapter_model_id = "vojay/Llama-2-7b-chat-hf-mental-health"
17
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
18
  model.load_adapter(adapter_model_id)
19
 
20
- tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
21
 
22
 
23
  def get_base_prompt():
@@ -34,7 +34,7 @@ def format_prompt(base, user_message):
34
 
35
  def predict(input, history=[]):
36
  input = format_prompt(get_base_prompt(), input)
37
- new_user_input_ids = tokenizer.encode(f"{input}{tokenizer.eos_token}", return_tensors="pt")
38
  bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
39
 
40
  history = model.generate(
 
17
  model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
18
  model.load_adapter(adapter_model_id)
19
 
20
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
21
 
22
 
23
  def get_base_prompt():
 
34
 
35
  def predict(input, history=[]):
36
  input = format_prompt(get_base_prompt(), input)
37
+ new_user_input_ids = tokenizer.encode(f"{tokenizer.eos_token}{input}", return_tensors="pt")
38
  bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)
39
 
40
  history = model.generate(