Tonic commited on
Commit
bf07a1e
1 Parent(s): 1b29238

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -70,23 +70,31 @@ peft_config = PeftConfig.from_pretrained("Tonic/mistralmed", token="hf_dQUWWpJJy
70
  peft_model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", trust_remote_code=True)
71
  peft_model = PeftModel.from_pretrained(peft_model, "Tonic/mistralmed", token="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")
72
  # Remove the memory function
 
 
 
 
 
 
73
  class ChatBot:
74
  def __init__(self):
 
75
  self.history = []
76
 
77
  def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
78
- # Combine user input and system prompt
79
  formatted_input = f"<s>[INST]{system_prompt} {user_input}[/INST]"
80
 
81
- # Encode user input
82
  user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")
83
 
84
  # Generate a response using the PEFT model
85
- response = peft_model.generate(input_ids, max_length=512, pad_token_id=tokenizer.eos_token_id)
86
 
87
- # Decode and return the response
88
  response_text = tokenizer.decode(response[0], skip_special_tokens=True)
89
- return response_text
 
90
 
91
  bot = ChatBot()
92
 
 
70
  peft_model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", trust_remote_code=True)
71
  peft_model = PeftModel.from_pretrained(peft_model, "Tonic/mistralmed", token="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")
72
  # Remove the memory function
73
+ # ... (previous code)
74
+
75
+ class ChatBot:
76
+ def __init__(self):
77
+ self.history = []
78
+
79
  class ChatBot:
80
  def __init__(self):
81
+ # Initialize the ChatBot class with an empty history
82
  self.history = []
83
 
84
  def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
85
+ # Combine the user's input with the system prompt
86
  formatted_input = f"<s>[INST]{system_prompt} {user_input}[/INST]"
87
 
88
+ # Encode the formatted input using the tokenizer
89
  user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt")
90
 
91
  # Generate a response using the PEFT model
92
+ response = peft_model.generate(user_input_ids, max_length=512, pad_token_id=tokenizer.eos_token_id)
93
 
94
+ # Decode the generated response to text
95
  response_text = tokenizer.decode(response[0], skip_special_tokens=True)
96
+
97
+ return response_text # Return the generated response
98
 
99
  bot = ChatBot()
100