Aragoner committed (verified)
Commit 7e92170 · 1 Parent(s): 039472c

Upload 8 files

backend/.DS_Store ADDED
Binary file (6.15 kB). View file
 
backend/__pycache__/query_llm.cpython-310.pyc ADDED
Binary file (4.03 kB). View file
 
backend/__pycache__/query_llm.cpython-37.pyc ADDED
Binary file (4.01 kB). View file
 
backend/__pycache__/query_llm.cpython-39.pyc ADDED
Binary file (4.04 kB). View file
 
backend/__pycache__/semantic_search.cpython-310.pyc ADDED
Binary file (1.18 kB). View file
 
backend/__pycache__/semantic_search.cpython-39.pyc ADDED
Binary file (2.14 kB). View file
 
backend/query_llm.py ADDED
@@ -0,0 +1,129 @@
+ import openai
+ import gradio as gr
+ import os
+
+ from typing import Any, Dict, Generator, List
+
+ from huggingface_hub import InferenceClient
+ from transformers import AutoTokenizer
+
+ from dotenv import load_dotenv
+ load_dotenv()
+ OPENAI_KEY = os.getenv("OPENAI_API_KEY")
+ HF_TOKEN = os.getenv("HF_TOKEN")
+ TOKENIZER = AutoTokenizer.from_pretrained(os.getenv("HF_MODEL"))
+
+ HF_CLIENT = InferenceClient(
+     os.getenv("HF_MODEL"),
+     token=HF_TOKEN
+ )
+ OAI_CLIENT = openai.Client(api_key=OPENAI_KEY)
+
+ HF_GENERATE_KWARGS = {
+     'temperature': max(float(os.getenv("TEMPERATURE", 0.9)), 1e-2),
+     'max_new_tokens': int(os.getenv("MAX_NEW_TOKENS", 256)),
+     'top_p': float(os.getenv("TOP_P", 0.6)),
+     'repetition_penalty': float(os.getenv("REP_PENALTY", 1.2)),
+     'do_sample': os.getenv("DO_SAMPLE", "true").lower() in ("1", "true", "yes")  # parse the env string; bool() on any non-empty string is always True
+ }
+
+ OAI_GENERATE_KWARGS = {
+     'temperature': max(float(os.getenv("TEMPERATURE", 0.9)), 1e-2),
+     'max_tokens': int(os.getenv("MAX_NEW_TOKENS", 256)),
+     'top_p': float(os.getenv("TOP_P", 0.6)),
+     'frequency_penalty': max(-2, min(float(os.getenv("FREQ_PENALTY", 0)), 2))
+ }
+
+
+ def format_prompt(message: str, api_kind: str):
+     """
+     Formats the given message using a chat template.
+
+     Args:
+         message (str): The user message to be formatted.
+         api_kind (str): LLM API provider, either "openai" or "hf".
+     Returns:
+         A list of chat messages for OpenAI, or a template-formatted string for HF.
+     """
+
+     # Create a list of message dictionaries with role and content
+     messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]
+
+     if api_kind == "openai":
+         return messages
+     elif api_kind == "hf":
+         return TOKENIZER.apply_chat_template(messages, tokenize=False)
+     else:
+         raise ValueError("API is not supported")
+
+
+ def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
+     """
+     Generate a sequence of tokens based on a given prompt and history using the Hugging Face Inference client.
+
+     Args:
+         prompt (str): The prompt for the text generation.
+         history (str): Context or history for the text generation.
+     Returns:
+         Generator[str, None, str]: A generator yielding chunks of generated text.
+         Raises gr.Error if the request fails.
+     """
+
+     formatted_prompt = format_prompt(prompt, "hf")
+     formatted_prompt = formatted_prompt.encode("utf-8").decode("utf-8")
+
+     try:
+         stream = HF_CLIENT.text_generation(
+             formatted_prompt,
+             **HF_GENERATE_KWARGS,
+             stream=True,
+             details=True,
+             return_full_text=False
+         )
+         output = ""
+         for response in stream:
+             output += response.token.text
+             yield output
+
+     except Exception as e:
+         if "Too Many Requests" in str(e):
+             raise gr.Error(f"Too many requests: {str(e)}")
+         elif "Authorization header is invalid" in str(e):
+             raise gr.Error("Authentication error: HF token was either not provided or incorrect")
+         else:
+             raise gr.Error(f"Unhandled Exception: {str(e)}")
+
+
+ def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
+     """
+     Generate a sequence of tokens based on a given prompt and history using the OpenAI client.
+
+     Args:
+         prompt (str): The initial prompt for the text generation.
+         history (str): Context or history for the text generation.
+     Returns:
+         Generator[str, None, str]: A generator yielding chunks of generated text.
+         Raises gr.Error if the request fails.
+     """
+     formatted_prompt = format_prompt(prompt, "openai")
+
+     try:
+         stream = OAI_CLIENT.chat.completions.create(
+             model=os.getenv("OPENAI_MODEL"),
+             messages=formatted_prompt,
+             **OAI_GENERATE_KWARGS,
+             stream=True
+         )
+         output = ""
+         for chunk in stream:
+             if chunk.choices[0].delta.content:
+                 output += chunk.choices[0].delta.content
+                 yield output
+
+     except Exception as e:
+         if "Too Many Requests" in str(e):
+             raise gr.Error("ERROR: Too many requests on OpenAI client")
+         elif "You didn't provide an API key" in str(e):
+             raise gr.Error("Authentication error: OpenAI key was either not provided or incorrect")
+         else:
+             raise gr.Error(f"Unhandled Exception: {str(e)}")
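Both generators stream by yielding the accumulated output so far, which is the shape Gradio's streaming chat components expect. A minimal sketch of calling them directly, assuming a .env that defines HF_MODEL, HF_TOKEN, OPENAI_MODEL and OPENAI_API_KEY (none of those values appear in this commit), and assuming the module is importable as backend.query_llm:

# Sketch only: the .env contents and the import path are assumptions, not part of this commit.
from backend.query_llm import generate_hf, generate_openai

question = "What is retrieval-augmented generation?"

# Each yield is the full text generated so far, not a single new token,
# so the last yielded value is the complete answer.
answer = ""
for partial in generate_hf(question, history=""):  # or generate_openai(question, history="")
    answer = partial
print(answer)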
backend/semantic_search.py ADDED
@@ -0,0 +1,59 @@
+ import lancedb
+ import os
+ import gradio as gr
+ from sentence_transformers import SentenceTransformer
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
+ import torch
+ import time
+ from pathlib import Path
+
+ db = lancedb.connect(".lancedb")
+
+ TABLE = db.open_table(os.getenv("TABLE_NAME"))
+ VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
+ TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
+ BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
+ CROSS_ENCODER = os.getenv("CROSS_ENCODER")
+
+ retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
+ cross_encoder = AutoModelForSequenceClassification.from_pretrained(CROSS_ENCODER)
+ cross_encoder.eval()
+ cross_encoder_tokenizer = AutoTokenizer.from_pretrained(CROSS_ENCODER)
+
+
+ def rerank(query, documents, k):
+     """Use a cross-encoder to rerank documents retrieved from the retriever and keep the top k."""
+     tokens = cross_encoder_tokenizer([query] * len(documents), documents, padding=True, truncation=True, return_tensors="pt")
+     with torch.no_grad():
+         logits = cross_encoder(**tokens).logits
+     scores = logits.reshape(-1).tolist()
+     documents = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
+     return [doc[0] for doc in documents[:k]]
+
+
+ # def retrieve(query, k):
+ #     query_vec = retriever.encode(query)
+ #     try:
+ #         documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
+ #         documents = [doc[TEXT_COLUMN] for doc in documents]
+ #
+ #         return documents
+ #
+ #     except Exception as e:
+ #         raise gr.Error(str(e))
+
+
+ def retrieve(query, top_k_retriever=25, use_reranking=True, top_k_reranker=5):
+     """Vector-search the LanceDB table for top_k_retriever chunks, then optionally rerank and keep top_k_reranker."""
+     query_vec = retriever.encode(query)
+     try:
+         documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(top_k_retriever).to_list()
+         documents = [doc[TEXT_COLUMN] for doc in documents]
+
+         if use_reranking:
+             documents = rerank(query, documents, top_k_reranker)
+
+         return documents
+
+     except Exception as e:
+         raise gr.Error(str(e))
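Together the two files cover the retrieval and generation halves of the app: retrieve returns the reranked text chunks, and the caller folds them into a prompt for one of the generators in query_llm.py. A rough sketch of that wiring, where the prompt template and import paths are assumptions rather than part of this commit:

# Sketch only: the prompt template and import paths are assumptions, not part of this commit.
from backend.semantic_search import retrieve
from backend.query_llm import generate_openai

query = "How do I add documents to the LanceDB table?"

# 25 candidates from vector search; the cross-encoder keeps the best 5.
docs = retrieve(query, top_k_retriever=25, use_reranking=True, top_k_reranker=5)

context = "\n\n".join(docs)
prompt = (
    "Answer the question using only the context below.\n\n"
    f"Context:\n{context}\n\n"
    f"Question: {query}"
)

answer = ""
for partial in generate_openai(prompt, history=""):
    answer = partial
print(answer)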