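"""Gradio chat demo for Med-Alpaca-2-7b-chat GGUF models, served with llama-cpp-python."""
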
import gc
import os

import gradio as gr
from llama_cpp import Llama

ALPACA_SYSTEM_PROMPT = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.'
ALPACA_SYSTEM_PROMPT_NO_INPUT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.'

DEFAULT_MODEL = 'Med-Alpaca-2-7b-chat.Q4_K_M'
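
# GGUF quantizations of Med-Alpaca-2-7b-chat published on the Hugging Face Hub;
# each entry is unpacked into Llama.from_pretrained(repo_id=..., filename=...).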
model_paths = {
    'Med-Alpaca-2-7b-chat.Q2_K': {
        'repo_id': 'minhnguyent546/Med-Alpaca-2-7b-chat-GGUF',
        'filename': 'Med-Alpaca-2-7B-chat.Q2_K.gguf',
    },
    'Med-Alpaca-2-7b-chat.Q4_K_M': {
        'repo_id': 'minhnguyent546/Med-Alpaca-2-7b-chat-GGUF',
        'filename': 'Med-Alpaca-2-7B-chat.Q4_K_M.gguf',
    },
    'Med-Alpaca-2-7b-chat.Q6_K': {
        'repo_id': 'minhnguyent546/Med-Alpaca-2-7b-chat-GGUF',
        'filename': 'Med-Alpaca-2-7B-chat.Q6_K.gguf',
    },
    'Med-Alpaca-2-7b-chat.Q8_0': {
        'repo_id': 'minhnguyent546/Med-Alpaca-2-7b-chat-GGUF',
        'filename': 'Med-Alpaca-2-7B-chat.Q8_0.gguf',
    },
    'Med-Alpaca-2-7b-chat.F16': {
        'repo_id': 'minhnguyent546/Med-Alpaca-2-7b-chat-GGUF',
        'filename': 'Med-Alpaca-2-7B-chat.F16.gguf',
    },
}
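
# The currently loaded llama_cpp.Llama instance. It is created lazily on the
# first request and replaced whenever the user picks a different quantization.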
model = None


def generate_alpaca_prompt(
    instruction: str,
    input: str | None = None,
    response: str = '',
) -> str:
    """Build an Alpaca-style prompt, omitting the Input section when `input`
    is empty or the literal '<noinput>' placeholder."""
    if input is not None and input and input.strip() != '<noinput>':
        prompt = (
            f'{ALPACA_SYSTEM_PROMPT}\n\n'
            f'### Instruction:\n'
            f'{instruction}\n\n'
            f'### Input:\n'
            f'{input}\n\n'
            f'### Response: '
            f'{response}'
        )
    else:
        prompt = (
            f'{ALPACA_SYSTEM_PROMPT_NO_INPUT}\n\n'
            f'### Instruction:\n'
            f'{instruction}\n\n'
            f'### Response: '
            f'{response}'
        )
    return prompt.strip()
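
# For example, generate_alpaca_prompt('What causes migraines?') produces the
# no-input template ending in '### Response:', ready for the model to complete.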


def chat_completion(
    message,
    history,
    seed: int,
    max_new_tokens: int,
    temperature: float,
    repetition_penalty: float,
    top_k: int,
    top_p: float,
):
    """Stream a completion for `message`. `history` is required by
    gr.ChatInterface but unused: each turn is answered independently."""
    if model is None:
        reload_model(DEFAULT_MODEL)
    prompt = generate_alpaca_prompt(instruction=message)
    response_iterator = model(
        prompt,
        stream=True,
        seed=seed,
        max_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repeat_penalty=repetition_penalty,
    )
    partial_response = ''
    for token in response_iterator:
        partial_response += token['choices'][0]['text']
        yield partial_response


def reload_model(model_name: str):
    """Free the current model, load the selected GGUF from the Hub, and return
    fresh title/chatbot components so the UI reflects the new model name."""
    global model
    # Drop the old model and collect so llama.cpp releases its memory before
    # the new weights are loaded.
    model = None
    gc.collect()
    model = Llama.from_pretrained(
        **model_paths[model_name],
        n_ctx=4096,
        n_threads=4,
        cache_dir='./.hf_cache',
    )
    app_title_mark = gr.Markdown(f'<center><font size=18>{model_name}</font></center>')
    chatbot = gr.Chatbot(
        type='messages',
        height=500,
        placeholder='<strong>Hi doctor, I have a headache, what should I do?</strong>',
        label=model_name,
        avatar_images=[None, 'https://raw.githubusercontent.com/minhnguyent546/medical-llama2/demo/assets/med_alpaca.png'],  # pyright: ignore[reportArgumentType]
    )
    return app_title_mark, chatbot


def main() -> None:
    with gr.Blocks(theme=gr.themes.Ocean()) as demo:
        app_title_mark = gr.Markdown(f'<center><font size=18>{DEFAULT_MODEL}</font></center>')
        model_options = list(model_paths.keys())
        with gr.Row():
            with gr.Column(scale=2):
                with gr.Row():
                    model_radio = gr.Radio(choices=model_options, label='Model', value=DEFAULT_MODEL)
                with gr.Row():
                    seed = gr.Number(value=998244353, label='Seed')
                    max_new_tokens = gr.Number(value=512, minimum=64, maximum=2048, label='Max new tokens')
                with gr.Row():
                    temperature = gr.Slider(0, 2, step=0.01, label='Temperature', value=0.6)
                    repetition_penalty = gr.Slider(0.01, 5, step=0.05, label='Repetition penalty', value=1.1)
                with gr.Row():
                    top_k = gr.Slider(1, 100, step=1, label='Top k', value=40)
                    top_p = gr.Slider(0, 1, step=0.01, label='Top p', value=0.9)
            with gr.Column(scale=5):
                chatbot = gr.Chatbot(
                    type='messages',
                    height=500,
                    placeholder='<strong>Hi doctor, I have a headache, what should I do?</strong>',
                    label=DEFAULT_MODEL,
                    avatar_images=[None, 'https://raw.githubusercontent.com/minhnguyent546/medical-llama2/demo/assets/med_alpaca.png'],  # pyright: ignore[reportArgumentType]
                )
                textbox = gr.Textbox(
                    placeholder='Hi doctor, I have a headache, what should I do?',
                    container=False,
                    submit_btn=True,
                    stop_btn=True,
                )
                chat_interface = gr.ChatInterface(
                    chat_completion,
                    type='messages',
                    chatbot=chatbot,
                    textbox=textbox,
                    additional_inputs=[
                        seed,
                        max_new_tokens,
                        temperature,
                        repetition_penalty,
                        top_k,
                        top_p,
                    ],
                )
        model_radio.change(reload_model, inputs=[model_radio], outputs=[app_title_mark, chatbot])

    demo.queue(api_open=False, default_concurrency_limit=20)
    # os.environ values are strings; cast to bool so launch() gets True only
    # when GRADIO_SHARE is set to a non-empty value.
    demo.launch(max_threads=5, share=bool(os.environ.get('GRADIO_SHARE')))
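
# Usage sketch (assuming this file is saved as app.py):
#   GRADIO_SHARE=1 python app.py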


if __name__ == '__main__':
    main()