AlexanderKazakov committed on
Commit 360f505
1 Parent(s): eba1a12

make it work in zero draft

.idea/rag-gradio-sample-project.iml CHANGED
@@ -2,6 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$">
+      <excludeFolder url="file://$MODULE_DIR$/data" />
       <excludeFolder url="file://$MODULE_DIR$/venv" />
     </content>
     <orderEntry type="jdk" jdkName="Python 3.11 (rag-gradio-sample-project) (2)" jdkType="Python SDK" />
gradio_app/app.py CHANGED
@@ -6,28 +6,25 @@ Credit to Derek Thomas, derek@huggingface.co
 # subprocess.run(["pip", "install", "--upgrade", "transformers[torch,sentencepiece]==4.34.1"])
 
 import logging
-from pathlib import Path
 from time import perf_counter
 
 import gradio as gr
 from jinja2 import Environment, FileSystemLoader
 
 from backend.query_llm import generate_hf, generate_openai
-# from backend.semantic_search import table, retriever
+from backend.semantic_search import table, embedder
 
-VECTOR_COLUMN_NAME = ""
-TEXT_COLUMN_NAME = ""
+from settings import *
 
-proj_dir = Path(__file__).parent
 # Setting up the logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 # Set up the template environment with the templates directory
-env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
+env = Environment(loader=FileSystemLoader('gradio_app/templates'))
 
 # Load the templates directly from the environment
-template = env.get_template('template.j2')
+prompt_template = env.get_template('prompt_template.j2')
 template_html = env.get_template('template_html.j2')
 
 # Examples
@@ -47,34 +44,34 @@ def bot(history, api_kind):
     query = history[-1][0]
 
     if not query:
-        gr.Warning("Please submit a non-empty string as a prompt")
-        raise ValueError("Empty string was submitted")
+        gr.Warning("Please submit a non-empty string as a prompt")
+        raise ValueError("Empty string was submitted")
 
-    logger.warning('Retrieving documents...')
+    logger.info('Retrieving documents...')
     # Retrieve documents relevant to query
     document_start = perf_counter()
 
-    query_vec = retriever.encode(query)
+    query_vec = embedder.encode(query)
     documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank).to_list()
    documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
 
     document_time = perf_counter() - document_start
-    logger.warning(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
+    logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
 
     # Create Prompt
-    prompt = template.render(documents=documents, query=query)
+    prompt = prompt_template.render(documents=documents, query=query)
     prompt_html = template_html.render(documents=documents, query=query)
 
     if api_kind == "HuggingFace":
-        generate_fn = generate_hf
+        generate_fn = generate_hf
     elif api_kind == "OpenAI":
-        generate_fn = generate_openai
+        generate_fn = generate_openai
     elif api_kind is None:
-        gr.Warning("API name was not provided")
-        raise ValueError("API name was not provided")
+        gr.Warning("API name was not provided")
+        raise ValueError("API name was not provided")
     else:
-        gr.Warning(f"API {api_kind} is not supported")
-        raise ValueError(f"API {api_kind} is not supported")
+        gr.Warning(f"API {api_kind} is not supported")
+        raise ValueError(f"API {api_kind} is not supported")
 
     history[-1][1] = ""
     for character in generate_fn(prompt, history[:-1]):
@@ -84,22 +81,22 @@ def bot(history, api_kind):
 
 with gr.Blocks() as demo:
     chatbot = gr.Chatbot(
-        [],
-        elem_id="chatbot",
-        avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
-                       'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
-        bubble_full_width=False,
-        show_copy_button=True,
-        show_share_button=True,
-    )
+        [],
+        elem_id="chatbot",
+        avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
+                       'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
+        bubble_full_width=False,
+        show_copy_button=True,
+        show_share_button=True,
+    )
 
     with gr.Row():
         txt = gr.Textbox(
-            scale=3,
-            show_label=False,
-            placeholder="Enter text and press enter",
-            container=False,
-        )
+            scale=3,
+            show_label=False,
+            placeholder="Enter text and press enter",
+            container=False,
+        )
         txt_btn = gr.Button(value="Submit text", scale=1)
 
     api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
@@ -107,14 +104,14 @@ with gr.Blocks() as demo:
     prompt_html = gr.HTML()
     # Turn off interactivity while generating if you click
     txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-        bot, [chatbot, api_kind], [chatbot, prompt_html])
+        bot, [chatbot, api_kind], [chatbot, prompt_html])
 
     # Turn it back on
     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
 
     # Turn off interactivity while generating if you hit enter
     txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-        bot, [chatbot, api_kind], [chatbot, prompt_html])
+        bot, [chatbot, api_kind], [chatbot, prompt_html])
 
     # Turn it back on
     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
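Note: the template loader now uses the cwd-relative path 'gradio_app/templates' instead of a path derived from __file__, so the app presumably has to be started from the repository root. A minimal cwd-independent alternative (hypothetical, not part of this commit), in case that assumption is inconvenient:

import os
from jinja2 import Environment, FileSystemLoader

# Resolve the templates directory relative to this file rather than
# the process working directory
templates_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'templates')
env = Environment(loader=FileSystemLoader(templates_dir))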
gradio_app/backend/query_llm.py CHANGED
@@ -7,19 +7,16 @@ from typing import Any, Dict, Generator, List
 from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer
 
-tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
+from settings import *
 
-temperature = 0.9
-top_p = 0.6
-repetition_penalty = 1.2
+
+tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
 
 OPENAI_KEY = getenv("OPENAI_API_KEY")
 HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
 
-hf_client = InferenceClient(
-    "mistralai/Mistral-7B-Instruct-v0.1",
-    token=HF_TOKEN
-)
+
+hf_client = InferenceClient(LLM_NAME, token=HF_TOKEN)
 
 
 def format_prompt(message: str, api_kind: str):
@@ -40,12 +37,12 @@ def format_prompt(message: str, api_kind: str):
         return messages
     elif api_kind == "hf":
         return tokenizer.apply_chat_template(messages, tokenize=False)
-    elif api_kind:
+    else:
         raise ValueError("API is not supported")
 
 
-def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 256,
-                top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
+def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
+                top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -99,8 +96,8 @@ def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tok
     return "I do not know what happened, but I couldn't understand you."
 
 
-def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 256,
-                    top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
+def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
+                    top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
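The sampling defaults move from module-level constants into the generate_* signatures, with a larger token budget (512 vs 256), a tighter nucleus (top_p 0.6 vs 0.95) and a mild repetition penalty (1.2 vs 1.0). The body of generate_hf is unchanged and not shown in this diff; as a rough, hedged sketch, these parameters would typically reach the Hub client like so:

from os import getenv
from huggingface_hub import InferenceClient

hf_client = InferenceClient("mistralai/Mistral-7B-Instruct-v0.1",
                            token=getenv("HUGGING_FACE_HUB_TOKEN"))

# Stream tokens with the updated sampling defaults
for token in hf_client.text_generation("What is retrieval-augmented generation?",
                                       max_new_tokens=512,
                                       temperature=0.9,
                                       top_p=0.6,
                                       repetition_penalty=1.2,
                                       stream=True):
    print(token, end="")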
gradio_app/backend/semantic_search.py CHANGED
@@ -1,18 +1,14 @@
 import logging
 import lancedb
-import os
-from pathlib import Path
 from sentence_transformers import SentenceTransformer
 
-EMB_MODEL_NAME = ""
-DB_TABLE_NAME = ""
+from settings import *
+
 
 # Setting up the logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
-retriever = SentenceTransformer(EMB_MODEL_NAME)
+embedder = SentenceTransformer(EMB_MODEL_NAME)
 
-# db
-db_uri = os.path.join(Path(__file__).parents[1], ".lancedb")
-db = lancedb.connect(db_uri)
-table = db.open_table(DB_TABLE_NAME)
+db = lancedb.connect(LANCEDB_DIRECTORY)
+table = db.open_table(LANCEDB_TABLE_NAME)
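With embedder and table built once at import time, retrieval elsewhere in the app reduces to encode-then-search. A minimal usage sketch, mirroring the call in app.py above (column names come from settings.py, and the import assumes the same sys.path as the app):

from backend.semantic_search import table, embedder
from settings import VECTOR_COLUMN_NAME, TEXT_COLUMN_NAME

query_vec = embedder.encode("How do I fine-tune a transformer?")
hits = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(5).to_list()
texts = [hit[TEXT_COLUMN_NAME] for hit in hits]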
 
 
gradio_app/templates/{template.j2 → prompt_template.j2} RENAMED
File without changes
prep_scripts/lancedb_setup.py CHANGED
@@ -1,3 +1,6 @@
+import shutil
+import traceback
+
 import lancedb
 import torch
 import pyarrow as pa
@@ -8,13 +11,16 @@ import numpy as np
 
 from sentence_transformers import SentenceTransformer
 
+from settings import *
+
+
+emb_sizes = {
+    "sentence-transformers/all-MiniLM-L6-v2": 384,
+    "thenlper/gte-large": 0
+}
 
-EMB_MODEL_NAME = ""
-DB_TABLE_NAME = ""
-VECTOR_COLUMN_NAME = ""
-TEXT_COLUMN_NAME = ""
-INPUT_DIR = "<chunked docs directory>"
-db = lancedb.connect(".lancedb") # db location
+shutil.rmtree(LANCEDB_DIRECTORY, ignore_errors=True)
+db = lancedb.connect(LANCEDB_DIRECTORY)
 batch_size = 32
 
 model = SentenceTransformer(EMB_MODEL_NAME)
@@ -29,17 +35,17 @@ else:
 
 schema = pa.schema(
     [
-        pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), 768)),
+        pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), emb_sizes[EMB_MODEL_NAME])),
         pa.field(TEXT_COLUMN_NAME, pa.string())
     ])
-tbl = db.create_table(DB_TABLE_NAME, schema=schema, mode="overwrite")
+tbl = db.create_table(LANCEDB_TABLE_NAME, schema=schema, mode="overwrite")
 
-input_dir = Path(INPUT_DIR)
+input_dir = Path(TEXT_CHUNKS_DIR)
 files = list(input_dir.rglob("*"))
 
 sentences = []
 for file in files:
-    with open(file) as f:
+    with open(file, encoding='utf-8') as f:
         sentences.append(f.read())
 
 for i in tqdm.tqdm(range(0, int(np.ceil(len(sentences) / batch_size)))):
@@ -54,12 +60,15 @@ for i in tqdm.tqdm(range(0, int(np.ceil(len(sentences) / batch_size)))):
         })
 
         tbl.add(df)
+
     except:
-        print(f"batch {i} was skipped")
+        print(f"batch {i} was skipped: {traceback.format_exc()}")
+
 
 '''
 create ivf-pd index https://lancedb.github.io/lancedb/ann_indexes/
 with the size of the transformer docs, index is not really needed
-but we'll do it for demonstrational purposes
+but we'll do it for demonstration purposes
 '''
-tbl.create_index(num_partitions=256, num_sub_vectors=96, vector_column_name=VECTOR_COLUMN_NAME)
+# tbl.create_index(num_partitions=256, num_sub_vectors=96, vector_column_name=VECTOR_COLUMN_NAME)
+
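One caveat in the new emb_sizes map: "thenlper/gte-large" is set to 0, a placeholder that would produce an invalid vector column if EMB_MODEL_NAME were ever switched to it. A hedged alternative that avoids the hand-maintained table, using the accessor sentence-transformers already provides:

from sentence_transformers import SentenceTransformer

model = SentenceTransformer(EMB_MODEL_NAME)
# Ask the loaded model for its own output width (384 for all-MiniLM-L6-v2),
# then use it in place of emb_sizes[EMB_MODEL_NAME] in the pa.list_ schema
emb_dim = model.get_sentence_embedding_dimension()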
prep_scripts/markdown_to_text.py CHANGED
@@ -1,12 +1,12 @@
+import shutil
+
 from bs4 import BeautifulSoup
 from markdown import markdown
 import os
 import re
 from pathlib import Path
 
-
-DIR_TO_SCRAPE = "data/transformers/docs/source/en/"
-OUTPUT_DIR = str(Path().resolve() / "docs_dump")
+from settings import *
 
 
 def markdown_to_text(markdown_string):
@@ -20,7 +20,7 @@ def markdown_to_text(markdown_string):
 
     # extract text
     soup = BeautifulSoup(html, "html.parser")
-    text = ''.join(soup.findAll(text=True))
+    text = ''.join(soup.findAll(string=True))
 
     text = re.sub('```(py|diff|python)', '', text)
     text = re.sub('```\n', '\n', text)
@@ -31,19 +31,20 @@ def markdown_to_text(markdown_string):
     return text
 
 
-dir_to_scrape = Path(DIR_TO_SCRAPE)
+dir_to_scrape = Path(MARKDOWN_DIR_TO_SCRAPE)
 files = list(dir_to_scrape.rglob("*"))
 
-os.makedirs(OUTPUT_DIR, exist_ok=True)
+shutil.rmtree(TEXT_CHUNKS_DIR, ignore_errors=True)
+os.makedirs(TEXT_CHUNKS_DIR)
 
 for file in files:
     parent = file.parent.stem if file.parent.stem != dir_to_scrape.stem else ""
     if file.is_file():
-        with open(file) as f:
+        with open(file, encoding='utf-8') as f:
             md = f.read()
 
         text = markdown_to_text(md)
 
-        with open(os.path.join(OUTPUT_DIR, f"{parent}_{file.stem}.txt"), "w") as f:
+        with open(os.path.join(TEXT_CHUNKS_DIR, f"{parent}_{file.stem}.txt"), "w", encoding='utf-8') as f:
             f.write(text)
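The findAll(text=True) -> findAll(string=True) change tracks BeautifulSoup's rename of the text argument to string (the old spelling is deprecated in recent bs4 releases). For reference, an effectively equivalent and shorter form the script could arguably use:

from bs4 import BeautifulSoup
from markdown import markdown

html = markdown("# Title\n\nSome *emphasis* here.")
soup = BeautifulSoup(html, "html.parser")
text = soup.get_text()  # concatenates all string nodes, like ''.join(soup.findAll(string=True))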
 
settings.py ADDED
@@ -0,0 +1,8 @@
+MARKDOWN_DIR_TO_SCRAPE = "data/transformers/docs/source/en/"
+TEXT_CHUNKS_DIR = "data/docs_dump"
+EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+LANCEDB_DIRECTORY = "data/lancedb"
+LANCEDB_TABLE_NAME = "table"
+VECTOR_COLUMN_NAME = "embedding"
+TEXT_COLUMN_NAME = "text"
+LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
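With all paths and model names centralized here, the intended pipeline appears to be: convert the markdown docs to text chunks, build the LanceDB index, then launch the app, each step run from the repository root so the relative data/ paths resolve (a presumed run order, not stated in the commit):

python prep_scripts/markdown_to_text.py   # data/transformers/docs/source/en -> data/docs_dump
python prep_scripts/lancedb_setup.py      # data/docs_dump -> data/lancedb
python gradio_app/app.py                  # serve the RAG demo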