File size: 3,662 Bytes
a3c3064
 
927b5de
 
a3c3064
 
 
63a0917
 
 
 
 
 
 
 
 
a3c3064
 
 
 
 
 
 
63a0917
1d6dfdf
a3c3064
 
 
 
 
 
 
 
 
 
 
 
 
b999a69
a3c3064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154b81c
 
a3c3064
 
 
 
 
 
 
 
 
63a0917
a3c3064
 
63a0917
 
a3c3064
27d5e20
63a0917
8de5029
a3c3064
8de5029
1874bf4
a3c3064
1874bf4
 
a3c3064
63a0917
1874bf4
edc6972
 
927b5de
1874bf4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
from peft import PeftModel, PeftConfig
import torch
import gradio as gr
import random
from textwrap import wrap


title = "๐Ÿ‘‹๐Ÿปํ† ๋‹‰์˜ ๋ฏธ์ŠคํŠธ๋ž„๋ฉ”๋“œ ์ฑ„ํŒ…์— ์˜ค์‹  ๊ฒƒ์„ ํ™˜์˜ํ•ฉ๋‹ˆ๋‹ค๐Ÿš€๐Ÿ‘‹๐ŸปWelcome to Tonic's MistralMed Chat๐Ÿš€"
description = "์ด ๊ณต๊ฐ„์„ ์‚ฌ์šฉํ•˜์—ฌ ํ˜„์žฌ ๋ชจ๋ธ์„ ํ…Œ์ŠคํŠธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. [(Tonic/MistralMed)](https://huggingface.co/Tonic/MistralMed) ๋˜๋Š” ์ด ๊ณต๊ฐ„์„ ๋ณต์ œํ•˜๊ณ  ๋กœ์ปฌ ๋˜๋Š” ๐Ÿค—HuggingFace์—์„œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. [Discord์—์„œ ํ•จ๊ป˜ ๋งŒ๋“ค๊ธฐ ์œ„ํ•ด Discord์— ๊ฐ€์ž…ํ•˜์‹ญ์‹œ์˜ค](https://discord.gg/VqTxc76K3u). You can use this Space to test out the current model [(Tonic/MistralMed)](https://huggingface.co/Tonic/MistralMed) or duplicate this Space and use it locally or on ๐Ÿค—HuggingFace. [Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."
examples = [["[Question:] What is the proper treatment for buccal herpes?", "You are a medicine and public health expert, you will receive a question, answer the question, and provide a complete answer"]]

base_model_id = "mistralai/Mistral-7B-v0.1"
model_directory = "Tonic/mistralmed"
device = "cuda" if torch.cuda.is_available() else "cpu"

def wrap_text(text, width=90):
    """Re-wrap ``text`` so that no output line is longer than ``width``.

    The input is split on newlines and every resulting line is wrapped
    independently, so existing line breaks (including blank lines) are kept.

    Args:
        text: The string to wrap.
        width: Maximum line length passed to ``textwrap.fill``.

    Returns:
        The wrapped lines joined back together with ``"\\n"``.
    """
    # The file header only does ``from textwrap import wrap``, which does not
    # bind the module name; without this import ``textwrap.fill`` below
    # raises NameError on every call.
    import textwrap

    return '\n'.join(
        textwrap.fill(line, width=width) for line in text.split('\n')
    )

def multimodal_prompt(user_input, system_prompt="You are an expert medical analyst:", max_length=256):
    """Generate a completion for a single instruction-style prompt.

    Despite the name, only text is handled: the system prompt and the user
    input are placed inside one ``[INST] ... [/INST]`` span, tokenized, and
    passed to the module-level ``peft_model`` for greedy decoding.

    Args:
        user_input: The user's question.
        system_prompt: Instruction text placed before ``user_input``.
        max_length: Upper bound passed to ``generate`` as ``max_length``.
            The original body read an undefined name here; the default
            matches the 256 that ``predict`` in this file uses.

    Returns:
        The decoded output sequence with special tokens stripped.
    """
    # No trailing "</s>": the end-of-sequence marker belongs after the model's
    # answer, not after the instruction. This matches the template that
    # ``predict`` in this file uses.
    formatted_input = f"<s>[INST]{system_prompt} {user_input}[/INST]"

    # The template already spells out "<s>", so the tokenizer must not add
    # its own special tokens on top.
    encodeds = tokenizer(formatted_input, return_tensors="pt", add_special_tokens=False)
    model_inputs = encodeds.to(device)

    # ``peft_model`` is the model object this file actually creates; the
    # previous reference to a bare ``model`` was an undefined name.
    # ``temperature`` and ``early_stopping`` were dropped: with
    # ``do_sample=False`` and a single beam they have no effect on decoding.
    output = peft_model.generate(
        **model_inputs,
        max_length=max_length,
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=False,
    )

    # Decode the first (only) sequence of the batch.
    return tokenizer.decode(output[0], skip_special_tokens=True)


# Tokenizer of the base checkpoint. The repo id comes from ``base_model_id``
# so it cannot drift from the constant declared at the top of the file.
# Padding reuses the end-of-sequence token and is applied on the left
# (``padding_side`` is set once, via the constructor keyword).
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Adapter configuration; kept as a module-level name although nothing in the
# visible file reads it.
peft_config = PeftConfig.from_pretrained(model_directory)

# Load the base weights directly in bfloat16 rather than loading them in the
# default dtype and casting afterwards, which lowers peak memory during
# start-up.
peft_model = MistralForCausalLM.from_pretrained(
    base_model_id,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
# Wrap the base model with the fine-tuned adapter weights.
peft_model = PeftModel.from_pretrained(peft_model, model_directory)
# The cast is still applied after attaching the adapter so that the adapter
# weights end up in bfloat16 too, exactly as before; then move to ``device``.
peft_model = peft_model.to(torch.bfloat16)
peft_model = peft_model.to(device)

class ChatBot:
    """Callable backend for the Gradio interface.

    The class was previously declared twice and ``predict`` was defined at
    module level (its ``def`` was not indented), so instances had no
    ``predict`` attribute. It is now a single class with ``predict`` as a
    method.
    """

    def __init__(self):
        # Container for conversation turns. Nothing in this class reads or
        # appends to it yet.
        self.history = []

    def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
        """Generate a reply for one user message.

        Args:
            user_input: The user's question.
            system_prompt: Instruction text placed before ``user_input``
                inside the ``[INST] ... [/INST]`` span.

        Returns:
            The decoded output sequence with special tokens stripped.
            NOTE(review): the whole generated sequence is decoded, so the
            returned text presumably still contains the prompt - confirm
            whether the UI should show only the continuation.
        """
        formatted_input = f"<s>[INST]{system_prompt} {user_input}[/INST]"
        # The template already contains "<s>"; ``add_special_tokens=False``
        # stops the tokenizer from adding special tokens of its own, matching
        # how ``multimodal_prompt`` in this file tokenizes the same template.
        user_input_ids = tokenizer.encode(
            formatted_input,
            return_tensors="pt",
            add_special_tokens=False,
        )
        user_input_ids = user_input_ids.to(device)
        # Padding reuses the end-of-sequence id, as configured on the tokenizer.
        response = peft_model.generate(
            input_ids=user_input_ids,
            max_length=256,
            pad_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.decode(response[0], skip_special_tokens=True)

# One shared bot instance; its ``predict`` attribute is the UI callback.
bot = ChatBot()

# Two free-text inputs (one per column of each ``examples`` row) and a single
# text output, rendered with a community theme referenced by its Hub id.
iface = gr.Interface(
    fn=bot.predict,
    inputs=["text", "text"],
    outputs="text",
    examples=examples,
    title=title,
    description=description,
    theme="ParityError/Anime",
)

# Start serving the interface.
iface.launch()