ZhongJingGPT / app.py
CMLL's picture
Update app.py
dd1d2c2 verified
raw
history blame
6.27 kB
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Generation limits used by the UI and by generate():
#   MAX_MAX_NEW_TOKENS     - upper bound of the "Max new tokens" slider
#   DEFAULT_MAX_NEW_TOKENS - initial value of that slider
#   MAX_INPUT_TOKEN_LENGTH - prompt length (in tokens) above which older tokens
#                            are trimmed; overridable via the environment
#                            variable of the same name.
MAX_MAX_NEW_TOKENS, DEFAULT_MAX_NEW_TOKENS = 2048, 1024
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "4096"))
# Markdown header shown above the chat UI (Chinese + English intro and a
# non-medical-use disclaimer). Paragraphs are separated by blank lines so that
# Markdown renders them as distinct paragraphs instead of one merged block.
# NOTE(review): the English link text says "ZhongJing-2-1_8b-merge" but points
# at CMLM/ZhongJing-2-1_8b, while the app loads CMLL/ZhongJing-2-1_8b-merge —
# confirm which repository is intended.
DESCRIPTION = """\
# FulPhil-仲景中医大语言模型-ZhongJingGPT-V2-1_8b-First LLM in TCM

致敬仲景先师,融汇古典与智能。本模型由医哲未来(FulPhil)研发,是专注于中医药领域的强大语言模型。它能够运用传统中医理论,结合现代人工智能技术,为中医研究和应用提供卓越助力。欢迎关注我们的 [GitHub主页](https://github.com/pariskang/CMLM-ZhongJing) 及模型 [ZhongJing-2-1_8b](https://huggingface.co/CMLM/ZhongJing-2-1_8b) 下载体验!

Paying tribute to the ancient master Zhang Zhongjing, this model integrates classical knowledge with modern intelligence. Developed by FulPhil (Future Medicine Philosophy), it is a powerful language model focused on the field of Traditional Chinese Medicine (TCM). It employs traditional TCM theories combined with contemporary artificial intelligence technology to provide excellent support for TCM research and applications. Welcome to visit our [GitHub homepage](https://github.com/pariskang/CMLM-ZhongJing) and download the model [ZhongJing-2-1_8b-merge](https://huggingface.co/CMLM/ZhongJing-2-1_8b) for a trial experience!

请注意!!!本模型不得用于任何医疗或潜在具有医疗或康养建议的任何场景,目前仍为科研测试阶段,敬请帮我们提出宝贵意见,谢谢。

Please note!!! This model should not be used for any medical purposes or scenarios potentially involving medical or health advice. It is currently still in the research and testing stage. We sincerely request your valuable feedback. Thank you.
"""
# Markdown footer rendered below the chat UI with the licensing notice.
# The blank line after the "<p/>" HTML tag is required: an HTML block only ends
# at a blank line, so without it the "---" rule and the Markdown links would be
# emitted as raw HTML text instead of being rendered.
LICENSE = """
<p/>

---
As a derivative work of [ZhongJing-2-1_8b](https://huggingface.co/CMLM/ZhongJing-2-1_8b) by FulPhil,
this demo is governed by the original [license](https://huggingface.co/CMLM/ZhongJing-2-1_8b/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/CMLM/ZhongJing-2-1_8b/blob/main/USE_POLICY.md).
"""
# Load the model only when a CUDA device is available. On CPU, a warning banner
# is appended to DESCRIPTION instead, and `model` / `tokenizer` are never
# defined, so calling generate() there would raise NameError.
if torch.cuda.is_available():
    model_id = "CMLL/ZhongJing-2-1_8b-merge"
    # torch_dtype="auto": use the dtype recorded for the checkpoint;
    # device_map="auto": let transformers/accelerate place the weights.
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Presumably disables a built-in default system prompt (a legacy tokenizer
    # attribute); the app always passes its own system turn — confirm it is
    # still honoured by the installed transformers version.
    tokenizer.use_default_system_prompt = False
else:
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
# NOTE(review): `spaces.GPU` is the Hugging Face ZeroGPU decorator; presumably
# it attaches a GPU for the duration of each call — confirm in the `spaces` docs.
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = "You are a helpful TCM assistant named 仲景中医大语言模型, created by 医哲未来. You can switch between Chinese and English based on user preference.",
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the earlier chat turns.

    Builds a chat-template prompt from the system prompt, every earlier
    ``(user, assistant)`` pair and the new user message. It then runs
    ``model.generate`` on a background thread and yields the reply
    accumulated so far each time the streamer produces new text.

    Args:
        message: The new user message.
        chat_history: Earlier turns as ``(user, assistant)`` pairs.
        system_prompt: Content of the leading ``system`` turn.
        max_new_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature, forwarded to ``model.generate``.
        top_p: Nucleus-sampling probability mass, forwarded to ``model.generate``.
        top_k: Top-k sampling cutoff, forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The full decoded reply produced so far (not only the newest chunk).

    Raises:
        queue.Empty: from the streamer if no new text arrives within its
            10-second timeout (e.g. when generation fails in the worker thread).
    """
    conversation = [{"role": "system", "content": system_prompt}]
    for user, assistant in chat_history:
        conversation.extend(
            [{"role": "user", "content": user}, {"role": "assistant", "content": assistant}]
        )
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the template's assistant-turn opener so
    # the model answers as the assistant instead of continuing the user turn.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens; this can cut off the system prompt
        # and split an early turn mid-way.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt=True streams only newly generated text; the timeout keeps the
    # consumer from blocking forever if the generation thread dies.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    # model.generate blocks, so run it in a worker thread and consume the streamer here.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted once generation has finished, so this join
    # should return promptly; it just reaps the worker thread.
    thread.join()
# Chat UI wired to generate(). The additional inputs are passed to generate()
# positionally after (message, chat_history), so their order must match its
# parameters: system prompt, max new tokens, temperature, top-p, top-k,
# repetition penalty.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(
            label="System prompt",
            lines=6,
            value="You are a helpful TCM assistant named 仲景中医大语言模型, created by 医哲未来. You can switch between Chinese and English based on user preference.",
        ),
        # One slider per row: (label, minimum, maximum, step, initial value).
        *(
            gr.Slider(label=label, minimum=low, maximum=high, step=step, value=initial)
            for label, low, high, step, initial in (
                ("Max new tokens", 1, MAX_MAX_NEW_TOKENS, 1, DEFAULT_MAX_NEW_TOKENS),
                ("Temperature", 0.1, 4.0, 0.1, 0.6),
                ("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.9),
                ("Top-k", 1, 1000, 1, 50),
                ("Repetition penalty", 1.0, 2.0, 0.05, 1.2),
            )
        ),
    ],
    stop_btn=None,
    # Each example is a single-message list, as ChatInterface expects.
    # NOTE(review): the last prompt opens and closes its title with ‘ (the
    # closing mark is probably meant to be ’).
    examples=[
        [prompt]
        for prompt in (
            "你能简要解释一下什么是中医吗?",
            "简述《黄帝内经》的主要内容。",
            "中医如何治疗失眠?",
            "我发热,咳嗽,咽痛,舌苔黄腻,脉滑数,请给出中医诊断及处方?",
            "写一篇关于‘AI在中医研究中的应用’的100字文章。",
            "写一篇从中医角度关于‘秋季女性健康调养方案‘的1000字科普文章,从季节变化、饮食调理、活动养生等方面进行阐述",
        )
    ],
)
# Page layout, top to bottom: description header, duplicate button, the chat
# UI, and the license footer.
demo = gr.Blocks(css="style.css")
with demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    # Enable Gradio's request queue (at most 20 pending requests), then serve the app.
    demo.queue(max_size=20)
    demo.launch()