lewtun committed
Commit b12b521
1 Parent(s): 423d96b
Files changed (2)
  1. app.ipynb +36 -17
  2. app.py +21 -9
app.ipynb CHANGED
@@ -31,7 +31,23 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 3,
+    "execution_count": 32,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# |export\n",
+     "def get_model_endpoint(model_id):\n",
+     "    if \"joi\" in model_id:\n",
+     "        headers = None\n",
+     "        return \"https://joi-20b.ngrok.io/generate\", headers\n",
+     "    else:\n",
+     "        headers = {\"Authorization\": f\"Bearer {HF_TOKEN}\", \"x-wait-for-model\": \"1\"}\n",
+     "        return f\"https://api-inference.huggingface.co/models/{model_id}\", headers\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 33,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -42,8 +58,7 @@
     "    temperature,\n",
     "    top_p\n",
     "):\n",
-    "    API_URL = f\"https://api-inference.huggingface.co/models/{model_id}\"\n",
-    "    headers = {\"Authorization\": f\"Bearer {HF_TOKEN}\", \"x-wait-for-model\": \"1\"}\n",
+    "    endpoint, headers = get_model_endpoint(model_id)\n",
     "\n",
     "    payload = {\n",
     "        \"inputs\": inputs,\n",
@@ -55,7 +70,7 @@
     "        },\n",
     "    }\n",
     "\n",
-    "    response = requests.post(API_URL, json=payload, headers=headers)\n",
+    "    response = requests.post(endpoint, json=payload, headers=headers)\n",
     "\n",
     "    if response.status_code == 200:\n",
     "        return response.json()\n",
@@ -65,23 +80,24 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 4,
+    "execution_count": 36,
     "metadata": {},
     "outputs": [
      {
       "data": {
        "text/plain": [
-        "[{'generated_text': 'love'}]"
+        "{'generated_text': '\\n\\nJoi: Black holes are regions of space-time where gravity is so strong that nothing'}"
        ]
       },
-      "execution_count": 4,
+      "execution_count": 36,
      "metadata": {},
      "output_type": "execute_result"
     }
    ],
    "source": [
-    "model_id = \"google/flan-t5-xl\"\n",
-    "query = \"what is the answer to the universe?\"\n",
+    "# model_id = \"google/flan-t5-xl\"\n",
+    "model_id = \"Rallio67/joi_20B_instruct_alpha\"\n",
+    "query = \"What can you tell me about black holes?\"\n",
     "query_chat_api(model_id, query, 1, 0.95)"
    ]
   },
@@ -101,7 +117,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 12,
+    "execution_count": 37,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -121,7 +137,10 @@
     "    inputs = prompt_template[\"prompt\"].format(human_input=text_input)\n",
     "\n",
     "    output = query_chat_api(model_id, inputs, temperature, top_p)\n",
-    "    history.append(\" \" + output[0][\"generated_text\"])\n",
+    "    # TODO: remove this hack when inference backend schema is updated\n",
+    "    if isinstance(output, list):\n",
+    "        output = output[0]\n",
+    "    history.append(\" \" + output[\"generated_text\"])\n",
     "\n",
     "    chat = [\n",
     "        (history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)\n",
@@ -695,14 +714,14 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 6,
+    "execution_count": 38,
     "metadata": {},
     "outputs": [
      {
       "name": "stdout",
       "output_type": "stream",
       "text": [
-       "Running on local URL: http://127.0.0.1:7860\n",
+       "Running on local URL: http://127.0.0.1:7861\n",
        "\n",
        "To create a public link, set `share=True` in `launch()`.\n"
       ]
@@ -710,7 +729,7 @@
     {
      "data": {
       "text/html": [
-       "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+       "<div><iframe src=\"http://127.0.0.1:7861/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
@@ -723,7 +742,7 @@
      "data": {
       "text/plain": []
      },
-     "execution_count": 6,
+     "execution_count": 38,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -744,7 +763,7 @@
     "    with gr.Row():\n",
     "        with gr.Column(scale=1):\n",
     "            model_id = gr.Dropdown(\n",
-    "                choices=[\"google/flan-t5-xl\"],\n",
+    "                choices=[\"google/flan-t5-xl\", \"Rallio67/joi_20B_instruct_alpha\"],\n",
     "                value=\"google/flan-t5-xl\",\n",
     "                label=\"Model\",\n",
     "                interactive=True,\n",
@@ -846,7 +865,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 13,
+    "execution_count": 15,
     "metadata": {},
     "outputs": [],
     "source": [
app.py CHANGED
@@ -1,7 +1,7 @@
 # AUTOGENERATED! DO NOT EDIT! File to edit: app.ipynb.
 
 # %% auto 0
-__all__ = ['HF_TOKEN', 'title', 'description', 'query_chat_api', 'inference_chat']
+__all__ = ['HF_TOKEN', 'title', 'description', 'get_model_endpoint', 'query_chat_api', 'inference_chat']
 
 # %% app.ipynb 0
 import gradio as gr
@@ -21,14 +21,23 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 
 
 # %% app.ipynb 2
+def get_model_endpoint(model_id):
+    if "joi" in model_id:
+        headers = None
+        return "https://joi-20b.ngrok.io/generate", headers
+    else:
+        headers = {"Authorization": f"Bearer {HF_TOKEN}", "x-wait-for-model": "1"}
+        return f"https://api-inference.huggingface.co/models/{model_id}", headers
+
+
+# %% app.ipynb 3
 def query_chat_api(
     model_id,
     inputs,
     temperature,
     top_p
 ):
-    API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
-    headers = {"Authorization": f"Bearer {HF_TOKEN}", "x-wait-for-model": "1"}
+    endpoint, headers = get_model_endpoint(model_id)
 
     payload = {
         "inputs": inputs,
@@ -40,7 +49,7 @@ def query_chat_api(
         },
     }
 
-    response = requests.post(API_URL, json=payload, headers=headers)
+    response = requests.post(endpoint, json=payload, headers=headers)
 
     if response.status_code == 200:
         return response.json()
@@ -48,7 +57,7 @@ def query_chat_api(
         return "Error: " + response.text
 
 
-# %% app.ipynb 5
+# %% app.ipynb 6
 def inference_chat(
     model_id,
     prompt_template,
@@ -64,7 +73,10 @@ def inference_chat(
     inputs = prompt_template["prompt"].format(human_input=text_input)
 
     output = query_chat_api(model_id, inputs, temperature, top_p)
-    history.append(" " + output[0]["generated_text"])
+    # TODO: remove this hack when inference backend schema is updated
+    if isinstance(output, list):
+        output = output[0]
+    history.append(" " + output["generated_text"])
 
     chat = [
         (history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)
@@ -73,7 +85,7 @@ def inference_chat(
     return {chatbot: chat, state: history}
 
 
-# %% app.ipynb 15
+# %% app.ipynb 16
 title = """<h1 align="center">Chatty Language Models</h1>"""
 description = """Language models can be conditioned to act like dialogue agents through a conversational prompt that typically takes the form:
 
@@ -98,7 +110,7 @@ So far, the following prompts are available:
 As you can see, most of these prompts exceed the maximum context size of models like Flan-T5, so an error usually means the Inference API has timed out.
 """
 
-# %% app.ipynb 16
+# %% app.ipynb 17
 with gr.Blocks(
     css="""
     .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
@@ -113,7 +125,7 @@ with gr.Blocks(
     with gr.Row():
         with gr.Column(scale=1):
             model_id = gr.Dropdown(
-                choices=["google/flan-t5-xl"],
+                choices=["google/flan-t5-xl", "Rallio67/joi_20B_instruct_alpha"],
                 value="google/flan-t5-xl",
                 label="Model",
                 interactive=True,
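
For reference, a minimal round trip through the exported helpers; this sketch assumes app.py is importable as app, HF_TOKEN is set in the environment, and the Joi ngrok tunnel is only reachable while the self-hosted server is running:

from app import get_model_endpoint, query_chat_api

# Ids containing "joi" route to the self-hosted tunnel with no auth header;
# everything else goes to the hosted Inference API with a bearer token.
endpoint, headers = get_model_endpoint("Rallio67/joi_20B_instruct_alpha")
assert endpoint == "https://joi-20b.ngrok.io/generate" and headers is None

endpoint, headers = get_model_endpoint("google/flan-t5-xl")
assert endpoint.endswith("/models/google/flan-t5-xl")

# Returns parsed JSON on HTTP 200, or an "Error: ..." string otherwise.
print(query_chat_api("google/flan-t5-xl", "What can you tell me about black holes?", 1, 0.95))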