# ZhongJingGPT / app.py
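# Gradio chat demo for CMLL/ZhongJing-2-1_8b-merge, a Traditional Chinese Medicine (TCM)
# assistant model. The model is served through the transformers "text-generation"
# pipeline and exposed through gr.ChatInterface.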
import os
from threading import Thread
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import pipeline, AutoTokenizer
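
# Generation limits. MAX_INPUT_TOKEN_LENGTH can be overridden via the environment
# variable of the same name; it is read here but not enforced later in this demo.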
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
DESCRIPTION = """\
# ZhongJing 2 1.8B Merge
This Space demonstrates model [CMLL/ZhongJing-2-1_8b-merge](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge) for text generation. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
"""
LICENSE = """
<p/>
---
As a derivative work of [CMLL/ZhongJing-2-1_8b-merge](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge),
this demo is governed by the original [license](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge/LICENSE).
"""
if not torch.cuda.is_available():
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
if torch.cuda.is_available():
    # Load the model pipeline and its tokenizer only when a GPU is available.
    model_id = "CMLL/ZhongJing-2-1_8b-merge"
    pipe = pipeline("text-generation", model=model_id)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
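
# On ZeroGPU Spaces, the @spaces.GPU decorator allocates a GPU for the duration of each call.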
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str = "You are a helpful TCM medical assistant named 仲景中医大语言模型, created by 医哲未来.",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # Rebuild the conversation from the system prompt, prior turns, and the new message.
    conversation = [{"role": "system", "content": system_prompt}]
    for user, assistant in chat_history:
        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    conversation.append({"role": "user", "content": message})
    # Flatten the conversation into a plain "role: content" prompt for the pipeline.
    input_text = "\n".join(f"{entry['role']}: {entry['content']}" for entry in conversation)
    generate_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
    }
    # Run the generation; any exception is surfaced to the user as an error string.
    def run_generation():
        try:
            return pipe(input_text, **generate_kwargs)
        except Exception as e:
            return [f"Error in generation: {e}"]

    # Run generation in a separate thread and wait for it to finish.
    outputs = []
    generation_thread = Thread(target=lambda: outputs.extend(run_generation()))
    generation_thread.start()
    generation_thread.join()

    # The pipeline returns a list of dicts with a "generated_text" key; error strings pass through as-is.
    for output in outputs:
        yield output["generated_text"] if isinstance(output, dict) else output
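
# The sliders below map one-to-one onto the keyword arguments of generate();
# gr.ChatInterface passes them as additional inputs on every call.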
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6, value="You are a helpful TCM medical assistant named 仲景中医大语言模型, created by 医哲未来."),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)
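
# Page layout: description, a duplicate-Space button, the chat widget, and the license notice.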
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)
if __name__ == "__main__":
    demo.queue(max_size=20).launch()