NCTCMumbai committed on
Commit e34a93e (1 parent: 8fe98c8)

Upload 7 files

app.py ADDED
@@ -0,0 +1,160 @@
"""
Credit to Derek Thomas, derek@huggingface.co
"""

import subprocess

# subprocess.run(["pip", "install", "--upgrade", "transformers[torch,sentencepiece]==4.34.1"])

import logging
from pathlib import Path
from time import perf_counter

import gradio as gr
from jinja2 import Environment, FileSystemLoader
import numpy as np
from sentence_transformers import CrossEncoder

from backend.query_llm import generate_hf, generate_openai
from backend.semantic_search import table, retriever

VECTOR_COLUMN_NAME = "embeddings"
TEXT_COLUMN_NAME = "text"

proj_dir = Path(__file__).parent

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set up the template environment with the templates directory
env = Environment(loader=FileSystemLoader(proj_dir / 'templates'))

# Load the templates directly from the environment
template = env.get_template('template.j2')
template_html = env.get_template('template_html.j2')

# Cross-encoder used to rerank the retrieved documents
cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# Example queries
examples = ['What is Let Export Order?',
            'What are benefits of the AEO Scheme?',
            'Which circular talks about EOU?']


def add_text(history, text):
    history = [] if history is None else history
    # Use a list (not a tuple) for the new turn so bot() can fill in the
    # assistant slot in place while streaming.
    history = history + [[text, None]]
    return history, gr.Textbox(value="", interactive=False)


def bot(history, api_kind):
    top_rerank = 15
    top_k_rank = 8
    query = history[-1][0]

    if not query:
        gr.Warning("Please submit a non-empty string as a prompt")
        raise ValueError("Empty string was submitted")

    logger.info('Retrieving documents...')
    document_start = perf_counter()

    # Retrieve candidate documents for the query from LanceDB
    query_vec = retriever.encode(query)
    logger.info('Finished encoding query')
    documents = table.search(query_vec, vector_column_name=VECTOR_COLUMN_NAME).limit(top_rerank).to_list()
    documents = [doc[TEXT_COLUMN_NAME] for doc in documents]
    logger.info('Starting cross-encoder on %d documents', len(documents))

    # Rerank the candidates with the cross-encoder and keep the best top_k_rank
    query_doc_pair = [[query, doc] for doc in documents]
    cross_scores = cross_encoder.predict(query_doc_pair)
    sim_scores_argsort = list(reversed(np.argsort(cross_scores)))
    logger.info('Finished cross-encoder on %d documents', len(documents))

    documents = [documents[idx] for idx in sim_scores_argsort[:top_k_rank]]

    document_time = perf_counter() - document_start
    logger.info('Finished retrieving documents in %.2f seconds', document_time)

    # Create prompts from the templates
    prompt = template.render(documents=documents, query=query)
    prompt_html = template_html.render(documents=documents, query=query)

    if api_kind == "HuggingFace":
        generate_fn = generate_hf
    elif api_kind == "OpenAI":
        generate_fn = generate_openai
    elif api_kind is None:
        gr.Warning("API name was not provided")
        raise ValueError("API name was not provided")
    else:
        gr.Warning(f"API {api_kind} is not supported")
        raise ValueError(f"API {api_kind} is not supported")

    # Each yield from the generators is the accumulated output so far,
    # so overwrite (rather than append to) the assistant message.
    history[-1][1] = ""
    for partial_text in generate_fn(prompt, history[:-1]):
        history[-1][1] = partial_text
        yield history, prompt_html


with gr.Blocks() as demo:
    # Heading with logo
    gr.HTML(value="""
        <div style="display: flex; align-items: center; justify-content: space-between;">
            <h1 style="color: #2ECC71">Customs Manual Chatbot</h1>
            <img src='logo.png' alt="Chatbot" width="50" height="50" />
        </div>
        """, elem_id="heading")

    # Formatted description
    gr.HTML(value="""<p style="font-family: sans-serif; font-size: 16px;">A free chatbot developed by the National Customs Targeting Center using open-source LLMs.</p>""", elem_id="description")

    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        avatar_images=('https://aui.atlassian.com/aui/8.8/docs/images/avatar-person.svg',
                       'https://huggingface.co/datasets/huggingface/brand-assets/resolve/main/hf-logo.svg'),
        bubble_full_width=False,
        show_copy_button=True,
        show_share_button=True,
    )

    with gr.Row():
        txt = gr.Textbox(
            scale=3,
            show_label=False,
            placeholder="Enter text and press enter",
            container=False,
        )
        txt_btn = gr.Button(value="Submit text", scale=1)

    api_kind = gr.Radio(choices=["HuggingFace", "OpenAI"], value="HuggingFace")

    prompt_html = gr.HTML()

    # Turn off interactivity while generating if you click
    txt_msg = txt_btn.click(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, [chatbot, api_kind], [chatbot, prompt_html])

    # Turn it back on
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    # Turn off interactivity while generating if you hit enter
    txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(
        bot, [chatbot, api_kind], [chatbot, prompt_html])

    # Turn it back on
    txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)

    # Examples
    gr.Examples(examples, txt)

demo.queue()
demo.launch(debug=True)
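Note that `templates/template.j2` and `template_html.j2`, which app.py renders above, are not part of this commit. A minimal sketch of the kind of RAG prompt those render calls assume; the inline template here is hypothetical, not the actual file:

from jinja2 import Template

# Hypothetical stand-in for templates/template.j2 (not in this commit):
# a RAG prompt that inlines the reranked passages ahead of the question.
rag_template = Template(
    "Answer the question using only the context below.\n\n"
    "{% for doc in documents %}Context: {{ doc }}\n{% endfor %}\n"
    "Question: {{ query }}\nAnswer:"
)

print(rag_template.render(documents=["A Let Export Order is ..."],
                          query="What is Let Export Order?"))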
backend/__pycache__/query_llm.cpython-310.pyc ADDED
Binary file (4.36 kB).
 
backend/__pycache__/semantic_search.cpython-310.pyc ADDED
Binary file (700 Bytes).
 
backend/query_llm.py ADDED
@@ -0,0 +1,156 @@
import openai
import gradio as gr

from os import getenv
from typing import Any, Dict, Generator, List

from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

temperature = 0.9
top_p = 0.6
repetition_penalty = 1.2

OPENAI_KEY = getenv("OPENAI_API_KEY")
HF_TOKEN = getenv("HUGGING_FACE_HUB_TOKEN")

# Wire the key into the client explicitly (openai 0.28 would otherwise read
# OPENAI_API_KEY from the environment itself).
openai.api_key = OPENAI_KEY

hf_client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1",
    token=HF_TOKEN
)


def format_prompt(message: str, api_kind: str):
    """
    Formats the given message using a chat template.

    Args:
        message (str): The user message to be formatted.
        api_kind (str): The target API, either "openai" or "hf".

    Returns:
        The messages list for OpenAI, or the chat-template-formatted string for HF.
    """

    # Create a list of message dictionaries with role and content
    messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]

    if api_kind == "openai":
        return messages
    elif api_kind == "hf":
        return tokenizer.apply_chat_template(messages, tokenize=False)
    else:
        raise ValueError(f"API kind {api_kind!r} is not supported")


def generate_hf(prompt: str, history: list, temperature: float = 0.9, max_new_tokens: int = 256,
                top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
    """
    Generate a sequence of tokens based on a given prompt and history using the
    Hugging Face Inference client for Mistral.

    Args:
        prompt (str): The initial prompt for the text generation.
        history (list): Chat history providing context for the text generation.
        temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
        max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
        repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.

    Returns:
        Generator[str, None, str]: A generator yielding the accumulated generated text.
            Returns a final string if an error occurs.
    """

    temperature = max(float(temperature), 1e-2)  # Ensure temperature isn't too low
    top_p = float(top_p)

    generate_kwargs = {
        'temperature': temperature,
        'max_new_tokens': max_new_tokens,
        'top_p': top_p,
        'repetition_penalty': repetition_penalty,
        'do_sample': True,
        'seed': 42,
    }

    formatted_prompt = format_prompt(prompt, "hf")

    try:
        stream = hf_client.text_generation(formatted_prompt, **generate_kwargs,
                                           stream=True, details=True, return_full_text=False)
        output = ""
        for response in stream:
            output += response.token.text
            yield output

    except Exception as e:
        if "Too Many Requests" in str(e):
            print("ERROR: Too many requests on Mistral client")
            gr.Warning("Unfortunately Mistral is unable to process")
            return "Unfortunately, I am not able to process your request now."
        elif "Authorization header is invalid" in str(e):
            print("Authentication error:", str(e))
            gr.Warning("Authentication error: HF token was either not provided or incorrect")
            return "Authentication error"
        else:
            print("Unhandled Exception:", str(e))
            gr.Warning("Unfortunately Mistral is unable to process")
            return "I do not know what happened, but I couldn't understand you."


def generate_openai(prompt: str, history: list, temperature: float = 0.9, max_new_tokens: int = 256,
                    top_p: float = 0.95, repetition_penalty: float = 1.0) -> Generator[str, None, str]:
    """
    Generate a sequence of tokens based on a given prompt and history using the OpenAI client.

    Args:
        prompt (str): The initial prompt for the text generation.
        history (list): Chat history providing context for the text generation.
        temperature (float, optional): The softmax temperature for sampling. Defaults to 0.9.
        max_new_tokens (int, optional): Maximum number of tokens to be generated. Defaults to 256.
        top_p (float, optional): Nucleus sampling probability. Defaults to 0.95.
        repetition_penalty (float, optional): Penalty for repeated tokens. Defaults to 1.0.

    Returns:
        Generator[str, None, str]: A generator yielding the accumulated generated text.
            Returns a final string if an error occurs.
    """

    temperature = max(float(temperature), 1e-2)  # Ensure temperature isn't too low
    top_p = float(top_p)

    generate_kwargs = {
        'temperature': temperature,
        'max_tokens': max_new_tokens,
        'top_p': top_p,
        'frequency_penalty': max(-2., min(repetition_penalty, 2.)),
    }

    formatted_prompt = format_prompt(prompt, "openai")

    try:
        stream = openai.ChatCompletion.create(model="gpt-3.5-turbo-0301",
                                              messages=formatted_prompt,
                                              **generate_kwargs,
                                              stream=True)
        output = ""
        for chunk in stream:
            output += chunk.choices[0].delta.get("content", "")
            yield output

    except Exception as e:
        if "Too Many Requests" in str(e):
            print("ERROR: Too many requests on OpenAI client")
            gr.Warning("Unfortunately OpenAI is unable to process")
            return "Unfortunately, I am not able to process your request now."
        elif "You didn't provide an API key" in str(e):
            print("Authentication error:", str(e))
            gr.Warning("Authentication error: OpenAI key was either not provided or incorrect")
            return "Authentication error"
        else:
            print("Unhandled Exception:", str(e))
            gr.Warning("Unfortunately OpenAI is unable to process")
            return "I do not know what happened, but I couldn't understand you."
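Both generators yield the accumulated output so far rather than per-token deltas, which is why bot() in app.py overwrites the assistant message on every step. A minimal consumption sketch, assuming HUGGING_FACE_HUB_TOKEN is set in the environment:

final_text = ""
for partial in generate_hf("What is Let Export Order?", history=[]):
    final_text = partial  # each yield is the full text generated so far
print(final_text)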
backend/semantic_search.py ADDED
@@ -0,0 +1,19 @@
import logging
import os
from pathlib import Path

import lancedb
from sentence_transformers import SentenceTransformer

EMB_MODEL_NAME = "thenlper/gte-base"
DB_TABLE_NAME = "Huggingface_docs"

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Embedding model for queries; must match the model used to build the table
retriever = SentenceTransformer(EMB_MODEL_NAME)

# Connect to LanceDB and open the pre-built table
db_uri = os.path.join(Path(__file__).parents[1], ".lancedb")
db = lancedb.connect(db_uri)
table = db.open_table(DB_TABLE_NAME)
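This module only opens an existing table, so the `.lancedb` directory must already contain `Huggingface_docs` with `text` and `embeddings` columns matching app.py. A minimal sketch of building such a table; the two-chunk corpus here is hypothetical (the real table is presumably built from the Customs Manual):

import lancedb
from sentence_transformers import SentenceTransformer

retriever = SentenceTransformer("thenlper/gte-base")

# Hypothetical corpus chunks standing in for the real ingestion pipeline.
chunks = ["A Let Export Order is issued by the proper officer ...",
          "The AEO Scheme offers benefits such as ..."]

db = lancedb.connect(".lancedb")
db.create_table(
    "Huggingface_docs",
    data=[{"text": c, "embeddings": retriever.encode(c)} for c in chunks],
)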
logo.png ADDED
requirements.txt ADDED
@@ -0,0 +1,9 @@
transformers[torch,sentencepiece]
wikiextractor==3.0.6
sentence-transformers>2.2.0
ipywidgets==8.1.1
tqdm==4.66.1
aiohttp==3.8.6
huggingface-hub==0.17.3
lancedb
openai==0.28