Aragoner commited on
Commit
039472c
1 Parent(s): 8a41d49

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +108 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Credit to Derek Thomas, derek@huggingface.co
3
+ """
4
+ import os
5
+ import logging
6
+ from pathlib import Path
7
+ from time import perf_counter
8
+
9
+ import gradio as gr
10
+ from jinja2 import Environment, FileSystemLoader
11
+ from dotenv import load_dotenv
12
+ load_dotenv()
13
+ from backend.query_llm import generate_hf, generate_openai
14
+ from backend.semantic_search import retrieve
15
+
16
+
17
+ TOP_K = int(os.getenv("TOP_K", 4))
18
+
19
+ proj_dir = Path(__file__).parent
20
+ # Setting up the logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Set up the template environment with the templates directory
25
+ env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))
26
+
27
+ # Load the templates directly from the environment
28
+ template = env.get_template('template.j2')
29
+ template_html = env.get_template('template_html.j2')
30
+
31
+
32
+ def add_text(history, text):
33
+ history = [] if history is None else history
34
+ history = history + [(text, None)]
35
+ return history, gr.Textbox(value="", interactive=False)
36
+
37
+
38
+ def bot(history, api_kind):
39
+ query = history[-1][0]
40
+
41
+ if not query:
42
+ raise gr.Warning("Please submit a non-empty string as a prompt")
43
+
44
+ logger.info('Retrieving documents...')
45
+ # Retrieve documents relevant to query
46
+ document_start = perf_counter()
47
+
48
+ documents = retrieve(query, TOP_K)
49
+
50
+ document_time = perf_counter() - document_start
51
+ logger.info(f'Finished Retrieving documents in {round(document_time, 2)} seconds...')
52
+
53
+ # Create Prompt
54
+ prompt = template.render(documents=documents, query=query)
55
+ prompt_html = template_html.render(documents=documents, query=query)
56
+
57
+ if api_kind == "HuggingFace":
58
+ generate_fn = generate_hf
59
+ elif api_kind == "OpenAI":
60
+ generate_fn = generate_openai
61
+ else:
62
+ raise gr.Error(f"API {api_kind} is not supported")
63
+
64
+ history[-1][1] = ""
65
+ for character in generate_fn(prompt, history[:-1]):
66
+ history[-1][1] = character
67
+ yield history, prompt_html
68
+
69
+
70
+ with gr.Blocks() as demo:
71
+ chatbot = gr.Chatbot(
72
+ [],
73
+ elem_id="chatbot",
74
+ avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
75
+ 'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
76
+ bubble_full_width=False,
77
+ show_copy_button=True,
78
+ show_share_button=True,
79
+ )
80
+
81
+ with gr.Row():
82
+ txt = gr.Textbox(
83
+ scale=3,
84
+ show_label=False,
85
+ placeholder="Enter text and press enter",
86
+ container=False,
87
+ )
88
+ txt_btn = gr.Button(value="Submit text", scale=1)
89
+
90
+ api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")
91
+
92
+ prompt_html = gr.HTML()
93
+ # Turn off interactivity while generating if you click
94
+ txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
95
+ bot, [chatbot, api_kind], [chatbot, prompt_html])
96
+
97
+ # Turn it back on
98
+ txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
99
+
100
+ # Turn off interactivity while generating if you hit enter
101
+ txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
102
+ bot, [chatbot, api_kind], [chatbot, prompt_html])
103
+
104
+ # Turn it back on
105
+ txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)
106
+
107
+ demo.queue()
108
+ demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ lancedb==0.5.3
2
+ openai==1.11.1
3
+ sentence-transformers==2.3.1
4
+ tqdm==4.66.1
5
+ torch==2.1.1
6
+ transformers==4.37.2
7
+ python-dotenv==1.0.1
8
+ jinja2==3.0.1