AlexanderKazakov committed
Commit: 360f505
Parent(s): eba1a12

make it work in zero draft
Files changed:
- .idea/rag-gradio-sample-project.iml +1 -0
- gradio_app/app.py +31 -34
- gradio_app/backend/query_llm.py +10 -13
- gradio_app/backend/semantic_search.py +5 -9
- gradio_app/templates/{template.j2 → prompt_template.j2} +0 -0
- prep_scripts/lancedb_setup.py +22 -13
- prep_scripts/markdown_to_text.py +9 -8
- settings.py +8 -0
.idea/rag-gradio-sample-project.iml
CHANGED
@@ -2,6 +2,7 @@
 <module type="PYTHON_MODULE" version="4">
   <component name="NewModuleRootManager">
     <content url="file://$MODULE_DIR$">
+      <excludeFolder url="file://$MODULE_DIR$/data" />
       <excludeFolder url="file://$MODULE_DIR$/venv" />
     </content>
     <orderEntry type="jdk" jdkName="Python 3.11 (rag-gradio-sample-project) (2)" jdkType="Python SDK" />
gradio_app/app.py
CHANGED
@@ -6,28 +6,25 @@ Credit to Derek Thomas, derek@huggingface.co
 # subprocess.run(["pip", "install", "--upgrade", "transformers[torch,sentencepiece]==4.34.1"])
 
 import logging
-from pathlib import Path
 from time import perf_counter
 
 import gradio as gr
 from jinja2 import Environment, FileSystemLoader
 
 from backend.query_llm import generate_hf, generate_openai
-
-
-TEXT_COLUMN_NAME = ""
-
-proj_dir = Path(__file__).parent
+from backend.semantic_search import table, embedder
+
+from settings import *
+
 # Setting up the logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
 # Set up the template environment with the templates directory
-env = Environment(loader=FileSystemLoader(
+env = Environment(loader=FileSystemLoader('gradio_app/templates'))
 
 # Load the templates directly from the environment
-
+prompt_template = env.get_template('prompt_template.j2')
 template_html = env.get_template('template_html.j2')
 
 # Examples
@@ -47,34 +44,34 @@ def bot(history, api_kind):
     query = history[-1][0]
 
     if not query:
-
-
+        gr.Warning("Please submit a non-empty string as a prompt")
+        raise ValueError("Empty string was submitted")
 
-    logger.
+    logger.info('Retrieving documents...')
     # Retrieve documents relevant to query
     document_start = perf_counter()
 
-    query_vec =
+    query_vec = embedder.encode(query)
     documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_k_rank).to_list()
     documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
 
     document_time = perf_counter() - document_start
-    logger.
+    logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
 
     # Create Prompt
-    prompt =
+    prompt = prompt_template.render(documents=documents, query=query)
     prompt_html = template_html.render(documents=documents, query=query)
 
     if api_kind == "HuggingFace":
-
+        generate_fn = generate_hf
     elif api_kind == "OpenAI":
-
+        generate_fn = generate_openai
     elif api_kind is None:
-
-
+        gr.Warning("API name was not provided")
+        raise ValueError("API name was not provided")
     else:
-
-
+        gr.Warning(f"API {api_kind} is not supported")
+        raise ValueError(f"API {api_kind} is not supported")
 
     history[-1][1] = ""
     for character in generate_fn(prompt, history[:-1]):
@@ -84,22 +81,22 @@ def bot(history, api_kind):
 
 with gr.Blocks() as demo:
     chatbot = gr.Chatbot(
-
-
-
-
-
-
-
-
+        [],
+        elem_id="chatbot",
+        avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
+                       'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
+        bubble_full_width=False,
+        show_copy_button=True,
+        show_share_button=True,
+    )
 
     with gr.Row():
         txt = gr.Textbox(
-
-
-
-
-
+            scale=3,
+            show_label=False,
+            placeholder="Enter text and press enter",
+            container=False,
+        )
        txt_btn = gr.Button(value="Submit text", scale=1)
 
     api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
@@ -107,14 +104,14 @@ with gr.Blocks() as demo:
     prompt_html = gr.HTML()
     # Turn off interactivity while generating if you click
     txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-
+        bot, [chatbot, api_kind], [chatbot, prompt_html])
 
     # Turn it back on
     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
 
     # Turn off interactivity while generating if you hit enter
     txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
-
+        bot, [chatbot, api_kind], [chatbot, prompt_html])
 
     # Turn it back on
     txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
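A note on the new loader: `FileSystemLoader('gradio_app/templates')` resolves against the current working directory, so the app presumably has to be launched from the repository root. A minimal sketch of the templating step this diff completes; the documents and query values here are made up for illustration:

    from jinja2 import Environment, FileSystemLoader

    # Loader path is relative to the working directory, so this only resolves
    # when run from the repository root (assumption based on this diff).
    env = Environment(loader=FileSystemLoader('gradio_app/templates'))
    prompt_template = env.get_template('prompt_template.j2')

    # Render with illustrative values; in app.py these come from retrieval.
    prompt = prompt_template.render(documents=["chunk one", "chunk two"],
                                    query="What does pipeline() do?")
    print(prompt)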
gradio_app/backend/query_llm.py
CHANGED
@@ -7,19 +7,16 @@ from typing import Any, Dict, Generator, List
 from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer
 
-
+from settings import *
 
-
-
-repetition_penalty = 1.2
+
+tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
 
 OPENAI_KEY = getenv("OPENAI_API_KEY")
 HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")
 
-
-
-    token=HF_TOKEN
-)
+
+hf_client = InferenceClient(LLM_NAME, token=HF_TOKEN)
 
 
 def format_prompt(message: str, api_kind: str):
@@ -40,12 +37,12 @@ def format_prompt(message: str, api_kind: str):
         return messages
     elif api_kind == "hf":
         return tokenizer.apply_chat_template(messages, tokenize=False)
-
+    else:
         raise ValueError("API is not supported")
 
 
-def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int =
-                top_p: float = 0.
+def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
+                top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
@@ -99,8 +96,8 @@ def generate_hf(prompt: str, history: str, temperature: float = 0.9, max_new_tok
     return "I do not know what happened, but I couldn't understand you."
 
 
-def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int =
-                    top_p: float = 0.
+def generate_openai(prompt: str, history: str, temperature: float = 0.9, max_new_tokens: int = 512,
+                    top_p: float = 0.6, repetition_penalty: float = 1.2) -> Generator[str, None, str]:
     """
     Generate a sequence of tokens based on a given prompt and history using Mistral client.
 
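The bodies of generate_hf and generate_openai are unchanged by this commit and not shown here. Against the new signature, a streaming loop built on huggingface_hub's InferenceClient.text_generation would look roughly like this sketch (not the file's actual body; `history` is ignored for brevity):

    def generate_hf_sketch(prompt: str, history: str, temperature: float = 0.9,
                           max_new_tokens: int = 512, top_p: float = 0.6,
                           repetition_penalty: float = 1.2):
        # Stream tokens from the hosted model and yield the growing string,
        # which is how a Gradio chatbot animates a partial answer.
        stream = hf_client.text_generation(
            format_prompt(prompt, "hf"),
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            stream=True,
            details=True,
        )
        output = ""
        for chunk in stream:
            output += chunk.token.text
            yield output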
gradio_app/backend/semantic_search.py
CHANGED
@@ -1,18 +1,14 @@
 import logging
 import lancedb
-import os
-from pathlib import Path
 from sentence_transformers import SentenceTransformer
 
-
-
+from settings import *
+
 
 # Setting up the logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
-
+embedder = SentenceTransformer(EMB_MODEL_NAME)
 
-
-
-db = lancedb.connect(db_uri)
-table = db.open_table(DB_TABLE_NAME)
+db = lancedb.connect(LANCEDB_DIRECTORY)
+table = db.open_table(LANCEDB_TABLE_NAME)
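Importing this module now loads the embedding model and opens the LanceDB table once at startup. A quick usage sketch against the new module-level objects (the query text is illustrative):

    from backend.semantic_search import table, embedder
    from settings import VECTOR_COLUMN_NAME, TEXT_COLUMN_NAME

    # Embed a query and pull the three nearest chunks from the table.
    vec = embedder.encode("What is attention?")
    hits = table.search(vec, vector_column_name=VECTOR_COLUMN_NAME).limit(3).to_list()
    print([hit[TEXT_COLUMN_NAME][:80] for hit in hits])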
gradio_app/templates/{template.j2 → prompt_template.j2}
RENAMED
File without changes
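The template's contents are untouched by the rename and not shown in this commit. Given that the render calls in app.py pass documents and query, a RAG prompt template of roughly this shape would fit (hypothetical contents, not the actual file):

    {# Hypothetical prompt_template.j2; the real file is not shown here. #}
    Context:
    {% for doc in documents %}
    ---
    {{ doc }}
    {% endfor %}
    ---
    Answer the question using only the context above.
    Question: {{ query }}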
prep_scripts/lancedb_setup.py
CHANGED
@@ -1,3 +1,6 @@
+import shutil
+import traceback
+
 import lancedb
 import torch
 import pyarrow as pa
@@ -8,13 +11,16 @@ import numpy as np
 
 from sentence_transformers import SentenceTransformer
 
-
-
-VECTOR_COLUMN_NAME = ""
-TEXT_COLUMN_NAME = ""
-INPUT_DIR = "<chunked docs directory>"
-db = lancedb.connect(".lancedb")  # db location
+from settings import *
+
+
+emb_sizes = {
+    "sentence-transformers/all-MiniLM-L6-v2": 384,
+    "thenlper/gte-large": 0
+}
+
+shutil.rmtree(LANCEDB_DIRECTORY, ignore_errors=True)
+db = lancedb.connect(LANCEDB_DIRECTORY)
 batch_size = 32
 
 model = SentenceTransformer(EMB_MODEL_NAME)
@@ -29,17 +35,17 @@ else:
 
 schema = pa.schema(
     [
-        pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(),
+        pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), emb_sizes[EMB_MODEL_NAME])),
         pa.field(TEXT_COLUMN_NAME, pa.string())
     ])
-tbl = db.create_table(
+tbl = db.create_table(LANCEDB_TABLE_NAME, schema=schema, mode="overwrite")
 
-input_dir = Path(
+input_dir = Path(TEXT_CHUNKS_DIR)
 files = list(input_dir.rglob("*"))
 
 sentences = []
 for file in files:
-    with open(file) as f:
+    with open(file, encoding='utf-8') as f:
         sentences.append(f.read())
 
 for i in tqdm.tqdm(range(0, int(np.ceil(len(sentences) / batch_size)))):
@@ -54,12 +60,15 @@ for i in tqdm.tqdm(range(0, int(np.ceil(len(sentences) / batch_size)))):
         })
 
         tbl.add(df)
+
     except:
-        print(f"batch {i} was skipped")
+        print(f"batch {i} was skipped: {traceback.format_exc()}")
+
 
 '''
 create ivf-pd index https://lancedb.github.io/lancedb/ann_indexes/
 with the size of the transformer docs, index is not really needed
-but we'll do it for
+but we'll do it for demonstration purposes
 '''
-tbl.create_index(num_partitions=256, num_sub_vectors=96, vector_column_name=VECTOR_COLUMN_NAME)
+# tbl.create_index(num_partitions=256, num_sub_vectors=96, vector_column_name=VECTOR_COLUMN_NAME)
+
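One caveat in the new code: emb_sizes maps "thenlper/gte-large" to a 0 placeholder, which would produce an unusable vector column if that model were ever selected. The width can instead be read from the loaded model itself; a sketch using the sentence-transformers API, with names taken from settings.py:

    from sentence_transformers import SentenceTransformer
    import pyarrow as pa

    from settings import EMB_MODEL_NAME, VECTOR_COLUMN_NAME, TEXT_COLUMN_NAME

    model = SentenceTransformer(EMB_MODEL_NAME)
    emb_dim = model.get_sentence_embedding_dimension()  # 384 for all-MiniLM-L6-v2

    # Schema derived from the model itself; no hard-coded emb_sizes dict needed.
    schema = pa.schema([
        pa.field(VECTOR_COLUMN_NAME, pa.list_(pa.float32(), emb_dim)),
        pa.field(TEXT_COLUMN_NAME, pa.string()),
    ])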
prep_scripts/markdown_to_text.py
CHANGED
@@ -1,12 +1,12 @@
+import shutil
+
 from bs4 import BeautifulSoup
 from markdown import markdown
 import os
 import re
 from pathlib import Path
 
-
-DIR_TO_SCRAPE = "data/transformers/docs/source/en/"
-OUTPUT_DIR = str(Path().resolve() / "docs_dump")
+from settings import *
 
 
 def markdown_to_text(markdown_string):
@@ -20,7 +20,7 @@ def markdown_to_text(markdown_string):
 
     # extract text
     soup = BeautifulSoup(html, "html.parser")
-    text = ''.join(soup.findAll(
+    text = ''.join(soup.findAll(string=True))
 
     text = re.sub('```(py|diff|python)', '', text)
     text = re.sub('```\n', '\n', text)
@@ -31,19 +31,20 @@ def markdown_to_text(markdown_string):
     return text
 
 
-dir_to_scrape = Path(
+dir_to_scrape = Path(MARKDOWN_DIR_TO_SCRAPE)
 files = list(dir_to_scrape.rglob("*"))
 
-
+shutil.rmtree(TEXT_CHUNKS_DIR, ignore_errors=True)
+os.makedirs(TEXT_CHUNKS_DIR)
 
 for file in files:
     parent = file.parent.stem if file.parent.stem != dir_to_scrape.stem else ""
     if file.is_file():
-        with open(file) as f:
+        with open(file, encoding='utf-8') as f:
             md = f.read()
 
         text = markdown_to_text(md)
 
-        with open(os.path.join(
+        with open(os.path.join(TEXT_CHUNKS_DIR, f"{parent}_{file.stem}.txt"), "w", encoding='utf-8') as f:
            f.write(text)
 
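A quick check of the completed extraction call (soup.findAll(string=True) replaces the truncated draft line); the resulting plain text is what gets written into TEXT_CHUNKS_DIR. Illustrative input, run in the context of this module:

    md = "# Title\n\nSome *emphasised* text.\n\n```py\nprint('hi')\n```\n"
    print(markdown_to_text(md))
    # -> plain text with markdown markup and code-fence markers stripped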
settings.py
ADDED
@@ -0,0 +1,8 @@
+MARKDOWN_DIR_TO_SCRAPE = "data/transformers/docs/source/en/"
+TEXT_CHUNKS_DIR = "data/docs_dump"
+EMB_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+LANCEDB_DIRECTORY = "data/lancedb"
+LANCEDB_TABLE_NAME = "table"
+VECTOR_COLUMN_NAME = "embedding"
+TEXT_COLUMN_NAME = "text"
+LLM_NAME = "mistralai/Mistral-7B-Instruct-v0.1"
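Every path here is relative and the modules do `from settings import *`, so the scripts presumably have to run from the repository root with the root on PYTHONPATH. The prep order implied by the diffs, as a sketch:

    # Run from the repository root; PYTHONPATH="." so `from settings import *` resolves.
    import os
    import subprocess

    env = {**os.environ, "PYTHONPATH": "."}
    subprocess.run(["python", "prep_scripts/markdown_to_text.py"], check=True, env=env)  # markdown -> text chunks
    subprocess.run(["python", "prep_scripts/lancedb_setup.py"], check=True, env=env)     # embed chunks -> LanceDB
    # then launch the UI: python gradio_app/app.py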