wangrongsheng committed on
Commit
bb2872c
·
1 Parent(s): 2ea34ca

init deploy code

Browse files
Files changed (3) hide show
  1. app.py +143 -0
  2. requirements.txt +9 -0
  3. style.css +16 -0
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from threading import Thread
3
+ from typing import Iterator
4
+
5
+ import gradio as gr
6
+ import spaces
7
+ import torch
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
+
10
# Generation limits. The first two bound the "Max new tokens" slider; the
# prompt budget can be overridden through the environment.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

# Markdown rendered above the chat widget.
DESCRIPTION = """\
# Machine Mindset

MM (Machine_Mindset) series models are developed through a collaboration between FarReel AI Lab(formerly known as the ChatLaw project) and Peking University's Deep Research Institute. These models are large-scale language models for various MBTI types in both Chinese and English, built on the Baichuan and LLaMA2 platforms.
"""

# Markdown rendered below the chat widget.
LICENSE = """

---
* Our code adheres to the Apache 2.0 open-source license. Please refer to the [LICENSE](https://github.com/PKU-YuanGroup/Machine-Mindset/blob/main/LICENSE) for specific details of the open-source agreement.

* Our model weights are subject to an open-source agreement based on the original weights, with specific details provided in the Chinese version under the baichuan open-source license. For commercial use, please refer to [model_LICENSE](https://huggingface.co/JessyTsu1/Machine_Mindset_zh_INTP/resolve/main/Machine_Mindset%E5%9F%BA%E4%BA%8Ebaichuan%E7%9A%84%E6%A8%A1%E5%9E%8B%E7%A4%BE%E5%8C%BA%E8%AE%B8%E5%8F%AF%E5%8D%8F%E8%AE%AE.pdf) for further information.

* The English version follows the open-source agreement under the [llama2 license](https://ai.meta.com/resources/models-and-libraries/llama-downloads/).
"""

if torch.cuda.is_available():
    # The model and tokenizer are only loaded when a CUDA device is present.
    model_id = "FarReelAILab/Machine_Mindset_en_INTJ"
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.use_default_system_prompt = False
else:
    # No GPU: nothing is loaded, so tell the visitor on the page itself.
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
39
+
40
+
41
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the earlier chat turns.

    Args:
        message: The newest user message.
        chat_history: Earlier ``(user, assistant)`` exchanges, oldest first.
        system_prompt: Optional system message; skipped when empty.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of highest-probability tokens kept for sampling.
        repetition_penalty: Penalty passed through to ``model.generate``.

    Yields:
        The reply text accumulated so far, once per streamed chunk.

    Raises:
        gr.Error: If CUDA is unavailable, because the module-level ``model``
            and ``tokenizer`` are only created when a GPU is present.
    """
    if not torch.cuda.is_available():
        # Without this guard the first use of `tokenizer` below would fail
        # with an opaque NameError.
        raise gr.Error("This demo does not work on CPU.")

    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    for user, assistant in chat_history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the trailing tokens so the newest turns survive.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }
    # Generation runs in a background thread so this generator can forward
    # text to the UI as soon as the streamer receives it.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)

    # The streamer is exhausted, so generation is finishing; reap the thread
    # instead of leaving it dangling.
    worker.join()
84
+
85
+
86
# Chat widget backed by `generate`. The extra inputs are listed in the same
# order as generate()'s parameters that follow (message, chat_history).
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=[
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS),
        gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6),
        gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50),
        gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2),
    ],
    stop_btn=None,
    # Sample prompts offered to the visitor.
    examples=[
        ["Hello there! How are you doing?"],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
    ],
)
135
+
136
# Page layout, top to bottom: description, duplicate button, chat widget,
# license notes. Styling comes from style.css.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_id="duplicate-button",
    )
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    # Cap the request queue at 20 pending jobs, then start the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.23.0
2
+ bitsandbytes==0.41.1
3
+ gradio==3.48.0
4
+ protobuf==3.20.3
5
+ scipy==1.11.2
6
+ sentencepiece==0.1.99
7
+ spaces==0.16.1
8
+ torch==2.0.0
9
+ transformers==4.34.0
style.css ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
4
+
5
+ #duplicate-button {
6
+ margin: auto;
7
+ color: white;
8
+ background: #1565c0;
9
+ border-radius: 100vh;
10
+ }
11
+
12
+ .contain {
13
+ max-width: 900px;
14
+ margin: auto;
15
+ padding-top: 1.5rem;
16
+ }