Tonic commited on
Commit
237d9d2
·
1 Parent(s): 4ffc0ce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -44
app.py CHANGED
@@ -1,4 +1,4 @@
1
- import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
4
 
@@ -10,48 +10,41 @@ bnb_config = BitsAndBytesConfig(
10
  bnb_4bit_compute_dtype=torch.bfloat16
11
  )
12
 
13
- model = AutoModelForCausalLM.from_pretrained(base_model_id, quantization_config=bnb_config)
14
- base_model = AutoModelForCausalLM.from_pretrained(
15
- base_model_id, # Mistral, same as before
16
- quantization_config=bnb_config, # Same quantization config as before
17
- device_map="auto",
18
- trust_remote_code=True,
19
- use_auth_token=api_token
20
- )
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True)
23
- tokenizer.pad_token = tokenizer.eos_tokentokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
24
- tokenizer.padding_side = 'left'
25
- model = PeftModel.from_pretrained(base_model, "Tonic/mistralmed")
26
-
27
- class ChatBot:
28
- def __init__(self):
29
- self.history = []
30
-
31
- def predict(self, input):
32
- new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
33
- flat_history = [item for sublist in self.history for item in sublist]
34
- flat_history_tensor = torch.tensor(flat_history).unsqueeze(dim=0) # convert list to 2-D tensor
35
- bot_input_ids = torch.cat([flat_history_tensor, new_user_input_ids], dim=-1) if self.history else new_user_input_ids
36
- chat_history_ids = model.generate(bot_input_ids, max_length=2000, pad_token_id=tokenizer.eos_token_id)
37
- self.history.append(chat_history_ids[:, bot_input_ids.shape[-1]:].tolist()[0])
38
- response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
39
- return response
40
-
41
- bot = ChatBot()
42
-
43
- title = "👋🏻Welcome to Tonic's EZ Chat🚀"
44
- description = "You can use this Space to test out the current model (MistralMed) or duplicate this Space and use it for any other model on 🤗HuggingFace. Join me on [Discord](https://discord.gg/fpEPNZGsbt) to build together."
45
- examples = [["What is the boiling point of nitrogen?"]]
46
-
47
- iface = gr.Interface(
48
- fn=bot.predict,
49
- title=title,
50
- description=description,
51
- examples=examples,
52
- inputs="text",
53
- outputs="text",
54
  theme="ParityError/Anime"
55
- )
56
-
57
- iface.launch()
 
1
+ import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
4
 
 
10
  bnb_4bit_compute_dtype=torch.bfloat16
11
  )
12
 
13
+ # Load the fine-tuned model "Tonic/mistralmed"
14
+ model = AutoModelForCausalLM.from_pretrained("Tonic/mistralmed", quantization_config=bnb_config)
15
+
16
+ tokenizer = AutoTokenizer.from_pretrained("Tonic/mistralmed", trust_remote_code=True)
17
+ tokenizer.pad_token = tokenizer.eos_token
18
+ tokenizer.padding_side = 'left'
19
+
20
+ class ChatBot:
21
+ def __init__(self):
22
+ self.history = []
23
+
24
+ def predict(self, input):
25
+ new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
26
+ flat_history = [item for sublist in self.history for item in sublist]
27
+ flat_history_tensor = torch.tensor(flat_history).unsqueeze(dim=0)
28
+ bot_input_ids = torch.cat([flat_history_tensor, new_user_input_ids], dim=-1) if self.history else new_user_input_ids
29
+ chat_history_ids = model.generate(bot_input_ids, max_length=2000, pad_token_id=tokenizer.eos_token_id)
30
+ self.history.append(chat_history_ids[:, bot_input_ids.shape[-1]:].tolist()[0])
31
+ response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
32
+ return response
33
 
34
+ bot = ChatBot()
35
+
36
+ title = "👋🏻Welcome to Tonic's EZ Chat🚀"
37
+ description = "You can use this Space to test out the current model (MistralMed) or duplicate this Space and use it for any other model on 🤗HuggingFace. Join me on [Discord](https://discord.gg/fpEPNZGsbt) to build together."
38
+ examples = [["What is the boiling point of nitrogen"]]
39
+
40
+ iface = gr.Interface(
41
+ fn=bot.predict,
42
+ title=title,
43
+ description=description,
44
+ examples=examples,
45
+ inputs="text",
46
+ outputs="text",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  theme="ParityError/Anime"
48
+ )
49
+
50
+ iface.launch()