import time
from dataclasses import dataclass
from datetime import datetime
from functools import reduce
import json
import os
from pathlib import Path
import re
import sys
from typing import List, Optional, Tuple, Dict, Callable, Any

from bs4 import BeautifulSoup
import docx
from html2text import html2text
import langchain
from langchain.callbacks import get_openai_callback
from langchain.cache import SQLiteCache
from langchain.chains import LLMChain
from langchain.chains.chat_vector_db.prompts import CONDENSE_QUESTION_PROMPT
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.base import BaseChatModel
from langchain.document_loaders import PyPDFLoader, PyMuPDFLoader
from langchain.embeddings.base import Embeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain.llms.base import LLM, BaseLLM
# ChatPromptTemplate and HumanMessagePromptTemplate are used in Dataset._query below
from langchain.prompts.chat import (
    AIMessagePromptTemplate,
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.text_splitter import TokenTextSplitter, RecursiveCharacterTextSplitter
from langchain.vectorstores import Pinecone as OriginalPinecone
import numpy as np
import openai
import pinecone
from pptx import Presentation
from pypdf import PdfReader
import requests
from requests.models import MissingSchema
import trafilatura

from streamlit_langchain_chat.constants import *
from streamlit_langchain_chat.customized_langchain.vectorstores import FAISS
from streamlit_langchain_chat.customized_langchain.vectorstores import Pinecone
from streamlit_langchain_chat.utils import maybe_is_text, maybe_is_truncated
from streamlit_langchain_chat.prompts import *

# Cache LLM responses locally so repeated questions reuse earlier answers.
if REUSE_ANSWERS:
    CACHE_PATH = TEMP_DIR / "llm_cache.db"
    os.makedirs(os.path.dirname(CACHE_PATH), exist_ok=True)
    langchain.llm_cache = SQLiteCache(str(CACHE_PATH))

# option 1
TextSplitter = TokenTextSplitter
# option 2
# TextSplitter = RecursiveCharacterTextSplitter  # used by gpt4_pdf_chatbot_langchain (aka GPCL)
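
# Example (sketch): with TokenTextSplitter the chunk_size/chunk_overlap arguments
# are measured in tokens rather than characters, so the chunk_chars parameters
# used throughout this module actually bound token counts:
#   splitter = TextSplitter(chunk_size=2000, chunk_overlap=50)
#   chunks = splitter.split_text(long_document)  # each chunk <= 2000 tokens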

@dataclass
class Answer:
    """A class to hold the answer to a question."""

    question: str = ""
    answer: str = ""
    context: str = ""
    chunks: str = ""
    packages: List[Any] = None
    references: str = ""
    cost_str: str = ""
    passages: Dict[str, str] = None
    tokens: List[Dict] = None

    def __post_init__(self):
        """Initialize the mutable default fields."""
        if self.packages is None:
            self.packages = []
        if self.passages is None:
            self.passages = {}

    def __str__(self) -> str:
        """Return the answer as a string."""
        return self.answer
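
# Example: an Answer is filled in incrementally as it flows through
# Dataset.get_evidence() and Dataset._query():
#   answer = Answer(question="What is the main finding?")
#   str(answer)  # -> "" until the QA chain populates .answer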

def parse_docx(path, citation, key, chunk_chars=2000, overlap=50):
    try:
        document = docx.Document(path)
        full_text = [paragraph.text for paragraph in document.paragraphs]
        doc = '\n'.join(full_text) + '\n'
    except Exception as e:
        print(f"code_error: {e}")
        sys.exit(1)

    if doc:
        text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
        texts = text_splitter.split_text(doc)
        return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
    else:
        return [], []

# TODO: if you wire connectors in the form loader = ...; data = loader.load(),
#       every langchain document loader can be plugged in here:
#       https://langchain.readthedocs.io/en/stable/modules/document_loaders/examples/pdf.html
def parse_pdf(path, citation, key, chunk_chars=2000, overlap=50):
    pdfFileObj = open(path, "rb")
    pdfReader = PdfReader(pdfFileObj)
    splits = []
    split = ""
    pages = []
    metadatas = []
    for i, page in enumerate(pdfReader.pages):
        split += page.extract_text()
        pages.append(str(i + 1))
        # split could be so long it needs to be split
        # into multiple chunks. Or it could be so short
        # that it needs to be combined with the next chunk.
        while len(split) > chunk_chars:
            splits.append(split[:chunk_chars])
            # pretty formatting of pages (e.g. 1-3, 4, 5-7)
            pg = "-".join([pages[0], pages[-1]])
            metadatas.append(
                dict(
                    citation=citation,
                    dockey=key,
                    key=f"{key} pages {pg}",
                )
            )
            split = split[chunk_chars - overlap:]
            pages = [str(i + 1)]
    if len(split) > overlap:
        splits.append(split[:chunk_chars])
        pg = "-".join([pages[0], pages[-1]])
        metadatas.append(
            dict(
                citation=citation,
                dockey=key,
                key=f"{key} pages {pg}",
            )
        )
    pdfFileObj.close()

    # # ### option 2. PyPDFLoader
    # loader = PyPDFLoader(path)
    # data = loader.load_and_split()
    # # ### option 2.1. PyPDFLoader as used by GPCL, which then splits the raw docs
    # loader = PyPDFLoader(path)
    # rawDocs = loader.load()
    # text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
    # texts = text_splitter.split_documents(rawDocs)
    # # ### option 3. PyMuPDFLoader. This seems like the best option
    # loader = PyMuPDFLoader(path)
    # data = loader.load()
    return splits, metadatas
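
# Example (hypothetical file name): parse_pdf returns parallel lists of text
# chunks and per-chunk metadata whose keys carry the covered page range:
#   splits, metadatas = parse_pdf("paper.pdf", citation="Doe et al., 2023", key="Doe2023")
#   metadatas[0]["key"]  # -> e.g. "Doe2023 pages 1-3"
#   assert len(splits) == len(metadatas)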

def parse_pptx(path, citation, key, chunk_chars=2000, overlap=50):
    try:
        presentation = Presentation(path)
        full_text = []
        for slide in presentation.slides:
            for shape in slide.shapes:
                if hasattr(shape, "text"):
                    full_text.append(shape.text)
        doc = ''.join(full_text)
        if doc:
            text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
            texts = text_splitter.split_text(doc)
            return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
        else:
            return [], []
    except Exception as e:
        print(f"code_error: {e}")
        sys.exit(1)

def parse_txt(path, citation, key, chunk_chars=2000, overlap=50, html=False):
    try:
        with open(path) as f:
            doc = f.read()
    except UnicodeDecodeError:
        with open(path, encoding="utf-8", errors="ignore") as f:
            doc = f.read()
    if html:
        doc = html2text(doc)
    # note: no idea why, but the texts are not always split correctly here
    text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
    texts = text_splitter.split_text(doc)
    return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)

def parse_url(url: str, citation, key, chunk_chars=2000, overlap=50):
    def beautifulsoup_extract_text_fallback(response_content):
        """
        This is a fallback function, so that we can always return a value for text content.
        Even for when both Trafilatura and BeautifulSoup are unable to extract the text from a
        single URL.
        """
        # Create the beautifulsoup object:
        soup = BeautifulSoup(response_content, 'html.parser')
        # Finding the text:
        text = soup.find_all(text=True)
        # Remove unwanted tag elements:
        cleaned_text = ''
        blacklist = [
            '[document]',
            'noscript',
            'header',
            'html',
            'meta',
            'head',
            'input',
            'script',
            'style',
        ]
        # Then we loop over every item in the extracted text and make sure
        # the beautifulsoup4 tag is NOT in the blacklist
        for item in text:
            if item.parent.name not in blacklist:
                cleaned_text += f'{item} '
        # Remove any tab separation and strip the text:
        cleaned_text = cleaned_text.replace('\t', '')
        return cleaned_text.strip()

    def extract_text_from_single_web_page(url):
        print(f"\n===========\n{url=}\n===========\n")
        downloaded_url = trafilatura.fetch_url(url)
        a = None
        try:
            a = trafilatura.extract(downloaded_url,
                                    output_format='json',
                                    with_metadata=True,
                                    include_comments=False,
                                    date_extraction_params={'extensive_search': True,
                                                            'original_date': True})
        except AttributeError:
            a = trafilatura.extract(downloaded_url,
                                    output_format='json',
                                    with_metadata=True,
                                    date_extraction_params={'extensive_search': True,
                                                            'original_date': True})
        except Exception as e:
            print(f"code_error: {e}")

        if a:
            json_output = json.loads(a)
            return json_output['text']
        else:
            try:
                headers = {'User-Agent': 'Chrome/83.0.4103.106'}
                resp = requests.get(url, headers=headers)
                print(f"{resp=}\n")
                # We will only extract the text from successful requests:
                if resp.status_code == 200:
                    return beautifulsoup_extract_text_fallback(resp.content)
                else:
                    # This handles any failure in both the Trafilatura and
                    # BeautifulSoup4 extraction paths:
                    return np.nan
            # Handling for any URLs that don't have the correct protocol
            except MissingSchema:
                return np.nan

    text_to_split = extract_text_from_single_web_page(url)
    text_splitter = TextSplitter(chunk_size=chunk_chars, chunk_overlap=overlap)
    texts = text_splitter.split_text(text_to_split)
    return texts, [dict(citation=citation, dockey=key, key=key)] * len(texts)
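
# Note (observed from the code above, not a behavior change): when both the
# trafilatura and BeautifulSoup paths fail, extract_text_from_single_web_page
# returns np.nan, and split_text will then raise because it expects a string.
# Callers may want to guard for that, e.g. (hypothetical URL):
#   texts, metadata = parse_url("https://example.com/post", citation="Post", key="Post2023")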

def read_source(path: str = None,
                citation: str = None,
                key: str = None,
                chunk_chars: int = 3000,
                overlap: int = 100,
                disable_check: bool = False):
    if path.endswith(".pdf"):
        return parse_pdf(path, citation, key, chunk_chars, overlap)
    elif path.endswith(".txt"):
        return parse_txt(path, citation, key, chunk_chars, overlap)
    elif path.endswith(".html"):
        return parse_txt(path, citation, key, chunk_chars, overlap, html=True)
    elif path.endswith(".docx"):
        return parse_docx(path, citation, key, chunk_chars, overlap)
    elif path.endswith(".pptx"):
        return parse_pptx(path, citation, key, chunk_chars, overlap)
    elif path.startswith("http://") or path.startswith("https://"):
        return parse_url(path, citation, key, chunk_chars, overlap)
    # TODO: add more connectors
    # else:
    #     return parse_code_txt(path, citation, key, chunk_chars, overlap)
    else:
        raise ValueError(f"Unknown extension: {path}")

class Dataset:
    """A collection of documents to be used for answering questions."""

    def __init__(
        self,
        chunk_size_limit: int = 3000,
        llm: Optional[BaseLLM | BaseChatModel] = None,
        summary_llm: Optional[BaseLLM] = None,
        name: str = "default",
        index_path: Optional[Path] = None,
    ) -> None:
        """Initialize the collection of documents.

        Args:
            chunk_size_limit: The maximum number of characters to use for a single chunk of text.
            llm: The language model to use for answering questions. Default: ChatOpenAI (gpt-3.5-turbo).
            summary_llm: The language model to use for summarizing documents. If None, llm is used.
            name: The name of the collection.
            index_path: The path to the index file IF pickled. If None, defaults to using name in $HOME/.paperqa/name
        """
        self.docs = dict()
        self.keys = set()
        self.chunk_size_limit = chunk_size_limit
        self.index_docstore = None
        if llm is None:
            llm = ChatOpenAI(temperature=0.1, max_tokens=512)
        if summary_llm is None:
            summary_llm = llm
        self.update_llm(llm, summary_llm)
        if index_path is None:
            index_path = TEMP_DIR / name
        self.index_path = index_path
        self.name = name

    def update_llm(self, llm: BaseLLM | ChatOpenAI, summary_llm: Optional[BaseLLM] = None) -> None:
        """Update the LLM for answering questions."""
        self.llm = llm
        if summary_llm is None:
            summary_llm = llm
        self.summary_llm = summary_llm
        self.summary_chain = LLMChain(prompt=chat_summary_prompt, llm=summary_llm)
        self.search_chain = LLMChain(prompt=search_prompt, llm=llm)
        self.cite_chain = LLMChain(prompt=citation_prompt, llm=llm)

    def add(
        self,
        path: str,
        citation: Optional[str] = None,
        key: Optional[str] = None,
        disable_check: bool = False,
        chunk_chars: Optional[int] = 3000,
    ) -> None:
        """Add a document to the collection."""
        if path in self.docs:
            print(f"Document {path} already in collection.")
            return None

        if citation is None:
            # peek at the first chunk to infer a citation
            texts, _ = read_source(path, "", "", chunk_chars=chunk_chars)
            with get_openai_callback() as cb:
                citation = self.cite_chain.run(texts[0])
            if len(citation) < 3 or "Unknown" in citation or "insufficient" in citation:
                citation = f"Unknown, {os.path.basename(path)}, {datetime.now().year}"

        if key is None:
            # get first name and year from citation
            try:
                author = re.search(r"([A-Z][a-z]+)", citation).group(1)
            except AttributeError:
                # panicking - no word??
                raise ValueError(
                    f"Could not parse key from citation {citation}. "
                    f"Consider just passing key explicitly - e.g. docs.add(path, citation, key='mykey')"
                )
            try:
                year = re.search(r"(\d{4})", citation).group(1)
            except AttributeError:
                year = ""
            key = f"{author}{year}"
        suffix = ""
        while key + suffix in self.keys:
            # move suffix to next letter
            if suffix == "":
                suffix = "a"
            else:
                suffix = chr(ord(suffix) + 1)
        key += suffix
        self.keys.add(key)

        texts, metadata = read_source(path, citation, key, chunk_chars=chunk_chars)
        # loose check to see if document was loaded
        if len("".join(texts)) < 10 or (
            not disable_check and not maybe_is_text("".join(texts))
        ):
            raise ValueError(
                f"This does not look like a text document: {path}. Pass disable_check to ignore this error."
            )
        self.docs[path] = dict(texts=texts, metadata=metadata, key=key)
        if self.index_docstore is not None:
            self.index_docstore.add_texts(texts, metadatas=metadata)
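
    # Example (hypothetical file name): passing citation and key explicitly
    # skips the cite_chain LLM call entirely:
    #   dataset = Dataset(name="demo")
    #   dataset.add("paper.pdf", citation="Doe et al., 2023", key="Doe2023")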

    def clear(self) -> None:
        """Clear the collection of documents."""
        self.docs = dict()
        self.keys = set()
        self.index_docstore = None
        # delete index files
        pkl = self.index_path / "index.pkl"
        if pkl.exists():
            pkl.unlink()
        fs = self.index_path / "index.faiss"
        if fs.exists():
            fs.unlink()

    def doc_previews(self) -> List[Tuple[int, str, str]]:
        """Return a list of (number of chunks, key, citation) tuples, one per document."""
        return [
            (
                len(doc["texts"]),
                doc["metadata"][0]["dockey"],
                doc["metadata"][0]["citation"],
            )
            for doc in self.docs.values()
        ]

    # to pickle, we have to save the index as a file
    def __getstate__(self, embedding: Optional[Embeddings] = None):
        # note: pickle calls __getstate__ with no extra arguments,
        # so embedding must stay optional here
        if embedding is None:
            embedding = OpenAIEmbeddings()
        if self.index_docstore is None and len(self.docs) > 0:
            self._build_faiss_index(embedding)
        state = self.__dict__.copy()
        if self.index_docstore is not None:
            state["index_docstore"].save_local(self.index_path)
            del state["index_docstore"]
        # remove LLM chains (they can have callbacks, which can't be pickled)
        del state["summary_chain"]
        state.pop("qa_chain", None)  # only exists after a query has been run
        del state["cite_chain"]
        del state["search_chain"]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        try:
            self.index_docstore = FAISS.load_local(self.index_path, OpenAIEmbeddings())
        except Exception:
            # FAISS raises its own exception type, but we don't want to import it
            self.index_docstore = None
        self.update_llm(
            ChatOpenAI(temperature=0.1, max_tokens=512)
        )

    def _build_faiss_index(self, embedding: Embeddings = None):
        if embedding is None:
            embedding = OpenAIEmbeddings()
        if self.index_docstore is None:
            texts = reduce(
                lambda x, y: x + y, [doc["texts"] for doc in self.docs.values()], []
            )
            metadatas = reduce(
                lambda x, y: x + y, [doc["metadata"] for doc in self.docs.values()], []
            )

            # if the index exists, load it
            if LOAD_INDEX_LOCALLY and (self.index_path / "index.faiss").exists():
                self.index_docstore = FAISS.load_local(self.index_path, embedding)

                # check whether each text/metadata pair already exists in the index
                for i in reversed(range(len(texts))):
                    text = texts[i]
                    metadata = metadatas[i]
                    for key, value in self.index_docstore.docstore.dict_.items():
                        if value.page_content == text:
                            basename = metadata.get('citation').split(os.sep)[-1]
                            if value.metadata.get('citation').split(os.sep)[-1] != basename:
                                self.index_docstore.docstore.dict_[key].metadata['citation'] = basename
                                self.index_docstore.docstore.dict_[key].metadata['dockey'] = basename
                                self.index_docstore.docstore.dict_[key].metadata['key'] = basename
                            texts.pop(i)
                            metadatas.pop(i)
                            break  # stop after the first match so we don't pop twice

                # add remaining texts
                if texts:
                    self.index_docstore.add_texts(texts=texts, metadatas=metadatas)
            else:
                # create new index
                self.index_docstore = FAISS.from_texts(texts, embedding, metadatas=metadatas)

            if SAVE_INDEX_LOCALLY:
                # save the index locally
                self.index_docstore.save_local(self.index_path)
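
    # Note: with LOAD_INDEX_LOCALLY enabled, rebuilding is incremental; only
    # texts whose page_content is not already in the loaded FAISS docstore are
    # embedded and added, which avoids re-paying for embeddings on restart.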

    def _build_pinecone_index(self, embedding: Embeddings = None):
        if embedding is None:
            embedding = OpenAIEmbeddings()
        if self.index_docstore is None:
            pinecone.init(
                api_key=os.environ['PINECONE_API_KEY'],  # find at app.pinecone.io
                environment=os.environ['PINECONE_ENVIRONMENT']  # next to api key in console
            )
            texts = reduce(
                lambda x, y: x + y, [doc["texts"] for doc in self.docs.values()], []
            )
            metadatas = reduce(
                lambda x, y: x + y, [doc["metadata"] for doc in self.docs.values()], []
            )
            # TODO: when the index already exists, update it instead of deleting it:
            # index_name = "langchain-demo1"
            # if index_name in pinecone.list_indexes():
            #     self.index_docstore = pinecone.Index(index_name)
            #     vectors = []
            #     for text, metadata in zip(texts, metadatas):
            #         # embed = <we would need to know which embedding built the pre-existing index>
            #     self.index_docstore.upsert(vectors=vectors)
            # else:
            #     if openai.api_type == 'azure':
            #         self.index_docstore = Pinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
            #     else:
            #         self.index_docstore = OriginalPinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)

            index_name = "langchain-demo1"
            # if the index exists, delete it
            if index_name in pinecone.list_indexes():
                pinecone.delete_index(index_name)
            # create a new index
            if openai.api_type == 'azure':
                self.index_docstore = Pinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)
            else:
                self.index_docstore = OriginalPinecone.from_texts(texts, embedding, metadatas=metadatas, index_name=index_name)

    def get_evidence(
        self,
        answer: Answer,
        embedding: Embeddings,
        k: int = 3,
        max_sources: int = 5,
        marginal_relevance: bool = True,
    ):
        if self.index_docstore is None:
            self._build_faiss_index(embedding)

        init_search_time = time.time()
        # want to work through indices but less k
        if marginal_relevance:
            docs = self.index_docstore.max_marginal_relevance_search(
                answer.question, k=k, fetch_k=5 * k
            )
        else:
            docs = self.index_docstore.similarity_search(
                answer.question, k=k, fetch_k=5 * k
            )
        if OPERATING_MODE == "debug":
            print(f"time to search docs to build context: {time.time() - init_search_time:.2f} [s]")

        init_summary_time = time.time()
        partial_summary_time = ""
        for i, doc in enumerate(docs):
            with get_openai_callback() as cb:
                init_partial_summary_time = time.time()
                summary_of_chunked_text = self.summary_chain.run(
                    question=answer.question, context_str=doc.page_content
                )
                if OPERATING_MODE == "debug":
                    partial_summary_time += f"- time to make relevant summary of doc '{i}': {time.time() - init_partial_summary_time:.2f} [s]\n"
                engine = self.summary_chain.llm.model_kwargs.get('deployment_id') or self.summary_chain.llm.model_name
                if not answer.tokens:
                    answer.tokens = [{
                        'engine': engine,
                        'total_tokens': cb.total_tokens}]
                else:
                    answer.tokens.append({
                        'engine': engine,
                        'total_tokens': cb.total_tokens
                    })
            summarized_package = (
                doc.metadata["key"],
                doc.metadata["citation"],
                summary_of_chunked_text,
                doc.page_content,
            )
            if "Not applicable" not in summary_of_chunked_text and summarized_package not in answer.packages:
                answer.packages.append(summarized_package)
            yield answer
            if len(answer.packages) == max_sources:
                break
        if OPERATING_MODE == "debug":
            print(f"time to make all relevant summaries: {time.time() - init_summary_time:.2f} [s]")
            # skip the last character, which is a trailing \n
            print(partial_summary_time[:-1])

        context_str = "\n\n".join(
            [f"{citation}: {summary_of_chunked_text}"
             for key, citation, summary_of_chunked_text, chunked_text in answer.packages
             if "Not applicable" not in summary_of_chunked_text]
        )
        chunks_str = "\n\n".join(
            [f"{citation}: {chunked_text}"
             for key, citation, summary_of_chunked_text, chunked_text in answer.packages
             if "Not applicable" not in summary_of_chunked_text]
        )
        valid_keys = [key
                      for key, citation, summary_of_chunked_text, chunked_text in answer.packages
                      if "Not applicable" not in summary_of_chunked_text]
        if len(valid_keys) > 0:
            context_str += "\n\nValid keys: " + ", ".join(valid_keys)
            chunks_str += "\n\nValid keys: " + ", ".join(valid_keys)
        answer.context = context_str
        answer.chunks = chunks_str
        yield answer
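
    # Example: get_evidence is a generator that repeatedly yields the same
    # Answer object as it accumulates (key, citation, summary, chunk) packages,
    # so a UI can stream progress:
    #   for partial in dataset.get_evidence(Answer(question="..."), OpenAIEmbeddings(), k=3):
    #       print(len(partial.packages), "sources so far")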

    def query(
        self,
        query: str,
        embedding: Embeddings,
        chat_history: list[tuple[str, str]],
        k: int = 10,
        max_sources: int = 5,
        length_prompt: str = "about 100 words",
        marginal_relevance: bool = True,
    ):
        # drain the _query generator and return the final Answer
        for answer in self._query(
            query,
            embedding,
            chat_history,
            k=k,
            max_sources=max_sources,
            length_prompt=length_prompt,
            marginal_relevance=marginal_relevance,
        ):
            pass
        return answer

    def _query(
        self,
        query: str,
        embedding: Embeddings,
        chat_history: list[tuple[str, str]],
        k: int,
        max_sources: int,
        length_prompt: str,
        marginal_relevance: bool,
    ):
        if k < max_sources:
            k = max_sources + 1
        answer = Answer(question=query)

        messages_qa = [system_message_prompt]
        if len(chat_history) != 0:
            for conversation in chat_history:
                messages_qa.append(HumanMessagePromptTemplate.from_template(conversation[0]))
                messages_qa.append(AIMessagePromptTemplate.from_template(conversation[1]))
        messages_qa.append(human_qa_message_prompt)
        chat_qa_prompt = ChatPromptTemplate.from_messages(messages_qa)
        self.qa_chain = LLMChain(prompt=chat_qa_prompt, llm=self.llm)

        for answer in self.get_evidence(
            answer,
            embedding,
            k=k,
            max_sources=max_sources,
            marginal_relevance=marginal_relevance,
        ):
            yield answer

        references_dict = dict()
        passages = dict()
        if len(answer.context) < 10:
            answer_text = "I cannot answer this question due to insufficient information."
        else:
            with get_openai_callback() as cb:
                init_qa_time = time.time()
                answer_text = self.qa_chain.run(
                    question=answer.question, context_str=answer.context, length=length_prompt
                )
                if OPERATING_MODE == "debug":
                    print(f"time to make the Q&A answer: {time.time() - init_qa_time:.2f} [s]")
                engine = self.qa_chain.llm.model_kwargs.get('deployment_id') or self.qa_chain.llm.model_name
                if not answer.tokens:
                    answer.tokens = [{
                        'engine': engine,
                        'total_tokens': cb.total_tokens}]
                else:
                    answer.tokens.append({
                        'engine': engine,
                        'total_tokens': cb.total_tokens
                    })

        # it still happens lol
        if "(Foo2012)" in answer_text:
            answer_text = answer_text.replace("(Foo2012)", "")

        for key, citation, summary, text in answer.packages:
            # check for the whole key (so we don't catch Callahan2019a with Callahan2019)
            skey = key.split(" ")[0]
            if skey + " " in answer_text or skey + ")" in answer_text:
                references_dict[skey] = citation
                passages[key] = text
        references_str = "\n\n".join(
            [f"{i+1}. ({k}): {c}" for i, (k, c) in enumerate(references_dict.items())]
        )

        # cost_str = f"{answer_text}\n\n"
        cost_str = ""
        itemized_cost = ""
        total_amount = 0
        for d in answer.tokens:
            total_tokens = d.get('total_tokens')
            if total_tokens:
                engine = d.get('engine')
                key_price = None
                for key in PRICES.keys():
                    if re.match(f"{key}", engine):
                        key_price = key
                        break
                if PRICES.get(key_price):
                    partial_amount = total_tokens / 1000 * PRICES.get(key_price)
                    total_amount += partial_amount
                    itemized_cost += f"- {engine}: {total_tokens} tokens\t ---> ${partial_amount:.4f},\n"
                else:
                    itemized_cost += f"- {engine}: {total_tokens} tokens,\n"
        # drop the trailing ",\n"
        itemized_cost = itemized_cost[:-2]

        # add token costs to the formatted answer
        cost_str += f"Total cost: ${total_amount:.4f}\nItemized cost:\n{itemized_cost}"

        answer.answer = answer_text
        answer.cost_str = cost_str
        answer.references = references_str
        answer.passages = passages
        yield answer
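

# Minimal end-to-end usage sketch (assumes OPENAI_API_KEY is set and that a
# local "paper.pdf" exists; both the file name and the question are
# illustrative, not part of this module):
if __name__ == "__main__":
    dataset = Dataset(name="demo")
    dataset.add("paper.pdf")  # citation and key are inferred via cite_chain
    result = dataset.query(
        "What is the main contribution of this paper?",
        embedding=OpenAIEmbeddings(),
        chat_history=[],
    )
    print(result.answer)
    print(result.references)
    print(result.cost_str)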