Tonic commited on
Commit
5ab0bbc
·
1 Parent(s): 02f3e50

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -25
app.py CHANGED
@@ -18,27 +18,16 @@ class OrcaChatBot:
18
  def __init__(self, model, tokenizer, system_message="You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."):
19
  self.model = model
20
  self.tokenizer = tokenizer
21
- self.system_message = system_message
22
- self.conversation_history = []
23
 
24
- def update_conversation_history(self, user_message, assistant_message):
25
- self.conversation_history.append(("user", user_message))
26
- self.conversation_history.append(("assistant", assistant_message))
27
-
28
-
29
- def format_prompt(self):
30
- prompt = f"<|im_start|>assistant\n{self.system_message}<|im_end|>\n"
31
- for role, message in self.conversation_history:
32
- if message.strip():
33
- prompt += f"<|im_start|>{role}\n{message}<|im_end|>\n"
34
- # if role == "assistant":
35
- # prompt += f"<|im_end|>\n"
36
- prompt += "<|im_start|> assistant\n"
37
  return prompt
38
 
39
- def predict(self, user_message, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
40
- self.update_conversation_history(user_message, "")
41
- prompt = self.format_prompt()
42
  inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
43
  input_ids = inputs["input_ids"].to(self.model.device)
44
 
@@ -48,19 +37,17 @@ class OrcaChatBot:
48
  temperature=temperature,
49
  top_p=top_p,
50
  repetition_penalty=repetition_penalty,
51
- # pad_token_id=self.tokenizer.eos_token_id,
52
  do_sample=True
53
- )
54
 
55
  response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
56
- self.update_conversation_history("", response)
57
  return response
58
-
59
- Orca_bot = OrcaChatBot(model, tokenizer)
60
 
61
  def gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty):
62
- full_message = f"{system_message}\n{user_message}" if system_message else user_message
63
- return Orca_bot.predict(full_message, temperature, max_new_tokens, top_p, repetition_penalty)
 
 
64
 
65
  iface = gr.Interface(
66
  fn=gradio_predict,
 
18
  def __init__(self, model, tokenizer, system_message="You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."):
19
  self.model = model
20
  self.tokenizer = tokenizer
21
+ self.default_system_message = system_message
 
22
 
23
+ def format_prompt(self, user_message, system_message):
24
+ if system_message is None:
25
+ system_message = self.default_system_message
26
+ prompt = f"<|im_start|>assistant\n{self.system_message}<|im_end|>\n<|im_start|>\nuser\n{user_message}<|im_end|>\nassistant\n"
 
 
 
 
 
 
 
 
 
27
  return prompt
28
 
29
+ def predict(self, user_message, system_message=None, temperature=0.4, max_new_tokens=70, top_p=0.99, repetition_penalty=1.9):
30
+ prompt = self.format_prompt(user_message, system_message)
 
31
  inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
32
  input_ids = inputs["input_ids"].to(self.model.device)
33
 
 
37
  temperature=temperature,
38
  top_p=top_p,
39
  repetition_penalty=repetition_penalty,
 
40
  do_sample=True
41
+ )
42
 
43
  response = self.tokenizer.decode(output_ids[0], skip_special_tokens=True)
 
44
  return response
 
 
45
 
46
  def gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty):
47
+ response = Orca_bot.predict(user_message, system_message, temperature, max_new_tokens, top_p, repetition_penalty)
48
+ return response
49
+
50
+ Orca_bot = OrcaChatBot(model, tokenizer)
51
 
52
  iface = gr.Interface(
53
  fn=gradio_predict,