Happzy-WHU committed
Commit 84a6c36
Parent: 6b39623

first commit.

Files changed (4)
  1. README.md +23 -5
  2. V3.py +37 -0
  3. app.py +91 -0
  4. requirements.txt +138 -0
README.md CHANGED
@@ -1,14 +1,32 @@
  ---
  title: Open O1
- emoji: 👁
- colorFrom: gray
- colorTo: green
+ emoji: 💬
+ colorFrom: yellow
+ colorTo: purple
  sdk: gradio
- sdk_version: 4.44.1
+ sdk_version: 4.36.1
  app_file: app.py
  pinned: false
  license: apache-2.0
  short_description: This is an official demo website for open-o1.
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # open o1 deployment
+
+ 1. Git clone this repository.
+
+ ```shell
+ git clone https://huggingface.co/spaces/happzy2633/open-o1
+ ```
+
+ 2. Install the dependencies listed in requirements.txt.
+
+ 3. Replace the value of the use_auth_token variable with your own (a shell sketch follows this diff).
+
+ 4. Execute the script below.
+
+ ```shell
+ python app.py
+ ```
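For reference, steps 2-3 can also be done from the shell before launching. A minimal sketch, assuming the token is supplied through the `YOUR_AUTH_TOKEN` environment variable that `V3.py` reads (the token value `hf_xxx` is a placeholder):

```shell
# Step 2: install the pinned dependencies
pip install -r requirements.txt

# Step 3: V3.py reads the token via os.getenv("YOUR_AUTH_TOKEN")
export YOUR_AUTH_TOKEN=hf_xxx
```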
V3.py ADDED
@@ -0,0 +1,37 @@
+ import os
+
+ from transformers import AutoTokenizer
+ from vllm import LLM, SamplingParams
+ from huggingface_hub import snapshot_download
+
+ use_auth_token = os.getenv("YOUR_AUTH_TOKEN")
+
+ repo_id = "m-a-p/qwen2.5-7b-ins-v3"
+ local_dir = repo_id.rsplit("/")[-1]
+ snapshot_download(repo_id=repo_id, local_dir=local_dir, use_auth_token=use_auth_token, resume_download=True)
+
+ model_path = "qwen2.5-7b-ins-v3/checkpoint-1000"
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+ sampling_params = SamplingParams(temperature=0.7, top_p=0.8, repetition_penalty=1.05, max_tokens=8192)
+ llm = LLM(model=model_path)
+
+ def api_call_batch(batch_messages):
+     # Render each conversation with the chat template, then generate the whole batch at once.
+     text_list = [
+         tokenizer.apply_chat_template(conversation=messages, tokenize=False, add_generation_prompt=True)
+         for messages in batch_messages
+     ]
+     outputs = llm.generate(text_list, sampling_params)
+     result = [output.outputs[0].text for output in outputs]
+     return result
+
+ def api_call(messages):
+     return api_call_batch([messages])[0]
+
+ def call_gpt(history, prompt):
+     return api_call(history + [{"role": "user", "content": prompt}])
+
+ if __name__ == "__main__":
+     messages = [{"role": "user", "content": "Who are you?"}]
+     print(api_call_batch([messages] * 4))
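`V3.py` exposes `api_call_batch`, `api_call`, and `call_gpt` as the inference entry points. A brief usage sketch (the prompts are illustrative; importing the module loads the checkpoint, so this needs a GPU with vLLM set up):

```python
from V3 import api_call, call_gpt

# Single-turn call: api_call wraps the messages in a one-element batch.
print(api_call([{"role": "user", "content": "Who are you?"}]))

# Multi-turn call: call_gpt appends the new user prompt to an existing history.
history = [{"role": "system", "content": "You are open-o1, a helpful assistant."}]
print(call_gpt(history, "Introduce yourself in one sentence."))
```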
app.py ADDED
@@ -0,0 +1,91 @@
+ import gradio as gr
+ from loguru import logger
+ from V3 import call_gpt
+
+ class Conversation:
+     def __init__(self, max_history_len=10):
+         self.max_history_len = max_history_len
+
+     def pop_history(self, history):
+         while len(history) > self.max_history_len:  # drop the oldest user/assistant pair until the history fits
+             for item in history:
+                 if item["role"] == "user":
+                     history.remove(item)
+                     break
+             for item in history:
+                 if item["role"] == "assistant":
+                     history.remove(item)
+                     break
+         return history
+
+     def ask(self, history, prompt):
+         history = self.pop_history(history)
+         logger.info(history)
+         return call_gpt(history, prompt)
+
+ conv = Conversation()
+
+ def make_history(system_prompt, qa_list):
+     history = [{"role": "system", "content": system_prompt}]
+     for q, a in qa_list:
+         history.append({"role": "user", "content": q})
+         history.append({"role": "assistant", "content": a})
+     return history
+
+ def answer(system_prompt, prompt, history=[]):  # the mutable default is never shared: gr.State always supplies history
+     history.append(prompt)
+     qa_list = [(u, b) for u, b in zip(history[::2], history[1::2])]
+     message = conv.ask(make_history(system_prompt, qa_list), prompt)
+
+     # Escape backticks in the model output
+     message = message.replace("`", "\\`")
+
+     # Wrap the reply in a code block
+     message = f"```\n{message}\n```"
+
+     history.append(message)
+
+     chatbot_messages = []
+     for q, a in qa_list:
+         chatbot_messages.append((q, a))
+
+     chatbot_messages.append((prompt, message))
+
+     return "", chatbot_messages, history
+
+ def clear_history(state):
+     state.clear()
+     return state, []
+
+ with gr.Blocks(css="#chatbot{height:500px} .overflow-y-auto{height:500px}") as rxbot:
+     with gr.Row():
+         sys = gr.Textbox(show_label=False, value="You are open-o1, a helpful assistant.")
+     chatbot = gr.Chatbot()
+     state = gr.State([])
+
+     with gr.Row():
+         txt = gr.Textbox(show_label=False, placeholder="Please enter your question", max_lines=8)
+
+     with gr.Row():
+         clear_button = gr.Button("🧹Clear History")
+         send_button = gr.Button("🚀Send")
+
+     send_button.click(
+         fn=answer,
+         inputs=[sys, txt, state],
+         outputs=[txt, chatbot, state]
+     )
+
+     txt.submit(
+         fn=answer,
+         inputs=[sys, txt, state],
+         outputs=[txt, chatbot, state]
+     )
+
+     clear_button.click(
+         fn=clear_history,
+         inputs=[state],
+         outputs=[state, chatbot]
+     )
+
+ rxbot.launch()
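One subtlety in `answer`: the Gradio `state` is a flat list of alternating user/assistant strings, and the `zip` pairing deliberately drops the trailing prompt that has no answer yet. A minimal illustration of that pairing:

```python
# state after one full exchange, plus a fresh prompt appended by answer()
history = ["Hi", "Hello!", "What is open-o1?"]

# Pair even-indexed (user) with odd-indexed (assistant) entries;
# zip stops at the shorter list, so the unanswered prompt is excluded.
qa_list = [(u, b) for u, b in zip(history[::2], history[1::2])]
assert qa_list == [("Hi", "Hello!")]
```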
requirements.txt ADDED
@@ -0,0 +1,138 @@
+ aiofiles==23.2.1
+ aiohappyeyeballs==2.3.6
+ aiohttp==3.10.3
+ aiosignal==1.3.1
+ annotated-types==0.7.0
+ anyio==4.4.0
+ async-timeout==4.0.3
+ attrs==24.2.0
+ blinker==1.8.2
+ certifi==2024.7.4
+ charset-normalizer==3.3.2
+ click==8.1.7
+ cloudpickle==3.0.0
+ cmake==3.30.2
+ contourpy==1.3.0
+ cycler==0.12.1
+ datasets==2.21.0
+ dill==0.3.8
+ diskcache==5.6.3
+ distro==1.9.0
+ exceptiongroup==1.2.2
+ fastapi==0.112.1
+ ffmpy==0.4.0
+ filelock==3.15.4
+ Flask==3.0.3
+ Flask-Cors==4.0.1
+ fonttools==4.54.1
+ frozenlist==1.4.1
+ fsspec==2024.6.1
+ gradio==4.44.1
+ gradio_client==1.3.0
+ h11==0.14.0
+ httpcore==1.0.5
+ httptools==0.6.1
+ httpx==0.27.0
+ huggingface-hub==0.24.5
+ idna==3.7
+ importlib_resources==6.4.5
+ interegular==0.3.3
+ itsdangerous==2.2.0
+ Jinja2==3.1.4
+ jiter==0.5.0
+ jsonschema==4.23.0
+ jsonschema-specifications==2023.12.1
+ kiwisolver==1.4.7
+ lark==1.2.2
+ llvmlite==0.43.0
+ lm-format-enforcer==0.10.3
+ loguru==0.7.2
+ markdown-it-py==3.0.0
+ MarkupSafe==2.1.5
+ matplotlib==3.9.2
+ mdurl==0.1.2
+ mpmath==1.3.0
+ msgpack==1.0.8
+ multidict==6.0.5
+ multiprocess==0.70.16
+ nest-asyncio==1.6.0
+ networkx==3.3
+ ninja==1.11.1.1
+ numba==0.60.0
+ numpy==1.26.4
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==9.1.0.70
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-ml-py==12.560.30
+ nvidia-nccl-cu12==2.20.5
+ nvidia-nvjitlink-cu12==12.6.20
+ nvidia-nvtx-cu12==12.1.105
+ openai==1.40.8
+ orjson==3.10.7
+ outlines==0.0.46
+ packaging==24.1
+ pandas==2.2.2
+ pillow==10.4.0
+ prometheus-fastapi-instrumentator==7.0.0
+ prometheus_client==0.20.0
+ protobuf==5.27.3
+ psutil==6.0.0
+ py-cpuinfo==9.0.0
+ pyairports==2.1.1
+ pyarrow==17.0.0
+ pycountry==24.6.1
+ pydantic==2.8.2
+ pydantic_core==2.20.1
+ pydub==0.25.1
+ pyext==0.7
+ Pygments==2.18.0
+ pyparsing==3.1.4
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.0.1
+ python-multipart==0.0.12
+ pytz==2024.1
+ PyYAML==6.0.2
+ pyzmq==26.1.0
+ ray==2.34.0
+ referencing==0.35.1
+ regex==2024.7.24
+ requests==2.32.3
+ rich==13.9.1
+ rpds-py==0.20.0
+ ruff==0.6.8
+ safetensors==0.4.4
+ semantic-version==2.10.0
+ sentencepiece==0.2.0
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ starlette==0.38.2
+ sympy==1.13.2
+ tiktoken==0.7.0
+ tokenizers==0.19.1
+ tomlkit==0.12.0
+ torch==2.4.0
+ torchvision==0.19.0
+ tqdm==4.66.5
+ transformers==4.44.0
+ triton==3.0.0
+ typer==0.12.5
+ typing_extensions==4.12.2
+ tzdata==2024.1
+ urllib3==2.2.2
+ uvicorn==0.30.6
+ uvloop==0.20.0
+ vllm==0.5.4
+ vllm-flash-attn==2.6.1
+ watchfiles==0.23.0
+ websockets==12.0
+ Werkzeug==3.0.3
+ xformers==0.0.27.post2
+ xxhash==3.4.1
+ yarl==1.9.4