Tonic commited on
Commit
81395fc
·
1 Parent(s): 6d33b71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -16
app.py CHANGED
@@ -1,35 +1,45 @@
1
- from transformers import AutoTokenizer, AutoModelForCausalLM
2
  from peft import PeftModel, PeftConfig
3
- from transformers import AutoModelForCausalLM
4
-
5
  import gradio as gr
6
 
7
  # Use the base model's ID
8
  base_model_id = "mistralai/Mistral-7B-v0.1"
9
  model_directory = "Tonic/mistralmed"
10
 
11
- #instantiate the Models
12
-
13
- config = PeftConfig.from_pretrained("Tonic/mistralmed", token="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")
14
- model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
15
- model = PeftModel.from_pretrained(model, "Tonic/mistralmed", token="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")
16
  tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
17
  tokenizer.pad_token = tokenizer.eos_token
18
  tokenizer.padding_side = 'left'
19
 
 
 
 
 
 
20
  class ChatBot:
21
  def __init__(self):
22
  self.history = []
23
 
24
  def predict(self, input):
25
- new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
26
- flat_history = [item for sublist in self.history for item in sublist]
27
- flat_history_tensor = torch.tensor(flat_history).unsqueeze(dim=0)
28
- bot_input_ids = torch.cat([flat_history_tensor, new_user_input_ids], dim=-1) if self.history else new_user_input_ids
29
- chat_history_ids = model.generate(bot_input_ids, max_length=512, pad_token_id=tokenizer.eos_token_id)
30
- self.history.append(chat_history_ids[:, bot_input_ids.shape[-1]:].tolist()[0])
31
- response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
32
- return response
 
 
 
 
 
 
 
 
 
 
33
 
34
  bot = ChatBot()
35
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  from peft import PeftModel, PeftConfig
3
+ import torch
 
4
  import gradio as gr
5
 
6
  # Use the base model's ID
7
  base_model_id = "mistralai/Mistral-7B-v0.1"
8
  model_directory = "Tonic/mistralmed"
9
 
10
+ # Instantiate the Models
 
 
 
 
11
  tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
12
  tokenizer.pad_token = tokenizer.eos_token
13
  tokenizer.padding_side = 'left'
14
 
15
+ # Load the PEFT model
16
+ peft_config = PeftConfig.from_pretrained("Tonic/mistralmed")
17
+ base_model = AutoModelForSeq2SeqLM.from_pretrained(model_directory)
18
+ peft_model = PeftModel.from_pretrained(base_model, "Tonic/mistralmed")
19
+
20
  class ChatBot:
21
  def __init__(self):
22
  self.history = []
23
 
24
  def predict(self, input):
25
+ # Encode user input
26
+ user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
27
+
28
+ # Concatenate the user input with chat history
29
+ if self.history:
30
+ chat_history_ids = torch.cat([self.history, user_input_ids], dim=-1)
31
+ else:
32
+ chat_history_ids = user_input_ids
33
+
34
+ # Generate a response using the PEFT model
35
+ response = peft_model.generate(chat_history_ids, max_length=512, pad_token_id=tokenizer.eos_token_id)
36
+
37
+ # Update chat history
38
+ self.history = response
39
+
40
+ # Decode and return the response
41
+ response_text = tokenizer.decode(response[0], skip_special_tokens=True)
42
+ return response_text
43
 
44
  bot = ChatBot()
45