wangrongsheng committed on
Commit
16a199e
1 Parent(s): fe7fe2a

del two models

Files changed (1): app.py +35 -79
app.py CHANGED
@@ -13,17 +13,13 @@ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
 DESCRIPTION = """\
 # Machine Mindset
-
 MM (Machine_Mindset) series models are developed through a collaboration between FarReel AI Lab(formerly known as the ChatLaw project) and Peking University's Deep Research Institute. These models are large-scale language models for various MBTI types in both Chinese and English, built on the Baichuan and LLaMA2 platforms.
 """
 
 LICENSE = """
-
 ---
 * Our code adheres to the Apache 2.0 open-source license. Please refer to the [LICENSE](https://github.com/PKU-YuanGroup/Machine-Mindset/blob/main/LICENSE) for specific details of the open-source agreement.
-
 * Our model weights are subject to an open-source agreement based on the original weights, with specific details provided in the Chinese version under the baichuan open-source license. For commercial use, please refer to [model_LICENSE](https://huggingface.co/JessyTsu1/Machine_Mindset_zh_INTP/resolve/main/Machine_Mindset%E5%9F%BA%E4%BA%8Ebaichuan%E7%9A%84%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) for further information.
-
 * The English version follows the open-source agreement under the [llama2 license](https://ai.meta.com/resources/models-and-libraries/llama-downloads/).
 """
 
@@ -36,16 +32,10 @@ if torch.cuda.is_available():
     model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
     tokenizer = AutoTokenizer.from_pretrained(model_id)
     tokenizer.use_default_system_prompt = False
-
-    model_id_zh = "FarReelAILab/Machine_Mindset_zh_INTJ"
-    model_zh = AutoModelForCausalLM.from_pretrained(model_id_zh, torch_dtype=torch.float16, device_map="auto", trust_remote_code=True)
-    tokenizer_zh = AutoTokenizer.from_pretrained(model_id_zh, trust_remote_code=True)
-    tokenizer_zh.use_default_system_prompt = False
 
 
 @spaces.GPU
 def generate(
-    select_model: str,
     message: str,
     chat_history: list[tuple[str, str]],
     system_prompt: str,
@@ -55,78 +45,43 @@ def generate(
     top_k: int = 50,
     repetition_penalty: float = 1.2,
 ) -> Iterator[str]:
-    if select_model=="INTJ-en":
-        conversation = []
-        if system_prompt:
-            conversation.append({"role": "system", "content": system_prompt})
-        for user, assistant in chat_history:
-            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-        conversation.append({"role": "user", "content": message})
-
-        input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-        input_ids = input_ids.to(model.device)
-
-        streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-        generate_kwargs = dict(
-            {"input_ids": input_ids},
-            streamer=streamer,
-            max_new_tokens=max_new_tokens,
-            do_sample=True,
-            top_p=top_p,
-            top_k=top_k,
-            temperature=temperature,
-            num_beams=1,
-            repetition_penalty=repetition_penalty,
-        )
-        t = Thread(target=model.generate, kwargs=generate_kwargs)
-        t.start()
-
-        outputs = []
-        for text in streamer:
-            outputs.append(text)
-            yield "".join(outputs)
-
-    if select_model=="INTJ-zh":
-        conversation = []
-        if system_prompt:
-            conversation.append({"role": "system", "content": system_prompt})
-        for user, assistant in chat_history:
-            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-        conversation.append({"role": "user", "content": message})
-
-        input_ids = tokenizer_zh.apply_chat_template(conversation, return_tensors="pt")
-        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-            gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-        input_ids = input_ids.to(model_zh.device)
-
-        streamer = TextIteratorStreamer(tokenizer_zh, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-        generate_kwargs = dict(
-            {"input_ids": input_ids},
-            streamer=streamer,
-            max_new_tokens=max_new_tokens,
-            do_sample=True,
-            top_p=top_p,
-            top_k=top_k,
-            temperature=temperature,
-            num_beams=1,
-            repetition_penalty=repetition_penalty,
-        )
-        t = Thread(target=model_zh.generate, kwargs=generate_kwargs)
-        t.start()
-
-        outputs = []
-        for text in streamer:
-            outputs.append(text)
-            yield "".join(outputs)
+    conversation = []
+    if system_prompt:
+        conversation.append({"role": "system", "content": system_prompt})
+    for user, assistant in chat_history:
+        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+    conversation.append({"role": "user", "content": message})
+
+    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+    input_ids = input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        {"input_ids": input_ids},
+        streamer=streamer,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+        yield "".join(outputs)
+
 
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
-        gr.Dropdown(choices=["INTJ-en", "INTJ-zh"], value="INTJ-en", label="Select Model"),
         gr.Textbox(label="System prompt", lines=6),
         gr.Slider(
             label="Max new tokens",
@@ -170,6 +125,7 @@ chat_interface = gr.ChatInterface(
         ["Can you explain briefly to me what is the Python programming language?"],
         ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
+        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
     ],
 )
 
@@ -180,4 +136,4 @@ with gr.Blocks(css="style.css") as demo:
     gr.Markdown(LICENSE)
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch()
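
For reference, the single code path kept by this commit uses the standard transformers streaming pattern: model.generate runs on a worker thread while TextIteratorStreamer yields decoded text as it is produced, and the Gradio callback re-yields the accumulated string. Below is a minimal, self-contained sketch of that pattern; the FarReelAILab/Machine_Mindset_en_INTJ checkpoint name is an assumption (the diff does not show the model_id assignment), and any chat-tuned causal LM would work the same way.

# Minimal sketch of the streaming pattern used by app.py's generate().
# Assumption: model_id is illustrative; the commit context does not show the real assignment.
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "FarReelAILab/Machine_Mindset_en_INTJ"  # hypothetical checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# Build a chat-formatted prompt, as the app does from system_prompt + chat_history + message.
conversation = [{"role": "user", "content": "Hello, who are you?"}]
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)

# skip_prompt=True keeps the echoed prompt out of the stream; timeout guards a stalled generate.
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate,
       kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=128)).start()

output = ""
for chunk in streamer:  # blocks until the worker thread produces new text
    output += chunk
print(output)

In the app itself this loop lives inside generate, whose incremental yields let gr.ChatInterface update the chat bubble as tokens arrive rather than after generation completes.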