Spaces:
Paused
Paused
import random | |
import time | |
import torch | |
import gradio as gr | |
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM | |
from peft import PeftModel, PeftConfig | |
from textwrap import wrap, fill | |
# Functions to Wrap the Prompt Correctly | |
def wrap_text(text, width=90):
    """Re-flow ``text`` so that each of its lines is wrapped at ``width`` columns.

    Every newline-separated line is filled on its own, so the line breaks
    already present in ``text`` are kept and only over-long lines are split.
    """
    return '\n'.join(fill(segment, width=width) for segment in text.split('\n'))
def multimodal_prompt(user_input, system_prompt):
    """Generate text from the language model for a user input and system prompt.

    Relies on the module-level ``tokenizer`` and ``peft_model`` objects.

    Args:
        user_input: The user's input text to generate a response for.
        system_prompt: Instruction text placed ahead of the user turn
            (required; there is no default).

    Returns:
        The decoded first output sequence with special tokens skipped.
        The sequence is decoded in full, i.e. nothing is trimmed from it
        before it is returned.
    """
    # The quadrupled braces render as literal "{{ ... }}" around the system prompt.
    formatted_input = f"{{{{ {system_prompt} }}}}\nUser: {user_input}\nFalcon:"

    # Encode the prompt and place the tensors on the device the model's
    # weights actually live on. Using the global ``device`` here breaks on
    # CUDA hosts whenever the model itself was left on the CPU.
    encodeds = tokenizer(formatted_input, return_tensors="pt", add_special_tokens=False)
    model_device = next(peft_model.parameters()).device
    model_inputs = encodeds.to(model_device)

    # Generate a response using the model
    output = peft_model.generate(
        **model_inputs,
        max_length=500,
        use_cache=True,
        early_stopping=False,
        bos_token_id=peft_model.config.bos_token_id,
        eos_token_id=peft_model.config.eos_token_id,
        # The EOS id doubles as the padding id.
        pad_token_id=peft_model.config.eos_token_id,
        temperature=0.4,
        do_sample=True,
    )

    # Decode the response
    response_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return response_text
class ChatbotInterface():
    """One chat panel: transcript, input box, Submit button and Clear button.

    The constructor creates Gradio components, so it has to be called while
    a ``gr.Blocks`` layout context is active. Generation relies on the
    module-level ``tokenizer`` and ``peft_model`` objects.
    """

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        """Build the panel's components and wire the Submit button.

        Args:
            name: Speaker tag written at the end of the prompt (``"<name>:"``).
            system_prompt: Instruction text placed ahead of the user turn.
        """
        self.name = name
        self.system_prompt = system_prompt
        self.chatbot = gr.Chatbot()
        # Retained for backward compatibility: mirrors the transcript most
        # recently returned by ``respond``. The Chatbot component's own value
        # is the source of truth.
        self.chat_history = []
        with gr.Row() as row:
            row.justify = "end"
            self.msg = gr.Textbox(scale=7)
            self.submit = gr.Button("Submit", scale=1)
        # Clears both the input box and the transcript.
        gr.ClearButton([self.msg, self.chatbot])
        self.submit.click(self.respond, [self.msg, self.chatbot], [self.msg, self.chatbot])

    def respond(self, msg, history):
        """Generate a reply to ``msg`` and append the exchange to ``history``.

        Args:
            msg: Current content of the input textbox.
            history: Current value of the Chatbot component, a list of
                ``[user, bot]`` pairs (or ``None`` when empty).

        Returns:
            A ``(textbox_value, transcript)`` tuple: an empty string that
            clears the textbox and the updated transcript.
        """
        # Work on the component's own value rather than on shared instance
        # state, so the Clear button really clears and sessions do not see
        # each other's messages.
        history = list(history or [])

        # Nothing to answer: do not spend a generation pass on blank input.
        if not msg or not msg.strip():
            return "", history

        # The quadrupled braces render as literal "{{ ... }}" around the system prompt.
        formatted_input = f"{{{{ {self.system_prompt} }}}}\nUser: {msg}\n{self.name}:"
        input_ids = tokenizer.encode(
            formatted_input,
            return_tensors="pt",
            add_special_tokens=False
        )
        # Keep the prompt tensor on the same device as the model weights.
        input_ids = input_ids.to(next(peft_model.parameters()).device)

        response = peft_model.generate(
            input_ids=input_ids,
            max_length=900,
            use_cache=False,
            early_stopping=False,
            bos_token_id=peft_model.config.bos_token_id,
            eos_token_id=peft_model.config.eos_token_id,
            # The EOS id doubles as the padding id.
            pad_token_id=peft_model.config.eos_token_id,
            temperature=0.4,
            do_sample=True
        )

        # The output sequence starts with the prompt tokens; decode only what
        # follows them so the bot bubble does not repeat the prompt.
        prompt_length = input_ids.shape[-1]
        response_text = tokenizer.decode(
            response[0][prompt_length:], skip_special_tokens=True
        ).strip()

        # Show the user's own message, not the internal prompt template.
        history.append([msg, response_text])
        self.chat_history = history
        return "", history
if __name__ == "__main__":
    # Define the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Use the base model's ID
    base_model_id = "tiiuae/falcon-7b-instruct"
    model_directory = "Tonic/GaiaMiniMed"

    # Instantiate the Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")

    # Specify the configuration class for the model
    model_config = AutoConfig.from_pretrained(base_model_id)

    # Load the model from the adapter repository with the base model's
    # configuration, then wrap it with the PEFT adapter from the same repository.
    peft_model = AutoModelForCausalLM.from_pretrained(model_directory, config=model_config)
    peft_model = PeftModel.from_pretrained(peft_model, model_directory)

    with gr.Blocks() as demo:
        with gr.Row() as intro:
            # NOTE(review): the text below mentions a "Submit to All" button,
            # but this layout does not create one - confirm the intended UI.
            # The string content is kept flush-left so the Markdown text does
            # not depend on any dedenting.
            gr.Markdown(
                """
# MedChat: Your Medical Assistant Chatbot
Welcome to MedChat, your friendly medical assistant chatbot! 🩺
Dive into a world of medical expertise where you can interact with three specialized chatbots, all trained on the latest and most comprehensive medical dataset. Whether you have health-related questions, need medical advice, or just want to learn more about your well-being, MedChat is here to help!
## How it Works
Simply type your medical query or concern, and let MedChat's advanced algorithms provide you with accurate and reliable responses.
## Explore and Compare
Feel like experimenting? Click the **Submit to All** button and witness the magic as all three chatbots compete to provide you with the best possible answer! It's a unique opportunity to compare the insights from different models and choose the one that suits your needs the best.
_Ready to get started? Type your question and let's begin!_
"""
            )

        # NOTE(review): all three panels generate with the same module-level
        # ``peft_model``; only the speaker tag in the prompt differs.
        with gr.Row() as row:
            with gr.Column() as col1:
                with gr.Tab("GaiaMinimed") as gaia:
                    gaia_bot = ChatbotInterface("GaiaMinimed")
            with gr.Column() as col2:
                with gr.Tab("MistralMed") as mistral:
                    mistral_bot = ChatbotInterface("MistralMed")
                with gr.Tab("Falcon-7B") as falcon7b:
                    falcon_bot = ChatbotInterface("Falcon-7B")

        # Keep the three input boxes in sync: whatever is typed into one box
        # is copied into the other two.
        message_boxes = [gaia_bot.msg, mistral_bot.msg, falcon_bot.msg]
        for source_box in message_boxes:
            mirror_boxes = [box for box in message_boxes if box is not source_box]
            source_box.change(fn=lambda text: (text, text), inputs=source_box, outputs=mirror_boxes)

    demo.launch()