xiaozhou0822 committed on
Commit 9a0bf3e
1 Parent(s): aec6838

Create aa.py

Files changed (1)
  aa.py  +203  -0
aa.py ADDED
@@ -0,0 +1,203 @@
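"""Gradio Space for the Code Llama Playground.

Streams completions from the CodeLlama-13b-hf model via the Hugging Face
Inference API and supports fill-in-the-middle prompts marked with <FILL_ME>.
"""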
import json
import os
import shutil
import requests

import gradio as gr
from huggingface_hub import Repository
from text_generation import Client

from share_btn import community_icon_html, loading_icon_html, share_js, share_btn_css

HF_TOKEN = os.environ.get("HF_TOKEN", None)
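# Hosted Inference API endpoint for the 13B Code Llama base model.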
API_URL = "https://api-inference.huggingface.co/models/codellama/CodeLlama-13b-hf"
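# Special tokens for Code Llama's fill-in-the-middle (infilling) prompt format.
# The user marks the gap with FIM_INDICATOR; the prompt is rewritten as
# "<PRE> {prefix} <SUF>{suffix} <MID>" before being sent to the model.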
FIM_PREFIX = "<PRE> "
FIM_MIDDLE = " <MID>"
FIM_SUFFIX = " <SUF>"

FIM_INDICATOR = "<FILL_ME>"
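# Strings that mark the end of generation: the model's end-of-sequence token
# and the end-of-infilling token.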
EOS_STRING = "</s>"
EOT_STRING = "<EOT>"

theme = gr.themes.Monochrome(
    primary_hue="indigo",
    secondary_hue="blue",
    neutral_hue="slate",
    radius_size=gr.themes.sizes.radius_sm,
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)
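# Client that streams tokens from the Inference API endpoint above.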
client = Client(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)

def generate(
    prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
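    """Stream a completion for `prompt` from the Code Llama endpoint.

    If the prompt contains FIM_INDICATOR (<FILL_ME>), switch to
    fill-in-the-middle mode and rewrite the prompt into Code Llama's
    <PRE>/<SUF>/<MID> infilling format.
    """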
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)
    fim_mode = False
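    # Sampling parameters forwarded to the text-generation backend.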
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        seed=42,
    )
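    # Rewrite <FILL_ME> prompts into the infilling format expected by the model.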
    if FIM_INDICATOR in prompt:
        fim_mode = True
        try:
            prefix, suffix = prompt.split(FIM_INDICATOR)
        except ValueError:
            raise ValueError(f"Only one {FIM_INDICATOR} allowed in prompt!")
        prompt = f"{FIM_PREFIX}{prefix}{FIM_SUFFIX}{suffix}{FIM_MIDDLE}"
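    # Request a token stream and start the output with the prompt (or just the
    # prefix in fill-in-the-middle mode) so the completion is shown in context.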
    stream = client.generate_stream(prompt, **generate_kwargs)

    if fim_mode:
        output = prefix
    else:
        output = prompt
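    # Stream tokens until an end-of-sequence / end-of-infilling marker appears;
    # in FIM mode, append the original suffix once generation stops.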
    for response in stream:
        if any(end_token in response.token.text for end_token in [EOS_STRING, EOT_STRING]):
            if fim_mode:
                output += suffix
                yield output
            return output
        else:
            output += response.token.text
            yield output
    return output
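# Example prompts shown in the UI; the last two use <FILL_ME> for infilling.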
examples = [
    "X_train, y_train, X_test, y_test = train_test_split(X, y, test_size=0.1)\n\n# Train a logistic regression model, predict the labels on the test set and compute the accuracy score",
    "// Returns every other value in the array as a new array.\nfunction everyOther(arr) {",
    "Poor English: She no went to the market. Corrected English:",
    "def alternating(list1, list2):\n    results = []\n    for i in range(min(len(list1), len(list2))):\n        results.append(list1[i])\n        results.append(list2[i])\n    if len(list1) > len(list2):\n        <FILL_ME>\n    else:\n        results.extend(list2[i+1:])\n    return results",
    "def remove_non_ascii(s: str) -> str:\n    \"\"\" <FILL_ME>\nprint(remove_non_ascii('afkdj$$('))",
]
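# Drain the generator and return only the final output (used by gr.Examples).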
def process_example(args):
    for x in generate(args):
        pass
    return x
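# CSS tweaks: hide Gradio's "generating" indicator and force a monospace input box.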
css = ".generating {visibility: hidden}"

monospace_css = """
#q-input textarea {
    font-family: monospace, 'Consolas', Courier, monospace;
}
"""

css += share_btn_css + monospace_css + ".gradio-container {color: black}"

description = """
<div style="text-align: center;">
    <h1> 🦙 Code Llama Playground</h1>
</div>
<div style="text-align: left;">
    <p>This is a demo to generate text and code with the <a href="https://huggingface.co/codellama/CodeLlama-13b-hf">Code Llama model (13B)</a>. Please note that this model is not designed for instruction purposes but for code completion. If you're looking for an instruction-following model or want to chat with a fine-tuned model, you can use <a href="https://huggingface.co/spaces/codellama/codellama-13b-chat">this demo instead</a>. You can learn more about the model in the <a href="https://huggingface.co/blog/codellama/">blog post</a> or <a href="https://huggingface.co/papers/2308.12950">paper</a>.</p>
    <p>For a chat demo of the largest Code Llama model (34B parameters), you can now <a href="https://huggingface.co/chat/">select Code Llama in Hugging Chat!</a></p>
</div>
"""
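# Assemble the UI: description, prompt box, advanced sampling sliders, examples,
# and the streamed code output.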
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    with gr.Column():
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                instruction = gr.Textbox(
                    placeholder="Enter your code here",
                    lines=5,
                    label="Input",
                    elem_id="q-input",
                )
                submit = gr.Button("Generate", variant="primary")
                output = gr.Code(elem_id="q-output", lines=30, label="Output")
                with gr.Row():
                    with gr.Column():
                        with gr.Accordion("Advanced settings", open=False):
                            with gr.Row():
                                column_1, column_2 = gr.Column(), gr.Column()
                                with column_1:
                                    temperature = gr.Slider(
                                        label="Temperature",
                                        value=0.1,
                                        minimum=0.0,
                                        maximum=1.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values produce more diverse outputs",
                                    )
                                    max_new_tokens = gr.Slider(
                                        label="Max new tokens",
                                        value=256,
                                        minimum=0,
                                        maximum=8192,
                                        step=64,
                                        interactive=True,
                                        info="The maximum number of new tokens",
                                    )
                                with column_2:
                                    top_p = gr.Slider(
                                        label="Top-p (nucleus sampling)",
                                        value=0.90,
                                        minimum=0.0,
                                        maximum=1,
                                        step=0.05,
                                        interactive=True,
                                        info="Higher values sample more low-probability tokens",
                                    )
                                    repetition_penalty = gr.Slider(
                                        label="Repetition penalty",
                                        value=1.05,
                                        minimum=1.0,
                                        maximum=2.0,
                                        step=0.05,
                                        interactive=True,
                                        info="Penalize repeated tokens",
                                    )
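                # Clickable example prompts for the input box.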
                gr.Examples(
                    examples=examples,
                    inputs=[instruction],
                    cache_examples=False,
                    fn=process_example,
                    outputs=[output],
                )
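    # Stream generate() into the output component when "Generate" is clicked.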
    submit.click(
        generate,
        inputs=[instruction, temperature, max_new_tokens, top_p, repetition_penalty],
        outputs=[output],
    )

demo.queue(concurrency_count=16).launch(debug=True)