vpcom committed on
Commit
8cf621f
1 Parent(s): 56d98ec

feat: initial interface inspired by Falcon spaces

Browse files
Files changed (1) hide show
  1. app.py +128 -2
app.py CHANGED
@@ -1,6 +1,132 @@
1
- import gradio as gr
2
  import os
 
 
 
 
 
3
 
4
  HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- gr.Interface.load("models/DataAnalyticsLab/PersianGPT-FT-Grover", api_key=HF_TOKEN).launch()
 
1
+ import json
2
  import os
3
+ import shutil
4
+ import requests
5
+
6
+ import gradio as gr
7
+ from huggingface_hub import Repository, InferenceClient
8
 
9
# Hugging Face API token; read from the environment so it is never hard-coded.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# Hosted inference endpoint for the Persian GPT model.
API_URL = "https://api-inference.huggingface.co/models/DataAnalyticsLab/PersianGPT-FT-Grover"
BOT_NAME = "PersianGPT-FT"

# Generation halts when any of these sequences is produced.
STOP_SEQUENCES = ["<|endoftext|>"]

# Example prompts shown in the chat UI. Format appears to be
# "<$<poem-form>$" optionally followed by "@<opening hemistich>" in Persian
# (ghazal, qasida, masnavi) — TODO confirm against the model's training format.
EXAMPLES = [
    ["<$غزل$@بر لبم هر ذره داغی می توان کردن"],
    ["<$غزل$"],
    ["<$قصیده$"],
    ["<$مثنوی$"],
    ["<$غزل$@دراین سرای بی کسی، کسی به در نمی زند"]
]

# Shared client for the hosted inference API; authenticated via HF_TOKEN.
client = InferenceClient(
    API_URL,
    headers={"Authorization": f"Bearer {HF_TOKEN}"},
)
27
+
28
def format_prompt(message, history, system_prompt):
    """Flatten a chat exchange into a single text prompt for the model.

    Args:
        message: The user's latest message.
        history: List of (user_prompt, bot_response) pairs from prior turns.
        system_prompt: Optional system instruction; prepended when truthy.

    Returns:
        A prompt string ending with "Falcon:" so the model continues as
        the assistant.
    """
    parts = []
    if system_prompt:
        parts.append(f"System: {system_prompt}\n")
    for user_prompt, bot_response in history:
        parts.append(f"User: {user_prompt}\n")
        # Prepend the role tag for each past assistant turn.
        parts.append(f"Falcon: {bot_response}\n")
    # Trailing "Falcon:" (no newline) cues the model to answer next.
    parts.append(f"User: {message}\nFalcon:")
    return "".join(parts)
38
+
39
# Module-level RNG seed; generate() reads it and increments it on every call
# so that repeated identical prompts produce different samples.
seed = 42
40
+
41
def generate(
    prompt, history, system_prompt="<|endoftext|>", temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a model completion for the chat UI.

    Yields progressively longer partial outputs (Gradio renders the latest
    yield), with any trailing stop sequence stripped.

    Args:
        prompt: Latest user message.
        history: Prior (user, bot) turns, forwarded to format_prompt.
        system_prompt: Optional system instruction.
        temperature, max_new_tokens, top_p, repetition_penalty: Sampling
            controls passed through to the inference API.
    """
    temperature = float(temperature)
    # The inference API rejects a temperature of 0; clamp to a small positive value.
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    global seed
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        stop_sequences=STOP_SEQUENCES,
        do_sample=True,
        seed=seed,
    )
    # Bump the shared seed so successive calls sample differently.
    # NOTE(review): module-level mutation is not thread-safe under concurrent
    # requests — confirm the Space runs with a single worker.
    seed = seed + 1
    formatted_prompt = format_prompt(prompt, history, system_prompt)

    stream = client.text_generation(
        formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False
    )
    output = ""

    for response in stream:
        output += response.token.text

        # If the accumulated text ends with a stop sequence, trim it before
        # surfacing the partial result to the UI.
        for stop_str in STOP_SEQUENCES:
            if output.endswith(stop_str):
                output = output[:-len(stop_str)]
                output = output.rstrip()
        yield output
    # Final yield so consumers always receive the completed text.
    # (The original also had a dead `return output` after this yield; a
    # generator's return value is never consumed by Gradio, so it is removed.)
    yield output
74
+
75
+
76
# Extra per-message controls surfaced by gr.ChatInterface; their values are
# passed positionally to generate() after (prompt, history), so the order here
# must match generate's parameter order: system_prompt, temperature,
# max_new_tokens, top_p, repetition_penalty.
additional_inputs=[
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=8192,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    )
]
115
+
116
+
117
# Top-level UI: a heading plus a streaming chat interface backed by generate().
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                PERSIAN GPT Trained by Mojtaba Valipour @ Data Analytics Lab
                """
            )

    # ChatInterface drives generate() and renders each yielded partial output.
    gr.ChatInterface(
        generate,
        examples=EXAMPLES,
        additional_inputs=additional_inputs,
    )

# Queue requests (API endpoints disabled) and launch the app without the
# auto-generated API docs page.
demo.queue(concurrency_count=100, api_open=False).launch(show_api=False)