import io
import json

import numpy as np
import requests


def import_talk_info() -> list[dict]:
    """
    Import the talk info, downloaded from the project's GitHub repository.

    Returns:
        list[dict]: A list of talk info records.
    """
    target_file_url = "https://raw.githubusercontent.com/AlanFeder/rgov-2024/main/data/rgov_talks.json"

    response = requests.get(target_file_url, timeout=30)
    response.raise_for_status()
    return response.json()


def import_embeds() -> np.ndarray:
    """
    Import the talk embeddings (one row per talk), downloaded from the
    project's GitHub repository.

    Returns:
        np.ndarray: The embedding matrix.
    """
    target_file_url = (
        "https://raw.githubusercontent.com/AlanFeder/rgov-2024/main/data/embeds.csv"
    )

    response = requests.get(target_file_url, timeout=30)
    response.raise_for_status()

    data = np.genfromtxt(io.StringIO(response.text), delimiter=",")
    return data


def import_data() -> tuple[list[dict], np.ndarray]:
    """
    Import both the talk info and the talk embeddings.

    Returns:
        tuple[list[dict], np.ndarray]: The talk info and the embedding matrix.
    """
    talk_info = import_talk_info()
    embeds = import_embeds()

    return talk_info, embeds


def do_1_embed(lt: str, oai_api_key: str) -> np.ndarray:
    """
    Generate an embedding for a single text using the OpenAI API.

    Args:
        lt (str): The text to generate an embedding for.
        oai_api_key (str): The OpenAI API key.

    Returns:
        np.ndarray: The generated embedding.
    """
    url = "https://api.openai.com/v1/embeddings"

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {oai_api_key}",
    }

    payload = {"input": lt, "model": "text-embedding-3-small"}

    response = requests.post(url, headers=headers, data=json.dumps(payload))

    if response.status_code == 200:
        # The embeddings endpoint returns {"data": [{"embedding": [...], ...}], ...}.
        embed_response = response.json()
        here_embed = np.array(embed_response["data"][0]["embedding"])
        return here_embed
    else:
        # Fail loudly: returning None here would break the downstream
        # dot product in do_sort.
        raise RuntimeError(f"Error: {response.status_code}\n{response.text}")


def do_sort(
    embed_q: np.ndarray, embed_talks: np.ndarray, list_talk_ids: list[str]
) -> list[dict[str, str | float]]:
    """
    Sort talks by the cosine similarity of their embeddings to the query embedding.

    Args:
        embed_q (np.ndarray): Query embedding.
        embed_talks (np.ndarray): Matrix of talk embeddings, one row per talk.
        list_talk_ids (list[str]): Talk IDs, in the same row order as embed_talks.

    Returns:
        list[dict[str, str | float]]: Talk IDs and similarity scores, sorted
            from most to least similar.
    """
    # OpenAI embeddings are normalized to unit length, so the dot product
    # is the cosine similarity.
    cos_sims = np.dot(embed_talks, embed_q)

    # Talk indices ordered from highest to lowest similarity.
    best_match_video_ids = np.argsort(-cos_sims)

    sorted_vids = [
        {"id0": list_talk_ids[i], "score": cos_sims[i]} for i in best_match_video_ids
    ]

    return sorted_vids
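# Illustrative example (hypothetical values): with cos_sims = [0.2, 0.9, 0.5]
# for talk IDs ["a", "b", "c"], np.argsort(-cos_sims) is [1, 2, 0], so do_sort
# returns [{"id0": "b", "score": 0.9}, {"id0": "c", "score": 0.5},
# {"id0": "a", "score": 0.2}].

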
def limit_docs(
    sorted_vids: list[dict],
    talk_info: dict,
    n_results: int,
) -> list[dict]:
    """
    Limit the retrieved documents to the top n_results, then drop any whose
    similarity score falls below a threshold derived from the top score.

    Args:
        sorted_vids (list[dict]): Talk IDs and similarity scores, sorted from
            most to least similar.
        talk_info (dict): Talk information, keyed by talk ID.
        n_results (int): Maximum number of documents to return.

    Returns:
        list[dict]: The kept documents, each with its talk info, ID, and score.
    """
    top_vids = sorted_vids[:n_results]

    # Keep documents within 0.2 of the top score, but never set the threshold
    # above 0.6 or below 0.2.
    top_score = top_vids[0]["score"]
    score_thresh = max(min(0.6, top_score - 0.2), 0.2)

    keep_texts = []
    for my_vid in top_vids:
        if my_vid["score"] >= score_thresh:
            vid_data = talk_info[my_vid["id0"]]
            vid_data = {**vid_data, **my_vid}
            keep_texts.append(vid_data)

    return keep_texts
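# Illustrative example of the threshold above: with a top score of 0.75 the
# threshold is max(min(0.6, 0.55), 0.2) = 0.55, so a second document scoring
# 0.48 is dropped; with a top score of 0.95 the threshold caps at 0.6, and
# with a top score of 0.3 it floors at 0.2.

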
def do_retrieval(
    query0: str,
    n_results: int,
    oai_api_key: str,
    embeds: np.ndarray,
    talk_info: list[dict],
) -> list[dict]:
    """
    Retrieve the documents most relevant to the user's query.

    Args:
        query0 (str): The user's query.
        n_results (int): The number of documents to retrieve.
        oai_api_key (str): The OpenAI API key used to embed the query.
        embeds (np.ndarray): Matrix of talk embeddings, one row per talk.
        talk_info (list[dict]): The talk info records.

    Returns:
        list[dict]: The retrieved documents.
    """
    # Embed the query with the same model used for the talk embeddings.
    arr_q = do_1_embed(query0, oai_api_key=oai_api_key)

    # Re-key the talk info by ID for the lookup in limit_docs.
    talk_ids = [ti["id0"] for ti in talk_info]
    talk_info = {ti["id0"]: ti for ti in talk_info}

    sorted_vids = do_sort(embed_q=arr_q, embed_talks=embeds, list_talk_ids=talk_ids)

    keep_texts = limit_docs(
        sorted_vids=sorted_vids, talk_info=talk_info, n_results=n_results
    )

    return keep_texts


SYSTEM_PROMPT = """
You are an AI assistant that helps answer questions by searching through video transcripts.
I have retrieved the transcripts most likely to answer the user's question.
Carefully read through the transcripts to find information that helps answer the question.
Be brief - your response should not be more than two paragraphs.
Only use information directly stated in the provided transcripts to answer the question.
Do not add any information or make any claims that are not explicitly supported by the transcripts.
If the transcripts do not contain enough information to answer the question, state that you do not have enough information to provide a complete answer.
Format the response clearly. If only one of the transcripts answers the question, don't reference the other and don't explain why its content is irrelevant.
Do not speak in the first person. DO NOT write a letter, make an introduction, or salutation.
Reference the speaker's name when you say what they said.
"""


def set_messages(system_prompt: str, user_prompt: str) -> list[dict[str, str]]:
    """
    Set the messages for the chat completion.

    Args:
        system_prompt (str): The system prompt.
        user_prompt (str): The user prompt.

    Returns:
        list[dict[str, str]]: The system and user messages.
    """
    messages1 = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    return messages1


def make_user_prompt(question: str, keep_texts: list[dict]) -> str:
    """
    Create the user prompt based on the question and the retrieved transcripts.

    Args:
        question (str): The user's question.
        keep_texts (list[dict]): The retrieved transcripts.

    Returns:
        str: The user prompt.
    """
    user_prompt = f"""
Question: {question}
==============================
"""
    if len(keep_texts) > 0:
        list_strs = []
        for i, tx_val in enumerate(keep_texts):
            text0 = tx_val["transcript"]
            speaker_name = tx_val["Speaker"]
            list_strs.append(
                f"Video Transcript {i + 1}\nSpeaker: {speaker_name}\n{text0}"
            )
        user_prompt += "\n-------\n".join(list_strs)
        user_prompt += """
==============================
After analyzing the above video transcripts, please provide a helpful answer to my question. Remember to stay within two paragraphs.
Address the response to me directly. Do not use any information not explicitly supported by the transcripts. Remember to reference the speaker's name."""
    else:
        user_prompt += (
            "No relevant video transcripts were found. Please just return a result "
            f"that says something like 'I'm sorry, but the answer to {question} was "
            "not found in the transcripts from the R/Gov Conference'"
        )

    return user_prompt
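# Example (hypothetical data): make_user_prompt("What is Shiny?",
# [{"transcript": "Shiny is a web framework for R.", "Speaker": "Jane Doe"}])
# yields a prompt with the question on top, then "Video Transcript 1" and
# "Speaker: Jane Doe" above the transcript text and closing instructions.

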
def parse_1_query_stream(response):
    """
    Parse a streaming (server-sent events) chat completion response,
    yielding content chunks as they arrive.
    """
    if response.status_code == 200:
        for line in response.iter_lines():
            if line:
                line = line.decode("utf-8")
                # Each SSE line looks like:
                # data: {"choices": [{"delta": {"content": "..."}}], ...}
                if line.startswith("data: "):
                    data = line[6:]
                    if data != "[DONE]":
                        try:
                            chunk = json.loads(data)
                            content = chunk["choices"][0]["delta"].get("content", "")
                            if content:
                                yield content
                        except json.JSONDecodeError:
                            yield f"Error decoding JSON: {data}"
    else:
        yield f"Error: {response.status_code}\n{response.text}"


def parse_1_query_no_stream(response):
    """
    Parse a non-streaming chat completion response, returning the full
    completion text.
    """
    if response.status_code == 200:
        try:
            response1 = response.json()
            completion = response1["choices"][0]["message"]["content"]
            return completion
        except json.JSONDecodeError:
            return f"Error decoding JSON: {response.text}"
    else:
        return f"Error: {response.status_code}\n{response.text}"


def do_1_query(
    messages1: list[dict[str, str]], oai_api_key: str, stream: bool, model_name: str
):
    """
    Generate a response using the specified chat completion model.

    Args:
        messages1 (list[dict[str, str]]): The messages for the chat completion.
        oai_api_key (str): The OpenAI API key.
        stream (bool): Whether to stream the response.
        model_name (str): The chat completion model to use.

    Returns:
        The completion text, or a generator of content chunks when streaming.
    """
    url = "https://api.openai.com/v1/chat/completions"

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {oai_api_key}",
    }
    if stream:
        headers["Accept"] = "text/event-stream"

    # A fixed seed and temperature=0 make the output as reproducible as the
    # API allows.
    payload = {
        "model": model_name,
        "messages": messages1,
        "seed": 18,
        "temperature": 0,
        "stream": stream,
    }

    response = requests.post(
        url, headers=headers, data=json.dumps(payload), stream=stream
    )

    if stream:
        response1 = parse_1_query_stream(response)
    else:
        response1 = parse_1_query_no_stream(response)

    return response1


def do_generation(
    query1: str, keep_texts: list[dict], oai_api_key: str, stream: bool, model_name: str
):
    """
    Generate the chatbot response from the query and the retrieved texts.

    Args:
        query1 (str): The user's query.
        keep_texts (list[dict]): The retrieved relevant texts.
        oai_api_key (str): The OpenAI API key.
        stream (bool): Whether to stream the response.
        model_name (str): The chat completion model to use.

    Returns:
        The completion text, or a generator of content chunks when streaming.
    """
    user_prompt = make_user_prompt(query1, keep_texts=keep_texts)
    messages1 = set_messages(SYSTEM_PROMPT, user_prompt)
    response = do_1_query(
        messages1, oai_api_key=oai_api_key, stream=stream, model_name=model_name
    )

    return response


def calc_cost(
    prompt_tokens: int, completion_tokens: int, embedding_tokens: int
) -> float:
    """
    Calculate the cost in cents based on the number of prompt, completion,
    and embedding tokens.

    Args:
        prompt_tokens (int): The number of tokens in the prompt.
        completion_tokens (int): The number of tokens in the completion.
        embedding_tokens (int): The number of tokens in the embedding.

    Returns:
        float: The cost in cents.
    """
    # Hard-coded rates: 1/2000 cent per prompt token ($5 per 1M), 3/2000 cent
    # per completion token ($15 per 1M), and 1/500000 cent per embedding
    # token ($0.02 per 1M).
    prompt_cost = prompt_tokens / 2000
    completion_cost = 3 * completion_tokens / 2000
    embedding_cost = embedding_tokens / 500000

    cost_cents = prompt_cost + completion_cost + embedding_cost

    return cost_cents
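# Worked example (illustrative token counts): calc_cost(1_000, 500, 200)
# = 1000 / 2000 + 3 * 500 / 2000 + 200 / 500000
# = 0.5 + 0.75 + 0.0004 = 1.2504 cents.

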
def do_rag(
    user_input: str,
    oai_api_key: str,
    model_name: str,
    stream: bool = False,
    n_results: int = 3,
):
    """
    Run the full RAG pipeline: load the data, retrieve the most relevant
    talks, and generate a response grounded in their transcripts.

    Args:
        user_input (str): The user's question.
        oai_api_key (str): The OpenAI API key.
        model_name (str): The chat completion model to use.
        stream (bool): Whether to stream the response.
        n_results (int): The maximum number of documents to retrieve.

    Returns:
        The generated response and the retrieved documents.
    """
    talk_info, embeds = import_data()

    retrieved_docs = do_retrieval(
        query0=user_input,
        n_results=n_results,
        oai_api_key=oai_api_key,
        embeds=embeds,
        talk_info=talk_info,
    )

    response = do_generation(
        query1=user_input,
        keep_texts=retrieved_docs,
        model_name=model_name,
        oai_api_key=oai_api_key,
        stream=stream,
    )

    return response, retrieved_docs


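if __name__ == "__main__":
    # Minimal usage sketch. Assumptions not specified elsewhere in this
    # module: the OPENAI_API_KEY environment variable is set, and "gpt-4o"
    # names an available chat completion model.
    import os

    answer, docs = do_rag(
        user_input="What did the speakers say about Shiny?",
        oai_api_key=os.environ["OPENAI_API_KEY"],
        model_name="gpt-4o",
    )
    print(f"Retrieved {len(docs)} talk(s)")
    print(answer)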