Staticaliza committed on
Commit
c02118d
1 Parent(s): 120bd05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -12
app.py CHANGED
@@ -2,12 +2,15 @@ import gradio as gr
2
  from huggingface_hub import Repository, InferenceClient
3
  import os
4
  import json
 
5
 
6
  API_TOKEN = os.environ.get("API_TOKEN")
7
  API_ENDPOINT = os.environ.get("API_ENDPOINT")
8
 
9
  KEY = os.environ.get("KEY")
10
 
 
 
11
  API_ENDPOINTS = {
12
  "Falcon": "tiiuae/falcon-180B-chat",
13
  "Llama": "meta-llama/Llama-2-70b-chat-hf"
@@ -20,25 +23,24 @@ for model_name, model_endpoint in API_ENDPOINTS.items():
20
  CHOICES.append(model_name)
21
  CLIENTS[model_name] = InferenceClient(model_endpoint, headers = { "Authorization": f"Bearer {API_TOKEN}" })
22
 
23
- def format(input, chat_history, : str) -> str:
24
- instructions = instructions.strip(" ").strip("\n")
25
- prompt = instructions
26
- for turn in chat_history:
27
- user_message, bot_message = turn
28
- prompt = f"{prompt}\n{USER_NAME}: {user_message}\n{BOT_NAME}: {bot_message}"
29
- prompt = f"{prompt}\n{USER_NAME}: {message}\n{BOT_NAME}:"
30
  return prompt
31
 
32
- def predict(instruction, history, input, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
33
 
34
  if (access_key != KEY):
35
  print(">>> MODEL FAILED: Input: " + input + ", Attempted Key: " + access_key)
36
  return ("[UNAUTHORIZED ACCESS]", input);
37
 
38
  stops = json.loads(stop_seqs)
 
 
39
 
40
  response = CLIENTS[model].text_generation(
41
- input,
42
  temperature = temperature,
43
  max_new_tokens = max_tokens,
44
  top_p = top_p,
@@ -52,9 +54,15 @@ def predict(instruction, history, input, access_key, model, temperature, top_p,
52
  return_full_text = False
53
  )
54
 
55
- print(f"---\nUSER: {input}\nBOT: {response}\n---")
 
 
 
 
 
 
56
 
57
- return (response, input)
58
 
59
  def maintain_cloud():
60
  print(">>> SPACE MAINTAINED!")
@@ -68,6 +76,7 @@ with gr.Blocks() as demo:
68
  with gr.Column():
69
  history = gr.Chatbot(elem_id = "chatbot")
70
  input = gr.Textbox(label = "Input", lines = 2)
 
71
  instruction = gr.Textbox(label = "Instruction", lines = 4)
72
  access_key = gr.Textbox(label = "Access Key", lines = 1)
73
  run = gr.Button("▶")
@@ -87,7 +96,7 @@ with gr.Blocks() as demo:
87
  with gr.Column():
88
  output = gr.Textbox(label = "Output", value = "", lines = 50)
89
 
90
- run.click(predict, inputs = [instruction, history, input, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input])
91
  cloud.click(maintain_cloud, inputs = [], outputs = [input, output])
92
 
93
  demo.queue(concurrency_count = 500, api_open = True).launch(show_api = True)
 
2
  from huggingface_hub import Repository, InferenceClient
3
  import os
4
  import json
5
+ import re
6
 
7
  API_TOKEN = os.environ.get("API_TOKEN")
8
  API_ENDPOINT = os.environ.get("API_ENDPOINT")
9
 
10
  KEY = os.environ.get("KEY")
11
 
12
+ SPECIAL_SYMBOLS = ["‹", "›"]
13
+
14
  API_ENDPOINTS = {
15
  "Falcon": "tiiuae/falcon-180B-chat",
16
  "Llama": "meta-llama/Llama-2-70b-chat-hf"
 
23
  CHOICES.append(model_name)
24
  CLIENTS[model_name] = InferenceClient(model_endpoint, headers = { "Authorization": f"Bearer {API_TOKEN}" })
25
 
26
def format(instruction = "", history = "", input = "", preinput = ""):
    # Build the prompt string sent to the model: the instruction, then each
    # prior history entry wrapped in the SPECIAL_SYMBOLS delimiters, then the
    # new input (also delimited), then the pre-input text on the final line.
    # NOTE(review): `history` is fed from a gr.Chatbot, whose entries are
    # (user, bot) pairs — each `message` here is presumably such a pair and is
    # rendered via its str() form; confirm against the caller.
    # NOTE(review): `format` shadows the builtin; kept because callers
    # (predict) reference it by this name.
    sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]
    # Fix: the original wrote f"{{sy_l}{message}{sy_r}" — "{{" is an escaped
    # literal brace and the lone "}" after "sy_l" is an f-string SyntaxError;
    # the delimiters must be interpolated, not escaped.
    formatted_history = '\n'.join(f"{sy_l}{message}{sy_r}" for message in history)
    task_message = f"{instruction}\n{formatted_history}\n{sy_l}{input}{sy_r}\n{preinput}"
    # Fix: the original returned the undefined name `prompt` (NameError at
    # runtime); return the string actually built above.
    return task_message
31
 
32
+ def predict(instruction, history, input, preinput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed):
33
 
34
  if (access_key != KEY):
35
  print(">>> MODEL FAILED: Input: " + input + ", Attempted Key: " + access_key)
36
  return ("[UNAUTHORIZED ACCESS]", input);
37
 
38
  stops = json.loads(stop_seqs)
39
+
40
+ formatted_input = format(instruction, history, input, preinput)
41
 
42
  response = CLIENTS[model].text_generation(
43
+ formatted_input,
44
  temperature = temperature,
45
  max_new_tokens = max_tokens,
46
  top_p = top_p,
 
54
  return_full_text = False
55
  )
56
 
57
+ sy_l, sy_r = SPECIAL_SYMBOLS[0], SPECIAL_SYMBOLS[1]
58
+ pre_result = f"{sy_l}{response}{sy_r}{''.join(SPECIAL_SYMBOLS)}"
59
+ pattern = re.compile(f"{sy_l}(.*?){sy_r}", re.DOTALL)
60
+ match = pattern.search(pre_result)
61
+ get_result = match.group(1).strip() if match else ""
62
+
63
+ print(f"---\nUSER: {input}\nBOT: {get_result}\n---")
64
 
65
+ return (get_result, input)
66
 
67
  def maintain_cloud():
68
  print(">>> SPACE MAINTAINED!")
 
76
  with gr.Column():
77
  history = gr.Chatbot(elem_id = "chatbot")
78
  input = gr.Textbox(label = "Input", lines = 2)
79
+ preinput = gr.Textbox(label = "Pre-Input", lines = 1)
80
  instruction = gr.Textbox(label = "Instruction", lines = 4)
81
  access_key = gr.Textbox(label = "Access Key", lines = 1)
82
  run = gr.Button("▶")
 
96
  with gr.Column():
97
  output = gr.Textbox(label = "Output", value = "", lines = 50)
98
 
99
+ run.click(predict, inputs = [instruction, history, input, preinput, access_key, model, temperature, top_p, top_k, rep_p, max_tokens, stop_seqs, seed], outputs = [output, input])
100
  cloud.click(maintain_cloud, inputs = [], outputs = [input, output])
101
 
102
  demo.queue(concurrency_count = 500, api_open = True).launch(show_api = True)