Upload 8 files
- backend/.DS_Store +0 -0
- backend/__pycache__/query_llm.cpython-310.pyc +0 -0
- backend/__pycache__/query_llm.cpython-37.pyc +0 -0
- backend/__pycache__/query_llm.cpython-39.pyc +0 -0
- backend/__pycache__/semantic_search.cpython-310.pyc +0 -0
- backend/__pycache__/semantic_search.cpython-39.pyc +0 -0
- backend/query_llm.py +129 -0
- backend/semantic_search.py +59 -0
backend/.DS_Store
ADDED
Binary file (6.15 kB)

backend/__pycache__/query_llm.cpython-310.pyc
ADDED
Binary file (4.03 kB)

backend/__pycache__/query_llm.cpython-37.pyc
ADDED
Binary file (4.01 kB)

backend/__pycache__/query_llm.cpython-39.pyc
ADDED
Binary file (4.04 kB)

backend/__pycache__/semantic_search.cpython-310.pyc
ADDED
Binary file (1.18 kB)

backend/__pycache__/semantic_search.cpython-39.pyc
ADDED
Binary file (2.14 kB)
backend/query_llm.py
ADDED
@@ -0,0 +1,129 @@
import openai
import gradio as gr
import os

from typing import Any, Dict, Generator, List

from huggingface_hub import InferenceClient
from transformers import AutoTokenizer

from dotenv import load_dotenv

# API keys, model id, and the chat tokenizer are read from the environment (.env file)
load_dotenv()
OPENAI_KEY = os.getenv("OPENAI_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")
TOKENIZER = AutoTokenizer.from_pretrained(os.getenv("HF_MODEL"))

HF_CLIENT = InferenceClient(
    os.getenv("HF_MODEL"),
    token=HF_TOKEN
)
OAI_CLIENT = openai.Client(api_key=OPENAI_KEY)

# Generation parameters for the Hugging Face Inference API
HF_GENERATE_KWARGS = {
    'temperature': max(float(os.getenv("TEMPERATURE", 0.9)), 1e-2),
    'max_new_tokens': int(os.getenv("MAX_NEW_TOKENS", 256)),
    'top_p': float(os.getenv("TOP_P", 0.6)),
    'repetition_penalty': float(os.getenv("REP_PENALTY", 1.2)),
    'do_sample': bool(os.getenv("DO_SAMPLE", True))  # note: any non-empty env value, including "False", is truthy
}

# Generation parameters for the OpenAI chat completions API
OAI_GENERATE_KWARGS = {
    'temperature': max(float(os.getenv("TEMPERATURE", 0.9)), 1e-2),
    'max_tokens': int(os.getenv("MAX_NEW_TOKENS", 256)),
    'top_p': float(os.getenv("TOP_P", 0.6)),
    'frequency_penalty': max(-2, min(float(os.getenv("FREQ_PENALTY", 0)), 2))
}
def format_prompt(message: str, api_kind: str):
    """
    Formats the given message using a chat template.

    Args:
        message (str): The user message to be formatted.
        api_kind (str): LLM API provider ("openai" or "hf").
    Returns:
        The formatted prompt: a list of chat messages for the OpenAI API, or a
        string with the tokenizer's chat template applied for the Hugging Face API.
    """

    # Create a list of message dictionaries with role and content
    messages: List[Dict[str, Any]] = [{'role': 'user', 'content': message}]

    if api_kind == "openai":
        return messages
    elif api_kind == "hf":
        return TOKENIZER.apply_chat_template(messages, tokenize=False)
    elif api_kind:
        raise ValueError("API is not supported")
def generate_hf(prompt: str, history: str) -> Generator[str, None, str]:
    """
    Generate a sequence of tokens based on a given prompt and history using the
    Hugging Face Inference client.

    Args:
        prompt (str): The prompt for the text generation.
        history (str): Context or history for the text generation.
    Returns:
        Generator[str, None, str]: A generator yielding chunks of generated text.
            Returns a final string if an error occurs.
    """

    formatted_prompt = format_prompt(prompt, "hf")
    formatted_prompt = formatted_prompt.encode("utf-8").decode("utf-8")

    try:
        stream = HF_CLIENT.text_generation(
            formatted_prompt,
            **HF_GENERATE_KWARGS,
            stream=True,
            details=True,
            return_full_text=False
        )
        output = ""
        for response in stream:
            output += response.token.text
            yield output

    except Exception as e:
        if "Too Many Requests" in str(e):
            raise gr.Error(f"Too many requests: {str(e)}")
        elif "Authorization header is invalid" in str(e):
            raise gr.Error("Authentication error: HF token was either not provided or incorrect")
        else:
            raise gr.Error(f"Unhandled Exception: {str(e)}")
def generate_openai(prompt: str, history: str) -> Generator[str, None, str]:
    """
    Generate a sequence of tokens based on a given prompt and history using the
    OpenAI client.

    Args:
        prompt (str): The initial prompt for the text generation.
        history (str): Context or history for the text generation.
    Returns:
        Generator[str, None, str]: A generator yielding chunks of generated text.
            Returns a final string if an error occurs.
    """
    formatted_prompt = format_prompt(prompt, "openai")

    try:
        stream = OAI_CLIENT.chat.completions.create(
            model=os.getenv("OPENAI_MODEL"),
            messages=formatted_prompt,
            **OAI_GENERATE_KWARGS,
            stream=True
        )
        output = ""
        for chunk in stream:
            if chunk.choices[0].delta.content:
                output += chunk.choices[0].delta.content
                yield output

    except Exception as e:
        if "Too Many Requests" in str(e):
            raise gr.Error("ERROR: Too many requests on OpenAI client")
        elif "You didn't provide an API key" in str(e):
            raise gr.Error("Authentication error: OpenAI key was either not provided or incorrect")
        else:
            raise gr.Error(f"Unhandled Exception: {str(e)}")
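For reference, a minimal sketch of the environment query_llm.py expects and of how its streaming generators are consumed. This is not part of the upload: the model ids and placeholder tokens below are illustrative assumptions; only the variable names are taken from the module.

# Hypothetical setup for query_llm.py (values normally live in a .env file or in Space secrets).
import os

os.environ.setdefault("HF_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")  # assumed model id
os.environ.setdefault("HF_TOKEN", "hf_...")                              # your Hugging Face token
os.environ.setdefault("OPENAI_MODEL", "gpt-3.5-turbo")                   # assumed model id
os.environ.setdefault("OPENAI_API_KEY", "sk-...")                        # your OpenAI key
os.environ.setdefault("TEMPERATURE", "0.9")
os.environ.setdefault("TOP_P", "0.6")
os.environ.setdefault("MAX_NEW_TOKENS", "256")
os.environ.setdefault("REP_PENALTY", "1.2")
os.environ.setdefault("FREQ_PENALTY", "0")

from backend.query_llm import generate_hf  # imported after the environment is populated

# Each yielded value is the cumulative answer so far, which is what Gradio's
# streaming chat components expect.
for partial_answer in generate_hf("What does this Space do?", history=""):
    print(partial_answer)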
backend/semantic_search.py
ADDED
@@ -0,0 +1,59 @@
import lancedb
import os
import gradio as gr
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import time
from pathlib import Path

# LanceDB table of embedded document chunks, plus the retriever and reranker models,
# all configured through environment variables.
db = lancedb.connect(".lancedb")

TABLE = db.open_table(os.getenv("TABLE_NAME"))
VECTOR_COLUMN = os.getenv("VECTOR_COLUMN", "vector")
TEXT_COLUMN = os.getenv("TEXT_COLUMN", "text")
BATCH_SIZE = int(os.getenv("BATCH_SIZE", 32))
CROSS_ENCODER = os.getenv("CROSS_ENCODER")

retriever = SentenceTransformer(os.getenv("EMB_MODEL"))
cross_encoder = AutoModelForSequenceClassification.from_pretrained(CROSS_ENCODER)
cross_encoder.eval()
cross_encoder_tokenizer = AutoTokenizer.from_pretrained(CROSS_ENCODER)
def rerank(query, documents, k):
    """Use cross-encoder to rerank documents retrieved from the retriever."""
    tokens = cross_encoder_tokenizer([query] * len(documents), documents, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = cross_encoder(**tokens).logits
    scores = logits.reshape(-1).tolist()
    documents = sorted(zip(documents, scores), key=lambda x: x[1], reverse=True)
    return [doc[0] for doc in documents[:k]]
# def retrieve(query, k):
#     query_vec = retriever.encode(query)
#     try:
#         documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(k).to_list()
#         documents = [doc[TEXT_COLUMN] for doc in documents]
#
#         return documents
#
#     except Exception as e:
#         raise gr.Error(str(e))


def retrieve(query, top_k_retriever=25, use_reranking=True, top_k_reranker=5):
    """Embed the query, fetch the nearest chunks from LanceDB, and optionally rerank them with the cross-encoder."""
    query_vec = retriever.encode(query)
    try:
        documents = TABLE.search(query_vec, vector_column_name=VECTOR_COLUMN).limit(top_k_retriever).to_list()
        documents = [doc[TEXT_COLUMN] for doc in documents]

        if use_reranking:
            documents = rerank(query, documents, top_k_reranker)

        return documents

    except Exception as e:
        raise gr.Error(str(e))
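Neither file includes the ingestion step that builds the LanceDB table or the Gradio app that ties retrieval and generation together. A minimal sketch of that glue is below; every name in it (app.py, the prompt wording, the import paths) is an assumption for illustration, not part of this commit.

# Hypothetical app.py: retrieve and rerank context with semantic_search, then stream an
# answer with query_llm. Assumes the env vars both modules read (TABLE_NAME, EMB_MODEL,
# CROSS_ENCODER, HF_MODEL / OPENAI_MODEL plus the matching tokens) and an existing
# LanceDB table under .lancedb with a vector column and a text column.
import gradio as gr

from backend.semantic_search import retrieve
from backend.query_llm import generate_hf, generate_openai


def answer(message, history):
    # 1) Retrieve a broad candidate set, then let the cross-encoder keep the best few.
    docs = retrieve(message, top_k_retriever=25, use_reranking=True, top_k_reranker=5)

    # 2) Stuff the retrieved chunks into the prompt.
    context = "\n\n".join(docs)
    prompt = f"Answer using only the context below.\n\nContext:\n{context}\n\nQuestion: {message}"

    # 3) Stream the growing answer back to the UI.
    yield from generate_hf(prompt, history)  # or: yield from generate_openai(prompt, history)


demo = gr.ChatInterface(answer)
demo.queue().launch()

The two-stage retrieval is the usual precision and latency trade-off: the bi-encoder fetches 25 candidates cheaply, and the heavier cross-encoder rescores only that set to keep the top 5.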