CMLL committed
Commit d71ad7e
1 Parent(s): 220ce3a

Update app.py

Files changed (1)
  1. app.py +38 -38
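In brief: this commit swaps the blocking `pipeline("text-generation", ...)` call for a directly loaded `AutoModelForCausalLM` plus a `TextIteratorStreamer`, so responses stream into the chat UI token by token instead of arriving all at once. It also builds the prompt with `tokenizer.apply_chat_template`, trims inputs that exceed `MAX_INPUT_TOKEN_LENGTH`, makes the system prompt an explicit optional argument, and replaces the generic example prompts with TCM-specific ones. Short sketches of the new patterns follow the relevant hunks below.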
app.py CHANGED
@@ -5,22 +5,22 @@ from typing import Iterator
 import gradio as gr
 import spaces
 import torch
-from transformers import pipeline, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

 MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

 DESCRIPTION = """\
-# ZhongJing 2 1.8B Merge
-This Space demonstrates model [CMLL/ZhongJing-2-1_8b-merge](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge) for text generation. Feel free to play with it, or duplicate to run generations without a queue! If you want to run your own service, you can also [deploy the model on Inference Endpoints](https://huggingface.co/inference-endpoints).
+# ZhongJing-2-1_8b-merge
+This Space demonstrates the [ZhongJing-2-1_8b-merge](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge) model by CMLL for TCM-related applications. Feel free to play with it, or duplicate it to run generations without a queue!
 """

 LICENSE = """
 <p/>
 ---
-As a derivative work of [CMLL/ZhongJing-2-1_8b-merge](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge),
-this demo is governed by the original [license](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge/LICENSE).
+As a derivative work of [ZhongJing-2-1_8b-merge](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge) by CMLL,
+this demo is governed by the original [license](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/CMLL/ZhongJing-2-1_8b-merge/blob/main/USE_POLICY.md).
 """

 if not torch.cuda.is_available():
@@ -28,7 +28,7 @@ if not torch.cuda.is_available():

 if torch.cuda.is_available():
     model_id = "CMLL/ZhongJing-2-1_8b-merge"
-    pipe = pipeline("text-generation", model=model_id)
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.use_default_system_prompt = False
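The hunk above replaces the high-level pipeline with a directly loaded model. A minimal sketch of that loading path, assuming a CUDA machine with `accelerate` installed (`device_map="auto"` depends on it); this is an illustration, not the Space's exact startup code:

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "CMLL/ZhongJing-2-1_8b-merge"

# torch_dtype="auto" keeps the checkpoint's native precision rather than
# upcasting everything to float32; device_map="auto" lets accelerate place
# the weights on the available GPU(s). Illustrative only.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

print(model.device, next(model.parameters()).dtype)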
 
@@ -36,50 +36,50 @@ if torch.cuda.is_available():
 def generate(
     message: str,
     chat_history: list[tuple[str, str]],
-    system_prompt: str = "You are a helpful TCM medical assistant named 仲景中医大语言模型, created by 医哲未来.",
+    system_prompt: str,
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
     top_p: float = 0.9,
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    conversation = [{"role": "system", "content": system_prompt}]
+    conversation = []
+    if system_prompt:
+        conversation.append({"role": "system", "content": system_prompt})
     for user, assistant in chat_history:
         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
     conversation.append({"role": "user", "content": message})

-    input_text = "\n".join([f"{entry['role']}: {entry['content']}" for entry in conversation])
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)

-    generate_kwargs = {
-        "max_new_tokens": max_new_tokens,
-        "do_sample": True,
-        "top_p": top_p,
-        "top_k": top_k,
-        "temperature": temperature,
-        "repetition_penalty": repetition_penalty,
-    }
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()

-    # Function to run the generation
-    def run_generation():
-        try:
-            results = pipe(input_text, **generate_kwargs)
-            return results
-        except Exception as e:
-            return [f"Error in generation: {e}"]
-
-    # Run generation in a separate thread and wait for it to finish
     outputs = []
-    generation_thread = Thread(target=lambda: outputs.extend(run_generation()))
-    generation_thread.start()
-    generation_thread.join()
-
-    for output in outputs:
-        yield output['generated_text'] if isinstance(output, dict) else output
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)

 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Textbox(label="System prompt", lines=6, value="You are a helpful TCM medical assistant named 仲景中医大语言模型, created by 医哲未来."),
+        gr.Textbox(label="System prompt", lines=6),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
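This is the heart of the change: `model.generate` runs on a worker thread and pushes decoded text through a `TextIteratorStreamer`, and the generator yields the accumulated text so Gradio can repaint the reply as it grows. A self-contained sketch of the pattern, reusing `model` and `tokenizer` from the snippet above; the prompt and sampling values are illustrative:

from threading import Thread

from transformers import TextIteratorStreamer

conversation = [
    {"role": "system", "content": "You are a helpful TCM medical assistant."},
    {"role": "user", "content": "How does TCM treat insomnia?"},
]
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)

# skip_prompt drops the echoed input tokens; timeout bounds how long the
# consuming loop blocks while waiting for the next decoded chunk.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

Thread(
    target=model.generate,
    kwargs=dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
        top_k=50,
        repetition_penalty=1.2,
    ),
).start()

# Iterating the streamer blocks until generation produces the next chunk;
# the app accumulates chunks so every yield is the full response so far.
pieces = []
for text in streamer:
    pieces.append(text)
print("".join(pieces))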
@@ -118,11 +118,11 @@ chat_interface = gr.ChatInterface(
     ],
     stop_btn=None,
     examples=[
-        ["Hello there! How are you doing?"],
-        ["Can you explain briefly to me what is the Python programming language?"],
-        ["Explain the plot of Cinderella in a sentence."],
-        ["How many hours does it take a man to eat a Helicopter?"],
-        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+        ["Who are you?"],
+        ["Can you briefly explain what Traditional Chinese Medicine is?"],
+        ["Summarize the main contents of the Huangdi Neijing."],
+        ["How does TCM treat insomnia?"],
+        ["Write a 100-word article on 'Applications of AI in TCM research'."],
     ],
 )
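Finally, a stripped-down sketch of how `gr.ChatInterface` routes the extra controls into the callback: values from `additional_inputs` arrive positionally after `(message, history)`. The toy echo function here is a hypothetical stand-in for the real `generate`:

import gradio as gr

def toy_generate(message, history, system_prompt, max_new_tokens):
    # Additional inputs are appended to the call, in the order declared below.
    prefix = f"[system: {system_prompt}] " if system_prompt else ""
    yield f"{prefix}(up to {int(max_new_tokens)} new tokens) {message}"

demo = gr.ChatInterface(
    fn=toy_generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
    ],
    examples=[["Who are you?"], ["How does TCM treat insomnia?"]],
)

if __name__ == "__main__":
    demo.launch()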
 
 