plaguss (HF staff) committed
Commit 68ffbe0
1 parent: 4f3757c

Update app to register interactions in an argilla dataset

Files changed (2):
  1. app.py +139 -47
  2. requirements.txt +2 -1
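
In short: app.py now connects to an Argilla Space at startup and logs every finished chat turn as a record, and requirements.txt picks up the argilla dependency. Below is a minimal sketch of the logging flow this commit wires in. The URL, API-key variable, dataset name, and record keys are taken from the new Settings class and the chatty function in the diff; the record values are invented for illustration, and the chatbot_interactions dataset is assumed to already exist on that server.

import os
import uuid

import argilla as rg

# Connect to the Argilla instance backing the Space.
client_rg = rg.Argilla(
    api_url="https://plaguss-argilla-sdk-chatbot.hf.space",
    api_key=os.getenv("ARGILLA_CHATBOT_API_KEY"),
)
dataset = client_rg.datasets("chatbot_interactions")

# Each turn is logged as one record, keyed by a per-conversation id.
dataset.records.log(
    [
        {
            "instruction": "How can I connect to an argilla server?",  # user message or HTML chat history
            "response": "Use rg.Argilla(api_url=..., api_key=...).",  # illustrative answer
            "conv_id": str(uuid.uuid4()),
            "turn": 0,
        }
    ]
)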
app.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Optional, Any, Generator
+from typing import Optional, Generator
 import os
 from pathlib import Path
 import tarfile
@@ -11,12 +11,14 @@ from huggingface_hub.file_download import hf_hub_download
 from huggingface_hub import InferenceClient, login
 from transformers import AutoTokenizer
 import gradio as gr
+import argilla as rg
+import uuid
 
 
 @dataclass
 class Settings:
-    """Settings class to store useful variables for the App.
-    """
+    """Settings class to store useful variables for the App."""
+
     LANCEDB: str = "lancedb"
     LANCEDB_FILE_TAR: str = "lancedb.tar.gz"
     TOKEN: str = os.getenv("HF_API_TOKEN")
@@ -24,13 +26,29 @@ class Settings:
     REPO_ID: str = "plaguss/argilla_sdk_docs_queries"
     TABLE_NAME: str = "docs"
     MODEL_NAME: str = "plaguss/bge-base-argilla-sdk-matryoshka"
-    DEVICE: str = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
+    DEVICE: str = (
+        "mps"
+        if torch.backends.mps.is_available()
+        else "cuda"
+        if torch.cuda.is_available()
+        else "cpu"
+    )
     MODEL_ID: str = "meta-llama/Meta-Llama-3-70B-Instruct"
+    ARGILLA_URL = r"https://plaguss-argilla-sdk-chatbot.hf.space"
+    ARGILLA_API_KEY = os.getenv("ARGILLA_CHATBOT_API_KEY")
+    ARGILLA_DATASET = "chatbot_interactions"
+
 
 settings = Settings()
 
 login(token=settings.TOKEN)
 
+client_rg = rg.Argilla(
+    api_url=settings.ARGILLA_URL,
+    api_key=settings.ARGILLA_API_KEY
+)
+argilla_dataset = client_rg.datasets(settings.ARGILLA_DATASET)
+
 
 def untar_file(source: Path) -> Path:
     """Untar and decompress files which have passed by `make_tarfile`.
@@ -51,7 +69,7 @@ def download_database(
     repo_id: str,
     lancedb_file: str = "lancedb.tar.gz",
     local_dir: Path = Path.home() / ".cache/argilla_sdk_docs_db",
-    token: str = os.getenv("HF_API_TOKEN")
+    token: str = os.getenv("HF_API_TOKEN"),
 ) -> Path:
     """Helper function to download the database. Will download a compressed lancedb stored
     in a Hugging Face repository.
@@ -69,18 +87,18 @@
     """
     lancedb_download = Path(
        hf_hub_download(
-            repo_id,
-            lancedb_file,
-            repo_type="dataset",
-            token=token,
-            local_dir=local_dir
+            repo_id, lancedb_file, repo_type="dataset", token=token, local_dir=local_dir
         )
     )
     return untar_file(lancedb_download)
 
 
 # Get the model to create the embeddings
-model = get_registry().get("sentence-transformers").create(name=settings.MODEL_NAME, device=settings.DEVICE)
+model = (
+    get_registry()
+    .get("sentence-transformers")
+    .create(name=settings.MODEL_NAME, device=settings.DEVICE)
+)
 
 
 class Database:
@@ -90,7 +108,12 @@ class Database:
     the expected location. Once ready, the only functionality available is
     to retrieve the doc chunks to be used as examples for the LLM.
     """
+
     def __init__(self, settings: Settings) -> None:
+        """
+        Args:
+            settings: Instance of the settings.
+        """
         self.settings = settings
         self._table: lancedb.table.LanceTable = self.get_table_from_db()
 
@@ -110,39 +133,56 @@ class Database:
             self.settings.REPO_ID,
             lancedb_file=self.settings.LANCEDB_FILE_TAR,
             local_dir=self.settings.LOCAL_DIR,
-            token=self.settings.TOKEN
+            token=self.settings.TOKEN,
         )
 
         db = lancedb.connect(str(lancedb_db_path))
         table = db.open_table(self.settings.TABLE_NAME)
         return table
 
-    def retrieve_doc_chunks(self, query: str, limit: int = 12, hard_limit: int = 4) -> str:
-        """Search for similar queries in the database, and return a list with
-
-        TODO: SPLIT IN TWO SEPARATE FUNCTIONS TO PREPARE THE CONTEXT.
+    def retrieve_doc_chunks(
+        self, query: str, limit: int = 12, hard_limit: int = 4
+    ) -> str:
+        """Search for similar queries in the database, and return the context to be passed
+        to the LLM.
 
         Args:
-            query (str): _description_
-            limit (int, optional): _description_. Defaults to 12.
-            hard_limit (int, optional): _description_. Defaults to 4.
+            query: Query from the user.
+            limit: Number of similar items to retrieve. Defaults to 12.
+            hard_limit: Limit of responses to take into account.
+                As we generated repeated questions initially, the database may contain
+                repeated chunks of documents; within the initial `limit` selection,
+                `hard_limit` caps the total number of unique retrieved chunks.
+                Defaults to 4.
 
         Returns:
-            str: _description_
+            The context to be used by the model to generate the response.
         """
-        # Embed the query to use our custom model instead of the default one.
+        # Embed the query to use our custom model instead of the default one.
         embedded_query = model.generate_embeddings([query])
         field_to_retrieve = "text"
         retrieved = (
-            self._table
-            .search(embedded_query[0])
-            .metric("cosine")
-            .limit(limit)
-            .select([field_to_retrieve])  # Just grab the chunk to use for context
-            .to_list()
+            self._table.search(embedded_query[0])
+            .metric("cosine")
+            .limit(limit)
+            .select([field_to_retrieve])  # Just grab the chunk to use for context
+            .to_list()
         )
-        # We have repeated questions (up to 4) for a given chunk, so we may get repeated chunks.
-        # Request more than necessary and filter them afterwards
+        return self._prepare_context(retrieved, hard_limit)
+
+    @staticmethod
+    def _prepare_context(retrieved: list[dict[str, str]], hard_limit: int) -> str:
+        """Prepares the examples to be used in the LLM prompt.
+
+        Args:
+            retrieved: The list of retrieved chunks.
+            hard_limit: Max number of doc pieces to return.
+
+        Returns:
+            Context to be used by the LLM.
+        """
+        # We have repeated questions (up to 4) for a given chunk, so we may get repeated chunks.
+        # Request more than necessary and filter them afterwards.
         responses = []
         unique_responses = set()
 
@@ -164,8 +204,7 @@ database = Database(settings=settings)
 
 
 def get_client_and_tokenizer(
-    model_id: str = settings.MODEL_ID,
-    tokenizer_id: Optional[str] = None
+    model_id: str = settings.MODEL_ID, tokenizer_id: Optional[str] = None
 ) -> tuple[InferenceClient, AutoTokenizer]:
     """Obtains the inference client and the tokenizer corresponding to the model.
 
@@ -182,14 +221,9 @@ def get_client_and_tokenizer(
         tokenizer_id = model_id
 
     client = InferenceClient()
-    base_url = client._resolve_url(
-        model=model_id, task="text-generation"
-    )
+    base_url = client._resolve_url(model=model_id, task="text-generation")
     # Note: We could move to the AsyncClient
-    client = InferenceClient(
-        model=base_url,
-        token=os.getenv("HF_API_TOKEN")
-    )
+    client = InferenceClient(model=base_url, token=os.getenv("HF_API_TOKEN"))
 
     tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
     return client, tokenizer
@@ -204,7 +238,9 @@ client_kwargs = {
     "temperature": 0.3,
     "top_p": None,
     "top_k": None,
-    "stop_sequences": ["<|eot_id|>", "<|end_of_text|>"] if settings.MODEL_ID.startswith("meta-llama/Meta-Llama-3") else None,
+    "stop_sequences": ["<|eot_id|>", "<|end_of_text|>"]
+    if settings.MODEL_ID.startswith("meta-llama/Meta-Llama-3")
+    else None,
     "seed": None,
 }
 
@@ -313,6 +349,42 @@ def prepare_input(message: str, history: list[tuple[str, str]]) -> str:
     )[0]
 
 
+def create_chat_html(history: list[tuple[str, str]]) -> str:
+    """Helper function to render a conversation as HTML to be stored in argilla.
+
+    Args:
+        history: History of messages with the chatbot.
+
+    Returns:
+        HTML formatted conversation.
+    """
+    chat_html = ""
+    alignments = ["right", "left"]
+    colors = ["#c2e3f7", "#f5f5f5"]
+
+    for turn in history:
+        # Create the HTML message div with inline styles
+        message_html = ""
+
+        # Include the message that has not been answered yet
+        (user, assistant) = turn
+        if assistant is None:
+            turn = (user,)
+
+        for i, content in enumerate(turn):
+            message_html += f'<div style="display: flex; justify-content: {alignments[i]}; margin: 10px;">'
+            message_html += f'<div style="background-color: {colors[i]}; padding: 10px; border-radius: 10px; max-width: 70%; word-wrap: break-word;">{content}</div>'
+            message_html += "</div>"
+
+        # Add the message to the chat HTML
+        chat_html += message_html
+
+    return chat_html
+
+
+conv_id = str(uuid.uuid4())
+
+
 def chatty(message: str, history: list[tuple[str, str]]) -> Generator[str, None, None]:
     """Main function of the app, contains the interaction with the LLM.
 
@@ -326,28 +398,48 @@ def chatty(message: str, history: list[tuple[str, str]]) -> Generator[str, None,
     """
     prompt = prepare_input(message, history)
 
-    partial_message = ""
-    for token_stream in client.text_generation(prompt=prompt, **client_kwargs):
-        partial_message += token_stream
-        yield partial_message
+    partial_response = ""
+
+    for token_stream in client.text_generation(prompt=prompt, **client_kwargs):
+        partial_response += token_stream
+        yield partial_response
+
+    global conv_id
+    new_conversation = len(history) == 0
+    if new_conversation:
+        conv_id = str(uuid.uuid4())
+    else:
+        history.append((message, None))
+
+    # Register the interaction in the argilla dataset
+    argilla_dataset.records.log(
+        [
+            {
+                "instruction": create_chat_html(history) if history else message,
+                "response": partial_response,
+                "conv_id": conv_id,
+                "turn": len(history)
+            },
+        ]
+    )
 
 
 if __name__ == "__main__":
-
     import gradio as gr
 
     gr.ChatInterface(
         chatty,
-        chatbot=gr.Chatbot(height=600),
-        textbox=gr.Textbox(placeholder="Ask me about the new argilla SDK", container=False, scale=7),
+        chatbot=gr.Chatbot(height=700),
+        textbox=gr.Textbox(
+            placeholder="Ask me about the new argilla SDK", container=False, scale=7
+        ),
         title="Argilla SDK Chatbot",
        description="Ask a question about Argilla SDK",
         theme="soft",
         examples=[
             "How can I connect to an argilla server?",
             "How can I access a dataset?",
-            "How can I get the current user?"
+            "How can I get the current user?",
         ],
         cache_examples=True,
         retry_btn=None,
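
Note that the commit only logs records; it assumes the chatbot_interactions dataset already exists on the Argilla server. For reference, here is a hypothetical sketch of how such a dataset could be defined with the argilla 2.0 SDK. The field, question, and metadata names simply mirror the record keys logged by chatty; none of this setup code appears in the commit itself.

import os

import argilla as rg

client = rg.Argilla(
    api_url="https://plaguss-argilla-sdk-chatbot.hf.space",
    api_key=os.getenv("ARGILLA_CHATBOT_API_KEY"),
)

# Hypothetical settings: one field for the rendered conversation, one question
# to collect/review the model response, plus metadata to group turns.
settings = rg.Settings(
    fields=[rg.TextField(name="instruction", use_markdown=True)],
    questions=[rg.TextQuestion(name="response")],
    metadata=[
        rg.TermsMetadataProperty(name="conv_id"),
        rg.IntegerMetadataProperty(name="turn"),
    ],
)

rg.Dataset(name="chatbot_interactions", settings=settings, client=client).create()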
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 torch==2.3.1
 lancedb==0.8.2
-sentence-transformers==3.0.1
+sentence-transformers==3.0.1
+argilla==2.0.0rc1