Update app.py
app.py CHANGED

@@ -1,10 +1,12 @@
 import gradio as gr
 import torch
-from transformers import AutoTokenizer, AutoModel,
+from transformers import AutoTokenizer, AutoModel, BertConfig  # Use BertConfig for your Mistral model
 
 # Use the base model's ID
 base_model_id = "mistralai/Mistral-7B-v0.1"
-
+
+# Create a configuration object specific to the base model (you can replace with your model's actual configuration if available)
+config = BertConfig()
 
 # Load the fine-tuned model "Tonic/mistralmed"
 model = AutoModel.from_pretrained("Tonic/mistralmed", config=config)
@@ -21,7 +23,7 @@ class ChatBot:
         new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
         flat_history = [item for sublist in self.history for item in sublist]
         flat_history_tensor = torch.tensor(flat_history).unsqueeze(dim=0)
-        bot_input_ids = torch.cat([flat_history_tensor, new_user_input_ids], dim=-1) if self
+        bot_input_ids = torch.cat([flat_history_tensor, new_user_input_ids], dim=-1) if self.history else new_user_input_ids
         chat_history_ids = model.generate(bot_input_ids, max_length=2000, pad_token_id=tokenizer.eos_token_id)
         self.history.append(chat_history_ids[:, bot_input_ids.shape[-1]:].tolist()[0])
         response = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)
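Two things in the first hunk deserve a note for anyone adapting this Space. AutoModel decides which architecture to build from the config object it is handed, so a freshly constructed BertConfig() asks it to build a BERT graph and then load Mistral weights into it; AutoModel also returns a bare backbone without the language-modelling head that the later model.generate(...) call relies on. A more conventional load, shown here as a minimal sketch rather than the Space's actual code (it assumes Tonic/mistralmed is a full checkpoint on the Hub and not a PEFT adapter), lets the checkpoint supply its own configuration:

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

    # Hypothetical rework, not the committed code: let the fine-tuned
    # checkpoint supply its own config instead of a hand-built BertConfig.
    config = AutoConfig.from_pretrained("Tonic/mistralmed")

    # AutoModelForCausalLM attaches the LM head that generate() needs;
    # plain AutoModel returns only the backbone.
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    model = AutoModelForCausalLM.from_pretrained(
        "Tonic/mistralmed",
        config=config,
        torch_dtype=torch.float16,  # assumption: half precision so a 7B model fits in memory
    )

Passing config explicitly here is redundant but harmless, since it matches what from_pretrained would resolve from the same repository on its own.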
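The second hunk completes the line that decides whether generation is seeded with the accumulated conversation or starts fresh on the first turn. Since the diff shows only part of the class, the following reconstruction of that pattern is a sketch: the __init__ body, the method name predict, and the reliance on the module-level tokenizer and model are assumptions, not code from the Space.

    import torch

    class ChatBot:
        def __init__(self):
            self.history = []  # one list of generated token ids per previous turn

        def predict(self, input):
            # Encode the new user message, terminated by the EOS token.
            new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors="pt")
            if self.history:
                # Flatten the per-turn id lists into one sequence and prepend it.
                flat_history = [item for sublist in self.history for item in sublist]
                flat_history_tensor = torch.tensor(flat_history).unsqueeze(dim=0)
                bot_input_ids = torch.cat([flat_history_tensor, new_user_input_ids], dim=-1)
            else:
                # First turn: there is nothing to prepend.
                bot_input_ids = new_user_input_ids
            chat_history_ids = model.generate(bot_input_ids, max_length=2000,
                                              pad_token_id=tokenizer.eos_token_id)
            # Everything past the prompt length is newly generated: store it as
            # the next history entry and decode it as the visible reply.
            new_tokens = chat_history_ids[:, bot_input_ids.shape[-1]:]
            self.history.append(new_tokens.tolist()[0])
            return tokenizer.decode(new_tokens[0], skip_special_tokens=True)

One observation on the committed logic: history stores only the bot's newly generated tokens, so earlier user messages drop out of the context on later turns; whether that is intended is not clear from the diff alone.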