Spaces:
Runtime error
Runtime error
add EKR files
Browse files- app.py +128 -0
- config.yaml +29 -0
- prompts/final_chain_prompt.yaml +17 -0
- prompts/llama7b-knowledge_retriever-custom_qa_prompt.yaml +19 -0
- prompts/qa_prompt.yaml +21 -0
- requirements.txt +30 -0
- src/bulkQA.py +132 -0
- src/document_retrieval.py +311 -0
- utils/model_wrappers/api_gateway.py +260 -0
- utils/model_wrappers/langchain_chat_models.py +465 -0
- utils/model_wrappers/langchain_embeddings.py +309 -0
- utils/model_wrappers/langchain_llms.py +770 -0
- utils/model_wrappers/usage.ipynb +878 -0
- utils/parsing/README.md +285 -0
- utils/parsing/config.yaml +69 -0
- utils/parsing/docker-compose.yaml +30 -0
- utils/parsing/parse_usage.ipynb +228 -0
- utils/parsing/requirements.txt +6 -0
- utils/parsing/sambaparse.py +525 -0
- utils/vectordb/create_vector_db.py +141 -0
- utils/vectordb/vector_db.py +353 -0
- utils/visual/env_utils.py +95 -0
app.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import logging
|
4 |
+
import yaml
|
5 |
+
import gradio as gr
|
6 |
+
import time
|
7 |
+
|
8 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
9 |
+
print(current_dir)
|
10 |
+
|
11 |
+
from src.document_retrieval import DocumentRetrieval
|
12 |
+
from utils.visual.env_utils import env_input_fields, initialize_env_variables, are_credentials_set, save_credentials
|
13 |
+
from utils.parsing.sambaparse import parse_doc_universal # added Petro
|
14 |
+
from utils.vectordb.vector_db import VectorDb
|
15 |
+
|
16 |
+
CONFIG_PATH = os.path.join(current_dir,'config.yaml')
|
17 |
+
PERSIST_DIRECTORY = os.path.join(current_dir,f"data/my-vector-db") # changed to current_dir
|
18 |
+
|
19 |
+
logging.basicConfig(level=logging.INFO)
|
20 |
+
logging.info("Gradio app is running")
|
21 |
+
|
22 |
+
class ChatState:
|
23 |
+
def __init__(self):
|
24 |
+
self.conversation = None
|
25 |
+
self.chat_history = []
|
26 |
+
self.show_sources = True
|
27 |
+
self.sources_history = []
|
28 |
+
self.vectorstore = None
|
29 |
+
self.input_disabled = True
|
30 |
+
self.document_retrieval = None
|
31 |
+
|
32 |
+
chat_state = ChatState()
|
33 |
+
|
34 |
+
chat_state.document_retrieval = DocumentRetrieval()
|
35 |
+
|
36 |
+
def handle_userinput(user_question):
|
37 |
+
if user_question:
|
38 |
+
try:
|
39 |
+
response_time = time.time()
|
40 |
+
response = chat_state.conversation.invoke({"question": user_question})
|
41 |
+
response_time = time.time() - response_time
|
42 |
+
chat_state.chat_history.append((user_question, response["answer"]))
|
43 |
+
|
44 |
+
#sources = set([f'{sd.metadata["filename"]}' for sd in response["source_documents"]])
|
45 |
+
#sources_text = "\n".join([f"{i+1}. {source}" for i, source in enumerate(sources)])
|
46 |
+
#state.sources_history.append(sources_text)
|
47 |
+
|
48 |
+
return chat_state.chat_history, "" #, state.sources_history
|
49 |
+
except Exception as e:
|
50 |
+
return f"An error occurred: {str(e)}", "" #, state.sources_history
|
51 |
+
return chat_state.chat_history, "" #, state.sources_history
|
52 |
+
|
53 |
+
def process_documents(files, save_location=None):
|
54 |
+
try:
|
55 |
+
#for doc in files:
|
56 |
+
_, _, text_chunks = parse_doc_universal(doc=files)
|
57 |
+
print(text_chunks)
|
58 |
+
#text_chunks = chat_state.document_retrieval.parse_doc(files)
|
59 |
+
embeddings = chat_state.document_retrieval.load_embedding_model()
|
60 |
+
collection_name = 'ekr_default_collection' if not config['prod_mode'] else None
|
61 |
+
vectorstore = chat_state.document_retrieval.create_vector_store(text_chunks, embeddings, output_db=save_location, collection_name=collection_name)
|
62 |
+
chat_state.vectorstore = vectorstore
|
63 |
+
chat_state.document_retrieval.init_retriever(vectorstore)
|
64 |
+
chat_state.conversation = chat_state.document_retrieval.get_qa_retrieval_chain()
|
65 |
+
chat_state.input_disabled = False
|
66 |
+
return "Documents processed successfully. You can now ask questions."
|
67 |
+
except Exception as e:
|
68 |
+
return f"An error occurred while processing: {str(e)}"
|
69 |
+
|
70 |
+
def reset_conversation():
|
71 |
+
chat_state.chat_history = []
|
72 |
+
#chat_state.sources_history = []
|
73 |
+
return chat_state.chat_history, ""
|
74 |
+
|
75 |
+
def show_selection(model):
|
76 |
+
return f"You selected: {model}"
|
77 |
+
|
78 |
+
# Read config file
|
79 |
+
with open(CONFIG_PATH, 'r') as yaml_file:
|
80 |
+
config = yaml.safe_load(yaml_file)
|
81 |
+
|
82 |
+
prod_mode = config.get('prod_mode', False)
|
83 |
+
default_collection = 'ekr_default_collection'
|
84 |
+
|
85 |
+
# Load env variables
|
86 |
+
initialize_env_variables(prod_mode)
|
87 |
+
|
88 |
+
caution_text = """⚠️ Note: depending on the size of your document, this could take several minutes.
|
89 |
+
"""
|
90 |
+
|
91 |
+
with gr.Blocks() as demo:
|
92 |
+
#gr.Markdown("# SambaNova Analyst Assistant") # title
|
93 |
+
gr.Markdown("# 🟠 SambaNova Analyst Assistant",
|
94 |
+
elem_id="title")
|
95 |
+
|
96 |
+
gr.Markdown("Powered by SambaNova Cloud. Get your API key [here](https://cloud.sambanova.ai/apis).")
|
97 |
+
|
98 |
+
api_key = gr.Textbox(label="API Key", type="password", placeholder="(Optional) Enter your API key here for more availability")
|
99 |
+
|
100 |
+
# Step 1: Add PDF file
|
101 |
+
gr.Markdown("## 1️⃣ Pick a datasource")
|
102 |
+
docs = gr.File(label="Add PDF file", file_types=["pdf"], file_count="single")
|
103 |
+
|
104 |
+
# Step 2: Process PDF file
|
105 |
+
gr.Markdown(("## 2️⃣ Process your documents and create vector store"))
|
106 |
+
process_btn = gr.Button("🔄 Process")
|
107 |
+
gr.Markdown(caution_text)
|
108 |
+
setup_output = gr.Textbox(label="Setup Output", visible=True)
|
109 |
+
|
110 |
+
process_btn.click(process_documents, inputs=[docs], outputs=setup_output, concurrency_limit=10)
|
111 |
+
#process_save_btn.click(process_documents, inputs=[file_upload, save_location], outputs=setup_output)
|
112 |
+
#load_db_btn.click(load_existing_db, inputs=[db_path], outputs=setup_output)
|
113 |
+
|
114 |
+
# Step 3: Chat with your data
|
115 |
+
gr.Markdown("## 3️⃣ Chat")
|
116 |
+
chatbot = gr.Chatbot(label="Chatbot", show_label=True, show_share_button=False, show_copy_button=True, likeable=True)
|
117 |
+
msg = gr.Textbox(label="Ask questions about your data", placeholder="Enter your message...")
|
118 |
+
clear = gr.Button("Clear chat")
|
119 |
+
#show_sources = gr.Checkbox(label="Show sources", value=True)
|
120 |
+
sources_output = gr.Textbox(label="Sources", visible=False)
|
121 |
+
|
122 |
+
#msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, sources_output])
|
123 |
+
msg.submit(handle_userinput, inputs=[msg], outputs=[chatbot, msg])
|
124 |
+
clear.click(reset_conversation, outputs=[chatbot,msg])
|
125 |
+
#show_sources.change(lambda x: gr.update(visible=x), show_sources, sources_output)
|
126 |
+
|
127 |
+
if __name__ == "__main__":
|
128 |
+
demo.launch()
|
config.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
api: "sncloud" # set either sambastudio or sncloud
|
2 |
+
|
3 |
+
embedding_model:
|
4 |
+
"type": "cpu" # set either sambastudio or cpu
|
5 |
+
"batch_size": 1 #set depending of your endpoint configuration (1 if CoE embedding expert)
|
6 |
+
"coe": True #set true if using Sambastudio embeddings in a CoE endpoint
|
7 |
+
"select_expert": "e5-mistral-7b-instruct" #set if using SambaStudio CoE embedding expert
|
8 |
+
|
9 |
+
llm:
|
10 |
+
"temperature": 0.0
|
11 |
+
"do_sample": False
|
12 |
+
"max_tokens_to_generate": 1200
|
13 |
+
"coe": True #set as true if using Sambastudio CoE endpoint
|
14 |
+
"select_expert": "llama3-8b" #set if using sncloud, SambaStudio CoE llm expert
|
15 |
+
#sncloud CoE expert name -> "llama3-8b"
|
16 |
+
|
17 |
+
retrieval:
|
18 |
+
"k_retrieved_documents": 15 #set if rerank enabled
|
19 |
+
"score_threshold": 0.2
|
20 |
+
"rerank": False # set if you want to rerank retriever results
|
21 |
+
"reranker": 'BAAI/bge-reranker-large' # set if you rerank enabled
|
22 |
+
"final_k_retrieved_documents": 5
|
23 |
+
|
24 |
+
pdf_only_mode: True # Set to true for PDF-only mode, false for all file types
|
25 |
+
prod_mode: False
|
26 |
+
|
27 |
+
prompts:
|
28 |
+
"qa_prompt": "prompts/qa_prompt.yaml"
|
29 |
+
"final_chain_prompt": "prompts/final_chain_prompt.yaml"
|
prompts/final_chain_prompt.yaml
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_type: prompt
|
2 |
+
input_types: {}
|
3 |
+
input_variables:
|
4 |
+
- question
|
5 |
+
- answers
|
6 |
+
name: null
|
7 |
+
output_parser: null
|
8 |
+
partial_variables: {}
|
9 |
+
template: |
|
10 |
+
<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are an assistant for question-answering tasks.
|
11 |
+
Use the following intermediate answers, provide a final answer to the original question. If you cannot answer based on the intermediate answers provided to you, say "Whoops! I don't know!". <|eot_id|><|start_header_id|>user<|end_header_id|>
|
12 |
+
Original Question: {question}
|
13 |
+
Intermediate Answers: {answers}
|
14 |
+
\n ------- \n
|
15 |
+
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
16 |
+
template_format: f-string
|
17 |
+
validate_template: false
|
prompts/llama7b-knowledge_retriever-custom_qa_prompt.yaml
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_type: prompt
|
2 |
+
input_types: {}
|
3 |
+
input_variables:
|
4 |
+
- context
|
5 |
+
- question
|
6 |
+
name: null
|
7 |
+
output_parser: null
|
8 |
+
partial_variables: {}
|
9 |
+
template: "[INST]<<SYS>> You are a helpful assistant for question-answering tasks.\
|
10 |
+
\ Use the following pieces of retrieved context to answer the question.\n \
|
11 |
+
\ each piece of context includes the Source for reference\n if the question \
|
12 |
+
\ references a specific source then filter out that source and give a response based on that source\n If\
|
13 |
+
\ the answer is not in the context, say that you don't know. Cross check if the\
|
14 |
+
\ answer is contained in provided context. If not than say \"I do not have information\
|
15 |
+
\ regarding this.\n Do not use images or emojis in your answer. Keep the answer\
|
16 |
+
\ conversational and professional.<</SYS>>\n\n {context} \n \n Question:\
|
17 |
+
\ {question} \n Helpful answer: [/INST]"
|
18 |
+
template_format: f-string
|
19 |
+
validate_template: false
|
prompts/qa_prompt.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
_type: prompt
|
2 |
+
input_types: {}
|
3 |
+
input_variables:
|
4 |
+
- context
|
5 |
+
- question
|
6 |
+
name: null
|
7 |
+
output_parser: null
|
8 |
+
partial_variables: {}
|
9 |
+
template: |
|
10 |
+
<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a knowledge base assistant chatbot powered by Sambanova's AI chip accelerator, designed to answer questions based on user-uploaded documents.
|
11 |
+
Use the following pieces of retrieved context to answer the question. Each piece of context includes the Source for reference. If the question references a specific source, then filter out that source and give a response based on that source.
|
12 |
+
If the answer is not in the context, say: "This information isn't in my current knowledge base." Then, suggest a related topic you can discuss based on the available context.
|
13 |
+
Maintain a professional yet conversational tone. Do not use images or emojis in your answer.
|
14 |
+
Prioritize accuracy and only provide information directly supported by the context. <|eot_id|><|start_header_id|>user<|end_header_id|>
|
15 |
+
Question: {question}
|
16 |
+
Context: {context}
|
17 |
+
\n ------- \n
|
18 |
+
Answer: <|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
19 |
+
|
20 |
+
template_format: f-string
|
21 |
+
validate_template: false
|
requirements.txt
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
streamlit==1.36.0
|
2 |
+
pydantic==2.7.0
|
3 |
+
pydantic_core==2.18.1
|
4 |
+
|
5 |
+
langchain==0.2.16
|
6 |
+
langchain-core==0.2.38
|
7 |
+
langchain-community==0.2.16
|
8 |
+
|
9 |
+
sentence_transformers==2.2.2
|
10 |
+
instructorembedding==1.0.1
|
11 |
+
faiss-cpu==1.7.4
|
12 |
+
python-dotenv==1.0.0
|
13 |
+
streamlit-extras==0.4.3
|
14 |
+
pillow==10.4.0
|
15 |
+
sseclient-py==1.8.0
|
16 |
+
# unstructured==0.14.9
|
17 |
+
# unstructured_inference==0.7.36
|
18 |
+
# unstructured_pytesseract==0.3.12
|
19 |
+
# pytesseract==0.3.10
|
20 |
+
chromadb==0.5.3
|
21 |
+
langgraph==0.0.55
|
22 |
+
openpyxl==3.1.4
|
23 |
+
psutil==6.0.0
|
24 |
+
pillow_heif==0.16.0
|
25 |
+
ipython==8.26.0
|
26 |
+
PyMuPDF==1.23.4
|
27 |
+
PyMuPDFb==1.23.3
|
28 |
+
|
29 |
+
#LLM Eval
|
30 |
+
weave==0.51.1
|
src/bulkQA.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import argparse
|
4 |
+
import pandas as pd
|
5 |
+
import time
|
6 |
+
from typing import Any, Dict, Optional
|
7 |
+
from langchain_core.callbacks import CallbackManagerForChainRun
|
8 |
+
from langchain.prompts import load_prompt
|
9 |
+
from langchain_core.output_parsers import StrOutputParser
|
10 |
+
from transformers import AutoTokenizer
|
11 |
+
|
12 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
13 |
+
kit_dir = os.path.abspath(os.path.join(current_dir, ".."))
|
14 |
+
repo_dir = os.path.abspath(os.path.join(kit_dir, ".."))
|
15 |
+
|
16 |
+
sys.path.append(kit_dir)
|
17 |
+
sys.path.append(repo_dir)
|
18 |
+
|
19 |
+
from enterprise_knowledge_retriever.src.document_retrieval import DocumentRetrieval, RetrievalQAChain
|
20 |
+
|
21 |
+
class TimedRetrievalQAChain(RetrievalQAChain):
|
22 |
+
#override call method to return times
|
23 |
+
def _call(self,
|
24 |
+
inputs: Dict[str, Any],
|
25 |
+
run_manager: Optional[CallbackManagerForChainRun] = None,
|
26 |
+
) -> Dict[str, Any]:
|
27 |
+
qa_chain = self.qa_prompt | self.llm | StrOutputParser()
|
28 |
+
response = {}
|
29 |
+
start_time = time.time()
|
30 |
+
documents = self.retriever.invoke(inputs["question"])
|
31 |
+
if self.rerank:
|
32 |
+
documents = self.rerank_docs(inputs["question"], documents, self.final_k_retrieved_documents)
|
33 |
+
docs = self._format_docs(documents)
|
34 |
+
end_preprocessing_time=time.time()
|
35 |
+
response["answer"] = qa_chain.invoke({"question": inputs["question"], "context": docs})
|
36 |
+
end_llm_time=time.time()
|
37 |
+
response["source_documents"] = documents
|
38 |
+
response["start_time"] = start_time
|
39 |
+
response["end_preprocessing_time"] = end_preprocessing_time
|
40 |
+
response["end_llm_time"] = end_llm_time
|
41 |
+
return response
|
42 |
+
|
43 |
+
def analyze_times(answer, start_time, end_preprocessing_time, end_llm_time, tokenizer):
|
44 |
+
preprocessing_time=end_preprocessing_time-start_time
|
45 |
+
llm_time=end_llm_time-end_preprocessing_time
|
46 |
+
token_count=len(tokenizer.encode(answer))
|
47 |
+
tokens_per_second = token_count / llm_time
|
48 |
+
perf = {"preprocessing_time": preprocessing_time,
|
49 |
+
"llm_time": llm_time,
|
50 |
+
"token_count": token_count,
|
51 |
+
"tokens_per_second": tokens_per_second}
|
52 |
+
return perf
|
53 |
+
|
54 |
+
def generate(qa_chain, question, tokenizer):
|
55 |
+
response = qa_chain.invoke({"question": question})
|
56 |
+
answer = response.get('answer')
|
57 |
+
sources = set([
|
58 |
+
f'{sd.metadata["filename"]}'
|
59 |
+
for sd in response["source_documents"]
|
60 |
+
])
|
61 |
+
times = analyze_times(
|
62 |
+
answer,
|
63 |
+
response.get("start_time"),
|
64 |
+
response.get("end_preprocessing_time"),
|
65 |
+
response.get("end_llm_time"),
|
66 |
+
tokenizer
|
67 |
+
)
|
68 |
+
return answer, sources, times
|
69 |
+
|
70 |
+
def process_bulk_QA(vectordb_path, questions_file_path):
|
71 |
+
documentRetrieval = DocumentRetrieval()
|
72 |
+
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf")
|
73 |
+
if os.path.exists(vectordb_path):
|
74 |
+
# load the vectorstore
|
75 |
+
embeddings = documentRetrieval.load_embedding_model()
|
76 |
+
vectorstore = documentRetrieval.load_vdb(vectordb_path, embeddings)
|
77 |
+
print("Database loaded")
|
78 |
+
documentRetrieval.init_retriever(vectorstore)
|
79 |
+
print("retriever initialized")
|
80 |
+
#get qa chain
|
81 |
+
qa_chain = TimedRetrievalQAChain(
|
82 |
+
retriever=documentRetrieval.retriever,
|
83 |
+
llm=documentRetrieval.llm,
|
84 |
+
qa_prompt = load_prompt(os.path.join(kit_dir, documentRetrieval.prompts["qa_prompt"])),
|
85 |
+
rerank = documentRetrieval.retrieval_info["rerank"],
|
86 |
+
final_k_retrieved_documents = documentRetrieval.retrieval_info["final_k_retrieved_documents"]
|
87 |
+
|
88 |
+
)
|
89 |
+
else:
|
90 |
+
raise f"vector db path {vectordb_path} does not exist"
|
91 |
+
if os.path.exists(questions_file_path):
|
92 |
+
df = pd.read_excel(questions_file_path)
|
93 |
+
print(df)
|
94 |
+
output_file_path = questions_file_path.replace('.xlsx', '_output.xlsx')
|
95 |
+
if 'Answer' not in df.columns:
|
96 |
+
df['Answer'] = ''
|
97 |
+
df['Sources'] = ''
|
98 |
+
df['preprocessing_time'] = ''
|
99 |
+
df['llm_time'] = ''
|
100 |
+
df['token_count'] = ''
|
101 |
+
df['tokens_per_second'] = ''
|
102 |
+
for index, row in df.iterrows():
|
103 |
+
if row['Answer'].strip()=='': # Only process if 'Answer' is empty
|
104 |
+
try:
|
105 |
+
# Generate the answer
|
106 |
+
print(f"Generating answer for row {index}")
|
107 |
+
answer, sources, times = generate(qa_chain, row['Questions'], tokenizer)
|
108 |
+
df.at[index, 'Answer'] = answer
|
109 |
+
df.at[index, 'Sources'] = sources
|
110 |
+
df.at[index, 'preprocessing_time'] = times.get("preprocessing_time")
|
111 |
+
df.at[index, 'llm_time'] = times.get("llm_time")
|
112 |
+
df.at[index, 'token_count'] = times.get("token_count")
|
113 |
+
df.at[index, 'tokens_per_second'] = times.get("tokens_per_second")
|
114 |
+
except Exception as e:
|
115 |
+
print(f"Error processing row {index}: {e}")
|
116 |
+
# Save the file after each iteration to avoid data loss
|
117 |
+
df.to_excel(output_file_path, index=False)
|
118 |
+
else:
|
119 |
+
print(f"Skipping row {index} because 'Answer' is already in the document")
|
120 |
+
return output_file_path
|
121 |
+
else:
|
122 |
+
raise f"questions file path {questions_file_path} does not exist"
|
123 |
+
|
124 |
+
if __name__ == "__main__":
|
125 |
+
# Parse the arguments
|
126 |
+
parser = argparse.ArgumentParser(description='use a vectordb and an excel file with questions in the first column and generate answers for all the questions')
|
127 |
+
parser.add_argument('vectordb_path', type=str, help='vector db path with stored documents for RAG')
|
128 |
+
parser.add_argument('questions_path', type=str, help='xlsx file containing questions in a column named Questions')
|
129 |
+
args = parser.parse_args()
|
130 |
+
# process in bulk
|
131 |
+
out_file = process_bulk_QA(args.vectordb_path, args.questions_path)
|
132 |
+
print(f"Finished, responses in: {out_file}")
|
src/document_retrieval.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import sys
|
4 |
+
from typing import Any, Dict, List, Optional
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import yaml
|
8 |
+
from dotenv import load_dotenv
|
9 |
+
from langchain.chains.base import Chain
|
10 |
+
from langchain.docstore.document import Document
|
11 |
+
from langchain.prompts import BasePromptTemplate, load_prompt
|
12 |
+
from langchain_core.callbacks import CallbackManagerForChainRun
|
13 |
+
from langchain_core.language_models import BaseLanguageModel
|
14 |
+
from langchain_core.output_parsers import StrOutputParser
|
15 |
+
from langchain_core.retrievers import BaseRetriever
|
16 |
+
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
17 |
+
|
18 |
+
current_dir = os.path.dirname(os.path.abspath(__file__)) # src/ directory
|
19 |
+
kit_dir = os.path.abspath(os.path.join(current_dir, '..')) # EKR/ directory
|
20 |
+
repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))
|
21 |
+
sys.path.append(kit_dir)
|
22 |
+
sys.path.append(repo_dir)
|
23 |
+
|
24 |
+
import streamlit as st
|
25 |
+
|
26 |
+
from utils.model_wrappers.api_gateway import APIGateway
|
27 |
+
from utils.vectordb.vector_db import VectorDb
|
28 |
+
from utils.visual.env_utils import get_wandb_key
|
29 |
+
|
30 |
+
CONFIG_PATH = os.path.join(kit_dir, 'config.yaml')
|
31 |
+
PERSIST_DIRECTORY = os.path.join(kit_dir, 'data/my-vector-db')
|
32 |
+
|
33 |
+
load_dotenv(os.path.join(kit_dir, '.env'))
|
34 |
+
|
35 |
+
|
36 |
+
from utils.parsing.sambaparse import parse_doc_universal
|
37 |
+
|
38 |
+
# Handle the WANDB_API_KEY resolution before importing weave
|
39 |
+
#wandb_api_key = get_wandb_key()
|
40 |
+
|
41 |
+
# If WANDB_API_KEY is set, proceed with weave initialization
|
42 |
+
#if wandb_api_key:
|
43 |
+
# import weave
|
44 |
+
|
45 |
+
# Initialize Weave with your project name
|
46 |
+
# weave.init('sambanova_ekr')
|
47 |
+
#else:
|
48 |
+
# print('WANDB_API_KEY is not set. Weave initialization skipped.')
|
49 |
+
|
50 |
+
|
51 |
+
class RetrievalQAChain(Chain):
|
52 |
+
"""class for question-answering."""
|
53 |
+
|
54 |
+
retriever: BaseRetriever
|
55 |
+
rerank: bool = True
|
56 |
+
llm: BaseLanguageModel
|
57 |
+
qa_prompt: BasePromptTemplate
|
58 |
+
final_k_retrieved_documents: int = 3
|
59 |
+
|
60 |
+
@property
|
61 |
+
def input_keys(self) -> List[str]:
|
62 |
+
"""Input keys.
|
63 |
+
:meta private:
|
64 |
+
"""
|
65 |
+
return ['question']
|
66 |
+
|
67 |
+
@property
|
68 |
+
def output_keys(self) -> List[str]:
|
69 |
+
"""Output keys.
|
70 |
+
:meta private:
|
71 |
+
"""
|
72 |
+
return ['answer', 'source_documents']
|
73 |
+
|
74 |
+
def _format_docs(self, docs):
|
75 |
+
return '\n\n'.join(doc.page_content for doc in docs)
|
76 |
+
|
77 |
+
def rerank_docs(self, query, docs, final_k):
|
78 |
+
# Lazy hardcoding for now
|
79 |
+
tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-reranker-large')
|
80 |
+
reranker = AutoModelForSequenceClassification.from_pretrained('BAAI/bge-reranker-large')
|
81 |
+
pairs = []
|
82 |
+
for d in docs:
|
83 |
+
pairs.append([query, d.page_content])
|
84 |
+
|
85 |
+
with torch.no_grad():
|
86 |
+
inputs = tokenizer(
|
87 |
+
pairs,
|
88 |
+
padding=True,
|
89 |
+
truncation=True,
|
90 |
+
return_tensors='pt',
|
91 |
+
max_length=512,
|
92 |
+
)
|
93 |
+
scores = (
|
94 |
+
reranker(**inputs, return_dict=True)
|
95 |
+
.logits.view(
|
96 |
+
-1,
|
97 |
+
)
|
98 |
+
.float()
|
99 |
+
)
|
100 |
+
|
101 |
+
scores_list = scores.tolist()
|
102 |
+
scores_sorted_idx = sorted(range(len(scores_list)), key=lambda k: scores_list[k], reverse=True)
|
103 |
+
|
104 |
+
docs_sorted = [docs[k] for k in scores_sorted_idx]
|
105 |
+
# docs_sorted = [docs[k] for k in scores_sorted_idx if scores_list[k]>0]
|
106 |
+
docs_sorted = docs_sorted[:final_k]
|
107 |
+
|
108 |
+
return docs_sorted
|
109 |
+
|
110 |
+
def _call(
|
111 |
+
self,
|
112 |
+
inputs: Dict[str, Any],
|
113 |
+
run_manager: Optional[CallbackManagerForChainRun] = None,
|
114 |
+
) -> Dict[str, Any]:
|
115 |
+
qa_chain = self.qa_prompt | self.llm | StrOutputParser()
|
116 |
+
response = {}
|
117 |
+
documents = self.retriever.invoke(inputs['question'])
|
118 |
+
if self.rerank:
|
119 |
+
documents = self.rerank_docs(inputs['question'], documents, self.final_k_retrieved_documents)
|
120 |
+
docs = self._format_docs(documents)
|
121 |
+
response['answer'] = qa_chain.invoke({'question': inputs['question'], 'context': docs})
|
122 |
+
response['source_documents'] = documents
|
123 |
+
return response
|
124 |
+
|
125 |
+
|
126 |
+
class DocumentRetrieval:
|
127 |
+
def __init__(self):
|
128 |
+
self.vectordb = VectorDb()
|
129 |
+
config_info = self.get_config_info()
|
130 |
+
self.api_info = config_info[0]
|
131 |
+
self.llm_info = config_info[1]
|
132 |
+
self.embedding_model_info = config_info[2]
|
133 |
+
self.retrieval_info = config_info[3]
|
134 |
+
self.prompts = config_info[4]
|
135 |
+
self.prod_mode = config_info[5]
|
136 |
+
self.retriever = None
|
137 |
+
self.llm = self.set_llm()
|
138 |
+
|
139 |
+
def get_config_info(self):
|
140 |
+
"""
|
141 |
+
Loads json config file
|
142 |
+
"""
|
143 |
+
# Read config file
|
144 |
+
with open(CONFIG_PATH, 'r') as yaml_file:
|
145 |
+
config = yaml.safe_load(yaml_file)
|
146 |
+
api_info = config['api']
|
147 |
+
llm_info = config['llm']
|
148 |
+
embedding_model_info = config['embedding_model']
|
149 |
+
retrieval_info = config['retrieval']
|
150 |
+
prompts = config['prompts']
|
151 |
+
prod_mode = config['prod_mode']
|
152 |
+
|
153 |
+
return api_info, llm_info, embedding_model_info, retrieval_info, prompts, prod_mode
|
154 |
+
|
155 |
+
def set_llm(self):
|
156 |
+
if self.prod_mode:
|
157 |
+
sambanova_api_key = st.session_state.SAMBANOVA_API_KEY
|
158 |
+
else:
|
159 |
+
if 'SAMBANOVA_API_KEY' in st.session_state:
|
160 |
+
sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY') or st.session_state.SAMBANOVA_API_KEY
|
161 |
+
else:
|
162 |
+
sambanova_api_key = os.environ.get('SAMBANOVA_API_KEY')
|
163 |
+
|
164 |
+
llm = APIGateway.load_llm(
|
165 |
+
type=self.api_info,
|
166 |
+
streaming=True,
|
167 |
+
coe=self.llm_info['coe'],
|
168 |
+
do_sample=self.llm_info['do_sample'],
|
169 |
+
max_tokens_to_generate=self.llm_info['max_tokens_to_generate'],
|
170 |
+
temperature=self.llm_info['temperature'],
|
171 |
+
select_expert=self.llm_info['select_expert'],
|
172 |
+
process_prompt=False,
|
173 |
+
sambanova_api_key=sambanova_api_key,
|
174 |
+
)
|
175 |
+
return llm
|
176 |
+
|
177 |
+
def parse_doc(self, docs: List, additional_metadata: Optional[Dict] = None) -> List[Document]:
|
178 |
+
"""
|
179 |
+
Parse the uploaded documents and return a list of LangChain documents.
|
180 |
+
|
181 |
+
Args:
|
182 |
+
docs (List[UploadFile]): A list of uploaded files.
|
183 |
+
additional_metadata (Optional[Dict], optional): Additional metadata to include in the processed documents.
|
184 |
+
Defaults to an empty dictionary.
|
185 |
+
|
186 |
+
Returns:
|
187 |
+
List[Document]: A list of LangChain documents.
|
188 |
+
"""
|
189 |
+
if additional_metadata is None:
|
190 |
+
additional_metadata = {}
|
191 |
+
|
192 |
+
# Create the data/tmp folder if it doesn't exist
|
193 |
+
temp_folder = os.path.join(kit_dir, 'data/tmp')
|
194 |
+
if not os.path.exists(temp_folder):
|
195 |
+
os.makedirs(temp_folder)
|
196 |
+
else:
|
197 |
+
# If there are already files there, delete them
|
198 |
+
for filename in os.listdir(temp_folder):
|
199 |
+
file_path = os.path.join(temp_folder, filename)
|
200 |
+
try:
|
201 |
+
if os.path.isfile(file_path) or os.path.islink(file_path):
|
202 |
+
os.unlink(file_path)
|
203 |
+
elif os.path.isdir(file_path):
|
204 |
+
shutil.rmtree(file_path)
|
205 |
+
except Exception as e:
|
206 |
+
print(f'Failed to delete {file_path}. Reason: {e}')
|
207 |
+
|
208 |
+
# Save all selected files to the tmp dir with their file names
|
209 |
+
#for doc in docs:
|
210 |
+
# temp_file = os.path.join(temp_folder, doc.name)
|
211 |
+
# with open(temp_file, 'wb') as f:
|
212 |
+
# f.write(doc.getvalue())
|
213 |
+
|
214 |
+
for doc_info in docs:
|
215 |
+
file_name, file_obj = doc_info
|
216 |
+
temp_file = os.path.join(temp_folder, file_name)
|
217 |
+
with open(temp_file, 'wb') as f:
|
218 |
+
f.write(file_obj.read())
|
219 |
+
|
220 |
+
# Pass in the temp folder for processing into the parse_doc_universal function
|
221 |
+
_, _, langchain_docs = parse_doc_universal(doc=temp_folder, additional_metadata=additional_metadata)
|
222 |
+
return langchain_docs
|
223 |
+
|
224 |
+
def load_embedding_model(self):
|
225 |
+
embeddings = APIGateway.load_embedding_model(
|
226 |
+
type=self.embedding_model_info['type'],
|
227 |
+
batch_size=self.embedding_model_info['batch_size'],
|
228 |
+
coe=self.embedding_model_info['coe'],
|
229 |
+
select_expert=self.embedding_model_info['select_expert'],
|
230 |
+
)
|
231 |
+
return embeddings
|
232 |
+
|
233 |
+
def create_vector_store(self, text_chunks, embeddings, output_db=None, collection_name=None):
|
234 |
+
print(f'Collection name is {collection_name}')
|
235 |
+
vectorstore = self.vectordb.create_vector_store(
|
236 |
+
text_chunks, embeddings, output_db=output_db, collection_name=collection_name, db_type='chroma'
|
237 |
+
)
|
238 |
+
return vectorstore
|
239 |
+
|
240 |
+
def load_vdb(self, db_path, embeddings, collection_name=None):
|
241 |
+
print(f'Loading collection name is {collection_name}')
|
242 |
+
vectorstore = self.vectordb.load_vdb(db_path, embeddings, db_type='chroma', collection_name=collection_name)
|
243 |
+
return vectorstore
|
244 |
+
|
245 |
+
def init_retriever(self, vectorstore):
|
246 |
+
if self.retrieval_info['rerank']:
|
247 |
+
self.retriever = vectorstore.as_retriever(
|
248 |
+
search_type='similarity_score_threshold',
|
249 |
+
search_kwargs={
|
250 |
+
'score_threshold': self.retrieval_info['score_threshold'],
|
251 |
+
'k': self.retrieval_info['k_retrieved_documents'],
|
252 |
+
},
|
253 |
+
)
|
254 |
+
else:
|
255 |
+
self.retriever = vectorstore.as_retriever(
|
256 |
+
search_type='similarity_score_threshold',
|
257 |
+
search_kwargs={
|
258 |
+
'score_threshold': self.retrieval_info['score_threshold'],
|
259 |
+
'k': self.retrieval_info['final_k_retrieved_documents'],
|
260 |
+
},
|
261 |
+
)
|
262 |
+
|
263 |
+
def get_qa_retrieval_chain(self):
|
264 |
+
"""
|
265 |
+
Generate a qa_retrieval chain using a language model.
|
266 |
+
|
267 |
+
This function uses a language model, specifically a SambaNova LLM, to generate a qa_retrieval chain
|
268 |
+
based on the input vector store of text chunks.
|
269 |
+
|
270 |
+
Parameters:
|
271 |
+
vectorstore (Chroma): A Vector Store containing embeddings of text chunks used as context
|
272 |
+
for generating the conversation chain.
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
RetrievalQA: A chain ready for QA without memory
|
276 |
+
"""
|
277 |
+
# customprompt = load_prompt(os.path.join(kit_dir, self.prompts["qa_prompt"]))
|
278 |
+
# qa_chain = customprompt | self.llm | StrOutputParser()
|
279 |
+
|
280 |
+
# response = {}
|
281 |
+
# documents = self.retriever.invoke(question)
|
282 |
+
# if self.retrieval_info["rerank"]:
|
283 |
+
# documents = self.rerank_docs(question, documents, self.retrieval_info["final_k_retrieved_documents"])
|
284 |
+
# docs = self._format_docs(documents)
|
285 |
+
|
286 |
+
# response["answer"] = qa_chain.invoke({"question": question, "context": docs})
|
287 |
+
# response["source_documents"] = documents
|
288 |
+
|
289 |
+
retrievalQAChain = RetrievalQAChain(
|
290 |
+
retriever=self.retriever,
|
291 |
+
llm=self.llm,
|
292 |
+
qa_prompt=load_prompt(os.path.join(kit_dir, self.prompts['qa_prompt'])),
|
293 |
+
rerank=self.retrieval_info['rerank'],
|
294 |
+
final_k_retrieved_documents=self.retrieval_info['final_k_retrieved_documents'],
|
295 |
+
)
|
296 |
+
return retrievalQAChain
|
297 |
+
|
298 |
+
def get_conversational_qa_retrieval_chain(self):
|
299 |
+
"""
|
300 |
+
Generate a conversational retrieval qa chain using a language model.
|
301 |
+
|
302 |
+
This function uses a language model, specifically a SambaNova LLM, to generate a conversational_qa_retrieval chain
|
303 |
+
based on the chat history and the relevant retrieved content from the input vector store of text chunks.
|
304 |
+
|
305 |
+
Parameters:
|
306 |
+
vectorstore (Chroma): A Vector Store containing embeddings of text chunks used as context
|
307 |
+
for generating the conversation chain.
|
308 |
+
|
309 |
+
Returns:
|
310 |
+
RetrievalQA: A chain ready for QA with memory
|
311 |
+
"""
|
utils/model_wrappers/api_gateway.py
ADDED
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
from typing import Optional, Dict
|
5 |
+
|
6 |
+
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
|
7 |
+
from langchain_core.embeddings import Embeddings
|
8 |
+
from langchain_core.language_models.llms import LLM
|
9 |
+
from langchain_core.language_models.chat_models import BaseChatModel
|
10 |
+
|
11 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
12 |
+
utils_dir = os.path.abspath(os.path.join(current_dir, '..'))
|
13 |
+
repo_dir = os.path.abspath(os.path.join(utils_dir, '..'))
|
14 |
+
sys.path.append(utils_dir)
|
15 |
+
sys.path.append(repo_dir)
|
16 |
+
|
17 |
+
from utils.model_wrappers.langchain_embeddings import SambaStudioEmbeddings
|
18 |
+
from utils.model_wrappers.langchain_llms import SambaStudio
|
19 |
+
from utils.model_wrappers.langchain_llms import SambaNovaCloud
|
20 |
+
from utils.model_wrappers.langchain_chat_models import ChatSambaNovaCloud
|
21 |
+
|
22 |
+
EMBEDDING_MODEL = 'intfloat/e5-large-v2'
|
23 |
+
NORMALIZE_EMBEDDINGS = True
|
24 |
+
|
25 |
+
# Configure the logger
|
26 |
+
logging.basicConfig(
|
27 |
+
level=logging.INFO,
|
28 |
+
format='%(asctime)s [%(levelname)s] - %(message)s',
|
29 |
+
handlers=[
|
30 |
+
logging.StreamHandler(),
|
31 |
+
],
|
32 |
+
)
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
|
35 |
+
|
36 |
+
class APIGateway:
|
37 |
+
@staticmethod
|
38 |
+
def load_embedding_model(
|
39 |
+
type: str = 'cpu',
|
40 |
+
batch_size: Optional[int] = None,
|
41 |
+
coe: bool = False,
|
42 |
+
select_expert: Optional[str] = None,
|
43 |
+
sambastudio_embeddings_base_url: Optional[str] = None,
|
44 |
+
sambastudio_embeddings_base_uri: Optional[str] = None,
|
45 |
+
sambastudio_embeddings_project_id: Optional[str] = None,
|
46 |
+
sambastudio_embeddings_endpoint_id: Optional[str] = None,
|
47 |
+
sambastudio_embeddings_api_key: Optional[str] = None,
|
48 |
+
) -> Embeddings:
|
49 |
+
"""Loads a langchain embedding model given a type and parameters
|
50 |
+
Args:
|
51 |
+
type (str): wether to use sambastudio embedding model or in local cpu model
|
52 |
+
batch_size (int, optional): batch size for sambastudio model. Defaults to None.
|
53 |
+
coe (bool, optional): whether to use coe model. Defaults to False. only for sambastudio models
|
54 |
+
select_expert (str, optional): expert model to be used when coe selected. Defaults to None.
|
55 |
+
only for sambastudio models.
|
56 |
+
sambastudio_embeddings_base_url (str, optional): base url for sambastudio model. Defaults to None.
|
57 |
+
sambastudio_embeddings_base_uri (str, optional): endpoint base uri for sambastudio model. Defaults to None.
|
58 |
+
sambastudio_embeddings_project_id (str, optional): project id for sambastudio model. Defaults to None.
|
59 |
+
sambastudio_embeddings_endpoint_id (str, optional): endpoint id for sambastudio model. Defaults to None.
|
60 |
+
sambastudio_embeddings_api_key (str, optional): api key for sambastudio model. Defaults to None.
|
61 |
+
Returns:
|
62 |
+
langchain embedding model
|
63 |
+
"""
|
64 |
+
|
65 |
+
if type == 'sambastudio':
|
66 |
+
envs = {
|
67 |
+
'sambastudio_embeddings_base_url': sambastudio_embeddings_base_url,
|
68 |
+
'sambastudio_embeddings_base_uri': sambastudio_embeddings_base_uri,
|
69 |
+
'sambastudio_embeddings_project_id': sambastudio_embeddings_project_id,
|
70 |
+
'sambastudio_embeddings_endpoint_id': sambastudio_embeddings_endpoint_id,
|
71 |
+
'sambastudio_embeddings_api_key': sambastudio_embeddings_api_key,
|
72 |
+
}
|
73 |
+
envs = {k: v for k, v in envs.items() if v is not None}
|
74 |
+
|
75 |
+
if coe:
|
76 |
+
if batch_size is None:
|
77 |
+
batch_size = 1
|
78 |
+
embeddings = SambaStudioEmbeddings(
|
79 |
+
**envs, batch_size=batch_size, model_kwargs={'select_expert': select_expert}
|
80 |
+
)
|
81 |
+
else:
|
82 |
+
if batch_size is None:
|
83 |
+
batch_size = 32
|
84 |
+
embeddings = SambaStudioEmbeddings(**envs, batch_size=batch_size)
|
85 |
+
elif type == 'cpu':
|
86 |
+
encode_kwargs = {'normalize_embeddings': NORMALIZE_EMBEDDINGS}
|
87 |
+
embedding_model = EMBEDDING_MODEL
|
88 |
+
embeddings = HuggingFaceInstructEmbeddings(
|
89 |
+
model_name=embedding_model,
|
90 |
+
embed_instruction='', # no instruction is needed for candidate passages
|
91 |
+
query_instruction='Represent this sentence for searching relevant passages: ',
|
92 |
+
encode_kwargs=encode_kwargs,
|
93 |
+
)
|
94 |
+
else:
|
95 |
+
raise ValueError(f'{type} is not a valid embedding model type')
|
96 |
+
|
97 |
+
return embeddings
|
98 |
+
|
99 |
+
@staticmethod
|
100 |
+
def load_llm(
|
101 |
+
type: str,
|
102 |
+
streaming: bool = False,
|
103 |
+
coe: bool = False,
|
104 |
+
do_sample: Optional[bool] = None,
|
105 |
+
max_tokens_to_generate: Optional[int] = None,
|
106 |
+
temperature: Optional[float] = None,
|
107 |
+
select_expert: Optional[str] = None,
|
108 |
+
top_p: Optional[float] = None,
|
109 |
+
top_k: Optional[int] = None,
|
110 |
+
repetition_penalty: Optional[float] = None,
|
111 |
+
stop_sequences: Optional[str] = None,
|
112 |
+
process_prompt: Optional[bool] = False,
|
113 |
+
sambastudio_base_url: Optional[str] = None,
|
114 |
+
sambastudio_base_uri: Optional[str] = None,
|
115 |
+
sambastudio_project_id: Optional[str] = None,
|
116 |
+
sambastudio_endpoint_id: Optional[str] = None,
|
117 |
+
sambastudio_api_key: Optional[str] = None,
|
118 |
+
sambanova_url: Optional[str] = None,
|
119 |
+
sambanova_api_key: Optional[str] = None,
|
120 |
+
) -> LLM:
|
121 |
+
"""Loads a langchain Sambanova llm model given a type and parameters
|
122 |
+
Args:
|
123 |
+
type (str): wether to use sambastudio, or SambaNova Cloud model "sncloud"
|
124 |
+
streaming (bool): wether to use streaming method. Defaults to False.
|
125 |
+
coe (bool): whether to use coe model. Defaults to False.
|
126 |
+
|
127 |
+
do_sample (bool) : Optional wether to do sample.
|
128 |
+
max_tokens_to_generate (int) : Optional max number of tokens to generate.
|
129 |
+
temperature (float) : Optional model temperature.
|
130 |
+
select_expert (str) : Optional expert to use when using CoE models.
|
131 |
+
top_p (float) : Optional model top_p.
|
132 |
+
top_k (int) : Optional model top_k.
|
133 |
+
repetition_penalty (float) : Optional model repetition penalty.
|
134 |
+
stop_sequences (str) : Optional model stop sequences.
|
135 |
+
process_prompt (bool) : Optional default to false.
|
136 |
+
|
137 |
+
sambastudio_base_url (str): Optional SambaStudio environment URL".
|
138 |
+
sambastudio_base_uri (str): Optional SambaStudio-base-URI".
|
139 |
+
sambastudio_project_id (str): Optional SambaStudio project ID.
|
140 |
+
sambastudio_endpoint_id (str): Optional SambaStudio endpoint ID.
|
141 |
+
sambastudio_api_token (str): Optional SambaStudio endpoint API key.
|
142 |
+
|
143 |
+
sambanova_url (str): Optional SambaNova Cloud URL",
|
144 |
+
sambanova_api_key (str): Optional SambaNovaCloud API key.
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
langchain llm model
|
148 |
+
"""
|
149 |
+
|
150 |
+
if type == 'sambastudio':
|
151 |
+
envs = {
|
152 |
+
'sambastudio_base_url': sambastudio_base_url,
|
153 |
+
'sambastudio_base_uri': sambastudio_base_uri,
|
154 |
+
'sambastudio_project_id': sambastudio_project_id,
|
155 |
+
'sambastudio_endpoint_id': sambastudio_endpoint_id,
|
156 |
+
'sambastudio_api_key': sambastudio_api_key,
|
157 |
+
}
|
158 |
+
envs = {k: v for k, v in envs.items() if v is not None}
|
159 |
+
if coe:
|
160 |
+
model_kwargs = {
|
161 |
+
'do_sample': do_sample,
|
162 |
+
'max_tokens_to_generate': max_tokens_to_generate,
|
163 |
+
'temperature': temperature,
|
164 |
+
'select_expert': select_expert,
|
165 |
+
'top_p': top_p,
|
166 |
+
'top_k': top_k,
|
167 |
+
'repetition_penalty': repetition_penalty,
|
168 |
+
'stop_sequences': stop_sequences,
|
169 |
+
'process_prompt': process_prompt,
|
170 |
+
}
|
171 |
+
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
|
172 |
+
|
173 |
+
llm = SambaStudio(
|
174 |
+
**envs,
|
175 |
+
streaming=streaming,
|
176 |
+
model_kwargs=model_kwargs,
|
177 |
+
)
|
178 |
+
else:
|
179 |
+
model_kwargs = {
|
180 |
+
'do_sample': do_sample,
|
181 |
+
'max_tokens_to_generate': max_tokens_to_generate,
|
182 |
+
'temperature': temperature,
|
183 |
+
'top_p': top_p,
|
184 |
+
'top_k': top_k,
|
185 |
+
'repetition_penalty': repetition_penalty,
|
186 |
+
'stop_sequences': stop_sequences,
|
187 |
+
}
|
188 |
+
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
|
189 |
+
llm = SambaStudio(
|
190 |
+
**envs,
|
191 |
+
streaming=streaming,
|
192 |
+
model_kwargs=model_kwargs,
|
193 |
+
)
|
194 |
+
|
195 |
+
elif type == 'sncloud':
|
196 |
+
envs = {
|
197 |
+
'sambanova_url': sambanova_url,
|
198 |
+
'sambanova_api_key': sambanova_api_key,
|
199 |
+
}
|
200 |
+
envs = {k: v for k, v in envs.items() if v is not None}
|
201 |
+
llm = SambaNovaCloud(
|
202 |
+
**envs,
|
203 |
+
max_tokens=max_tokens_to_generate,
|
204 |
+
model=select_expert,
|
205 |
+
temperature=temperature,
|
206 |
+
top_k=top_k,
|
207 |
+
top_p=top_p,
|
208 |
+
)
|
209 |
+
|
210 |
+
else:
|
211 |
+
raise ValueError(f"Invalid LLM API: {type}, only 'sncloud' and 'sambastudio' are supported.")
|
212 |
+
|
213 |
+
return llm
|
214 |
+
|
215 |
+
@staticmethod
|
216 |
+
def load_chat(
|
217 |
+
model: str,
|
218 |
+
streaming: bool = False,
|
219 |
+
max_tokens: int = 1024,
|
220 |
+
temperature: Optional[float] = 0.0,
|
221 |
+
top_p: Optional[float] = None,
|
222 |
+
top_k: Optional[int] = None,
|
223 |
+
stream_options: Optional[Dict[str, bool]] = {"include_usage": True},
|
224 |
+
sambanova_url: Optional[str] = None,
|
225 |
+
sambanova_api_key: Optional[str] = None,
|
226 |
+
) -> BaseChatModel:
|
227 |
+
"""
|
228 |
+
Loads a langchain SambanovaCloud chat model given some parameters
|
229 |
+
Args:
|
230 |
+
model (str): The name of the model to use, e.g., llama3-8b.
|
231 |
+
streaming (bool): whether to use streaming method. Defaults to False.
|
232 |
+
max_tokens (int) : Optional max number of tokens to generate.
|
233 |
+
temperature (float) : Optional model temperature.
|
234 |
+
top_p (float) : Optional model top_p.
|
235 |
+
top_k (int) : Optional model top_k.
|
236 |
+
stream_options (dict) : stream options, include usage to get generation metrics
|
237 |
+
|
238 |
+
sambanova_url (str): Optional SambaNova Cloud URL",
|
239 |
+
sambanova_api_key (str): Optional SambaNovaCloud API key.
|
240 |
+
|
241 |
+
Returns:
|
242 |
+
langchain BaseChatModel
|
243 |
+
"""
|
244 |
+
|
245 |
+
envs = {
|
246 |
+
'sambanova_url': sambanova_url,
|
247 |
+
'sambanova_api_key': sambanova_api_key,
|
248 |
+
}
|
249 |
+
envs = {k: v for k, v in envs.items() if v is not None}
|
250 |
+
model = ChatSambaNovaCloud(
|
251 |
+
**envs,
|
252 |
+
model= model,
|
253 |
+
streaming=streaming,
|
254 |
+
max_tokens=max_tokens,
|
255 |
+
temperature=temperature,
|
256 |
+
top_k=top_k,
|
257 |
+
top_p=top_p,
|
258 |
+
stream_options=stream_options
|
259 |
+
)
|
260 |
+
return model
|
utils/model_wrappers/langchain_chat_models.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from typing import Any, Dict, Iterator, List, Optional
|
3 |
+
|
4 |
+
import requests
|
5 |
+
from langchain_core.callbacks import (
|
6 |
+
CallbackManagerForLLMRun,
|
7 |
+
)
|
8 |
+
from langchain_core.language_models.chat_models import (
|
9 |
+
BaseChatModel,
|
10 |
+
generate_from_stream,
|
11 |
+
)
|
12 |
+
from langchain_core.messages import (
|
13 |
+
AIMessage,
|
14 |
+
AIMessageChunk,
|
15 |
+
BaseMessage,
|
16 |
+
ChatMessage,
|
17 |
+
HumanMessage,
|
18 |
+
SystemMessage,
|
19 |
+
ToolMessage,
|
20 |
+
)
|
21 |
+
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
22 |
+
from langchain_core.pydantic_v1 import Field, SecretStr
|
23 |
+
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
24 |
+
|
25 |
+
|
26 |
+
class ChatSambaNovaCloud(BaseChatModel):
|
27 |
+
"""
|
28 |
+
SambaNova Cloud chat model.
|
29 |
+
|
30 |
+
Setup:
|
31 |
+
To use, you should have the environment variables
|
32 |
+
``SAMBANOVA_URL`` set with your SambaNova Cloud URL.
|
33 |
+
``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API Key.
|
34 |
+
http://cloud.sambanova.ai/
|
35 |
+
Example:
|
36 |
+
.. code-block:: python
|
37 |
+
ChatSambaNovaCloud(
|
38 |
+
sambanova_url = SambaNova cloud endpoint URL,
|
39 |
+
sambanova_api_key = set with your SambaNova cloud API key,
|
40 |
+
model = model name,
|
41 |
+
streaming = set True for use streaming API
|
42 |
+
max_tokens = max number of tokens to generate,
|
43 |
+
temperature = model temperature,
|
44 |
+
top_p = model top p,
|
45 |
+
top_k = model top k,
|
46 |
+
stream_options = include usage to get generation metrics
|
47 |
+
)
|
48 |
+
|
49 |
+
Key init args — completion params:
|
50 |
+
model: str
|
51 |
+
The name of the model to use, e.g., llama3-8b.
|
52 |
+
streaming: bool
|
53 |
+
Whether to use streaming or not
|
54 |
+
max_tokens: int
|
55 |
+
max tokens to generate
|
56 |
+
temperature: float
|
57 |
+
model temperature
|
58 |
+
top_p: float
|
59 |
+
model top p
|
60 |
+
top_k: int
|
61 |
+
model top k
|
62 |
+
stream_options: dict
|
63 |
+
stream options, include usage to get generation metrics
|
64 |
+
|
65 |
+
Key init args — client params:
|
66 |
+
sambanova_url: str
|
67 |
+
SambaNova Cloud Url
|
68 |
+
sambanova_api_key: str
|
69 |
+
SambaNova Cloud api key
|
70 |
+
|
71 |
+
Instantiate:
|
72 |
+
.. code-block:: python
|
73 |
+
|
74 |
+
from langchain_community.chat_models import ChatSambaNovaCloud
|
75 |
+
|
76 |
+
chat = ChatSambaNovaCloud(
|
77 |
+
sambanova_url = SambaNova cloud endpoint URL,
|
78 |
+
sambanova_api_key = set with your SambaNova cloud API key,
|
79 |
+
model = model name,
|
80 |
+
streaming = set True for streaming
|
81 |
+
max_tokens = max number of tokens to generate,
|
82 |
+
temperature = model temperature,
|
83 |
+
top_p = model top p,
|
84 |
+
top_k = model top k,
|
85 |
+
stream_options = include usage to get generation metrics
|
86 |
+
)
|
87 |
+
Invoke:
|
88 |
+
.. code-block:: python
|
89 |
+
messages = [
|
90 |
+
SystemMessage(content="your are an AI assistant."),
|
91 |
+
HumanMessage(content="tell me a joke."),
|
92 |
+
]
|
93 |
+
response = chat.invoke(messages)
|
94 |
+
|
95 |
+
Stream:
|
96 |
+
.. code-block:: python
|
97 |
+
|
98 |
+
for chunk in chat.stream(messages):
|
99 |
+
print(chunk.content, end="", flush=True)
|
100 |
+
|
101 |
+
Async:
|
102 |
+
.. code-block:: python
|
103 |
+
|
104 |
+
response = chat.ainvoke(messages)
|
105 |
+
await response
|
106 |
+
|
107 |
+
Token usage:
|
108 |
+
.. code-block:: python
|
109 |
+
response = chat.invoke(messages)
|
110 |
+
print(response.response_metadata["usage"]["prompt_tokens"]
|
111 |
+
print(response.response_metadata["usage"]["total_tokens"]
|
112 |
+
|
113 |
+
Response metadata
|
114 |
+
.. code-block:: python
|
115 |
+
|
116 |
+
response = chat.invoke(messages)
|
117 |
+
print(response.response_metadata)
|
118 |
+
"""
|
119 |
+
|
120 |
+
sambanova_url: str = Field(default="")
|
121 |
+
"""SambaNova Cloud Url"""
|
122 |
+
|
123 |
+
sambanova_api_key: SecretStr = Field(default="")
|
124 |
+
"""SambaNova Cloud api key"""
|
125 |
+
|
126 |
+
model: str = Field(default="llama3-8b")
|
127 |
+
"""The name of the model"""
|
128 |
+
|
129 |
+
streaming: bool = Field(default=False)
|
130 |
+
"""Whether to use streaming or not"""
|
131 |
+
|
132 |
+
max_tokens: int = Field(default=1024)
|
133 |
+
"""max tokens to generate"""
|
134 |
+
|
135 |
+
temperature: float = Field(default=0.7)
|
136 |
+
"""model temperature"""
|
137 |
+
|
138 |
+
top_p: float = Field(default=0.0)
|
139 |
+
"""model top p"""
|
140 |
+
|
141 |
+
top_k: int = Field(default=1)
|
142 |
+
"""model top k"""
|
143 |
+
|
144 |
+
stream_options: dict = Field(default={"include_usage": True})
|
145 |
+
"""stream options, include usage to get generation metrics"""
|
146 |
+
|
147 |
+
class Config:
|
148 |
+
allow_population_by_field_name = True
|
149 |
+
|
150 |
+
@classmethod
|
151 |
+
def is_lc_serializable(cls) -> bool:
|
152 |
+
"""Return whether this model can be serialized by Langchain."""
|
153 |
+
return False
|
154 |
+
|
155 |
+
@property
|
156 |
+
def lc_secrets(self) -> Dict[str, str]:
|
157 |
+
return {"sambanova_api_key": "sambanova_api_key"}
|
158 |
+
|
159 |
+
@property
|
160 |
+
def _identifying_params(self) -> Dict[str, Any]:
|
161 |
+
"""Return a dictionary of identifying parameters.
|
162 |
+
|
163 |
+
This information is used by the LangChain callback system, which
|
164 |
+
is used for tracing purposes make it possible to monitor LLMs.
|
165 |
+
"""
|
166 |
+
return {
|
167 |
+
"model": self.model,
|
168 |
+
"streaming": self.streaming,
|
169 |
+
"max_tokens": self.max_tokens,
|
170 |
+
"temperature": self.temperature,
|
171 |
+
"top_p": self.top_p,
|
172 |
+
"top_k": self.top_k,
|
173 |
+
"stream_options": self.stream_options,
|
174 |
+
}
|
175 |
+
|
176 |
+
@property
|
177 |
+
def _llm_type(self) -> str:
|
178 |
+
"""Get the type of language model used by this chat model."""
|
179 |
+
return "sambanovacloud-chatmodel"
|
180 |
+
|
181 |
+
def __init__(self, **kwargs: Any) -> None:
|
182 |
+
"""init and validate environment variables"""
|
183 |
+
kwargs["sambanova_url"] = get_from_dict_or_env(
|
184 |
+
kwargs,
|
185 |
+
"sambanova_url",
|
186 |
+
"SAMBANOVA_URL",
|
187 |
+
default="https://api.sambanova.ai/v1/chat/completions",
|
188 |
+
)
|
189 |
+
kwargs["sambanova_api_key"] = convert_to_secret_str(
|
190 |
+
get_from_dict_or_env(kwargs, "sambanova_api_key", "SAMBANOVA_API_KEY")
|
191 |
+
)
|
192 |
+
super().__init__(**kwargs)
|
193 |
+
|
194 |
+
def _handle_request(
|
195 |
+
self, messages_dicts: List[Dict], stop: Optional[List[str]] = None
|
196 |
+
) -> Dict[str, Any]:
|
197 |
+
"""
|
198 |
+
Performs a post request to the LLM API.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
messages_dicts: List of role / content dicts to use as input.
|
202 |
+
stop: list of stop tokens
|
203 |
+
|
204 |
+
Returns:
|
205 |
+
An iterator of response dicts.
|
206 |
+
"""
|
207 |
+
data = {
|
208 |
+
"messages": messages_dicts,
|
209 |
+
"max_tokens": self.max_tokens,
|
210 |
+
"stop": stop,
|
211 |
+
"model": self.model,
|
212 |
+
"temperature": self.temperature,
|
213 |
+
"top_p": self.top_p,
|
214 |
+
"top_k": self.top_k,
|
215 |
+
}
|
216 |
+
http_session = requests.Session()
|
217 |
+
response = http_session.post(
|
218 |
+
self.sambanova_url,
|
219 |
+
headers={
|
220 |
+
"Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}",
|
221 |
+
"Content-Type": "application/json",
|
222 |
+
},
|
223 |
+
json=data,
|
224 |
+
)
|
225 |
+
if response.status_code != 200:
|
226 |
+
raise RuntimeError(
|
227 |
+
f"Sambanova /complete call failed with status code "
|
228 |
+
f"{response.status_code}."
|
229 |
+
f"{response.text}."
|
230 |
+
)
|
231 |
+
response_dict = response.json()
|
232 |
+
if response_dict.get("error"):
|
233 |
+
raise RuntimeError(
|
234 |
+
f"Sambanova /complete call failed with status code "
|
235 |
+
f"{response.status_code}."
|
236 |
+
f"{response_dict}."
|
237 |
+
)
|
238 |
+
return response_dict
|
239 |
+
|
240 |
+
def _handle_streaming_request(
|
241 |
+
self, messages_dicts: List[Dict], stop: Optional[List[str]] = None
|
242 |
+
) -> Iterator[Dict]:
|
243 |
+
"""
|
244 |
+
Performs an streaming post request to the LLM API.
|
245 |
+
|
246 |
+
Args:
|
247 |
+
messages_dicts: List of role / content dicts to use as input.
|
248 |
+
stop: list of stop tokens
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
An iterator of response dicts.
|
252 |
+
"""
|
253 |
+
try:
|
254 |
+
import sseclient
|
255 |
+
except ImportError:
|
256 |
+
raise ImportError(
|
257 |
+
"could not import sseclient library"
|
258 |
+
"Please install it with `pip install sseclient-py`."
|
259 |
+
)
|
260 |
+
data = {
|
261 |
+
"messages": messages_dicts,
|
262 |
+
"max_tokens": self.max_tokens,
|
263 |
+
"stop": stop,
|
264 |
+
"model": self.model,
|
265 |
+
"temperature": self.temperature,
|
266 |
+
"top_p": self.top_p,
|
267 |
+
"top_k": self.top_k,
|
268 |
+
"stream": True,
|
269 |
+
"stream_options": self.stream_options,
|
270 |
+
}
|
271 |
+
http_session = requests.Session()
|
272 |
+
response = http_session.post(
|
273 |
+
self.sambanova_url,
|
274 |
+
headers={
|
275 |
+
"Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}",
|
276 |
+
"Content-Type": "application/json",
|
277 |
+
},
|
278 |
+
json=data,
|
279 |
+
stream=True,
|
280 |
+
)
|
281 |
+
|
282 |
+
client = sseclient.SSEClient(response)
|
283 |
+
|
284 |
+
if response.status_code != 200:
|
285 |
+
raise RuntimeError(
|
286 |
+
f"Sambanova /complete call failed with status code "
|
287 |
+
f"{response.status_code}."
|
288 |
+
f"{response.text}."
|
289 |
+
)
|
290 |
+
|
291 |
+
for event in client.events():
|
292 |
+
chunk = {
|
293 |
+
"event": event.event,
|
294 |
+
"data": event.data,
|
295 |
+
"status_code": response.status_code,
|
296 |
+
}
|
297 |
+
|
298 |
+
if chunk["event"] == "error_event" or chunk["status_code"] != 200:
|
299 |
+
raise RuntimeError(
|
300 |
+
f"Sambanova /complete call failed with status code "
|
301 |
+
f"{chunk['status_code']}."
|
302 |
+
f"{chunk}."
|
303 |
+
)
|
304 |
+
|
305 |
+
try:
|
306 |
+
# check if the response is a final event
|
307 |
+
# in that case event data response is '[DONE]'
|
308 |
+
if chunk["data"] != "[DONE]":
|
309 |
+
if isinstance(chunk["data"], str):
|
310 |
+
data = json.loads(chunk["data"])
|
311 |
+
else:
|
312 |
+
raise RuntimeError(
|
313 |
+
f"Sambanova /complete call failed with status code "
|
314 |
+
f"{chunk['status_code']}."
|
315 |
+
f"{chunk}."
|
316 |
+
)
|
317 |
+
if data.get("error"):
|
318 |
+
raise RuntimeError(
|
319 |
+
f"Sambanova /complete call failed with status code "
|
320 |
+
f"{chunk['status_code']}."
|
321 |
+
f"{chunk}."
|
322 |
+
)
|
323 |
+
yield data
|
324 |
+
except Exception:
|
325 |
+
raise Exception(
|
326 |
+
f"Error getting content chunk raw streamed response: {chunk}"
|
327 |
+
)
|
328 |
+
|
329 |
+
def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
|
330 |
+
"""
|
331 |
+
convert a BaseMessage to a dictionary with Role / content
|
332 |
+
|
333 |
+
Args:
|
334 |
+
message: BaseMessage
|
335 |
+
|
336 |
+
Returns:
|
337 |
+
messages_dict: role / content dict
|
338 |
+
"""
|
339 |
+
if isinstance(message, ChatMessage):
|
340 |
+
message_dict = {"role": message.role, "content": message.content}
|
341 |
+
elif isinstance(message, SystemMessage):
|
342 |
+
message_dict = {"role": "system", "content": message.content}
|
343 |
+
elif isinstance(message, HumanMessage):
|
344 |
+
message_dict = {"role": "user", "content": message.content}
|
345 |
+
elif isinstance(message, AIMessage):
|
346 |
+
message_dict = {"role": "assistant", "content": message.content}
|
347 |
+
elif isinstance(message, ToolMessage):
|
348 |
+
message_dict = {"role": "tool", "content": message.content}
|
349 |
+
else:
|
350 |
+
raise TypeError(f"Got unknown type {message}")
|
351 |
+
return message_dict
|
352 |
+
|
353 |
+
def _create_message_dicts(
|
354 |
+
self, messages: List[BaseMessage]
|
355 |
+
) -> List[Dict[str, Any]]:
|
356 |
+
"""
|
357 |
+
convert a lit of BaseMessages to a list of dictionaries with Role / content
|
358 |
+
|
359 |
+
Args:
|
360 |
+
messages: list of BaseMessages
|
361 |
+
|
362 |
+
Returns:
|
363 |
+
messages_dicts: list of role / content dicts
|
364 |
+
"""
|
365 |
+
message_dicts = [self._convert_message_to_dict(m) for m in messages]
|
366 |
+
return message_dicts
|
367 |
+
|
368 |
+
def _generate(
|
369 |
+
self,
|
370 |
+
messages: List[BaseMessage],
|
371 |
+
stop: Optional[List[str]] = None,
|
372 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
373 |
+
**kwargs: Any,
|
374 |
+
) -> ChatResult:
|
375 |
+
"""
|
376 |
+
SambaNovaCloud chat model logic.
|
377 |
+
|
378 |
+
Call SambaNovaCloud API.
|
379 |
+
|
380 |
+
Args:
|
381 |
+
messages: the prompt composed of a list of messages.
|
382 |
+
stop: a list of strings on which the model should stop generating.
|
383 |
+
If generation stops due to a stop token, the stop token itself
|
384 |
+
SHOULD BE INCLUDED as part of the output. This is not enforced
|
385 |
+
across models right now, but it's a good practice to follow since
|
386 |
+
it makes it much easier to parse the output of the model
|
387 |
+
downstream and understand why generation stopped.
|
388 |
+
run_manager: A run manager with callbacks for the LLM.
|
389 |
+
"""
|
390 |
+
if self.streaming:
|
391 |
+
stream_iter = self._stream(
|
392 |
+
messages, stop=stop, run_manager=run_manager, **kwargs
|
393 |
+
)
|
394 |
+
if stream_iter:
|
395 |
+
return generate_from_stream(stream_iter)
|
396 |
+
messages_dicts = self._create_message_dicts(messages)
|
397 |
+
response = self._handle_request(messages_dicts, stop)
|
398 |
+
message = AIMessage(
|
399 |
+
content=response["choices"][0]["message"]["content"],
|
400 |
+
additional_kwargs={},
|
401 |
+
response_metadata={
|
402 |
+
"finish_reason": response["choices"][0]["finish_reason"],
|
403 |
+
"usage": response.get("usage"),
|
404 |
+
"model_name": response["model"],
|
405 |
+
"system_fingerprint": response["system_fingerprint"],
|
406 |
+
"created": response["created"],
|
407 |
+
},
|
408 |
+
id=response["id"],
|
409 |
+
)
|
410 |
+
|
411 |
+
generation = ChatGeneration(message=message)
|
412 |
+
return ChatResult(generations=[generation])
|
413 |
+
|
414 |
+
def _stream(
|
415 |
+
self,
|
416 |
+
messages: List[BaseMessage],
|
417 |
+
stop: Optional[List[str]] = None,
|
418 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
419 |
+
**kwargs: Any,
|
420 |
+
) -> Iterator[ChatGenerationChunk]:
|
421 |
+
"""
|
422 |
+
Stream the output of the SambaNovaCloud chat model.
|
423 |
+
|
424 |
+
Args:
|
425 |
+
messages: the prompt composed of a list of messages.
|
426 |
+
stop: a list of strings on which the model should stop generating.
|
427 |
+
If generation stops due to a stop token, the stop token itself
|
428 |
+
SHOULD BE INCLUDED as part of the output. This is not enforced
|
429 |
+
across models right now, but it's a good practice to follow since
|
430 |
+
it makes it much easier to parse the output of the model
|
431 |
+
downstream and understand why generation stopped.
|
432 |
+
run_manager: A run manager with callbacks for the LLM.
|
433 |
+
"""
|
434 |
+
messages_dicts = self._create_message_dicts(messages)
|
435 |
+
finish_reason = None
|
436 |
+
for partial_response in self._handle_streaming_request(messages_dicts, stop):
|
437 |
+
if len(partial_response["choices"]) > 0:
|
438 |
+
finish_reason = partial_response["choices"][0].get("finish_reason")
|
439 |
+
content = partial_response["choices"][0]["delta"]["content"]
|
440 |
+
id = partial_response["id"]
|
441 |
+
chunk = ChatGenerationChunk(
|
442 |
+
message=AIMessageChunk(content=content, id=id, additional_kwargs={})
|
443 |
+
)
|
444 |
+
else:
|
445 |
+
content = ""
|
446 |
+
id = partial_response["id"]
|
447 |
+
metadata = {
|
448 |
+
"finish_reason": finish_reason,
|
449 |
+
"usage": partial_response.get("usage"),
|
450 |
+
"model_name": partial_response["model"],
|
451 |
+
"system_fingerprint": partial_response["system_fingerprint"],
|
452 |
+
"created": partial_response["created"],
|
453 |
+
}
|
454 |
+
chunk = ChatGenerationChunk(
|
455 |
+
message=AIMessageChunk(
|
456 |
+
content=content,
|
457 |
+
id=id,
|
458 |
+
response_metadata=metadata,
|
459 |
+
additional_kwargs={},
|
460 |
+
)
|
461 |
+
)
|
462 |
+
|
463 |
+
if run_manager:
|
464 |
+
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
|
465 |
+
yield chunk
|
utils/model_wrappers/langchain_embeddings.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Langchain Wrapper around Sambanova embedding APIs."""
|
2 |
+
|
3 |
+
import json
|
4 |
+
from typing import Dict, Generator, List, Optional
|
5 |
+
|
6 |
+
import requests
|
7 |
+
from langchain_core.embeddings import Embeddings
|
8 |
+
from langchain_core.pydantic_v1 import BaseModel
|
9 |
+
from langchain_core.utils import get_from_dict_or_env, pre_init
|
10 |
+
|
11 |
+
|
12 |
+
class SambaStudioEmbeddings(BaseModel, Embeddings):
|
13 |
+
"""SambaNova embedding models.
|
14 |
+
|
15 |
+
To use, you should have the environment variables
|
16 |
+
``SAMBASTUDIO_EMBEDDINGS_BASE_URL``, ``SAMBASTUDIO_EMBEDDINGS_BASE_URI``
|
17 |
+
``SAMBASTUDIO_EMBEDDINGS_PROJECT_ID``, ``SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID``,
|
18 |
+
``SAMBASTUDIO_EMBEDDINGS_API_KEY``
|
19 |
+
set with your personal sambastudio variable or pass it as a named parameter
|
20 |
+
to the constructor.
|
21 |
+
|
22 |
+
Example:
|
23 |
+
.. code-block:: python
|
24 |
+
|
25 |
+
from langchain_community.embeddings import SambaStudioEmbeddings
|
26 |
+
|
27 |
+
embeddings = SambaStudioEmbeddings(sambastudio_embeddings_base_url=base_url,
|
28 |
+
sambastudio_embeddings_base_uri=base_uri,
|
29 |
+
sambastudio_embeddings_project_id=project_id,
|
30 |
+
sambastudio_embeddings_endpoint_id=endpoint_id,
|
31 |
+
sambastudio_embeddings_api_key=api_key,
|
32 |
+
batch_size=32)
|
33 |
+
(or)
|
34 |
+
|
35 |
+
embeddings = SambaStudioEmbeddings(batch_size=32)
|
36 |
+
|
37 |
+
(or)
|
38 |
+
|
39 |
+
# CoE example
|
40 |
+
embeddings = SambaStudioEmbeddings(
|
41 |
+
batch_size=1,
|
42 |
+
model_kwargs={
|
43 |
+
'select_expert':'e5-mistral-7b-instruct'
|
44 |
+
}
|
45 |
+
)
|
46 |
+
"""
|
47 |
+
|
48 |
+
sambastudio_embeddings_base_url: str = ''
|
49 |
+
"""Base url to use"""
|
50 |
+
|
51 |
+
sambastudio_embeddings_base_uri: str = ''
|
52 |
+
"""endpoint base uri"""
|
53 |
+
|
54 |
+
sambastudio_embeddings_project_id: str = ''
|
55 |
+
"""Project id on sambastudio for model"""
|
56 |
+
|
57 |
+
sambastudio_embeddings_endpoint_id: str = ''
|
58 |
+
"""endpoint id on sambastudio for model"""
|
59 |
+
|
60 |
+
sambastudio_embeddings_api_key: str = ''
|
61 |
+
"""sambastudio api key"""
|
62 |
+
|
63 |
+
model_kwargs: dict = {}
|
64 |
+
"""Key word arguments to pass to the model."""
|
65 |
+
|
66 |
+
batch_size: int = 32
|
67 |
+
"""Batch size for the embedding models"""
|
68 |
+
|
69 |
+
@pre_init
|
70 |
+
def validate_environment(cls, values: Dict) -> Dict:
|
71 |
+
"""Validate that api key and python package exists in environment."""
|
72 |
+
values['sambastudio_embeddings_base_url'] = get_from_dict_or_env(
|
73 |
+
values, 'sambastudio_embeddings_base_url', 'SAMBASTUDIO_EMBEDDINGS_BASE_URL'
|
74 |
+
)
|
75 |
+
values['sambastudio_embeddings_base_uri'] = get_from_dict_or_env(
|
76 |
+
values,
|
77 |
+
'sambastudio_embeddings_base_uri',
|
78 |
+
'SAMBASTUDIO_EMBEDDINGS_BASE_URI',
|
79 |
+
default='api/predict/generic',
|
80 |
+
)
|
81 |
+
values['sambastudio_embeddings_project_id'] = get_from_dict_or_env(
|
82 |
+
values,
|
83 |
+
'sambastudio_embeddings_project_id',
|
84 |
+
'SAMBASTUDIO_EMBEDDINGS_PROJECT_ID',
|
85 |
+
)
|
86 |
+
values['sambastudio_embeddings_endpoint_id'] = get_from_dict_or_env(
|
87 |
+
values,
|
88 |
+
'sambastudio_embeddings_endpoint_id',
|
89 |
+
'SAMBASTUDIO_EMBEDDINGS_ENDPOINT_ID',
|
90 |
+
)
|
91 |
+
values['sambastudio_embeddings_api_key'] = get_from_dict_or_env(
|
92 |
+
values, 'sambastudio_embeddings_api_key', 'SAMBASTUDIO_EMBEDDINGS_API_KEY'
|
93 |
+
)
|
94 |
+
return values
|
95 |
+
|
96 |
+
def _get_tuning_params(self) -> str:
|
97 |
+
"""
|
98 |
+
Get the tuning parameters to use when calling the model
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
The tuning parameters as a JSON string.
|
102 |
+
"""
|
103 |
+
if 'api/v2/predict/generic' in self.sambastudio_embeddings_base_uri:
|
104 |
+
tuning_params_dict = self.model_kwargs
|
105 |
+
else:
|
106 |
+
tuning_params_dict = {
|
107 |
+
k: {'type': type(v).__name__, 'value': str(v)} for k, v in (self.model_kwargs.items())
|
108 |
+
}
|
109 |
+
tuning_params = json.dumps(tuning_params_dict)
|
110 |
+
return tuning_params
|
111 |
+
|
112 |
+
def _get_full_url(self, path: str) -> str:
|
113 |
+
"""
|
114 |
+
Return the full API URL for a given path.
|
115 |
+
|
116 |
+
:param str path: the sub-path
|
117 |
+
:returns: the full API URL for the sub-path
|
118 |
+
:rtype: str
|
119 |
+
"""
|
120 |
+
return f'{self.sambastudio_embeddings_base_url}/{self.sambastudio_embeddings_base_uri}/{path}' # noqa: E501
|
121 |
+
|
122 |
+
def _iterate_over_batches(self, texts: List[str], batch_size: int) -> Generator:
|
123 |
+
"""Generator for creating batches in the embed documents method
|
124 |
+
Args:
|
125 |
+
texts (List[str]): list of strings to embed
|
126 |
+
batch_size (int, optional): batch size to be used for the embedding model.
|
127 |
+
Will depend on the RDU endpoint used.
|
128 |
+
Yields:
|
129 |
+
List[str]: list (batch) of strings of size batch size
|
130 |
+
"""
|
131 |
+
for i in range(0, len(texts), batch_size):
|
132 |
+
yield texts[i : i + batch_size]
|
133 |
+
|
134 |
+
def embed_documents(self, texts: List[str], batch_size: Optional[int] = None) -> List[List[float]]:
|
135 |
+
"""Returns a list of embeddings for the given sentences.
|
136 |
+
Args:
|
137 |
+
texts (`List[str]`): List of texts to encode
|
138 |
+
batch_size (`int`): Batch size for the encoding
|
139 |
+
|
140 |
+
Returns:
|
141 |
+
`List[np.ndarray]` or `List[tensor]`: List of embeddings
|
142 |
+
for the given sentences
|
143 |
+
"""
|
144 |
+
if batch_size is None:
|
145 |
+
batch_size = self.batch_size
|
146 |
+
http_session = requests.Session()
|
147 |
+
url = self._get_full_url(f'{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}')
|
148 |
+
params = json.loads(self._get_tuning_params())
|
149 |
+
embeddings = []
|
150 |
+
|
151 |
+
if 'api/predict/nlp' in self.sambastudio_embeddings_base_uri:
|
152 |
+
for batch in self._iterate_over_batches(texts, batch_size):
|
153 |
+
data = {'inputs': batch, 'params': params}
|
154 |
+
response = http_session.post(
|
155 |
+
url,
|
156 |
+
headers={'key': self.sambastudio_embeddings_api_key},
|
157 |
+
json=data,
|
158 |
+
)
|
159 |
+
if response.status_code != 200:
|
160 |
+
raise RuntimeError(
|
161 |
+
f'Sambanova /complete call failed with status code '
|
162 |
+
f'{response.status_code}.\n Details: {response.text}'
|
163 |
+
)
|
164 |
+
try:
|
165 |
+
embedding = response.json()['data']
|
166 |
+
embeddings.extend(embedding)
|
167 |
+
except KeyError:
|
168 |
+
raise KeyError(
|
169 |
+
"'data' not found in endpoint response",
|
170 |
+
response.json(),
|
171 |
+
)
|
172 |
+
|
173 |
+
elif 'api/v2/predict/generic' in self.sambastudio_embeddings_base_uri:
|
174 |
+
for batch in self._iterate_over_batches(texts, batch_size):
|
175 |
+
items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(batch)]
|
176 |
+
data = {'items': items, 'params': params}
|
177 |
+
response = http_session.post(
|
178 |
+
url,
|
179 |
+
headers={'key': self.sambastudio_embeddings_api_key},
|
180 |
+
json=data,
|
181 |
+
)
|
182 |
+
if response.status_code != 200:
|
183 |
+
raise RuntimeError(
|
184 |
+
f'Sambanova /complete call failed with status code '
|
185 |
+
f'{response.status_code}.\n Details: {response.text}'
|
186 |
+
)
|
187 |
+
try:
|
188 |
+
embedding = [item['value'] for item in response.json()['items']]
|
189 |
+
embeddings.extend(embedding)
|
190 |
+
except KeyError:
|
191 |
+
raise KeyError(
|
192 |
+
"'items' not found in endpoint response",
|
193 |
+
response.json(),
|
194 |
+
)
|
195 |
+
|
196 |
+
elif 'api/predict/generic' in self.sambastudio_embeddings_base_uri:
|
197 |
+
for batch in self._iterate_over_batches(texts, batch_size):
|
198 |
+
data = {'instances': batch, 'params': params}
|
199 |
+
response = http_session.post(
|
200 |
+
url,
|
201 |
+
headers={'key': self.sambastudio_embeddings_api_key},
|
202 |
+
json=data,
|
203 |
+
)
|
204 |
+
if response.status_code != 200:
|
205 |
+
raise RuntimeError(
|
206 |
+
f'Sambanova /complete call failed with status code '
|
207 |
+
f'{response.status_code}.\n Details: {response.text}'
|
208 |
+
)
|
209 |
+
try:
|
210 |
+
if params.get('select_expert'):
|
211 |
+
embedding = response.json()['predictions']
|
212 |
+
else:
|
213 |
+
embedding = response.json()['predictions']
|
214 |
+
embeddings.extend(embedding)
|
215 |
+
except KeyError:
|
216 |
+
raise KeyError(
|
217 |
+
"'predictions' not found in endpoint response",
|
218 |
+
response.json(),
|
219 |
+
)
|
220 |
+
|
221 |
+
else:
|
222 |
+
raise ValueError(
|
223 |
+
f'handling of endpoint uri: {self.sambastudio_embeddings_base_uri} not implemented' # noqa: E501
|
224 |
+
)
|
225 |
+
|
226 |
+
return embeddings
|
227 |
+
|
228 |
+
def embed_query(self, text: str) -> List[float]:
|
229 |
+
"""Returns a list of embeddings for the given sentences.
|
230 |
+
Args:
|
231 |
+
sentences (`List[str]`): List of sentences to encode
|
232 |
+
|
233 |
+
Returns:
|
234 |
+
`List[np.ndarray]` or `List[tensor]`: List of embeddings
|
235 |
+
for the given sentences
|
236 |
+
"""
|
237 |
+
http_session = requests.Session()
|
238 |
+
url = self._get_full_url(f'{self.sambastudio_embeddings_project_id}/{self.sambastudio_embeddings_endpoint_id}')
|
239 |
+
params = json.loads(self._get_tuning_params())
|
240 |
+
|
241 |
+
if 'api/predict/nlp' in self.sambastudio_embeddings_base_uri:
|
242 |
+
data = {'inputs': [text], 'params': params}
|
243 |
+
response = http_session.post(
|
244 |
+
url,
|
245 |
+
headers={'key': self.sambastudio_embeddings_api_key},
|
246 |
+
json=data,
|
247 |
+
)
|
248 |
+
if response.status_code != 200:
|
249 |
+
raise RuntimeError(
|
250 |
+
f'Sambanova /complete call failed with status code '
|
251 |
+
f'{response.status_code}.\n Details: {response.text}'
|
252 |
+
)
|
253 |
+
try:
|
254 |
+
embedding = response.json()['data'][0]
|
255 |
+
except KeyError:
|
256 |
+
raise KeyError(
|
257 |
+
"'data' not found in endpoint response",
|
258 |
+
response.json(),
|
259 |
+
)
|
260 |
+
|
261 |
+
elif 'api/v2/predict/generic' in self.sambastudio_embeddings_base_uri:
|
262 |
+
data = {'items': [{'id': 'item0', 'value': text}], 'params': params}
|
263 |
+
response = http_session.post(
|
264 |
+
url,
|
265 |
+
headers={'key': self.sambastudio_embeddings_api_key},
|
266 |
+
json=data,
|
267 |
+
)
|
268 |
+
if response.status_code != 200:
|
269 |
+
raise RuntimeError(
|
270 |
+
f'Sambanova /complete call failed with status code '
|
271 |
+
f'{response.status_code}.\n Details: {response.text}'
|
272 |
+
)
|
273 |
+
try:
|
274 |
+
embedding = response.json()['items'][0]['value']
|
275 |
+
except KeyError:
|
276 |
+
raise KeyError(
|
277 |
+
"'items' not found in endpoint response",
|
278 |
+
response.json(),
|
279 |
+
)
|
280 |
+
|
281 |
+
elif 'api/predict/generic' in self.sambastudio_embeddings_base_uri:
|
282 |
+
data = {'instances': [text], 'params': params}
|
283 |
+
response = http_session.post(
|
284 |
+
url,
|
285 |
+
headers={'key': self.sambastudio_embeddings_api_key},
|
286 |
+
json=data,
|
287 |
+
)
|
288 |
+
if response.status_code != 200:
|
289 |
+
raise RuntimeError(
|
290 |
+
f'Sambanova /complete call failed with status code '
|
291 |
+
f'{response.status_code}.\n Details: {response.text}'
|
292 |
+
)
|
293 |
+
try:
|
294 |
+
if params.get('select_expert'):
|
295 |
+
embedding = response.json()['predictions'][0]
|
296 |
+
else:
|
297 |
+
embedding = response.json()['predictions'][0]
|
298 |
+
except KeyError:
|
299 |
+
raise KeyError(
|
300 |
+
"'predictions' not found in endpoint response",
|
301 |
+
response.json(),
|
302 |
+
)
|
303 |
+
|
304 |
+
else:
|
305 |
+
raise ValueError(
|
306 |
+
f'handling of endpoint uri: {self.sambastudio_embeddings_base_uri} not implemented' # noqa: E501
|
307 |
+
)
|
308 |
+
|
309 |
+
return embedding
|
utils/model_wrappers/langchain_llms.py
ADDED
@@ -0,0 +1,770 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Langchain Wrapper around Sambanova LLM APIs."""
|
2 |
+
|
3 |
+
import json
|
4 |
+
from typing import Any, Dict, Generator, Iterator, List, Optional, Union
|
5 |
+
|
6 |
+
import requests
|
7 |
+
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
8 |
+
from langchain_core.language_models.llms import LLM
|
9 |
+
from langchain_core.outputs import GenerationChunk
|
10 |
+
from langchain_core.pydantic_v1 import Extra
|
11 |
+
from langchain_core.utils import get_from_dict_or_env, pre_init
|
12 |
+
|
13 |
+
|
14 |
+
class SSEndpointHandler:
|
15 |
+
"""
|
16 |
+
SambaNova Systems Interface for SambaStudio model endpoints.
|
17 |
+
|
18 |
+
:param str host_url: Base URL of the DaaS API service
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self, host_url: str, api_base_uri: str):
|
22 |
+
"""
|
23 |
+
Initialize the SSEndpointHandler.
|
24 |
+
|
25 |
+
:param str host_url: Base URL of the DaaS API service
|
26 |
+
:param str api_base_uri: Base URI of the DaaS API service
|
27 |
+
"""
|
28 |
+
self.host_url = host_url
|
29 |
+
self.api_base_uri = api_base_uri
|
30 |
+
self.http_session = requests.Session()
|
31 |
+
|
32 |
+
def _process_response(self, response: requests.Response) -> Dict:
|
33 |
+
"""
|
34 |
+
Processes the API response and returns the resulting dict.
|
35 |
+
|
36 |
+
All resulting dicts, regardless of success or failure, will contain the
|
37 |
+
`status_code` key with the API response status code.
|
38 |
+
|
39 |
+
If the API returned an error, the resulting dict will contain the key
|
40 |
+
`detail` with the error message.
|
41 |
+
|
42 |
+
If the API call was successful, the resulting dict will contain the key
|
43 |
+
`data` with the response data.
|
44 |
+
|
45 |
+
:param requests.Response response: the response object to process
|
46 |
+
:return: the response dict
|
47 |
+
:type: dict
|
48 |
+
"""
|
49 |
+
result: Dict[str, Any] = {}
|
50 |
+
try:
|
51 |
+
result = response.json()
|
52 |
+
except Exception as e:
|
53 |
+
result['detail'] = str(e)
|
54 |
+
if 'status_code' not in result:
|
55 |
+
result['status_code'] = response.status_code
|
56 |
+
return result
|
57 |
+
|
58 |
+
def _process_streaming_response(
|
59 |
+
self,
|
60 |
+
response: requests.Response,
|
61 |
+
) -> Generator[Dict, None, None]:
|
62 |
+
"""Process the streaming response"""
|
63 |
+
if 'api/predict/nlp' in self.api_base_uri:
|
64 |
+
try:
|
65 |
+
import sseclient
|
66 |
+
except ImportError:
|
67 |
+
raise ImportError(
|
68 |
+
'could not import sseclient library' 'Please install it with `pip install sseclient-py`.'
|
69 |
+
)
|
70 |
+
client = sseclient.SSEClient(response)
|
71 |
+
close_conn = False
|
72 |
+
for event in client.events():
|
73 |
+
if event.event == 'error_event':
|
74 |
+
close_conn = True
|
75 |
+
chunk = {
|
76 |
+
'event': event.event,
|
77 |
+
'data': event.data,
|
78 |
+
'status_code': response.status_code,
|
79 |
+
}
|
80 |
+
yield chunk
|
81 |
+
if close_conn:
|
82 |
+
client.close()
|
83 |
+
elif 'api/v2/predict/generic' in self.api_base_uri or 'api/predict/generic' in self.api_base_uri:
|
84 |
+
try:
|
85 |
+
for line in response.iter_lines():
|
86 |
+
chunk = json.loads(line)
|
87 |
+
if 'status_code' not in chunk:
|
88 |
+
chunk['status_code'] = response.status_code
|
89 |
+
yield chunk
|
90 |
+
except Exception as e:
|
91 |
+
raise RuntimeError(f'Error processing streaming response: {e}')
|
92 |
+
else:
|
93 |
+
raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
|
94 |
+
|
95 |
+
def _get_full_url(self, path: str) -> str:
|
96 |
+
"""
|
97 |
+
Return the full API URL for a given path.
|
98 |
+
|
99 |
+
:param str path: the sub-path
|
100 |
+
:returns: the full API URL for the sub-path
|
101 |
+
:type: str
|
102 |
+
"""
|
103 |
+
return f'{self.host_url}/{self.api_base_uri}/{path}'
|
104 |
+
|
105 |
+
def nlp_predict(
|
106 |
+
self,
|
107 |
+
project: str,
|
108 |
+
endpoint: str,
|
109 |
+
key: str,
|
110 |
+
input: Union[List[str], str],
|
111 |
+
params: Optional[str] = '',
|
112 |
+
stream: bool = False,
|
113 |
+
) -> Dict:
|
114 |
+
"""
|
115 |
+
NLP predict using inline input string.
|
116 |
+
|
117 |
+
:param str project: Project ID in which the endpoint exists
|
118 |
+
:param str endpoint: Endpoint ID
|
119 |
+
:param str key: API Key
|
120 |
+
:param str input_str: Input string
|
121 |
+
:param str params: Input params string
|
122 |
+
:returns: Prediction results
|
123 |
+
:type: dict
|
124 |
+
"""
|
125 |
+
if isinstance(input, str):
|
126 |
+
input = [input]
|
127 |
+
if 'api/predict/nlp' in self.api_base_uri:
|
128 |
+
if params:
|
129 |
+
data = {'inputs': input, 'params': json.loads(params)}
|
130 |
+
else:
|
131 |
+
data = {'inputs': input}
|
132 |
+
elif 'api/v2/predict/generic' in self.api_base_uri:
|
133 |
+
items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]
|
134 |
+
if params:
|
135 |
+
data = {'items': items, 'params': json.loads(params)}
|
136 |
+
else:
|
137 |
+
data = {'items': items}
|
138 |
+
elif 'api/predict/generic' in self.api_base_uri:
|
139 |
+
if params:
|
140 |
+
data = {'instances': input, 'params': json.loads(params)}
|
141 |
+
else:
|
142 |
+
data = {'instances': input}
|
143 |
+
else:
|
144 |
+
raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
|
145 |
+
response = self.http_session.post(
|
146 |
+
self._get_full_url(f'{project}/{endpoint}'),
|
147 |
+
headers={'key': key},
|
148 |
+
json=data,
|
149 |
+
)
|
150 |
+
return self._process_response(response)
|
151 |
+
|
152 |
+
def nlp_predict_stream(
|
153 |
+
self,
|
154 |
+
project: str,
|
155 |
+
endpoint: str,
|
156 |
+
key: str,
|
157 |
+
input: Union[List[str], str],
|
158 |
+
params: Optional[str] = '',
|
159 |
+
) -> Iterator[Dict]:
|
160 |
+
"""
|
161 |
+
NLP predict using inline input string.
|
162 |
+
|
163 |
+
:param str project: Project ID in which the endpoint exists
|
164 |
+
:param str endpoint: Endpoint ID
|
165 |
+
:param str key: API Key
|
166 |
+
:param str input_str: Input string
|
167 |
+
:param str params: Input params string
|
168 |
+
:returns: Prediction results
|
169 |
+
:type: dict
|
170 |
+
"""
|
171 |
+
if 'api/predict/nlp' in self.api_base_uri:
|
172 |
+
if isinstance(input, str):
|
173 |
+
input = [input]
|
174 |
+
if params:
|
175 |
+
data = {'inputs': input, 'params': json.loads(params)}
|
176 |
+
else:
|
177 |
+
data = {'inputs': input}
|
178 |
+
elif 'api/v2/predict/generic' in self.api_base_uri:
|
179 |
+
if isinstance(input, str):
|
180 |
+
input = [input]
|
181 |
+
items = [{'id': f'item{i}', 'value': item} for i, item in enumerate(input)]
|
182 |
+
if params:
|
183 |
+
data = {'items': items, 'params': json.loads(params)}
|
184 |
+
else:
|
185 |
+
data = {'items': items}
|
186 |
+
elif 'api/predict/generic' in self.api_base_uri:
|
187 |
+
if isinstance(input, list):
|
188 |
+
input = input[0]
|
189 |
+
if params:
|
190 |
+
data = {'instance': input, 'params': json.loads(params)}
|
191 |
+
else:
|
192 |
+
data = {'instance': input}
|
193 |
+
else:
|
194 |
+
raise ValueError(f'handling of endpoint uri: {self.api_base_uri} not implemented')
|
195 |
+
# Streaming output
|
196 |
+
response = self.http_session.post(
|
197 |
+
self._get_full_url(f'stream/{project}/{endpoint}'),
|
198 |
+
headers={'key': key},
|
199 |
+
json=data,
|
200 |
+
stream=True,
|
201 |
+
)
|
202 |
+
for chunk in self._process_streaming_response(response):
|
203 |
+
yield chunk
|
204 |
+
|
205 |
+
|
206 |
+
class SambaStudio(LLM):
|
207 |
+
"""
|
208 |
+
SambaStudio large language models.
|
209 |
+
|
210 |
+
To use, you should have the environment variables
|
211 |
+
``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL.
|
212 |
+
``SAMBASTUDIO_BASE_URI`` set with your SambaStudio api base URI.
|
213 |
+
``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID.
|
214 |
+
``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID.
|
215 |
+
``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.
|
216 |
+
|
217 |
+
https://sambanova.ai/products/enterprise-ai-platform-sambanova-suite
|
218 |
+
|
219 |
+
read extra documentation in https://docs.sambanova.ai/sambastudio/latest/index.html
|
220 |
+
|
221 |
+
Example:
|
222 |
+
.. code-block:: python
|
223 |
+
|
224 |
+
from langchain_community.llms.sambanova import SambaStudio
|
225 |
+
SambaStudio(
|
226 |
+
sambastudio_base_url="your-SambaStudio-environment-URL",
|
227 |
+
sambastudio_base_uri="your-SambaStudio-base-URI",
|
228 |
+
sambastudio_project_id="your-SambaStudio-project-ID",
|
229 |
+
sambastudio_endpoint_id="your-SambaStudio-endpoint-ID",
|
230 |
+
sambastudio_api_key="your-SambaStudio-endpoint-API-key,
|
231 |
+
streaming=False
|
232 |
+
model_kwargs={
|
233 |
+
"do_sample": False,
|
234 |
+
"max_tokens_to_generate": 1000,
|
235 |
+
"temperature": 0.7,
|
236 |
+
"top_p": 1.0,
|
237 |
+
"repetition_penalty": 1,
|
238 |
+
"top_k": 50,
|
239 |
+
#"process_prompt": False,
|
240 |
+
#"select_expert": "Meta-Llama-3-8B-Instruct"
|
241 |
+
},
|
242 |
+
)
|
243 |
+
"""
|
244 |
+
|
245 |
+
sambastudio_base_url: str = ''
|
246 |
+
"""Base url to use"""
|
247 |
+
|
248 |
+
sambastudio_base_uri: str = ''
|
249 |
+
"""endpoint base uri"""
|
250 |
+
|
251 |
+
sambastudio_project_id: str = ''
|
252 |
+
"""Project id on sambastudio for model"""
|
253 |
+
|
254 |
+
sambastudio_endpoint_id: str = ''
|
255 |
+
"""endpoint id on sambastudio for model"""
|
256 |
+
|
257 |
+
sambastudio_api_key: str = ''
|
258 |
+
"""sambastudio api key"""
|
259 |
+
|
260 |
+
model_kwargs: Optional[dict] = None
|
261 |
+
"""Key word arguments to pass to the model."""
|
262 |
+
|
263 |
+
streaming: Optional[bool] = False
|
264 |
+
"""Streaming flag to get streamed response."""
|
265 |
+
|
266 |
+
class Config:
|
267 |
+
"""Configuration for this pydantic object."""
|
268 |
+
|
269 |
+
extra = Extra.forbid
|
270 |
+
|
271 |
+
@classmethod
|
272 |
+
def is_lc_serializable(cls) -> bool:
|
273 |
+
return True
|
274 |
+
|
275 |
+
@property
|
276 |
+
def _identifying_params(self) -> Dict[str, Any]:
|
277 |
+
"""Get the identifying parameters."""
|
278 |
+
return {**{'model_kwargs': self.model_kwargs}}
|
279 |
+
|
280 |
+
@property
|
281 |
+
def _llm_type(self) -> str:
|
282 |
+
"""Return type of llm."""
|
283 |
+
return 'Sambastudio LLM'
|
284 |
+
|
285 |
+
@pre_init
|
286 |
+
def validate_environment(cls, values: Dict) -> Dict:
|
287 |
+
"""Validate that api key and python package exists in environment."""
|
288 |
+
values['sambastudio_base_url'] = get_from_dict_or_env(values, 'sambastudio_base_url', 'SAMBASTUDIO_BASE_URL')
|
289 |
+
values['sambastudio_base_uri'] = get_from_dict_or_env(
|
290 |
+
values,
|
291 |
+
'sambastudio_base_uri',
|
292 |
+
'SAMBASTUDIO_BASE_URI',
|
293 |
+
default='api/predict/generic',
|
294 |
+
)
|
295 |
+
values['sambastudio_project_id'] = get_from_dict_or_env(
|
296 |
+
values, 'sambastudio_project_id', 'SAMBASTUDIO_PROJECT_ID'
|
297 |
+
)
|
298 |
+
values['sambastudio_endpoint_id'] = get_from_dict_or_env(
|
299 |
+
values, 'sambastudio_endpoint_id', 'SAMBASTUDIO_ENDPOINT_ID'
|
300 |
+
)
|
301 |
+
values['sambastudio_api_key'] = get_from_dict_or_env(values, 'sambastudio_api_key', 'SAMBASTUDIO_API_KEY')
|
302 |
+
return values
|
303 |
+
|
304 |
+
def _get_tuning_params(self, stop: Optional[List[str]]) -> str:
|
305 |
+
"""
|
306 |
+
Get the tuning parameters to use when calling the LLM.
|
307 |
+
|
308 |
+
Args:
|
309 |
+
stop: Stop words to use when generating. Model output is cut off at the
|
310 |
+
first occurrence of any of the stop substrings.
|
311 |
+
|
312 |
+
Returns:
|
313 |
+
The tuning parameters as a JSON string.
|
314 |
+
"""
|
315 |
+
_model_kwargs = self.model_kwargs or {}
|
316 |
+
_kwarg_stop_sequences = _model_kwargs.get('stop_sequences', [])
|
317 |
+
_stop_sequences = stop or _kwarg_stop_sequences
|
318 |
+
# if not _kwarg_stop_sequences:
|
319 |
+
# _model_kwargs["stop_sequences"] = ",".join(
|
320 |
+
# f'"{x}"' for x in _stop_sequences
|
321 |
+
# )
|
322 |
+
if 'api/v2/predict/generic' in self.sambastudio_base_uri:
|
323 |
+
tuning_params_dict = _model_kwargs
|
324 |
+
else:
|
325 |
+
tuning_params_dict = {k: {'type': type(v).__name__, 'value': str(v)} for k, v in (_model_kwargs.items())}
|
326 |
+
# _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
|
327 |
+
tuning_params = json.dumps(tuning_params_dict)
|
328 |
+
return tuning_params
|
329 |
+
|
330 |
+
def _handle_nlp_predict(self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str) -> str:
|
331 |
+
"""
|
332 |
+
Perform an NLP prediction using the SambaStudio endpoint handler.
|
333 |
+
|
334 |
+
Args:
|
335 |
+
sdk: The SSEndpointHandler to use for the prediction.
|
336 |
+
prompt: The prompt to use for the prediction.
|
337 |
+
tuning_params: The tuning parameters to use for the prediction.
|
338 |
+
|
339 |
+
Returns:
|
340 |
+
The prediction result.
|
341 |
+
|
342 |
+
Raises:
|
343 |
+
ValueError: If the prediction fails.
|
344 |
+
"""
|
345 |
+
response = sdk.nlp_predict(
|
346 |
+
self.sambastudio_project_id,
|
347 |
+
self.sambastudio_endpoint_id,
|
348 |
+
self.sambastudio_api_key,
|
349 |
+
prompt,
|
350 |
+
tuning_params,
|
351 |
+
)
|
352 |
+
if response['status_code'] != 200:
|
353 |
+
optional_detail = response.get('detail')
|
354 |
+
if optional_detail:
|
355 |
+
raise RuntimeError(
|
356 |
+
f"Sambanova /complete call failed with status code "
|
357 |
+
f"{response['status_code']}.\n Details: {optional_detail}"
|
358 |
+
)
|
359 |
+
else:
|
360 |
+
raise RuntimeError(
|
361 |
+
f"Sambanova /complete call failed with status code "
|
362 |
+
f"{response['status_code']}.\n response {response}"
|
363 |
+
)
|
364 |
+
if 'api/predict/nlp' in self.sambastudio_base_uri:
|
365 |
+
return response['data'][0]['completion']
|
366 |
+
elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
|
367 |
+
return response['items'][0]['value']['completion']
|
368 |
+
elif 'api/predict/generic' in self.sambastudio_base_uri:
|
369 |
+
return response['predictions'][0]['completion']
|
370 |
+
else:
|
371 |
+
raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri} not implemented')
|
372 |
+
|
373 |
+
def _handle_completion_requests(self, prompt: Union[List[str], str], stop: Optional[List[str]]) -> str:
|
374 |
+
"""
|
375 |
+
Perform a prediction using the SambaStudio endpoint handler.
|
376 |
+
|
377 |
+
Args:
|
378 |
+
prompt: The prompt to use for the prediction.
|
379 |
+
stop: stop sequences.
|
380 |
+
|
381 |
+
Returns:
|
382 |
+
The prediction result.
|
383 |
+
|
384 |
+
Raises:
|
385 |
+
ValueError: If the prediction fails.
|
386 |
+
"""
|
387 |
+
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
|
388 |
+
tuning_params = self._get_tuning_params(stop)
|
389 |
+
return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
|
390 |
+
|
391 |
+
def _handle_nlp_predict_stream(
|
392 |
+
self, sdk: SSEndpointHandler, prompt: Union[List[str], str], tuning_params: str
|
393 |
+
) -> Iterator[GenerationChunk]:
|
394 |
+
"""
|
395 |
+
Perform a streaming request to the LLM.
|
396 |
+
|
397 |
+
Args:
|
398 |
+
sdk: The SVEndpointHandler to use for the prediction.
|
399 |
+
prompt: The prompt to use for the prediction.
|
400 |
+
tuning_params: The tuning parameters to use for the prediction.
|
401 |
+
|
402 |
+
Returns:
|
403 |
+
An iterator of GenerationChunks.
|
404 |
+
"""
|
405 |
+
for chunk in sdk.nlp_predict_stream(
|
406 |
+
self.sambastudio_project_id,
|
407 |
+
self.sambastudio_endpoint_id,
|
408 |
+
self.sambastudio_api_key,
|
409 |
+
prompt,
|
410 |
+
tuning_params,
|
411 |
+
):
|
412 |
+
if chunk['status_code'] != 200:
|
413 |
+
error = chunk.get('error')
|
414 |
+
if error:
|
415 |
+
optional_code = error.get('code')
|
416 |
+
optional_details = error.get('details')
|
417 |
+
optional_message = error.get('message')
|
418 |
+
raise ValueError(
|
419 |
+
f"Sambanova /complete call failed with status code "
|
420 |
+
f"{chunk['status_code']}.\n"
|
421 |
+
f"Message: {optional_message}\n"
|
422 |
+
f"Details: {optional_details}\n"
|
423 |
+
f"Code: {optional_code}\n"
|
424 |
+
)
|
425 |
+
else:
|
426 |
+
raise RuntimeError(
|
427 |
+
f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
|
428 |
+
)
|
429 |
+
if 'api/predict/nlp' in self.sambastudio_base_uri:
|
430 |
+
text = json.loads(chunk['data'])['stream_token']
|
431 |
+
elif 'api/v2/predict/generic' in self.sambastudio_base_uri:
|
432 |
+
text = chunk['result']['items'][0]['value']['stream_token']
|
433 |
+
elif 'api/predict/generic' in self.sambastudio_base_uri:
|
434 |
+
if len(chunk['result']['responses']) > 0:
|
435 |
+
text = chunk['result']['responses'][0]['stream_token']
|
436 |
+
else:
|
437 |
+
text = ''
|
438 |
+
else:
|
439 |
+
raise ValueError(f'handling of endpoint uri: {self.sambastudio_base_uri}' f'not implemented')
|
440 |
+
generated_chunk = GenerationChunk(text=text)
|
441 |
+
yield generated_chunk
|
442 |
+
|
443 |
+
def _stream(
|
444 |
+
self,
|
445 |
+
prompt: Union[List[str], str],
|
446 |
+
stop: Optional[List[str]] = None,
|
447 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
448 |
+
**kwargs: Any,
|
449 |
+
) -> Iterator[GenerationChunk]:
|
450 |
+
"""Call out to Sambanova's complete endpoint.
|
451 |
+
|
452 |
+
Args:
|
453 |
+
prompt: The prompt to pass into the model.
|
454 |
+
stop: Optional list of stop words to use when generating.
|
455 |
+
|
456 |
+
Returns:
|
457 |
+
The string generated by the model.
|
458 |
+
"""
|
459 |
+
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url, self.sambastudio_base_uri)
|
460 |
+
tuning_params = self._get_tuning_params(stop)
|
461 |
+
try:
|
462 |
+
if self.streaming:
|
463 |
+
for chunk in self._handle_nlp_predict_stream(ss_endpoint, prompt, tuning_params):
|
464 |
+
if run_manager:
|
465 |
+
run_manager.on_llm_new_token(chunk.text)
|
466 |
+
yield chunk
|
467 |
+
else:
|
468 |
+
return
|
469 |
+
except Exception as e:
|
470 |
+
# Handle any errors raised by the inference endpoint
|
471 |
+
raise ValueError(f'Error raised by the inference endpoint: {e}') from e
|
472 |
+
|
473 |
+
def _handle_stream_request(
|
474 |
+
self,
|
475 |
+
prompt: Union[List[str], str],
|
476 |
+
stop: Optional[List[str]],
|
477 |
+
run_manager: Optional[CallbackManagerForLLMRun],
|
478 |
+
kwargs: Dict[str, Any],
|
479 |
+
) -> str:
|
480 |
+
"""
|
481 |
+
Perform a streaming request to the LLM.
|
482 |
+
|
483 |
+
Args:
|
484 |
+
prompt: The prompt to generate from.
|
485 |
+
stop: Stop words to use when generating. Model output is cut off at the
|
486 |
+
first occurrence of any of the stop substrings.
|
487 |
+
run_manager: Callback manager for the run.
|
488 |
+
**kwargs: Additional keyword arguments. directly passed
|
489 |
+
to the sambastudio model in API call.
|
490 |
+
|
491 |
+
Returns:
|
492 |
+
The model output as a string.
|
493 |
+
"""
|
494 |
+
completion = ''
|
495 |
+
for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs):
|
496 |
+
completion += chunk.text
|
497 |
+
return completion
|
498 |
+
|
499 |
+
def _call(
|
500 |
+
self,
|
501 |
+
prompt: Union[List[str], str],
|
502 |
+
stop: Optional[List[str]] = None,
|
503 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
504 |
+
**kwargs: Any,
|
505 |
+
) -> str:
|
506 |
+
"""Call out to Sambanova's complete endpoint.
|
507 |
+
|
508 |
+
Args:
|
509 |
+
prompt: The prompt to pass into the model.
|
510 |
+
stop: Optional list of stop words to use when generating.
|
511 |
+
|
512 |
+
Returns:
|
513 |
+
The string generated by the model.
|
514 |
+
"""
|
515 |
+
if stop is not None:
|
516 |
+
raise Exception('stop not implemented')
|
517 |
+
try:
|
518 |
+
if self.streaming:
|
519 |
+
return self._handle_stream_request(prompt, stop, run_manager, kwargs)
|
520 |
+
return self._handle_completion_requests(prompt, stop)
|
521 |
+
except Exception as e:
|
522 |
+
# Handle any errors raised by the inference endpoint
|
523 |
+
raise ValueError(f'Error raised by the inference endpoint: {e}') from e
|
524 |
+
|
525 |
+
|
526 |
+
class SambaNovaCloud(LLM):
|
527 |
+
"""
|
528 |
+
SambaNova Cloud large language models.
|
529 |
+
|
530 |
+
To use, you should have the environment variables
|
531 |
+
``SAMBANOVA_URL`` set with your SambaNova Cloud URL.
|
532 |
+
``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API Key.
|
533 |
+
|
534 |
+
http://cloud.sambanova.ai/
|
535 |
+
|
536 |
+
Example:
|
537 |
+
.. code-block:: python
|
538 |
+
|
539 |
+
SambaNovaCloud(
|
540 |
+
sambanova_url = SambaNova cloud endpoint URL,
|
541 |
+
sambanova_api_key = set with your SambaNova cloud API key,
|
542 |
+
max_tokens = mas number of tokens to generate
|
543 |
+
stop_tokens = list of stop tokens
|
544 |
+
model = model name
|
545 |
+
)
|
546 |
+
"""
|
547 |
+
|
548 |
+
sambanova_url: str = ''
|
549 |
+
"""SambaNova Cloud Url"""
|
550 |
+
|
551 |
+
sambanova_api_key: str = ''
|
552 |
+
"""SambaNova Cloud api key"""
|
553 |
+
|
554 |
+
max_tokens: int = 1024
|
555 |
+
"""max tokens to generate"""
|
556 |
+
|
557 |
+
stop_tokens: list = ['<|eot_id|>']
|
558 |
+
"""Stop tokens"""
|
559 |
+
|
560 |
+
model: str = 'llama3-8b'
|
561 |
+
"""LLM model expert to use"""
|
562 |
+
|
563 |
+
temperature: float = 0.0
|
564 |
+
"""model temperature"""
|
565 |
+
|
566 |
+
top_p: float = 0.0
|
567 |
+
"""model top p"""
|
568 |
+
|
569 |
+
top_k: int = 1
|
570 |
+
"""model top k"""
|
571 |
+
|
572 |
+
stream_api: bool = True
|
573 |
+
"""use stream api"""
|
574 |
+
|
575 |
+
stream_options: dict = {'include_usage': True}
|
576 |
+
"""stream options, include usage to get generation metrics"""
|
577 |
+
|
578 |
+
class Config:
|
579 |
+
"""Configuration for this pydantic object."""
|
580 |
+
|
581 |
+
extra = Extra.forbid
|
582 |
+
|
583 |
+
@classmethod
|
584 |
+
def is_lc_serializable(cls) -> bool:
|
585 |
+
return True
|
586 |
+
|
587 |
+
@property
|
588 |
+
def _identifying_params(self) -> Dict[str, Any]:
|
589 |
+
"""Get the identifying parameters."""
|
590 |
+
return {
|
591 |
+
'model': self.model,
|
592 |
+
'max_tokens': self.max_tokens,
|
593 |
+
'stop': self.stop_tokens,
|
594 |
+
'temperature': self.temperature,
|
595 |
+
'top_p': self.top_p,
|
596 |
+
'top_k': self.top_k,
|
597 |
+
}
|
598 |
+
|
599 |
+
@property
|
600 |
+
def _llm_type(self) -> str:
|
601 |
+
"""Return type of llm."""
|
602 |
+
return 'SambaNova Cloud'
|
603 |
+
|
604 |
+
@pre_init
|
605 |
+
def validate_environment(cls, values: Dict) -> Dict:
|
606 |
+
"""Validate that api key and python package exists in environment."""
|
607 |
+
values['sambanova_url'] = get_from_dict_or_env(
|
608 |
+
values, 'sambanova_url', 'SAMBANOVA_URL', default='https://api.sambanova.ai/v1/chat/completions'
|
609 |
+
)
|
610 |
+
values['sambanova_api_key'] = get_from_dict_or_env(values, 'sambanova_api_key', 'SAMBANOVA_API_KEY')
|
611 |
+
return values
|
612 |
+
|
613 |
+
def _handle_nlp_predict_stream(
|
614 |
+
self,
|
615 |
+
prompt: Union[List[str], str],
|
616 |
+
stop: List[str],
|
617 |
+
) -> Iterator[GenerationChunk]:
|
618 |
+
"""
|
619 |
+
Perform a streaming request to the LLM.
|
620 |
+
|
621 |
+
Args:
|
622 |
+
prompt: The prompt to use for the prediction.
|
623 |
+
stop: list of stop tokens
|
624 |
+
|
625 |
+
Returns:
|
626 |
+
An iterator of GenerationChunks.
|
627 |
+
"""
|
628 |
+
try:
|
629 |
+
import sseclient
|
630 |
+
except ImportError:
|
631 |
+
raise ImportError('could not import sseclient library' 'Please install it with `pip install sseclient-py`.')
|
632 |
+
try:
|
633 |
+
formatted_prompt = json.loads(prompt)
|
634 |
+
except:
|
635 |
+
formatted_prompt = [{'role': 'user', 'content': prompt}]
|
636 |
+
|
637 |
+
http_session = requests.Session()
|
638 |
+
if not stop:
|
639 |
+
stop = self.stop_tokens
|
640 |
+
data = {
|
641 |
+
'messages': formatted_prompt,
|
642 |
+
'max_tokens': self.max_tokens,
|
643 |
+
'stop': stop,
|
644 |
+
'model': self.model,
|
645 |
+
'temperature': self.temperature,
|
646 |
+
'top_p': self.top_p,
|
647 |
+
'top_k': self.top_k,
|
648 |
+
'stream': self.stream_api,
|
649 |
+
'stream_options': self.stream_options,
|
650 |
+
}
|
651 |
+
# Streaming output
|
652 |
+
response = http_session.post(
|
653 |
+
self.sambanova_url,
|
654 |
+
headers={'Authorization': f'Bearer {self.sambanova_api_key}', 'Content-Type': 'application/json'},
|
655 |
+
json=data,
|
656 |
+
stream=True,
|
657 |
+
)
|
658 |
+
|
659 |
+
client = sseclient.SSEClient(response)
|
660 |
+
close_conn = False
|
661 |
+
|
662 |
+
if response.status_code != 200:
|
663 |
+
raise RuntimeError(
|
664 |
+
f'Sambanova /complete call failed with status code ' f'{response.status_code}.' f'{response.text}.'
|
665 |
+
)
|
666 |
+
|
667 |
+
for event in client.events():
|
668 |
+
if event.event == 'error_event':
|
669 |
+
close_conn = True
|
670 |
+
chunk = {
|
671 |
+
'event': event.event,
|
672 |
+
'data': event.data,
|
673 |
+
'status_code': response.status_code,
|
674 |
+
}
|
675 |
+
|
676 |
+
if chunk.get('error'):
|
677 |
+
raise RuntimeError(
|
678 |
+
f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
|
679 |
+
)
|
680 |
+
|
681 |
+
try:
|
682 |
+
# check if the response is a final event in that case event data response is '[DONE]'
|
683 |
+
if chunk['data'] != '[DONE]':
|
684 |
+
data = json.loads(chunk['data'])
|
685 |
+
if data.get('error'):
|
686 |
+
raise RuntimeError(
|
687 |
+
f"Sambanova /complete call failed with status code " f"{chunk['status_code']}." f"{chunk}."
|
688 |
+
)
|
689 |
+
# check if the response is a final response with usage stats (not includes content)
|
690 |
+
if data.get('usage') is None:
|
691 |
+
# check is not "end of text" response
|
692 |
+
if data['choices'][0]['finish_reason'] is None:
|
693 |
+
text = data['choices'][0]['delta']['content']
|
694 |
+
generated_chunk = GenerationChunk(text=text)
|
695 |
+
yield generated_chunk
|
696 |
+
except Exception as e:
|
697 |
+
raise Exception(f'Error getting content chunk raw streamed response: {chunk}')
|
698 |
+
|
699 |
+
def _stream(
|
700 |
+
self,
|
701 |
+
prompt: Union[List[str], str],
|
702 |
+
stop: Optional[List[str]] = None,
|
703 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
704 |
+
**kwargs: Any,
|
705 |
+
) -> Iterator[GenerationChunk]:
|
706 |
+
"""Call out to Sambanova's complete endpoint.
|
707 |
+
|
708 |
+
Args:
|
709 |
+
prompt: The prompt to pass into the model.
|
710 |
+
stop: Optional list of stop words to use when generating.
|
711 |
+
|
712 |
+
Returns:
|
713 |
+
The string generated by the model.
|
714 |
+
"""
|
715 |
+
try:
|
716 |
+
for chunk in self._handle_nlp_predict_stream(prompt, stop):
|
717 |
+
if run_manager:
|
718 |
+
run_manager.on_llm_new_token(chunk.text)
|
719 |
+
yield chunk
|
720 |
+
except Exception as e:
|
721 |
+
# Handle any errors raised by the inference endpoint
|
722 |
+
raise ValueError(f'Error raised by the inference endpoint: {e}') from e
|
723 |
+
|
724 |
+
def _handle_stream_request(
|
725 |
+
self,
|
726 |
+
prompt: Union[List[str], str],
|
727 |
+
stop: Optional[List[str]],
|
728 |
+
run_manager: Optional[CallbackManagerForLLMRun],
|
729 |
+
kwargs: Dict[str, Any],
|
730 |
+
) -> str:
|
731 |
+
"""
|
732 |
+
Perform a streaming request to the LLM.
|
733 |
+
|
734 |
+
Args:
|
735 |
+
prompt: The prompt to generate from.
|
736 |
+
stop: Stop words to use when generating. Model output is cut off at the
|
737 |
+
first occurrence of any of the stop substrings.
|
738 |
+
run_manager: Callback manager for the run.
|
739 |
+
**kwargs: Additional keyword arguments. directly passed
|
740 |
+
to the Sambanova Cloud model in API call.
|
741 |
+
|
742 |
+
Returns:
|
743 |
+
The model output as a string.
|
744 |
+
"""
|
745 |
+
completion = ''
|
746 |
+
for chunk in self._stream(prompt=prompt, stop=stop, run_manager=run_manager, **kwargs):
|
747 |
+
completion += chunk.text
|
748 |
+
return completion
|
749 |
+
|
750 |
+
def _call(
|
751 |
+
self,
|
752 |
+
prompt: Union[List[str], str],
|
753 |
+
stop: Optional[List[str]] = None,
|
754 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
755 |
+
**kwargs: Any,
|
756 |
+
) -> str:
|
757 |
+
"""Call out to Sambanova's complete endpoint.
|
758 |
+
|
759 |
+
Args:
|
760 |
+
prompt: The prompt to pass into the model.
|
761 |
+
stop: Optional list of stop words to use when generating.
|
762 |
+
|
763 |
+
Returns:
|
764 |
+
The string generated by the model.
|
765 |
+
"""
|
766 |
+
try:
|
767 |
+
return self._handle_stream_request(prompt, stop, run_manager, kwargs)
|
768 |
+
except Exception as e:
|
769 |
+
# Handle any errors raised by the inference endpoint
|
770 |
+
raise ValueError(f'Error raised by the inference endpoint: {e}') from e
|
utils/model_wrappers/usage.ipynb
ADDED
@@ -0,0 +1,878 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# SambanNova Langchain Wrappers Usage"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "code",
|
12 |
+
"execution_count": 2,
|
13 |
+
"metadata": {},
|
14 |
+
"outputs": [
|
15 |
+
{
|
16 |
+
"data": {
|
17 |
+
"text/plain": [
|
18 |
+
"True"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
"execution_count": 2,
|
22 |
+
"metadata": {},
|
23 |
+
"output_type": "execute_result"
|
24 |
+
}
|
25 |
+
],
|
26 |
+
"source": [
|
27 |
+
"import os\n",
|
28 |
+
"\n",
|
29 |
+
"from dotenv import load_dotenv\n",
|
30 |
+
"from langchain_embeddings import SambaStudioEmbeddings\n",
|
31 |
+
"from langchain_llms import SambaStudio, SambaNovaCloud\n",
|
32 |
+
"from langchain_chat_models import ChatSambaNovaCloud\n",
|
33 |
+
"from langchain_core.messages import SystemMessage, HumanMessage\n",
|
34 |
+
"\n",
|
35 |
+
"current_dir = os.getcwd()\n",
|
36 |
+
"utils_dir = os.path.abspath(os.path.join(current_dir, '..'))\n",
|
37 |
+
"repo_dir = os.path.abspath(os.path.join(utils_dir, '..'))\n",
|
38 |
+
"\n",
|
39 |
+
"load_dotenv(os.path.join(repo_dir, '.env'), override=True)"
|
40 |
+
]
|
41 |
+
},
|
42 |
+
{
|
43 |
+
"cell_type": "markdown",
|
44 |
+
"metadata": {},
|
45 |
+
"source": [
|
46 |
+
"# SambaStudio LLM"
|
47 |
+
]
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "markdown",
|
51 |
+
"metadata": {},
|
52 |
+
"source": [
|
53 |
+
"## Non streaming"
|
54 |
+
]
|
55 |
+
},
|
56 |
+
{
|
57 |
+
"cell_type": "code",
|
58 |
+
"execution_count": 9,
|
59 |
+
"metadata": {},
|
60 |
+
"outputs": [],
|
61 |
+
"source": [
|
62 |
+
"llm = SambaStudio(\n",
|
63 |
+
" streaming=False,\n",
|
64 |
+
" # base_uri=\"api/predict/generic\",\n",
|
65 |
+
" model_kwargs={\n",
|
66 |
+
" 'do_sample': False,\n",
|
67 |
+
" 'temperature': 0.01,\n",
|
68 |
+
" 'max_tokens_to_generate': 256,\n",
|
69 |
+
" 'process_prompt': False,\n",
|
70 |
+
" 'select_expert': 'Meta-Llama-3-70B-Instruct-4096',\n",
|
71 |
+
" },\n",
|
72 |
+
")"
|
73 |
+
]
|
74 |
+
},
|
75 |
+
{
|
76 |
+
"cell_type": "code",
|
77 |
+
"execution_count": 11,
|
78 |
+
"metadata": {},
|
79 |
+
"outputs": [
|
80 |
+
{
|
81 |
+
"data": {
|
82 |
+
"text/plain": [
|
83 |
+
"' of a brave knight\\nSir Valoric, the fearless knight, charged into the dark forest, his armor shining like the sun. He battled the dragon, its fiery breath singeing his beard, but he stood tall, his sword flashing in the moonlight, until the beast lay defeated at his feet, its treasure his noble reward.'"
|
84 |
+
]
|
85 |
+
},
|
86 |
+
"execution_count": 11,
|
87 |
+
"metadata": {},
|
88 |
+
"output_type": "execute_result"
|
89 |
+
}
|
90 |
+
],
|
91 |
+
"source": [
|
92 |
+
"llm.invoke('tell me a 50 word tale')"
|
93 |
+
]
|
94 |
+
},
|
95 |
+
{
|
96 |
+
"cell_type": "markdown",
|
97 |
+
"metadata": {},
|
98 |
+
"source": [
|
99 |
+
"## Streaming"
|
100 |
+
]
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"cell_type": "code",
|
104 |
+
"execution_count": null,
|
105 |
+
"metadata": {},
|
106 |
+
"outputs": [],
|
107 |
+
"source": [
|
108 |
+
"llm = SambaStudio(\n",
|
109 |
+
" streaming=True,\n",
|
110 |
+
" model_kwargs={\n",
|
111 |
+
" 'do_sample': False,\n",
|
112 |
+
" 'max_tokens_to_generate': 256,\n",
|
113 |
+
" 'temperature': 0.01,\n",
|
114 |
+
" 'process_prompt': False,\n",
|
115 |
+
" 'select_expert': 'Meta-Llama-3-70B-Instruct-4096',\n",
|
116 |
+
" },\n",
|
117 |
+
")"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "code",
|
122 |
+
"execution_count": null,
|
123 |
+
"metadata": {},
|
124 |
+
"outputs": [
|
125 |
+
{
|
126 |
+
"name": "stdout",
|
127 |
+
"output_type": "stream",
|
128 |
+
"text": [
|
129 |
+
" of a character who is a master of disguise\n",
|
130 |
+
"\n",
|
131 |
+
"Sure! Here is a 50-word tale of a character who is a master of disguise:\n",
|
132 |
+
"\n",
|
133 |
+
"\"Araxys, the skilled disguise artist, transformed into a stunning mermaid to infiltrate a pirate's lair. With a flick of her tail, she charmed the pirates and stole their treasure.\""
|
134 |
+
]
|
135 |
+
}
|
136 |
+
],
|
137 |
+
"source": [
|
138 |
+
"for chunk in llm.stream('tell me a 50 word tale'):\n",
|
139 |
+
" print(chunk, end='', flush=True)"
|
140 |
+
]
|
141 |
+
},
|
142 |
+
{
|
143 |
+
"cell_type": "markdown",
|
144 |
+
"metadata": {},
|
145 |
+
"source": [
|
146 |
+
"# SambaNovaCloud LLM"
|
147 |
+
]
|
148 |
+
},
|
149 |
+
{
|
150 |
+
"cell_type": "markdown",
|
151 |
+
"metadata": {},
|
152 |
+
"source": [
|
153 |
+
"## Non Streaming"
|
154 |
+
]
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"cell_type": "code",
|
158 |
+
"execution_count": 4,
|
159 |
+
"metadata": {},
|
160 |
+
"outputs": [],
|
161 |
+
"source": [
|
162 |
+
"llm = SambaNovaCloud(model='llama3-70b')"
|
163 |
+
]
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"cell_type": "code",
|
167 |
+
"execution_count": 5,
|
168 |
+
"metadata": {},
|
169 |
+
"outputs": [
|
170 |
+
{
|
171 |
+
"data": {
|
172 |
+
"text/plain": [
|
173 |
+
"'Hello. How can I assist you today?'"
|
174 |
+
]
|
175 |
+
},
|
176 |
+
"execution_count": 5,
|
177 |
+
"metadata": {},
|
178 |
+
"output_type": "execute_result"
|
179 |
+
}
|
180 |
+
],
|
181 |
+
"source": [
|
182 |
+
"import json\n",
|
183 |
+
"\n",
|
184 |
+
"llm.invoke(json.dumps([{'role': 'user', 'content': 'hello'}]))"
|
185 |
+
]
|
186 |
+
},
|
187 |
+
{
|
188 |
+
"cell_type": "code",
|
189 |
+
"execution_count": 6,
|
190 |
+
"metadata": {},
|
191 |
+
"outputs": [
|
192 |
+
{
|
193 |
+
"data": {
|
194 |
+
"text/plain": [
|
195 |
+
"'Hello. How can I assist you today?'"
|
196 |
+
]
|
197 |
+
},
|
198 |
+
"execution_count": 6,
|
199 |
+
"metadata": {},
|
200 |
+
"output_type": "execute_result"
|
201 |
+
}
|
202 |
+
],
|
203 |
+
"source": [
|
204 |
+
"llm.invoke('hello')"
|
205 |
+
]
|
206 |
+
},
|
207 |
+
{
|
208 |
+
"cell_type": "markdown",
|
209 |
+
"metadata": {},
|
210 |
+
"source": [
|
211 |
+
"## Streaming"
|
212 |
+
]
|
213 |
+
},
|
214 |
+
{
|
215 |
+
"cell_type": "code",
|
216 |
+
"execution_count": 7,
|
217 |
+
"metadata": {},
|
218 |
+
"outputs": [
|
219 |
+
{
|
220 |
+
"name": "stdout",
|
221 |
+
"output_type": "stream",
|
222 |
+
"text": [
|
223 |
+
"\n",
|
224 |
+
"Here's a long story \n",
|
225 |
+
"for you:\n",
|
226 |
+
"\n",
|
227 |
+
"Once upon \n",
|
228 |
+
"a time, in a small village \n",
|
229 |
+
"nestled in the rolling hills of \n",
|
230 |
+
"rural France, there lived a \n",
|
231 |
+
"young girl named Sophie. Sophie \n",
|
232 |
+
"was a curious and adventurous \n",
|
233 |
+
"child, with a mop of curly \n",
|
234 |
+
"brown hair and a smile that \n",
|
235 |
+
"could light up the darkest \n",
|
236 |
+
"of rooms. She lived with \n",
|
237 |
+
"her parents, Pierre and \n",
|
238 |
+
"Colette, in a small stone cottage \n",
|
239 |
+
"on the outskirts of \n",
|
240 |
+
"the village.\n",
|
241 |
+
"\n",
|
242 |
+
"Sophie's village was \n",
|
243 |
+
"a charming \n",
|
244 |
+
"place, filled with narrow \n",
|
245 |
+
"cobblestone streets, quaint shops, \n",
|
246 |
+
"and \n",
|
247 |
+
"bustling cafes. The villagers \n",
|
248 |
+
"were a tight-knit \n",
|
249 |
+
"community, and everyone knew each \n",
|
250 |
+
"other's names and stories. Sophie \n",
|
251 |
+
"loved listening to the villagers' \n",
|
252 |
+
"tales of \n",
|
253 |
+
"old, which \n",
|
254 |
+
"often featured brave knights, \n",
|
255 |
+
"beautiful princesses, and \n",
|
256 |
+
"magical creatures.\n",
|
257 |
+
"\n",
|
258 |
+
"One day, while exploring \n",
|
259 |
+
"the village, Sophie stumbled upon \n",
|
260 |
+
"a small, mysterious shop tucked \n",
|
261 |
+
"away on a quiet street. \n",
|
262 |
+
"The sign above the door \n",
|
263 |
+
"read \"Curios \n",
|
264 |
+
"and Wonders,\" and the \n",
|
265 |
+
"windows were filled \n",
|
266 |
+
"with a dazzling array of strange \n",
|
267 |
+
"and exotic objects. Sophie's \n",
|
268 |
+
"curiosity was piqued, \n",
|
269 |
+
"and she pushed open the door \n",
|
270 |
+
"to venture inside.\n",
|
271 |
+
"\n",
|
272 |
+
"The shop \n",
|
273 |
+
"was dimly lit, and \n",
|
274 |
+
"the air was thick with the \n",
|
275 |
+
"scent of old books and \n",
|
276 |
+
"dust. Sophie's eyes \n",
|
277 |
+
"adjusted slowly, and she \n",
|
278 |
+
"saw that the shop was filled \n",
|
279 |
+
"with all manner of curious \n",
|
280 |
+
"objects: vintage \n",
|
281 |
+
"clocks, rare coins, \n",
|
282 |
+
"and even a \n",
|
283 |
+
"taxidermied owl perched on \n",
|
284 |
+
"a shelf. Behind the counter stood \n",
|
285 |
+
"an old man with a kind \n",
|
286 |
+
"face \n",
|
287 |
+
"and a twinkle in his eye.\n",
|
288 |
+
"\n",
|
289 |
+
"\n",
|
290 |
+
"\n",
|
291 |
+
"\"Bonjour, mademoiselle,\" he \n",
|
292 |
+
"said, his voice low and \n",
|
293 |
+
"soothing. \"Welcome to Curios \n",
|
294 |
+
"and Wonders. I \n",
|
295 |
+
"am Monsieur LaFleur, \n",
|
296 |
+
"the proprietor. How may I \n",
|
297 |
+
"assist you \n",
|
298 |
+
"today?\"\n",
|
299 |
+
"\n",
|
300 |
+
"Sophie wandered the aisles, \n",
|
301 |
+
"running her fingers over \n",
|
302 |
+
"the strange objects on \n",
|
303 |
+
"display. She picked up \n",
|
304 |
+
"a small, delicate music \n",
|
305 |
+
"box and wound \n",
|
306 |
+
"it up, listening \n",
|
307 |
+
"as it played \n",
|
308 |
+
"a soft, melancholy \n",
|
309 |
+
"tune. Monsieur LaFleur \n",
|
310 |
+
"smiled and nodded \n",
|
311 |
+
"in approval.\n",
|
312 |
+
"\n",
|
313 |
+
"\"Ah, you have a \n",
|
314 |
+
"good ear for \n",
|
315 |
+
"music, mademoiselle,\" he \n",
|
316 |
+
"said. \"That music box \n",
|
317 |
+
"is a \n",
|
318 |
+
"rare and precious item. It \n",
|
319 |
+
"was crafted by a skilled artisan \n",
|
320 |
+
"in the 18th century.\"\n",
|
321 |
+
"\n",
|
322 |
+
"\n",
|
323 |
+
"As Sophie continued to \n",
|
324 |
+
"explore the shop, \n",
|
325 |
+
"she stumbled upon \n",
|
326 |
+
"a large, leather-bound book \n",
|
327 |
+
"with strange symbols etched into \n",
|
328 |
+
"the cover. \n",
|
329 |
+
"Monsieur LaFleur noticed her interest and \n",
|
330 |
+
"approached \n",
|
331 |
+
"her.\n",
|
332 |
+
"\n",
|
333 |
+
"\"Ah, you've found \n",
|
334 |
+
"the infamous 'Livre \n",
|
335 |
+
"\n",
|
336 |
+
"des Secrets,'\" \n",
|
337 |
+
"he said, his \n",
|
338 |
+
"voice low and mysterious. \n",
|
339 |
+
"\"That book is said to contain \n",
|
340 |
+
"the secrets of the universe, \n",
|
341 |
+
"hidden within its pages. But \n",
|
342 |
+
"be \n",
|
343 |
+
"warned, mademoiselle, \n",
|
344 |
+
"the book is said to \n",
|
345 |
+
"be cursed. Many have attempted \n",
|
346 |
+
"to unlock its secrets, but \n",
|
347 |
+
"none have \n",
|
348 |
+
"succeeded.\"\n",
|
349 |
+
"\n",
|
350 |
+
"Sophie's eyes widened with \n",
|
351 |
+
"excitement as she carefully opened \n",
|
352 |
+
"the book. The pages \n",
|
353 |
+
"were yellowed and \n",
|
354 |
+
"crackling, and \n",
|
355 |
+
"the text was written in a \n",
|
356 |
+
"language she couldn't understand. \n",
|
357 |
+
"But as she turned the \n",
|
358 |
+
"pages, \n",
|
359 |
+
"she felt a strange sensation, \n",
|
360 |
+
"as if the book \n",
|
361 |
+
"was calling \n",
|
362 |
+
"to her.\n",
|
363 |
+
"\n",
|
364 |
+
"Monsieur \n",
|
365 |
+
"LaFleur smiled \n",
|
366 |
+
"and \n",
|
367 |
+
"nodded. \"I see you have a \n",
|
368 |
+
"connection to the \n",
|
369 |
+
"book, mademoiselle. Perhaps you \n",
|
370 |
+
"are the one who can unlock \n",
|
371 |
+
"its secrets.\"\n",
|
372 |
+
"\n",
|
373 |
+
"Over the next \n",
|
374 |
+
"few weeks, Sophie returned to \n",
|
375 |
+
"the shop again and again, \n",
|
376 |
+
"pouring over \n",
|
377 |
+
"the pages of the Livre \n",
|
378 |
+
"des Secrets. She spent hours \n",
|
379 |
+
"studying \n",
|
380 |
+
"the symbols and trying to decipher \n",
|
381 |
+
"the text. \n",
|
382 |
+
"Monsieur \n",
|
383 |
+
"LaFleur watched her with a \n",
|
384 |
+
"keen eye, offering guidance and encouragement \n",
|
385 |
+
"whenever she needed it.\n",
|
386 |
+
"\n",
|
387 |
+
"As \n",
|
388 |
+
"the days turned into weeks, \n",
|
389 |
+
"Sophie began to notice strange occurrences \n",
|
390 |
+
"happening around her. She would \n",
|
391 |
+
"find objects moved from their \n",
|
392 |
+
"usual places, and she would hear \n",
|
393 |
+
"whispers in the night. She \n",
|
394 |
+
"began \n",
|
395 |
+
"to feel as though the book \n",
|
396 |
+
"was exerting some kind of \n",
|
397 |
+
"influence over her, drawing her \n",
|
398 |
+
"deeper into \n",
|
399 |
+
"its secrets.\n",
|
400 |
+
"\n",
|
401 |
+
"One \n",
|
402 |
+
"night, Sophie had a vivid dream \n",
|
403 |
+
"in which \n",
|
404 |
+
"she saw herself standing in \n",
|
405 |
+
"a \n",
|
406 |
+
"grand, \n",
|
407 |
+
"candlelit hall. \n",
|
408 |
+
"The walls were lined with \n",
|
409 |
+
"ancient tapestries, and the \n",
|
410 |
+
"air was thick with the scent \n",
|
411 |
+
"of \n",
|
412 |
+
"incense. At the far end of \n",
|
413 |
+
"the hall, she saw a \n",
|
414 |
+
"figure cloaked in shadows.\n",
|
415 |
+
"\n",
|
416 |
+
"\n",
|
417 |
+
"As she approached \n",
|
418 |
+
"the figure, it stepped forward, \n",
|
419 |
+
"revealing a woman \n",
|
420 |
+
"with long, flowing hair and \n",
|
421 |
+
"piercing green eyes. The woman \n",
|
422 |
+
"spoke in a voice that was \n",
|
423 |
+
"both familiar and yet \n",
|
424 |
+
"completely alien.\n",
|
425 |
+
"\n",
|
426 |
+
"\"Sophie, you \n",
|
427 |
+
"have been chosen to unlock the \n",
|
428 |
+
"secrets of the Livre \n",
|
429 |
+
"des Secrets,\" she \n",
|
430 |
+
"said. \"But be warned, \n",
|
431 |
+
"the \n",
|
432 |
+
"journey will \n",
|
433 |
+
"be difficult, and the cost \n",
|
434 |
+
"will be high. Are you \n",
|
435 |
+
"prepared to pay \n",
|
436 |
+
"the price?\"\n",
|
437 |
+
"\n",
|
438 |
+
"Sophie woke up with \n",
|
439 |
+
"a start, her heart racing and \n",
|
440 |
+
"her mind reeling. She \n",
|
441 |
+
"knew that she had \n",
|
442 |
+
"to return to the shop and \n",
|
443 |
+
"confront Monsieur LaFleur \n",
|
444 |
+
"about the \n",
|
445 |
+
"strange \n",
|
446 |
+
"occurrences. But when she \n",
|
447 |
+
"arrived at the shop, she \n",
|
448 |
+
"found that it \n",
|
449 |
+
"was closed, \n",
|
450 |
+
"and \n",
|
451 |
+
"a sign on the door \n",
|
452 |
+
"read \"Gone on \n",
|
453 |
+
"a \n",
|
454 |
+
"journey. Will return \n",
|
455 |
+
"soon.\"\n",
|
456 |
+
"\n",
|
457 |
+
"Sophie \n",
|
458 |
+
"was devastated. \n",
|
459 |
+
"She felt as though she had \n",
|
460 |
+
"been abandoned, left \n",
|
461 |
+
"to navigate the mysteries of \n",
|
462 |
+
"the Livre des Secrets on \n",
|
463 |
+
"her own. But as \n",
|
464 |
+
"she turned to leave, she \n",
|
465 |
+
"noticed a\n"
|
466 |
+
]
|
467 |
+
}
|
468 |
+
],
|
469 |
+
"source": [
|
470 |
+
"for i in llm.stream('hello tell me a long story'):\n",
|
471 |
+
" print(i)"
|
472 |
+
]
|
473 |
+
},
|
474 |
+
{
|
475 |
+
"cell_type": "markdown",
|
476 |
+
"metadata": {},
|
477 |
+
"source": [
|
478 |
+
"# SambaNova Cloud Chat Model"
|
479 |
+
]
|
480 |
+
},
|
481 |
+
{
|
482 |
+
"cell_type": "markdown",
|
483 |
+
"metadata": {},
|
484 |
+
"source": [
|
485 |
+
"## Non Streaming"
|
486 |
+
]
|
487 |
+
},
|
488 |
+
{
|
489 |
+
"cell_type": "code",
|
490 |
+
"execution_count": 4,
|
491 |
+
"metadata": {},
|
492 |
+
"outputs": [],
|
493 |
+
"source": [
|
494 |
+
"llm = ChatSambaNovaCloud(\n",
|
495 |
+
" model= \"llama3-405b\",\n",
|
496 |
+
" max_tokens=1024,\n",
|
497 |
+
" temperature=0.7,\n",
|
498 |
+
" top_k=1,\n",
|
499 |
+
" top_p=0.01,\n",
|
500 |
+
" stream_options={'include_usage':True}\n",
|
501 |
+
" )"
|
502 |
+
]
|
503 |
+
},
|
504 |
+
{
|
505 |
+
"cell_type": "code",
|
506 |
+
"execution_count": 5,
|
507 |
+
"metadata": {},
|
508 |
+
"outputs": [
|
509 |
+
{
|
510 |
+
"data": {
|
511 |
+
"text/plain": [
|
512 |
+
"AIMessage(content='A man walked into a library and asked the librarian, \"Do you have any books on Pavlov\\'s dogs and Schrödinger\\'s cat?\"\\n\\nThe librarian replied, \"It rings a bell, but I\\'m not sure if it\\'s here or not.\"', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 6.875, 'completion_tokens': 54, 'completion_tokens_after_first_per_sec': 146.48573712341215, 'completion_tokens_after_first_per_sec_first_ten': 172.9005798161617, 'completion_tokens_per_sec': 81.99632208428116, 'end_time': 1726178488.071125, 'is_last_response': True, 'prompt_tokens': 40, 'start_time': 1726178487.3630672, 'time_to_first_token': 0.34624791145324707, 'total_latency': 0.658566123789007, 'total_tokens': 94, 'total_tokens_per_sec': 142.73433844300794}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726178487}, id='a5590b89-4853-4bd9-9fd8-83276b369278')"
|
513 |
+
]
|
514 |
+
},
|
515 |
+
"execution_count": 5,
|
516 |
+
"metadata": {},
|
517 |
+
"output_type": "execute_result"
|
518 |
+
}
|
519 |
+
],
|
520 |
+
"source": [
|
521 |
+
"llm.invoke(\"tell me a joke\")"
|
522 |
+
]
|
523 |
+
},
|
524 |
+
{
|
525 |
+
"cell_type": "code",
|
526 |
+
"execution_count": 7,
|
527 |
+
"metadata": {},
|
528 |
+
"outputs": [
|
529 |
+
{
|
530 |
+
"data": {
|
531 |
+
"text/plain": [
|
532 |
+
"AIMessage(content=\"Yer lookin' fer a joke, eh? Alright then, matey! Here be one fer ye:\\n\\nWhy did the pirate quit his job?\\n\\n(pause fer dramatic effect)\\n\\nBecause he was sick o' all the arrrr-guments!\\n\\nYarrr, hope that made ye laugh, me hearty!\", response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 5.583333333333333, 'completion_tokens': 64, 'completion_tokens_after_first_per_sec': 120.91573778458478, 'completion_tokens_after_first_per_sec_first_ten': 140.3985499426452, 'completion_tokens_per_sec': 79.98855768735817, 'end_time': 1726065701.9732044, 'is_last_response': True, 'prompt_tokens': 48, 'start_time': 1726065701.107911, 'time_to_first_token': 0.3442692756652832, 'total_latency': 0.8001144394945743, 'total_tokens': 112, 'total_tokens_per_sec': 139.9799759528768}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726065701}, id='7b0748bb-c5f7-4696-ae56-03b734b60fb9')"
|
533 |
+
]
|
534 |
+
},
|
535 |
+
"execution_count": 7,
|
536 |
+
"metadata": {},
|
537 |
+
"output_type": "execute_result"
|
538 |
+
}
|
539 |
+
],
|
540 |
+
"source": [
|
541 |
+
"messages = [\n",
|
542 |
+
" SystemMessage(content=\"You are a helpful assistant with pirate accent\"),\n",
|
543 |
+
" HumanMessage(content=\"tell me a joke\")\n",
|
544 |
+
" ]\n",
|
545 |
+
"llm.invoke(messages)"
|
546 |
+
]
|
547 |
+
},
|
548 |
+
{
|
549 |
+
"cell_type": "code",
|
550 |
+
"execution_count": 8,
|
551 |
+
"metadata": {},
|
552 |
+
"outputs": [
|
553 |
+
{
|
554 |
+
"data": {
|
555 |
+
"text/plain": [
|
556 |
+
"AIMessage(content='A man walked into a library and asked the librarian, \"Do you have any books on Pavlov\\'s dogs and Schrödinger\\'s cat?\"\\n\\nThe librarian replied, \"It rings a bell, but I\\'m not sure if it\\'s here or not.\"', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 6.875, 'completion_tokens': 54, 'completion_tokens_after_first_per_sec': 146.72813415408498, 'completion_tokens_after_first_per_sec_first_ten': 172.71830994351703, 'completion_tokens_per_sec': 82.34884281970663, 'end_time': 1726065746.6364844, 'is_last_response': True, 'prompt_tokens': 40, 'start_time': 1726065745.932173, 'time_to_first_token': 0.34309911727905273, 'total_latency': 0.6557469194585627, 'total_tokens': 94, 'total_tokens_per_sec': 143.34798564911895}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726065745}, id='27e7d4fe-8e24-419a-b75b-51ea2519781b')"
|
557 |
+
]
|
558 |
+
},
|
559 |
+
"execution_count": 8,
|
560 |
+
"metadata": {},
|
561 |
+
"output_type": "execute_result"
|
562 |
+
}
|
563 |
+
],
|
564 |
+
"source": [
|
565 |
+
"future_response = llm.ainvoke(\"tell me a joke\")\n",
|
566 |
+
"await(future_response) "
|
567 |
+
]
|
568 |
+
},
|
569 |
+
{
|
570 |
+
"cell_type": "markdown",
|
571 |
+
"metadata": {},
|
572 |
+
"source": [
|
573 |
+
"## Batching"
|
574 |
+
]
|
575 |
+
},
|
576 |
+
{
|
577 |
+
"cell_type": "code",
|
578 |
+
"execution_count": 9,
|
579 |
+
"metadata": {},
|
580 |
+
"outputs": [],
|
581 |
+
"source": [
|
582 |
+
"llm = ChatSambaNovaCloud(\n",
|
583 |
+
" model= \"llama3-405b\",\n",
|
584 |
+
" streaming=False,\n",
|
585 |
+
" max_tokens=1024,\n",
|
586 |
+
" temperature=0.7,\n",
|
587 |
+
" top_k=1,\n",
|
588 |
+
" top_p=0.01,\n",
|
589 |
+
" stream_options={'include_usage':True}\n",
|
590 |
+
" )"
|
591 |
+
]
|
592 |
+
},
|
593 |
+
{
|
594 |
+
"cell_type": "code",
|
595 |
+
"execution_count": 11,
|
596 |
+
"metadata": {},
|
597 |
+
"outputs": [
|
598 |
+
{
|
599 |
+
"data": {
|
600 |
+
"text/plain": [
|
601 |
+
"[AIMessage(content='A man walked into a library and asked the librarian, \"Do you have any books on Pavlov\\'s dogs and Schrödinger\\'s cat?\"\\n\\nThe librarian replied, \"It rings a bell, but I\\'m not sure if it\\'s here or not.\"', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 6.875, 'completion_tokens': 54, 'completion_tokens_after_first_per_sec': 146.72232349940003, 'completion_tokens_after_first_per_sec_first_ten': 173.01988455676758, 'completion_tokens_per_sec': 82.21649876350362, 'end_time': 1726065879.4066722, 'is_last_response': True, 'prompt_tokens': 40, 'start_time': 1726065878.700746, 'time_to_first_token': 0.3446996212005615, 'total_latency': 0.656802476536144, 'total_tokens': 94, 'total_tokens_per_sec': 143.1176089586915}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726065878}, id='28d3a38b-5dae-4d62-bf6c-cface081df34'),\n",
|
602 |
+
" AIMessage(content='The capital of the United Kingdom is London.', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 13, 'completion_tokens': 10, 'completion_tokens_after_first_per_sec': 110.21174794386165, 'completion_tokens_after_first_per_sec_first_ten': 327.0275172132524, 'completion_tokens_per_sec': 26.88555788272027, 'end_time': 1726065879.138034, 'is_last_response': True, 'prompt_tokens': 43, 'start_time': 1726065878.7150047, 'time_to_first_token': 0.3413684368133545, 'total_latency': 0.37194690337547887, 'total_tokens': 53, 'total_tokens_per_sec': 142.49345677841742}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726065878}, id='859a9e45-c0a5-44ec-bd53-686877c2cf89')]"
|
603 |
+
]
|
604 |
+
},
|
605 |
+
"execution_count": 11,
|
606 |
+
"metadata": {},
|
607 |
+
"output_type": "execute_result"
|
608 |
+
}
|
609 |
+
],
|
610 |
+
"source": [
|
611 |
+
"llm.batch([\"tell me a joke\",\"which is the capital of UK?\"])"
|
612 |
+
]
|
613 |
+
},
|
614 |
+
{
|
615 |
+
"cell_type": "code",
|
616 |
+
"execution_count": 13,
|
617 |
+
"metadata": {},
|
618 |
+
"outputs": [
|
619 |
+
{
|
620 |
+
"name": "stderr",
|
621 |
+
"output_type": "stream",
|
622 |
+
"text": [
|
623 |
+
"/var/folders/p4/y0q2kh796nx_k_yzfhxs57f00000gp/T/ipykernel_33601/1543848179.py:1: RuntimeWarning: coroutine 'Runnable.abatch' was never awaited\n",
|
624 |
+
" future_responses = llm.abatch([\"tell me a joke\",\"which is the capital of UK?\"])\n",
|
625 |
+
"RuntimeWarning: Enable tracemalloc to get the object allocation traceback\n"
|
626 |
+
]
|
627 |
+
},
|
628 |
+
{
|
629 |
+
"data": {
|
630 |
+
"text/plain": [
|
631 |
+
"[AIMessage(content='A man walked into a library and asked the librarian, \"Do you have any books on Pavlov\\'s dogs and Schrödinger\\'s cat?\"\\n\\nThe librarian replied, \"It rings a bell, but I\\'m not sure if it\\'s here or not.\"', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 6.875, 'completion_tokens': 54, 'completion_tokens_after_first_per_sec': 120.34699641554552, 'completion_tokens_after_first_per_sec_first_ten': 141.51170437257693, 'completion_tokens_per_sec': 36.223157123884754, 'end_time': 1726065914.8678048, 'is_last_response': True, 'prompt_tokens': 40, 'start_time': 1726065913.3182464, 'time_to_first_token': 1.1091651916503906, 'total_latency': 1.4907590692693538, 'total_tokens': 94, 'total_tokens_per_sec': 63.05512536379939}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726065913}, id='f279d0fb-70b5-428c-9283-457b9831b559'),\n",
|
632 |
+
" AIMessage(content='The capital of the United Kingdom is London.', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 9.5, 'completion_tokens': 10, 'completion_tokens_after_first_per_sec': 60.73429985889864, 'completion_tokens_after_first_per_sec_first_ten': 195.5434460421063, 'completion_tokens_per_sec': 8.61842566880045, 'end_time': 1726065914.575598, 'is_last_response': True, 'prompt_tokens': 43, 'start_time': 1726065913.3182464, 'time_to_first_token': 1.1091651916503906, 'total_latency': 1.160304722033049, 'total_tokens': 53, 'total_tokens_per_sec': 45.67765604464238}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726065913}, id='f279d0fb-70b5-428c-9283-457b9831b559')]"
|
633 |
+
]
|
634 |
+
},
|
635 |
+
"execution_count": 13,
|
636 |
+
"metadata": {},
|
637 |
+
"output_type": "execute_result"
|
638 |
+
}
|
639 |
+
],
|
640 |
+
"source": [
|
641 |
+
"future_responses = llm.abatch([\"tell me a joke\",\"which is the capital of UK?\"])\n",
|
642 |
+
"await(future_responses)"
|
643 |
+
]
|
644 |
+
},
|
645 |
+
{
|
646 |
+
"cell_type": "markdown",
|
647 |
+
"metadata": {},
|
648 |
+
"source": [
|
649 |
+
"## Streaming"
|
650 |
+
]
|
651 |
+
},
|
652 |
+
{
|
653 |
+
"cell_type": "code",
|
654 |
+
"execution_count": 14,
|
655 |
+
"metadata": {},
|
656 |
+
"outputs": [],
|
657 |
+
"source": [
|
658 |
+
"llm = ChatSambaNovaCloud(\n",
|
659 |
+
" model= \"llama3-405b\",\n",
|
660 |
+
" streaming=True,\n",
|
661 |
+
" max_tokens=1024,\n",
|
662 |
+
" temperature=0.7,\n",
|
663 |
+
" top_k=1,\n",
|
664 |
+
" top_p=0.01,\n",
|
665 |
+
" stream_options={'include_usage':True}\n",
|
666 |
+
" )"
|
667 |
+
]
|
668 |
+
},
|
669 |
+
{
|
670 |
+
"cell_type": "code",
|
671 |
+
"execution_count": 15,
|
672 |
+
"metadata": {},
|
673 |
+
"outputs": [
|
674 |
+
{
|
675 |
+
"name": "stdout",
|
676 |
+
"output_type": "stream",
|
677 |
+
"text": [
|
678 |
+
"\n",
|
679 |
+
"A man walked into a \n",
|
680 |
+
"library and asked the \n",
|
681 |
+
"librarian, \"Do you have any books \n",
|
682 |
+
"on Pavlov's dogs \n",
|
683 |
+
"and Schrödinger's cat?\"\n",
|
684 |
+
"\n",
|
685 |
+
"\n",
|
686 |
+
"The librarian \n",
|
687 |
+
"replied, \"It rings a bell, \n",
|
688 |
+
"but I'm not sure \n",
|
689 |
+
"if it's here \n",
|
690 |
+
"or not.\"\n",
|
691 |
+
"\n",
|
692 |
+
"\n",
|
693 |
+
"\n"
|
694 |
+
]
|
695 |
+
}
|
696 |
+
],
|
697 |
+
"source": [
|
698 |
+
"for chunk in llm.stream(\"tell me a joke\"):\n",
|
699 |
+
" print(chunk.content)"
|
700 |
+
]
|
701 |
+
},
|
702 |
+
{
|
703 |
+
"cell_type": "code",
|
704 |
+
"execution_count": 16,
|
705 |
+
"metadata": {},
|
706 |
+
"outputs": [
|
707 |
+
{
|
708 |
+
"name": "stdout",
|
709 |
+
"output_type": "stream",
|
710 |
+
"text": [
|
711 |
+
"\n",
|
712 |
+
"Yer lookin' \n",
|
713 |
+
"fer a joke, eh? \n",
|
714 |
+
"Alright then, matey! \n",
|
715 |
+
"Here be one fer \n",
|
716 |
+
"ye:\n",
|
717 |
+
"\n",
|
718 |
+
"Why did the pirate quit his job?\n",
|
719 |
+
"\n",
|
720 |
+
"\n",
|
721 |
+
"\n",
|
722 |
+
"(pause fer \n",
|
723 |
+
"dramatic effect)\n",
|
724 |
+
"\n",
|
725 |
+
"Because he was sick \n",
|
726 |
+
"o' all the arrrr-guments!\n",
|
727 |
+
"\n",
|
728 |
+
"\n",
|
729 |
+
"\n",
|
730 |
+
"\n",
|
731 |
+
"Yarrr, hope that made ye \n",
|
732 |
+
"laugh, \n",
|
733 |
+
"me hearty!\n",
|
734 |
+
"\n",
|
735 |
+
"\n",
|
736 |
+
"\n"
|
737 |
+
]
|
738 |
+
}
|
739 |
+
],
|
740 |
+
"source": [
|
741 |
+
"messages = [\n",
|
742 |
+
" SystemMessage(content=\"You are a helpful assistant with pirate accent\"),\n",
|
743 |
+
" HumanMessage(content=\"tell me a joke\")\n",
|
744 |
+
" ]\n",
|
745 |
+
"for chunk in llm.stream(messages):\n",
|
746 |
+
" print(chunk.content)"
|
747 |
+
]
|
748 |
+
},
|
749 |
+
{
|
750 |
+
"cell_type": "code",
|
751 |
+
"execution_count": 17,
|
752 |
+
"metadata": {},
|
753 |
+
"outputs": [
|
754 |
+
{
|
755 |
+
"name": "stdout",
|
756 |
+
"output_type": "stream",
|
757 |
+
"text": [
|
758 |
+
"\n",
|
759 |
+
"A man walked into a \n",
|
760 |
+
"library and asked the \n",
|
761 |
+
"librarian, \"Do you have any books \n",
|
762 |
+
"on Pavlov's dogs \n",
|
763 |
+
"and Schrödinger's cat?\"\n",
|
764 |
+
"\n",
|
765 |
+
"\n",
|
766 |
+
"The librarian \n",
|
767 |
+
"replied, \"It rings a bell, \n",
|
768 |
+
"but I'm not sure \n",
|
769 |
+
"if it's here \n",
|
770 |
+
"or not.\"\n",
|
771 |
+
"\n",
|
772 |
+
"\n",
|
773 |
+
"\n"
|
774 |
+
]
|
775 |
+
}
|
776 |
+
],
|
777 |
+
"source": [
|
778 |
+
"async for chunk in llm.astream(\"tell me a joke\"):\n",
|
779 |
+
" print(chunk.content)"
|
780 |
+
]
|
781 |
+
},
|
782 |
+
{
|
783 |
+
"cell_type": "markdown",
|
784 |
+
"metadata": {},
|
785 |
+
"source": [
|
786 |
+
"# Sambastudio Embeddings"
|
787 |
+
]
|
788 |
+
},
|
789 |
+
{
|
790 |
+
"cell_type": "code",
|
791 |
+
"execution_count": null,
|
792 |
+
"metadata": {},
|
793 |
+
"outputs": [],
|
794 |
+
"source": [
|
795 |
+
"embedding = SambaStudioEmbeddings(batch_size=1, model_kwargs={'select_expert': 'e5-mistral-7b-instruct'})\n",
|
796 |
+
"embedding.embed_documents(['tell me a 50 word tale', 'tell me a joke'])\n",
|
797 |
+
"embedding.embed_query('tell me a 50 word tale')"
|
798 |
+
]
|
799 |
+
},
|
800 |
+
{
|
801 |
+
"cell_type": "code",
|
802 |
+
"execution_count": 13,
|
803 |
+
"metadata": {},
|
804 |
+
"outputs": [
|
805 |
+
{
|
806 |
+
"name": "stderr",
|
807 |
+
"output_type": "stream",
|
808 |
+
"text": [
|
809 |
+
"/Users/jorgep/Documents/ask_public_own/finetuning_env/lib/python3.11/site-packages/langchain_core/_api/deprecation.py:139: LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 0.3.0. Use invoke instead.\n",
|
810 |
+
" warn_deprecated(\n"
|
811 |
+
]
|
812 |
+
},
|
813 |
+
{
|
814 |
+
"data": {
|
815 |
+
"text/plain": [
|
816 |
+
"[Document(page_content='tell me a 50 word tale'),\n",
|
817 |
+
" Document(page_content='tell me a joke'),\n",
|
818 |
+
" Document(page_content='give me 3 party activities'),\n",
|
819 |
+
" Document(page_content='give me three healty dishes')]"
|
820 |
+
]
|
821 |
+
},
|
822 |
+
"execution_count": 13,
|
823 |
+
"metadata": {},
|
824 |
+
"output_type": "execute_result"
|
825 |
+
}
|
826 |
+
],
|
827 |
+
"source": [
|
828 |
+
"from langchain.schema import Document\n",
|
829 |
+
"from langchain.vectorstores import Chroma\n",
|
830 |
+
"\n",
|
831 |
+
"docs = [\n",
|
832 |
+
" 'tell me a 50 word tale',\n",
|
833 |
+
" 'tell me a joke',\n",
|
834 |
+
" 'when was America discoverd?',\n",
|
835 |
+
" 'how to build an engine?',\n",
|
836 |
+
" 'give me 3 party activities',\n",
|
837 |
+
" 'give me three healty dishes',\n",
|
838 |
+
"]\n",
|
839 |
+
"docs = [Document(doc) for doc in docs]\n",
|
840 |
+
"\n",
|
841 |
+
"query = 'prompt for generating something fun'\n",
|
842 |
+
"\n",
|
843 |
+
"vectordb = Chroma.from_documents(docs, embedding)\n",
|
844 |
+
"retriever = vectordb.as_retriever()\n",
|
845 |
+
"\n",
|
846 |
+
"retriever.get_relevant_documents(query)"
|
847 |
+
]
|
848 |
+
},
|
849 |
+
{
|
850 |
+
"cell_type": "code",
|
851 |
+
"execution_count": null,
|
852 |
+
"metadata": {},
|
853 |
+
"outputs": [],
|
854 |
+
"source": []
|
855 |
+
}
|
856 |
+
],
|
857 |
+
"metadata": {
|
858 |
+
"kernelspec": {
|
859 |
+
"display_name": "peenv",
|
860 |
+
"language": "python",
|
861 |
+
"name": "python3"
|
862 |
+
},
|
863 |
+
"language_info": {
|
864 |
+
"codemirror_mode": {
|
865 |
+
"name": "ipython",
|
866 |
+
"version": 3
|
867 |
+
},
|
868 |
+
"file_extension": ".py",
|
869 |
+
"mimetype": "text/x-python",
|
870 |
+
"name": "python",
|
871 |
+
"nbconvert_exporter": "python",
|
872 |
+
"pygments_lexer": "ipython3",
|
873 |
+
"version": "3.10.11"
|
874 |
+
}
|
875 |
+
},
|
876 |
+
"nbformat": 4,
|
877 |
+
"nbformat_minor": 2
|
878 |
+
}
|
utils/parsing/README.md
ADDED
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# SambaParse
|
2 |
+
|
3 |
+
SambaParse is a Python library that simplifies the process of extracting and processing unstructured data using the Unstructured.io API. It provides a convenient wrapper around the Unstructured.io CLI tool, allowing you to ingest data from various sources, perform partitioning, chunking, embedding, and load the processed data into a vector database. It's designed to be used within AI Starter kits and SN Apps, unifying our data ingestion and document intelligence platform. This allows us to keep our code base centralized for data ingestion kits.
|
4 |
+
|
5 |
+
## Prerequisites
|
6 |
+
|
7 |
+
Before using SambaParse, make sure you have the following:
|
8 |
+
|
9 |
+
- Docker installed on your machine (or access to another API server)
|
10 |
+
- An Unstructured.io API key
|
11 |
+
|
12 |
+
Before using SambaParse, make sure you have the following:
|
13 |
+
|
14 |
+
- Create a `.env` file in the ai-starter-kit root directory (not in the parsing folder root):
|
15 |
+
|
16 |
+
```bash
|
17 |
+
UNSTRUCTURED_API_KEY=your_api_key_here
|
18 |
+
```
|
19 |
+
|
20 |
+
## Setup
|
21 |
+
|
22 |
+
### Pre Reqs
|
23 |
+
|
24 |
+
Using pyenv to manage virtualenv's is recommended
|
25 |
+
Mac install instructions. See pyenv-virtualenv repo for more detailed instructions.
|
26 |
+
|
27 |
+
```bash
|
28 |
+
brew install pyenv-virtualenv
|
29 |
+
```
|
30 |
+
|
31 |
+
- Create a python venv using python version 3.10.12
|
32 |
+
|
33 |
+
```bash
|
34 |
+
pyenv install 3.10.12
|
35 |
+
pyenv virtualenv 3.10.12 sambaparse
|
36 |
+
pyenv activate sambaparse
|
37 |
+
```
|
38 |
+
|
39 |
+
- Clone the ai-starter-kit repo and cd:
|
40 |
+
|
41 |
+
```bash
|
42 |
+
git clone https://github.com/sambanova/ai-starter-kit
|
43 |
+
```
|
44 |
+
|
45 |
+
- cd into utils/parsing and pip install the requirements
|
46 |
+
|
47 |
+
```bash
|
48 |
+
pip install -r requirements.txt
|
49 |
+
```
|
50 |
+
|
51 |
+
- cd into the unstructured-api foder and Install the unstructured-api make-file:
|
52 |
+
|
53 |
+
```bash
|
54 |
+
cd unstructured-api
|
55 |
+
```
|
56 |
+
|
57 |
+
- Run
|
58 |
+
|
59 |
+
```bash
|
60 |
+
make install
|
61 |
+
```
|
62 |
+
|
63 |
+
- Run The Web Server:
|
64 |
+
|
65 |
+
```bash
|
66 |
+
make run-web-app
|
67 |
+
```
|
68 |
+
|
69 |
+
This script will start the Unstructured API server using the specified API key and expose it on port 8005.
|
70 |
+
|
71 |
+
- Alternatively, if you have another Unstructured API server running on a different instance, make sure to update the `partition_endpoint` and `unstructured_port` values in the YAML configuration file accordingly.
|
72 |
+
|
73 |
+
## Usage
|
74 |
+
|
75 |
+
1. Import the `SambaParse` class from the `ai-starter-kit` library:
|
76 |
+
|
77 |
+
```python
|
78 |
+
from utils.parsing.sambaparse import SambaParse
|
79 |
+
```
|
80 |
+
|
81 |
+
2. Create a YAML configuration file (e.g., `config.yaml`) to specify the desired settings for the ingestion process. Here's the configuration for use cases 1 and 2 ie local files and folders:
|
82 |
+
|
83 |
+
```yaml
|
84 |
+
processor:
|
85 |
+
verbose: True
|
86 |
+
output_dir: './output'
|
87 |
+
num_processes: 2
|
88 |
+
|
89 |
+
sources:
|
90 |
+
local:
|
91 |
+
recursive: True
|
92 |
+
confluence:
|
93 |
+
api_token: 'your_confluence_api_token'
|
94 |
+
user_email: 'your_email@example.com'
|
95 |
+
url: 'https://your-confluence-url.atlassian.net'
|
96 |
+
github:
|
97 |
+
url: 'owner/repo'
|
98 |
+
branch: 'main'
|
99 |
+
google_drive:
|
100 |
+
service_account_key: 'path/to/service_account_key.json'
|
101 |
+
recursive: True
|
102 |
+
drive_id: 'your_drive_id'
|
103 |
+
|
104 |
+
partitioning:
|
105 |
+
pdf_infer_table_structure: True
|
106 |
+
skip_infer_table_types: []
|
107 |
+
strategy: 'auto'
|
108 |
+
hi_res_model_name: 'yolox'
|
109 |
+
ocr_languages: ['eng']
|
110 |
+
encoding: 'utf-8'
|
111 |
+
fields_include: ['element_id', 'text', 'type', 'metadata', 'embeddings']
|
112 |
+
flatten_metadata: False
|
113 |
+
metadata_exclude: []
|
114 |
+
metadata_include: []
|
115 |
+
partition_endpoint: 'http://localhost'
|
116 |
+
unstructured_port: 8005
|
117 |
+
partition_by_api: True
|
118 |
+
|
119 |
+
chunking:
|
120 |
+
enabled: True
|
121 |
+
strategy: 'basic'
|
122 |
+
chunk_max_characters: 1500
|
123 |
+
chunk_overlap: 300
|
124 |
+
|
125 |
+
embedding:
|
126 |
+
enabled: False
|
127 |
+
provider: 'langchain-huggingface'
|
128 |
+
model_name: 'intfloat/e5-large-v2'
|
129 |
+
|
130 |
+
destination_connectors:
|
131 |
+
enabled: False
|
132 |
+
type: 'chroma'
|
133 |
+
batch_size: 80
|
134 |
+
chroma:
|
135 |
+
host: 'localhost'
|
136 |
+
port: 8004
|
137 |
+
collection_name: 'snconf'
|
138 |
+
tenant: 'default_tenant'
|
139 |
+
database: 'default_database'
|
140 |
+
qdrant:
|
141 |
+
location: 'http://localhost:6333'
|
142 |
+
collection_name: 'test'
|
143 |
+
|
144 |
+
additional_processing:
|
145 |
+
enabled: True
|
146 |
+
extend_metadata: True
|
147 |
+
replace_table_text: True
|
148 |
+
table_text_key: 'text_as_html'
|
149 |
+
return_langchain_docs: True
|
150 |
+
convert_metadata_keys_to_string: True
|
151 |
+
```
|
152 |
+
|
153 |
+
Make sure to place the `config.yaml` file in the desired folder.
|
154 |
+
|
155 |
+
3. Create an instance of the `SambaParse` class, passing the path to the YAML configuration file:
|
156 |
+
|
157 |
+
```python
|
158 |
+
sambaparse = SambaParse('path/to/config.yaml')
|
159 |
+
```
|
160 |
+
|
161 |
+
4. Use the `run_ingest` method to process your data:
|
162 |
+
|
163 |
+
- For a single file:
|
164 |
+
|
165 |
+
```python
|
166 |
+
source_type = 'local'
|
167 |
+
input_path = 'path/to/your/file.pdf'
|
168 |
+
additional_metadata = {'key': 'value'}
|
169 |
+
texts, metadata_list, langchain_docs = sambaparse.run_ingest(source_type, input_path=input_path, additional_metadata=additional_metadata)
|
170 |
+
```
|
171 |
+
|
172 |
+
- For a folder:
|
173 |
+
|
174 |
+
```python
|
175 |
+
source_type = 'local'
|
176 |
+
input_path = 'path/to/your/file.pdf'
|
177 |
+
additional_metadata = {'key': 'value'}
|
178 |
+
texts, metadata_list, langchain_docs = sambaparse.run_ingest(source_type, input_path=input_path, additional_metadata=additional_metadata)
|
179 |
+
```
|
180 |
+
|
181 |
+
- For Confluence:
|
182 |
+
|
183 |
+
```python
|
184 |
+
source_type = 'confluence'
|
185 |
+
additional_metadata = {'key': 'value'}
|
186 |
+
texts, metadata_list, langchain_docs = sambaparse.run_ingest(source_type, additional_metadata=additional_metadata)
|
187 |
+
```
|
188 |
+
|
189 |
+
Note that for conflence you must enable embedding and destinatation connectors automatically ie Chroma and turn off additional processing (ie langchain), an example yaml to do that is below
|
190 |
+
|
191 |
+
```yaml
|
192 |
+
processor:
|
193 |
+
verbose: True
|
194 |
+
output_dir: './output'
|
195 |
+
num_processes: 2
|
196 |
+
|
197 |
+
sources:
|
198 |
+
local:
|
199 |
+
recursive: True
|
200 |
+
confluence:
|
201 |
+
api_token: 'your_confluence_api_token'
|
202 |
+
user_email: 'your_email@example.com'
|
203 |
+
url: 'https://your-confluence-url.atlassian.net'
|
204 |
+
github:
|
205 |
+
url: 'owner/repo'
|
206 |
+
branch: 'main'
|
207 |
+
google_drive:
|
208 |
+
service_account_key: 'path/to/service_account_key.json'
|
209 |
+
recursive: True
|
210 |
+
drive_id: 'your_drive_id'
|
211 |
+
|
212 |
+
partitioning:
|
213 |
+
pdf_infer_table_structure: True
|
214 |
+
skip_infer_table_types: []
|
215 |
+
strategy: 'auto'
|
216 |
+
hi_res_model_name: 'yolox'
|
217 |
+
ocr_languages: ['eng']
|
218 |
+
encoding: 'utf-8'
|
219 |
+
fields_include: ['element_id', 'text', 'type', 'metadata', 'embeddings']
|
220 |
+
flatten_metadata: False
|
221 |
+
metadata_exclude: []
|
222 |
+
metadata_include: []
|
223 |
+
partition_endpoint: 'http://localhost'
|
224 |
+
unstructured_port: 8005
|
225 |
+
partition_by_api: True
|
226 |
+
|
227 |
+
chunking:
|
228 |
+
enabled: True
|
229 |
+
strategy: 'basic'
|
230 |
+
chunk_max_characters: 1500
|
231 |
+
chunk_overlap: 300
|
232 |
+
|
233 |
+
embedding:
|
234 |
+
enabled: True
|
235 |
+
provider: 'langchain-huggingface'
|
236 |
+
model_name: 'intfloat/e5-large-v2'
|
237 |
+
|
238 |
+
destination_connectors:
|
239 |
+
enabled: True
|
240 |
+
type: 'chroma'
|
241 |
+
batch_size: 80
|
242 |
+
chroma:
|
243 |
+
host: 'localhost'
|
244 |
+
port: 8004
|
245 |
+
collection_name: 'snconf'
|
246 |
+
tenant: 'default_tenant'
|
247 |
+
database: 'default_database'
|
248 |
+
qdrant:
|
249 |
+
location: 'http://localhost:6333'
|
250 |
+
collection_name: 'test'
|
251 |
+
|
252 |
+
additional_processing:
|
253 |
+
enabled: False
|
254 |
+
extend_metadata: True
|
255 |
+
replace_table_text: True
|
256 |
+
table_text_key: 'text_as_html'
|
257 |
+
return_langchain_docs: True
|
258 |
+
convert_metadata_keys_to_string: True
|
259 |
+
```
|
260 |
+
|
261 |
+
In addition for confluence you will need to have a Chroma Server running on port 8004, you can do this by running the docker command below
|
262 |
+
|
263 |
+
```bash
|
264 |
+
docker run -d --rm --name chromadb -v ./chroma:/chroma/chroma -e IS_PERSISTENT=TRUE -e ANONYMIZED_TELEMETRY=TRUE -p 8004:8000 chromadb/chroma:latest
|
265 |
+
```
|
266 |
+
|
267 |
+
The `run_ingest` method returns a tuple containing the extracted texts, metadata, and LangChain documents (if `return_langchain_docs` is set to `True` in the configuration).
|
268 |
+
|
269 |
+
5. Process the returned data as needed:
|
270 |
+
- `texts`: A list of extracted text elements from the documents.
|
271 |
+
- `metadata_list`: A list of metadata dictionaries for each text element.
|
272 |
+
- `langchain_docs`: A list of LangChain `Document` objects, which combine the text and metadata.
|
273 |
+
|
274 |
+
#### Configuration Options
|
275 |
+
|
276 |
+
The YAML configuration file allows you to customize various aspects of the ingestion process. Here are some of the key options:
|
277 |
+
|
278 |
+
- `processor`: Settings related to the processing of documents, such as the output directory and the number of processes to use.
|
279 |
+
- `sources`: Configuration for different data sources, including local files, Confluence, GitHub, and Google Drive.
|
280 |
+
- `partitioning`: Options for partitioning the documents, including the strategy, OCR languages, and API settings.
|
281 |
+
- `chunking`: Settings for chunking the documents, such as enabling chunking, specifying the chunking strategy, and setting the maximum chunk size and overlap.
|
282 |
+
- `embedding`: Options for embedding the documents, including enabling embedding, specifying the embedding provider, and setting the model name.
|
283 |
+
- `additional_processing`: Configuration for additional processing steps, such as extending metadata, replacing table text, and returning LangChain documents.
|
284 |
+
|
285 |
+
Make sure to review and modify the configuration file according to your specific requirements.
|
utils/parsing/config.yaml
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
processor:
|
2 |
+
verbose: True
|
3 |
+
output_dir: './output'
|
4 |
+
num_processes: 2
|
5 |
+
reprocess: False
|
6 |
+
|
7 |
+
sources:
|
8 |
+
local:
|
9 |
+
recursive: True
|
10 |
+
confluence:
|
11 |
+
api_token: 'your_confluence_api_token'
|
12 |
+
user_email: 'your_email@example.com'
|
13 |
+
url: 'https://your-confluence-url.atlassian.net'
|
14 |
+
github:
|
15 |
+
url: 'owner/repo'
|
16 |
+
branch: 'main'
|
17 |
+
google_drive:
|
18 |
+
service_account_key: 'path/to/service_account_key.json'
|
19 |
+
recursive: True
|
20 |
+
drive_id: 'your_drive_id'
|
21 |
+
|
22 |
+
partitioning:
|
23 |
+
skip_infer_table_types: []
|
24 |
+
strategy: 'auto'
|
25 |
+
hi_res_model_name: 'yolox'
|
26 |
+
ocr_languages: ['eng']
|
27 |
+
encoding: 'utf-8'
|
28 |
+
fields_include: ['element_id', 'text', 'type', 'metadata', 'embeddings']
|
29 |
+
flatten_metadata: False
|
30 |
+
metadata_exclude: []
|
31 |
+
metadata_include: []
|
32 |
+
partition_endpoint: 'http://localhost'
|
33 |
+
unstructured_port: 8005
|
34 |
+
partition_by_api: False # set as true if using API server
|
35 |
+
default_unstructured_api_key: 123456789abcde
|
36 |
+
|
37 |
+
chunking:
|
38 |
+
enabled: True
|
39 |
+
strategy: 'by_title'
|
40 |
+
chunk_max_characters: 1500
|
41 |
+
chunk_overlap: 300
|
42 |
+
combine_under_n_chars: 1500
|
43 |
+
|
44 |
+
embedding:
|
45 |
+
enabled: False
|
46 |
+
provider: 'langchain-huggingface'
|
47 |
+
model_name: 'intfloat/e5-large-v2'
|
48 |
+
|
49 |
+
destination_connectors:
|
50 |
+
enabled: False
|
51 |
+
type: 'chroma'
|
52 |
+
batch_size: 80
|
53 |
+
chroma:
|
54 |
+
host: 'localhost'
|
55 |
+
port: 8004
|
56 |
+
collection_name: 'snconf'
|
57 |
+
tenant: 'default_tenant'
|
58 |
+
database: 'default_database'
|
59 |
+
qdrant:
|
60 |
+
location: 'http://localhost:6333'
|
61 |
+
collection_name: 'test'
|
62 |
+
|
63 |
+
additional_processing:
|
64 |
+
enabled: True
|
65 |
+
extend_metadata: True
|
66 |
+
replace_table_text: True
|
67 |
+
table_text_key: 'text_as_html'
|
68 |
+
return_langchain_docs: True
|
69 |
+
convert_metadata_keys_to_string: True
|
utils/parsing/docker-compose.yaml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
version: '3.9'
|
2 |
+
|
3 |
+
networks:
|
4 |
+
net:
|
5 |
+
driver: bridge
|
6 |
+
|
7 |
+
services:
|
8 |
+
unstructured-api:
|
9 |
+
image: downloads.unstructured.io/unstructured-io/unstructured-api:latest
|
10 |
+
command: --port 8000 --host 0.0.0.0
|
11 |
+
ports:
|
12 |
+
- "${UNSTRUCTURED_PORT:-8005}:8000"
|
13 |
+
env_file:
|
14 |
+
- ../../.env
|
15 |
+
|
16 |
+
networks:
|
17 |
+
- net
|
18 |
+
|
19 |
+
chromadb:
|
20 |
+
image: chromadb/chroma:latest
|
21 |
+
volumes:
|
22 |
+
- ./chromadb:/chroma/chroma
|
23 |
+
environment:
|
24 |
+
- IS_PERSISTENT=TRUE
|
25 |
+
- PERSIST_DIRECTORY=/chroma/chroma
|
26 |
+
- ANONYMIZED_TELEMETRY=${ANONYMIZED_TELEMETRY:-TRUE}
|
27 |
+
ports:
|
28 |
+
- "${CHROMA_PORT:-8004}:8000"
|
29 |
+
networks:
|
30 |
+
- net
|
utils/parsing/parse_usage.ipynb
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 15,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stdout",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"This is the repo dir /Users/kwasia/Documents/Projects/ai-starter-kit\n"
|
13 |
+
]
|
14 |
+
}
|
15 |
+
],
|
16 |
+
"source": [
|
17 |
+
"import os\n",
|
18 |
+
"import sys\n",
|
19 |
+
"\n",
|
20 |
+
"current_dir = os.getcwd()\n",
|
21 |
+
"kit_dir = os.path.abspath(os.path.join(current_dir, '..'))\n",
|
22 |
+
"repo_dir = os.path.abspath(os.path.join(kit_dir, '..'))\n",
|
23 |
+
"\n",
|
24 |
+
"sys.path.append(kit_dir)\n",
|
25 |
+
"sys.path.append(repo_dir)\n",
|
26 |
+
"\n",
|
27 |
+
"print(f'This is the repo dir {repo_dir}')"
|
28 |
+
]
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"cell_type": "code",
|
32 |
+
"execution_count": 16,
|
33 |
+
"metadata": {},
|
34 |
+
"outputs": [
|
35 |
+
{
|
36 |
+
"data": {
|
37 |
+
"text/plain": [
|
38 |
+
"True"
|
39 |
+
]
|
40 |
+
},
|
41 |
+
"execution_count": 16,
|
42 |
+
"metadata": {},
|
43 |
+
"output_type": "execute_result"
|
44 |
+
}
|
45 |
+
],
|
46 |
+
"source": [
|
47 |
+
"# Load DotEnv\n",
|
48 |
+
"\n",
|
49 |
+
"from dotenv import load_dotenv\n",
|
50 |
+
"\n",
|
51 |
+
"load_dotenv('../../.env')"
|
52 |
+
]
|
53 |
+
},
|
54 |
+
{
|
55 |
+
"cell_type": "code",
|
56 |
+
"execution_count": 17,
|
57 |
+
"metadata": {},
|
58 |
+
"outputs": [],
|
59 |
+
"source": [
|
60 |
+
"from utils.parsing.sambaparse import SambaParse"
|
61 |
+
]
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"cell_type": "markdown",
|
65 |
+
"metadata": {},
|
66 |
+
"source": [
|
67 |
+
"# Use Case 1 - Process a Single File"
|
68 |
+
]
|
69 |
+
},
|
70 |
+
{
|
71 |
+
"cell_type": "code",
|
72 |
+
"execution_count": 19,
|
73 |
+
"metadata": {},
|
74 |
+
"outputs": [
|
75 |
+
{
|
76 |
+
"name": "stderr",
|
77 |
+
"output_type": "stream",
|
78 |
+
"text": [
|
79 |
+
"2024-06-20 16:15:20,971 - INFO - Deleting contents of output directory: ./output\n"
|
80 |
+
]
|
81 |
+
},
|
82 |
+
{
|
83 |
+
"name": "stderr",
|
84 |
+
"output_type": "stream",
|
85 |
+
"text": [
|
86 |
+
"2024-06-20 16:15:20,995 - INFO - Running command: unstructured-ingest local --output-dir ./output --num-processes 2 --strategy auto --ocr-languages eng --encoding utf-8 --fields-include element_id,text,type,metadata,embeddings --metadata-exclude --metadata-include --pdf-infer-table-structure --input-path \"./test_docs/samba_turbo.pdf\" --recursive --verbose --partition-by-api --api-key EA6ZX3037WEZUV8THwco --partition-endpoint http://localhost:8005 --pdf-infer-table-structure --chunking-strategy basic --chunk-max-characters 1500 --chunk-overlap 300\n",
|
87 |
+
"2024-06-20 16:15:20,996 - INFO - This may take some time depending on the size of your data. Please be patient...\n",
|
88 |
+
"2024-06-20 16:15:20,996 - INFO - This may take some time depending on the size of your data. Please be patient...\n",
|
89 |
+
"/Users/kwasia/.pyenv/versions/sambaparse/lib/python3.10/site-packages/dataclasses_json/core.py:201: RuntimeWarning: 'NoneType' object value of non-optional type additional_partition_args detected when decoding CliPartitionConfig.\n",
|
90 |
+
" warnings.warn(\n",
|
91 |
+
"2024-06-20 16:15:22,908 MainProcess INFO running pipeline: DocFactory -> Reader -> Partitioner -> Chunker -> Copier with config: {\"reprocess\": false, \"verbose\": true, \"work_dir\": \"/Users/kwasia/.cache/unstructured/ingest/pipeline\", \"output_dir\": \"./output\", \"num_processes\": 2, \"raise_on_error\": false}\n",
|
92 |
+
"2024-06-20 16:15:24,658 MainProcess INFO Running doc factory to generate ingest docs. Source connector: {\"processor_config\": {\"reprocess\": false, \"verbose\": true, \"work_dir\": \"/Users/kwasia/.cache/unstructured/ingest/pipeline\", \"output_dir\": \"./output\", \"num_processes\": 2, \"raise_on_error\": false}, \"read_config\": {\"download_dir\": null, \"re_download\": false, \"preserve_downloads\": false, \"download_only\": false, \"max_docs\": null}, \"connector_config\": {\"input_path\": \"./test_docs/samba_turbo.pdf\", \"recursive\": true, \"file_glob\": null}}\n",
|
93 |
+
"2024-06-20 16:15:24,661 MainProcess INFO processing 1 docs via 2 processes\n",
|
94 |
+
"2024-06-20 16:15:24,661 MainProcess INFO Calling Reader with 1 docs\n",
|
95 |
+
"2024-06-20 16:15:24,661 MainProcess INFO Running source node to download data associated with ingest docs\n",
|
96 |
+
"2024-06-20 16:15:26,511 SpawnPoolWorker-3 INFO File exists: test_docs/samba_turbo.pdf, skipping download\n",
|
97 |
+
"2024-06-20 16:15:26,522 MainProcess INFO Calling Partitioner with 1 docs\n",
|
98 |
+
"2024-06-20 16:15:26,523 MainProcess INFO Running partition node to extract content from json files. Config: {\"pdf_infer_table_structure\": true, \"strategy\": \"auto\", \"ocr_languages\": [\"eng\"], \"encoding\": \"utf-8\", \"additional_partition_args\": null, \"skip_infer_table_types\": null, \"fields_include\": [\"element_id\", \"text\", \"type\", \"metadata\", \"embeddings\"], \"flatten_metadata\": false, \"metadata_exclude\": [\"--metadata-include\"], \"metadata_include\": [], \"partition_endpoint\": \"http://localhost:8005\", \"partition_by_api\": true, \"api_key\": \"*******\", \"hi_res_model_name\": null}, partition kwargs: {}]\n",
|
99 |
+
"2024-06-20 16:15:26,523 MainProcess INFO Creating /Users/kwasia/.cache/unstructured/ingest/pipeline/partitioned\n",
|
100 |
+
"2024-06-20 16:15:28,387 SpawnPoolWorker-4 INFO Processing test_docs/samba_turbo.pdf\n",
|
101 |
+
"2024-06-20 16:15:29,836 SpawnPoolWorker-4 DEBUG Using remote partition (http://localhost:8005)\n",
|
102 |
+
"2024-06-20 16:15:40,244 SpawnPoolWorker-4 INFO writing partitioned content to /Users/kwasia/.cache/unstructured/ingest/pipeline/partitioned/eb87c25354d57b8c7434994ca9c3f796.json\n",
|
103 |
+
"2024-06-20 16:15:40,254 MainProcess INFO Calling Chunker with 1 docs\n",
|
104 |
+
"2024-06-20 16:15:40,255 MainProcess INFO Running chunking node. Chunking config: {\"chunking_strategy\": \"basic\", \"combine_text_under_n_chars\": null, \"include_orig_elements\": true, \"max_characters\": 1500, \"multipage_sections\": true, \"new_after_n_chars\": null, \"overlap\": 300, \"overlap_all\": false}]\n",
|
105 |
+
"2024-06-20 16:15:40,255 MainProcess INFO Creating /Users/kwasia/.cache/unstructured/ingest/pipeline/chunked\n",
|
106 |
+
"2024-06-20 16:15:42,318 SpawnPoolWorker-6 INFO writing chunking content to /Users/kwasia/.cache/unstructured/ingest/pipeline/chunked/df2636b5a36c11e91958dfd7ae81ddb1.json\n",
|
107 |
+
"2024-06-20 16:15:42,323 MainProcess INFO Calling Copier with 1 docs\n",
|
108 |
+
"2024-06-20 16:15:42,323 MainProcess INFO Running copy node to move content to desired output location\n",
|
109 |
+
"2024-06-20 16:15:44,114 SpawnPoolWorker-9 INFO Copying /Users/kwasia/.cache/unstructured/ingest/pipeline/chunked/df2636b5a36c11e91958dfd7ae81ddb1.json -> output/samba_turbo.pdf.json\n",
|
110 |
+
"2024-06-20 16:15:44,320 - INFO - Ingest process completed successfully!\n",
|
111 |
+
"2024-06-20 16:15:44,321 - INFO - Performing additional processing...\n",
|
112 |
+
"2024-06-20 16:15:44,324 - INFO - Additional processing completed.\n"
|
113 |
+
]
|
114 |
+
}
|
115 |
+
],
|
116 |
+
"source": [
|
117 |
+
"config_yaml = './config.yaml'\n",
|
118 |
+
"sambaparse = SambaParse(config_yaml)\n",
|
119 |
+
"\n",
|
120 |
+
"source_type = 'local'\n",
|
121 |
+
"input_path = './test_docs/samba_turbo.pdf'\n",
|
122 |
+
"additional_metadata = {'key': 'value'}\n",
|
123 |
+
"\n",
|
124 |
+
"texts, metadata_list, langchain_docs = sambaparse.run_ingest(\n",
|
125 |
+
" source_type, input_path=input_path, additional_metadata=additional_metadata\n",
|
126 |
+
")"
|
127 |
+
]
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"cell_type": "code",
|
131 |
+
"execution_count": 20,
|
132 |
+
"metadata": {},
|
133 |
+
"outputs": [
|
134 |
+
{
|
135 |
+
"name": "stdout",
|
136 |
+
"output_type": "stream",
|
137 |
+
"text": [
|
138 |
+
"This is the length of the lanchain docs 5\n",
|
139 |
+
"This is an example langcahin doc \n",
|
140 |
+
"\n",
|
141 |
+
" page_content=\"6/20/24, 3:23 PM\\n\\nSambaNova has broken the 1000 t/s barrier: why it's a big deal for enterprise AI\\n\\nG\\\\SambaNovar\\n\\nEN\\n\\nBACK TO RESOURCES\\n\\n<\\n\\nPREVIOUS | NEXT\\n\\n>\\n\\nMay 29, 2024\\n\\njn\\n\\nNX\\n\\nfF\\n\\nBS\\n\\nSambaNova has broken the 1000 t/s barrier: why it's a big deal for enterprise AI\\n\\nSambaNova is the clear winner of the latest large language model LLM benchmark by Artificial Analysis. Topping the Leaderboad at over 1000 tokens per second (t/s), Samba-1 Turbo sets a new record for Llama 3 8B performance on a single SN40L node and with full precision.\\n\\nWith speeds like this, enterprises can expect to accelerate an array of use cases and will enable innovation around unblocking agentic workflow, copilot, and synthetic data, to name a few. This breakthrough in AI technology is possible because the purpose-built SambaNova SN40L Reconfigurable Dataflow Unit RDU can hold hundreds of models at the same time and can switch between them in microseconds.\\n\\nSpeed for today and tomorrow\" metadata={'filename': 'samba_turbo.pdf', 'filetype': 'application/pdf', 'languages': 'eng', 'page_number': '1', 'orig_elements': 'eJzVl21v2zYQx7/KwW+2AV7DJ1FUMQxI22wrlqZFHrYCbVHw4WhzkSVBkut63b77jvYejCJF7BdDkleCyBN597v/Hak3nyZY4wKb8X0Kk8cwKWzl0TimbZBSVEWprZZRSx20YbJwkylMFjjaYEdL9p8mMdXY2AXmjwe7cPb9uOxd+6gLMdvm6XHdbaZt19XJ2zG1zdHf07VtZks7w4Hm30ywmU3e0WhHI++b5cJhT+P8Txoa8eOY19BHgh0JNQX5WEh49SIv8s/6P6EN9AWZfx5VjIWpKvRWRFlZ6dEJxr3SElEFr8xdR3WRtzhrP1iY2wFc315jA+McgTPGYDyiMdv3CfvHsJqvIY1fDWDBpRkEtDXEtgcKFvuuTwPC8fP9qJQlE8YV0ToeXamKqiAuKI3RJTNFFe6CymakPyBzuxh/fPv2X5L9LoPLNNZ4EwIUpQqxRKYL4RwtaQxHdMYrFSup7lzuu0E8X5DdTUGEGENRVbIqfCU1s0wFGazTVMGFJnZ3HcTJ2X56VFgq6jWoKMOlM0FykmMMEjk6LoW4D3q8lfVu4E+On/4Mly/h/OTi5dX505OLvTQZjUWunZeclwKDFJ4JpqMzUflKlXgfMNyaql0M3+2GfdXQ9jhr+/Q7hstscQMCxlUUUikUSjkjnePCCGFtYJE77x+eEl6dn/zy/OXVBfwBZyevL/fSQSWo+3tqzRELaniyUpoTFKkq5JWLd9KbPodwa6J2IXx/sA5Ky6pYeMZZITPlUAmnOTFBI6lC2IND8MKuQVRTEEyoXRpndLyTKx/wSyQsL5UpiAcLwUWLRisTSsmCcVYUgd8HEgdVxG/NXkVQBCtj0JJjLA1TvhROBaF9QR3IBV48uLjPXu8VNzeVQDR05SF5Oead08JjENFLr6R4ePmOP+wVty+rUvsouLHcM7qHlbzAwAXngU4Zey8Ov8PuAPsd+nQniq5iOieawi4dD5F5Q1WvqkKZ+ODi/j9/ab58dQqs0oXWBT0tp0opQwx0I6GyQS+re3Fe3JrpmymmYUPP12h7WKWmwR7auBmr6SAdRnr0s/yy9QkWbcAaTk9fgMPGzxe2vwa3huN+TDH5RIyPG1uvhzQ8gsu261Iz26x2urmfu9YGsCO0H2ifbcpyBgfo6H1A3zYBvqYsfjOFjY/fcrjM5GhuzIlscAU9mfVhk8rT2i4sSDBP8gI0srCNR2gbMh1o5xrh4kyxU2jIa7C0+CqNc4jLuoaO1kkDYX900ImpSAUycG1N6Zig25NkvIraMvqLlE7ci4o6SAu/ZiJDhxgGqNM1UrbSMN2pkgG8bQA/duhHShdY7wkJUcpEIfNaZ8ksqZ68zeZbzoSY4naUAlIVSS1HRdbtkmaXjatbf521QZE0Y/KwavvrWLerKfi2S3U7TjfrDOuG1JMNMtBp3j/DpPxGXJHCyFfqAmivxzktPZvTZlTVMKKfN23dztZZ4V07DCl74uiPP/uZBdktexrHb90y1SP8VxFbwZxnLcY0W/abEJ7R5tk7uGrSCOfPrjZM5m0dYE4B9RkeMdgUx5AFnncYsqNjWmyVlz8YSH5+Tm6MK9z2rUV2eJF8327VPxymRsOYDFHooL1WwmjDCxc8liW6KnJu74MaD+vvWYib2h7bQMLK5MZ20fZ9u7qhV7/7C2wCbXA=', 'key': 'value', 'type': 'CompositeElement', 'element_id': '34922f62e3c3e7600d32eb0627b79202', 'page': '1'}\n"
|
142 |
+
]
|
143 |
+
}
|
144 |
+
],
|
145 |
+
"source": [
|
146 |
+
"# Inspect the Output\n",
|
147 |
+
"\n",
|
148 |
+
"# 1. Number of Chunks\n",
|
149 |
+
"print(f'This is the length of the lanchain docs {len(langchain_docs)}')\n",
|
150 |
+
"\n",
|
151 |
+
"# 2. Example Chunk\n",
|
152 |
+
"print(f'This is an example langcahin doc \\n\\n {langchain_docs[0]}')"
|
153 |
+
]
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"cell_type": "markdown",
|
157 |
+
"metadata": {},
|
158 |
+
"source": [
|
159 |
+
"# Use Case 2 - Process Whole Directory "
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "code",
|
164 |
+
"execution_count": null,
|
165 |
+
"metadata": {},
|
166 |
+
"outputs": [],
|
167 |
+
"source": [
|
168 |
+
"config_yaml = './config.yaml'\n",
|
169 |
+
"sambaparse = SambaParse(config_yaml)\n",
|
170 |
+
"\n",
|
171 |
+
"source_type = 'local'\n",
|
172 |
+
"input_path = './test_docs'\n",
|
173 |
+
"additional_metadata = {'key': 'value'}\n",
|
174 |
+
"\n",
|
175 |
+
"texts, metadata_list, langchain_docs = sambaparse.run_ingest(\n",
|
176 |
+
" source_type, input_path=input_path, additional_metadata=additional_metadata\n",
|
177 |
+
")"
|
178 |
+
]
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"cell_type": "code",
|
182 |
+
"execution_count": 22,
|
183 |
+
"metadata": {},
|
184 |
+
"outputs": [
|
185 |
+
{
|
186 |
+
"name": "stdout",
|
187 |
+
"output_type": "stream",
|
188 |
+
"text": [
|
189 |
+
"This is the length of the lanchain docs 44\n",
|
190 |
+
"This is an example langcahin doc \n",
|
191 |
+
"\n",
|
192 |
+
" page_content=\"6/20/24, 3:23 PM\\n\\nSambaNova has broken the 1000 t/s barrier: why it's a big deal for enterprise AI\\n\\nG\\\\SambaNovar\\n\\nEN\\n\\nBACK TO RESOURCES\\n\\n<\\n\\nPREVIOUS | NEXT\\n\\n>\\n\\nMay 29, 2024\\n\\njn\\n\\nNX\\n\\nfF\\n\\nBS\\n\\nSambaNova has broken the 1000 t/s barrier: why it's a big deal for enterprise AI\\n\\nSambaNova is the clear winner of the latest large language model LLM benchmark by Artificial Analysis. Topping the Leaderboad at over 1000 tokens per second (t/s), Samba-1 Turbo sets a new record for Llama 3 8B performance on a single SN40L node and with full precision.\\n\\nWith speeds like this, enterprises can expect to accelerate an array of use cases and will enable innovation around unblocking agentic workflow, copilot, and synthetic data, to name a few. This breakthrough in AI technology is possible because the purpose-built SambaNova SN40L Reconfigurable Dataflow Unit RDU can hold hundreds of models at the same time and can switch between them in microseconds.\\n\\nSpeed for today and tomorrow\" metadata={'filename': 'samba_turbo.pdf', 'filetype': 'application/pdf', 'languages': 'eng', 'page_number': '1', 'orig_elements': 'eJzVl21v2zYQx7/KwW+2AV7DJ1FUMQxI22wrlqZFHrYCbVHw4WhzkSVBkut63b77jvYejCJF7BdDkleCyBN597v/Hak3nyZY4wKb8X0Kk8cwKWzl0TimbZBSVEWprZZRSx20YbJwkylMFjjaYEdL9p8mMdXY2AXmjwe7cPb9uOxd+6gLMdvm6XHdbaZt19XJ2zG1zdHf07VtZks7w4Hm30ywmU3e0WhHI++b5cJhT+P8Txoa8eOY19BHgh0JNQX5WEh49SIv8s/6P6EN9AWZfx5VjIWpKvRWRFlZ6dEJxr3SElEFr8xdR3WRtzhrP1iY2wFc315jA+McgTPGYDyiMdv3CfvHsJqvIY1fDWDBpRkEtDXEtgcKFvuuTwPC8fP9qJQlE8YV0ToeXamKqiAuKI3RJTNFFe6CymakPyBzuxh/fPv2X5L9LoPLNNZ4EwIUpQqxRKYL4RwtaQxHdMYrFSup7lzuu0E8X5DdTUGEGENRVbIqfCU1s0wFGazTVMGFJnZ3HcTJ2X56VFgq6jWoKMOlM0FykmMMEjk6LoW4D3q8lfVu4E+On/4Mly/h/OTi5dX505OLvTQZjUWunZeclwKDFJ4JpqMzUflKlXgfMNyaql0M3+2GfdXQ9jhr+/Q7hstscQMCxlUUUikUSjkjnePCCGFtYJE77x+eEl6dn/zy/OXVBfwBZyevL/fSQSWo+3tqzRELaniyUpoTFKkq5JWLd9KbPodwa6J2IXx/sA5Ky6pYeMZZITPlUAmnOTFBI6lC2IND8MKuQVRTEEyoXRpndLyTKx/wSyQsL5UpiAcLwUWLRisTSsmCcVYUgd8HEgdVxG/NXkVQBCtj0JJjLA1TvhROBaF9QR3IBV48uLjPXu8VNzeVQDR05SF5Oead08JjENFLr6R4ePmOP+wVty+rUvsouLHcM7qHlbzAwAXngU4Zey8Ov8PuAPsd+nQniq5iOieawi4dD5F5Q1WvqkKZ+ODi/j9/ab58dQqs0oXWBT0tp0opQwx0I6GyQS+re3Fe3JrpmymmYUPP12h7WKWmwR7auBmr6SAdRnr0s/yy9QkWbcAaTk9fgMPGzxe2vwa3huN+TDH5RIyPG1uvhzQ8gsu261Iz26x2urmfu9YGsCO0H2ifbcpyBgfo6H1A3zYBvqYsfjOFjY/fcrjM5GhuzIlscAU9mfVhk8rT2i4sSDBP8gI0srCNR2gbMh1o5xrh4kyxU2jIa7C0+CqNc4jLuoaO1kkDYX900ImpSAUycG1N6Zig25NkvIraMvqLlE7ci4o6SAu/ZiJDhxgGqNM1UrbSMN2pkgG8bQA/duhHShdY7wkJUcpEIfNaZ8ksqZ68zeZbzoSY4naUAlIVSS1HRdbtkmaXjatbf521QZE0Y/KwavvrWLerKfi2S3U7TjfrDOuG1JMNMtBp3j/DpPxGXJHCyFfqAmivxzktPZvTZlTVMKKfN23dztZZ4V07DCl74uiPP/uZBdktexrHb90y1SP8VxFbwZxnLcY0W/abEJ7R5tk7uGrSCOfPrjZM5m0dYE4B9RkeMdgUx5AFnncYsqNjWmyVlz8YSH5+Tm6MK9z2rUV2eJF8327VPxymRsOYDFHooL1WwmjDCxc8liW6KnJu74MaD+vvWYib2h7bQMLK5MZ20fZ9u7qhV7/7C2wCbXA=', 'key': 'value', 'type': 'CompositeElement', 'element_id': '34922f62e3c3e7600d32eb0627b79202', 'page': '1'}\n"
|
193 |
+
]
|
194 |
+
}
|
195 |
+
],
|
196 |
+
"source": [
|
197 |
+
"# Inspect the Output\n",
|
198 |
+
"\n",
|
199 |
+
"# 1. Number of Chunks\n",
|
200 |
+
"print(f'This is the length of the lanchain docs {len(langchain_docs)}')\n",
|
201 |
+
"\n",
|
202 |
+
"# 2. Example Chunk\n",
|
203 |
+
"print(f'This is an example langcahin doc \\n\\n {langchain_docs[0]}')"
|
204 |
+
]
|
205 |
+
}
|
206 |
+
],
|
207 |
+
"metadata": {
|
208 |
+
"kernelspec": {
|
209 |
+
"display_name": "aisk-fine-tune-embeddings",
|
210 |
+
"language": "python",
|
211 |
+
"name": "python3"
|
212 |
+
},
|
213 |
+
"language_info": {
|
214 |
+
"codemirror_mode": {
|
215 |
+
"name": "ipython",
|
216 |
+
"version": 3
|
217 |
+
},
|
218 |
+
"file_extension": ".py",
|
219 |
+
"mimetype": "text/x-python",
|
220 |
+
"name": "python",
|
221 |
+
"nbconvert_exporter": "python",
|
222 |
+
"pygments_lexer": "ipython3",
|
223 |
+
"version": "3.10.12"
|
224 |
+
}
|
225 |
+
},
|
226 |
+
"nbformat": 4,
|
227 |
+
"nbformat_minor": 2
|
228 |
+
}
|
utils/parsing/requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
unstructured==0.13.6
|
2 |
+
unstructured-client==0.18.0
|
3 |
+
unstructured-inference==0.7.29
|
4 |
+
langchain==0.1.16
|
5 |
+
PyMuPDF==1.23.4
|
6 |
+
PyMuPDFb==1.23.3
|
utils/parsing/sambaparse.py
ADDED
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import yaml
|
3 |
+
import subprocess
|
4 |
+
import json
|
5 |
+
import logging
|
6 |
+
from typing import Dict, Optional, List, Tuple, Union, Any
|
7 |
+
from dotenv import load_dotenv
|
8 |
+
from langchain.docstore.document import Document
|
9 |
+
import shutil
|
10 |
+
from langchain_community.document_loaders import PyMuPDFLoader
|
11 |
+
|
12 |
+
load_dotenv()
|
13 |
+
|
14 |
+
logging.basicConfig(
|
15 |
+
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
|
16 |
+
)
|
17 |
+
logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
class SambaParse:
|
21 |
+
def __init__(self, config_path: str):
|
22 |
+
with open(config_path, "r") as file:
|
23 |
+
self.config = yaml.safe_load(file)
|
24 |
+
|
25 |
+
# Set the default Unstructured API key as an environment variable if not already set
|
26 |
+
if "UNSTRUCTURED_API_KEY" not in os.environ:
|
27 |
+
default_api_key = self.config.get("partitioning", {}).get("default_unstructured_api_key")
|
28 |
+
if default_api_key:
|
29 |
+
os.environ["UNSTRUCTURED_API_KEY"] = default_api_key
|
30 |
+
|
31 |
+
|
32 |
+
def run_ingest(
|
33 |
+
self,
|
34 |
+
source_type: str,
|
35 |
+
input_path: Optional[str] = None,
|
36 |
+
additional_metadata: Optional[Dict] = None,
|
37 |
+
) -> Tuple[List[str], List[Dict], List[Document]]:
|
38 |
+
"""
|
39 |
+
Runs the ingest process for the specified source type and input path.
|
40 |
+
|
41 |
+
Args:
|
42 |
+
source_type (str): The type of source to ingest (e.g., 'local', 'confluence', 'github', 'google-drive').
|
43 |
+
input_path (Optional[str]): The input path for the source (only required for 'local' source type).
|
44 |
+
additional_metadata (Optional[Dict]): Additional metadata to include in the processed documents.
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
Tuple[List[str], List[Dict], List[Document]]: A tuple containing the extracted texts, metadata, and LangChain documents.
|
48 |
+
"""
|
49 |
+
if not self.config["partitioning"]["partition_by_api"]:
|
50 |
+
return self._run_ingest_pymupdf(input_path, additional_metadata)
|
51 |
+
|
52 |
+
output_dir = self.config["processor"]["output_dir"]
|
53 |
+
|
54 |
+
# Create the output directory if it doesn't exist
|
55 |
+
os.makedirs(output_dir, exist_ok=True)
|
56 |
+
|
57 |
+
# Delete contents of the output directory using shell command
|
58 |
+
del_command = f"rm -rf {output_dir}/*"
|
59 |
+
logger.info(f"Deleting contents of output directory: {output_dir}")
|
60 |
+
subprocess.run(del_command, shell=True, check=True)
|
61 |
+
|
62 |
+
command = [
|
63 |
+
"unstructured-ingest",
|
64 |
+
source_type,
|
65 |
+
"--output-dir",
|
66 |
+
output_dir,
|
67 |
+
"--num-processes",
|
68 |
+
str(self.config["processor"]["num_processes"]),
|
69 |
+
]
|
70 |
+
|
71 |
+
if self.config["processor"]["reprocess"] == True:
|
72 |
+
command.extend(["--reprocess"])
|
73 |
+
|
74 |
+
# Add partition arguments
|
75 |
+
command.extend(
|
76 |
+
[
|
77 |
+
"--strategy",
|
78 |
+
self.config["partitioning"]["strategy"],
|
79 |
+
"--ocr-languages",
|
80 |
+
",".join(self.config["partitioning"]["ocr_languages"]),
|
81 |
+
"--encoding",
|
82 |
+
self.config["partitioning"]["encoding"],
|
83 |
+
"--fields-include",
|
84 |
+
",".join(self.config["partitioning"]["fields_include"]),
|
85 |
+
"--metadata-exclude",
|
86 |
+
",".join(self.config["partitioning"]["metadata_exclude"]),
|
87 |
+
"--metadata-include",
|
88 |
+
",".join(self.config["partitioning"]["metadata_include"]),
|
89 |
+
]
|
90 |
+
)
|
91 |
+
|
92 |
+
if self.config["partitioning"]["skip_infer_table_types"]:
|
93 |
+
command.extend(
|
94 |
+
[
|
95 |
+
"--skip-infer-table-types",
|
96 |
+
",".join(self.config["partitioning"]["skip_infer_table_types"]),
|
97 |
+
]
|
98 |
+
)
|
99 |
+
|
100 |
+
if self.config["partitioning"]["flatten_metadata"]:
|
101 |
+
command.append("--flatten-metadata")
|
102 |
+
|
103 |
+
if source_type == "local":
|
104 |
+
if input_path is None:
|
105 |
+
raise ValueError("Input path is required for local source type.")
|
106 |
+
command.extend(["--input-path", f'"{input_path}"'])
|
107 |
+
|
108 |
+
if self.config["sources"]["local"]["recursive"]:
|
109 |
+
command.append("--recursive")
|
110 |
+
elif source_type == "confluence":
|
111 |
+
command.extend(
|
112 |
+
[
|
113 |
+
"--url",
|
114 |
+
self.config["sources"]["confluence"]["url"],
|
115 |
+
"--user-email",
|
116 |
+
self.config["sources"]["confluence"]["user_email"],
|
117 |
+
"--api-token",
|
118 |
+
self.config["sources"]["confluence"]["api_token"],
|
119 |
+
]
|
120 |
+
)
|
121 |
+
elif source_type == "github":
|
122 |
+
command.extend(
|
123 |
+
[
|
124 |
+
"--url",
|
125 |
+
self.config["sources"]["github"]["url"],
|
126 |
+
"--git-branch",
|
127 |
+
self.config["sources"]["github"]["branch"],
|
128 |
+
]
|
129 |
+
)
|
130 |
+
elif source_type == "google-drive":
|
131 |
+
command.extend(
|
132 |
+
[
|
133 |
+
"--drive-id",
|
134 |
+
self.config["sources"]["google_drive"]["drive_id"],
|
135 |
+
"--service-account-key",
|
136 |
+
self.config["sources"]["google_drive"]["service_account_key"],
|
137 |
+
]
|
138 |
+
)
|
139 |
+
if self.config["sources"]["google_drive"]["recursive"]:
|
140 |
+
command.append("--recursive")
|
141 |
+
else:
|
142 |
+
raise ValueError(f"Unsupported source type: {source_type}")
|
143 |
+
|
144 |
+
if self.config["processor"]["verbose"]:
|
145 |
+
command.append("--verbose")
|
146 |
+
|
147 |
+
if self.config["partitioning"]["partition_by_api"]:
|
148 |
+
api_key = os.getenv("UNSTRUCTURED_API_KEY")
|
149 |
+
partition_endpoint_url = f"{self.config['partitioning']['partition_endpoint']}:{self.config['partitioning']['unstructured_port']}"
|
150 |
+
if api_key:
|
151 |
+
command.extend(["--partition-by-api", "--api-key", api_key])
|
152 |
+
command.extend(["--partition-endpoint", partition_endpoint_url])
|
153 |
+
else:
|
154 |
+
logger.warning("No Unstructured API key available. Partitioning by API will be skipped.")
|
155 |
+
|
156 |
+
if self.config["partitioning"]["strategy"] == "hi_res":
|
157 |
+
if (
|
158 |
+
"hi_res_model_name" in self.config["partitioning"]
|
159 |
+
and self.config["partitioning"]["hi_res_model_name"]
|
160 |
+
):
|
161 |
+
command.extend(
|
162 |
+
[
|
163 |
+
"--hi-res-model-name",
|
164 |
+
self.config["partitioning"]["hi_res_model_name"],
|
165 |
+
]
|
166 |
+
)
|
167 |
+
logger.warning(
|
168 |
+
"You've chosen the high-resolution partitioning strategy. Grab a cup of coffee or tea while you wait, as this may take some time due to OCR and table detection."
|
169 |
+
)
|
170 |
+
|
171 |
+
if self.config["chunking"]["enabled"]:
|
172 |
+
command.extend(
|
173 |
+
[
|
174 |
+
"--chunking-strategy",
|
175 |
+
self.config["chunking"]["strategy"],
|
176 |
+
"--chunk-max-characters",
|
177 |
+
str(self.config["chunking"]["chunk_max_characters"]),
|
178 |
+
"--chunk-overlap",
|
179 |
+
str(self.config["chunking"]["chunk_overlap"]),
|
180 |
+
]
|
181 |
+
)
|
182 |
+
|
183 |
+
if self.config["chunking"]["strategy"] == "by_title":
|
184 |
+
command.extend(
|
185 |
+
[
|
186 |
+
"--chunk-combine-text-under-n-chars",
|
187 |
+
str(self.config["chunking"]["combine_under_n_chars"]),
|
188 |
+
]
|
189 |
+
)
|
190 |
+
|
191 |
+
if self.config["embedding"]["enabled"]:
|
192 |
+
command.extend(
|
193 |
+
[
|
194 |
+
"--embedding-provider",
|
195 |
+
self.config["embedding"]["provider"],
|
196 |
+
"--embedding-model-name",
|
197 |
+
self.config["embedding"]["model_name"],
|
198 |
+
]
|
199 |
+
)
|
200 |
+
|
201 |
+
if self.config["destination_connectors"]["enabled"]:
|
202 |
+
destination_type = self.config["destination_connectors"]["type"]
|
203 |
+
if destination_type == "chroma":
|
204 |
+
command.extend(
|
205 |
+
[
|
206 |
+
"chroma",
|
207 |
+
"--host",
|
208 |
+
self.config["destination_connectors"]["chroma"]["host"],
|
209 |
+
"--port",
|
210 |
+
str(self.config["destination_connectors"]["chroma"]["port"]),
|
211 |
+
"--collection-name",
|
212 |
+
self.config["destination_connectors"]["chroma"][
|
213 |
+
"collection_name"
|
214 |
+
],
|
215 |
+
"--tenant",
|
216 |
+
self.config["destination_connectors"]["chroma"]["tenant"],
|
217 |
+
"--database",
|
218 |
+
self.config["destination_connectors"]["chroma"]["database"],
|
219 |
+
"--batch-size",
|
220 |
+
str(self.config["destination_connectors"]["batch_size"]),
|
221 |
+
]
|
222 |
+
)
|
223 |
+
elif destination_type == "qdrant":
|
224 |
+
command.extend(
|
225 |
+
[
|
226 |
+
"qdrant",
|
227 |
+
"--location",
|
228 |
+
self.config["destination_connectors"]["qdrant"]["location"],
|
229 |
+
"--collection-name",
|
230 |
+
self.config["destination_connectors"]["qdrant"][
|
231 |
+
"collection_name"
|
232 |
+
],
|
233 |
+
"--batch-size",
|
234 |
+
str(self.config["destination_connectors"]["batch_size"]),
|
235 |
+
]
|
236 |
+
)
|
237 |
+
else:
|
238 |
+
raise ValueError(
|
239 |
+
f"Unsupported destination connector type: {destination_type}"
|
240 |
+
)
|
241 |
+
|
242 |
+
command_str = " ".join(command)
|
243 |
+
logger.info(f"Running command: {command_str}")
|
244 |
+
logger.info(
|
245 |
+
"This may take some time depending on the size of your data. Please be patient..."
|
246 |
+
)
|
247 |
+
|
248 |
+
subprocess.run(command_str, shell=True, check=True)
|
249 |
+
|
250 |
+
logger.info("Ingest process completed successfully!")
|
251 |
+
|
252 |
+
# Call the additional processing function if enabled
|
253 |
+
if self.config["additional_processing"]["enabled"]:
|
254 |
+
logger.info("Performing additional processing...")
|
255 |
+
texts, metadata_list, langchain_docs = additional_processing(
|
256 |
+
directory=output_dir,
|
257 |
+
extend_metadata=self.config["additional_processing"]["extend_metadata"],
|
258 |
+
additional_metadata=additional_metadata,
|
259 |
+
replace_table_text=self.config["additional_processing"][
|
260 |
+
"replace_table_text"
|
261 |
+
],
|
262 |
+
table_text_key=self.config["additional_processing"]["table_text_key"],
|
263 |
+
return_langchain_docs=self.config["additional_processing"][
|
264 |
+
"return_langchain_docs"
|
265 |
+
],
|
266 |
+
convert_metadata_keys_to_string=self.config["additional_processing"][
|
267 |
+
"convert_metadata_keys_to_string"
|
268 |
+
],
|
269 |
+
)
|
270 |
+
logger.info("Additional processing completed.")
|
271 |
+
return texts, metadata_list, langchain_docs
|
272 |
+
|
273 |
+
def _run_ingest_pymupdf(
|
274 |
+
self, input_path: str, additional_metadata: Optional[Dict] = None
|
275 |
+
) -> Tuple[List[str], List[Dict], List[Document]]:
|
276 |
+
"""
|
277 |
+
Runs the ingest process using PyMuPDF via LangChain.
|
278 |
+
|
279 |
+
Args:
|
280 |
+
input_path (str): The input path for the source.
|
281 |
+
additional_metadata (Optional[Dict]): Additional metadata to include in the processed documents.
|
282 |
+
|
283 |
+
Returns:
|
284 |
+
Tuple[List[str], List[Dict], List[Document]]: A tuple containing the extracted texts, metadata, and LangChain documents.
|
285 |
+
"""
|
286 |
+
if not input_path:
|
287 |
+
raise ValueError("Input path is required for PyMuPDF processing.")
|
288 |
+
|
289 |
+
texts = []
|
290 |
+
metadata_list = []
|
291 |
+
langchain_docs = []
|
292 |
+
|
293 |
+
if os.path.isfile(input_path):
|
294 |
+
file_paths = [input_path]
|
295 |
+
else:
|
296 |
+
file_paths = [
|
297 |
+
os.path.join(input_path, f)
|
298 |
+
for f in os.listdir(input_path)
|
299 |
+
if f.lower().endswith('.pdf')
|
300 |
+
]
|
301 |
+
|
302 |
+
for file_path in file_paths:
|
303 |
+
loader = PyMuPDFLoader(file_path)
|
304 |
+
docs = loader.load()
|
305 |
+
|
306 |
+
for doc in docs:
|
307 |
+
text = doc.page_content
|
308 |
+
metadata = doc.metadata
|
309 |
+
|
310 |
+
# Add 'filename' key to metadata
|
311 |
+
metadata['filename'] = os.path.basename(metadata['source'])
|
312 |
+
|
313 |
+
if additional_metadata:
|
314 |
+
metadata.update(additional_metadata)
|
315 |
+
|
316 |
+
texts.append(text)
|
317 |
+
metadata_list.append(metadata)
|
318 |
+
langchain_docs.append(doc)
|
319 |
+
|
320 |
+
return texts, metadata_list, langchain_docs
|
321 |
+
|
322 |
+
|
323 |
+
def convert_to_string(value: Union[List, Tuple, Dict, Any]) -> str:
|
324 |
+
"""
|
325 |
+
Convert a value to its string representation.
|
326 |
+
|
327 |
+
Args:
|
328 |
+
value (Union[List, Tuple, Dict, Any]): The value to be converted to a string.
|
329 |
+
|
330 |
+
Returns:
|
331 |
+
str: The string representation of the value.
|
332 |
+
"""
|
333 |
+
if isinstance(value, (list, tuple)):
|
334 |
+
return ", ".join(map(str, value))
|
335 |
+
elif isinstance(value, dict):
|
336 |
+
return json.dumps(value)
|
337 |
+
else:
|
338 |
+
return str(value)
|
339 |
+
|
340 |
+
|
341 |
+
def additional_processing(
|
342 |
+
directory: str,
|
343 |
+
extend_metadata: bool,
|
344 |
+
additional_metadata: Optional[Dict],
|
345 |
+
replace_table_text: bool,
|
346 |
+
table_text_key: str,
|
347 |
+
return_langchain_docs: bool,
|
348 |
+
convert_metadata_keys_to_string: bool,
|
349 |
+
):
|
350 |
+
"""
|
351 |
+
Performs additional processing on the extracted documents.
|
352 |
+
|
353 |
+
Args:
|
354 |
+
directory (str): The directory containing the extracted JSON files.
|
355 |
+
extend_metadata (bool): Whether to extend the metadata with additional metadata.
|
356 |
+
additional_metadata (Optional[Dict]): Additional metadata to include in the processed documents.
|
357 |
+
replace_table_text (bool): Whether to replace table text with the specified table text key.
|
358 |
+
table_text_key (str): The key to use for replacing table text.
|
359 |
+
return_langchain_docs (bool): Whether to return LangChain documents.
|
360 |
+
convert_metadata_keys_to_string (bool): Whether to convert non-string metadata keys to string.
|
361 |
+
|
362 |
+
Returns:
|
363 |
+
Tuple[List[str], List[Dict], List[Document]]: A tuple containing the extracted texts, metadata, and LangChain documents.
|
364 |
+
"""
|
365 |
+
if os.path.isfile(directory):
|
366 |
+
file_paths = [directory]
|
367 |
+
else:
|
368 |
+
file_paths = [
|
369 |
+
os.path.join(directory, f)
|
370 |
+
for f in os.listdir(directory)
|
371 |
+
if f.endswith(".json")
|
372 |
+
]
|
373 |
+
|
374 |
+
texts = []
|
375 |
+
metadata_list = []
|
376 |
+
langchain_docs = []
|
377 |
+
|
378 |
+
for file_path in file_paths:
|
379 |
+
with open(file_path, "r") as file:
|
380 |
+
data = json.load(file)
|
381 |
+
|
382 |
+
for element in data:
|
383 |
+
if extend_metadata and additional_metadata:
|
384 |
+
element["metadata"].update(additional_metadata)
|
385 |
+
|
386 |
+
if replace_table_text and element["type"] == "Table":
|
387 |
+
element["text"] = element["metadata"][table_text_key]
|
388 |
+
|
389 |
+
metadata = element["metadata"].copy()
|
390 |
+
if convert_metadata_keys_to_string:
|
391 |
+
metadata = {
|
392 |
+
str(key): convert_to_string(value)
|
393 |
+
for key, value in metadata.items()
|
394 |
+
}
|
395 |
+
for key in element:
|
396 |
+
if key not in ["text", "metadata", "embeddings"]:
|
397 |
+
metadata[key] = element[key]
|
398 |
+
if "page_number" in metadata:
|
399 |
+
metadata["page"] = metadata["page_number"]
|
400 |
+
else:
|
401 |
+
metadata["page"] = 1
|
402 |
+
|
403 |
+
metadata_list.append(metadata)
|
404 |
+
texts.append(element["text"])
|
405 |
+
|
406 |
+
if return_langchain_docs:
|
407 |
+
langchain_docs.extend(get_langchain_docs(texts, metadata_list))
|
408 |
+
|
409 |
+
with open(file_path, "w") as file:
|
410 |
+
json.dump(data, file, indent=2)
|
411 |
+
|
412 |
+
return texts, metadata_list, langchain_docs
|
413 |
+
|
414 |
+
|
415 |
+
def get_langchain_docs(texts: List[str], metadata_list: List[Dict]) -> List[Document]:
|
416 |
+
"""
|
417 |
+
Creates LangChain documents from the extracted texts and metadata.
|
418 |
+
|
419 |
+
Args:
|
420 |
+
texts (List[str]): The extracted texts.
|
421 |
+
metadata_list (List[Dict]): The metadata associated with each text.
|
422 |
+
|
423 |
+
Returns:
|
424 |
+
List[Document]: A list of LangChain documents.
|
425 |
+
"""
|
426 |
+
return [
|
427 |
+
Document(page_content=content, metadata=metadata)
|
428 |
+
for content, metadata in zip(texts, metadata_list)
|
429 |
+
]
|
430 |
+
|
431 |
+
|
432 |
+
def parse_doc_universal(
|
433 |
+
doc: str, additional_metadata: Optional[Dict] = None, source_type: str = "local"
|
434 |
+
) -> Tuple[List[str], List[Dict], List[Document]]:
|
435 |
+
"""
|
436 |
+
Extract text, tables, images, and metadata from a document or a folder of documents.
|
437 |
+
|
438 |
+
Args:
|
439 |
+
doc (str): Path to the document or folder of documents.
|
440 |
+
additional_metadata (Optional[Dict], optional): Additional metadata to include in the processed documents.
|
441 |
+
Defaults to an empty dictionary.
|
442 |
+
source_type (str, optional): The type of source to ingest. Defaults to 'local'.
|
443 |
+
|
444 |
+
Returns:
|
445 |
+
Tuple[List[str], List[Dict], List[Document]]: A tuple containing:
|
446 |
+
- A list of extracted text per page.
|
447 |
+
- A list of extracted metadata per page.
|
448 |
+
- A list of LangChain documents.
|
449 |
+
"""
|
450 |
+
if additional_metadata is None:
|
451 |
+
additional_metadata = {}
|
452 |
+
|
453 |
+
# Get the directory of the current file
|
454 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
455 |
+
|
456 |
+
# Join the current directory with the relative path of the config file
|
457 |
+
config_path = os.path.join(current_dir, "config.yaml")
|
458 |
+
|
459 |
+
wrapper = SambaParse(config_path)
|
460 |
+
|
461 |
+
def process_file(file_path):
|
462 |
+
if file_path.lower().endswith('.pdf'):
|
463 |
+
return wrapper._run_ingest_pymupdf(file_path, additional_metadata)
|
464 |
+
else:
|
465 |
+
# Use the original method for non-PDF files
|
466 |
+
return wrapper.run_ingest(source_type, input_path=file_path, additional_metadata=additional_metadata)
|
467 |
+
|
468 |
+
if os.path.isfile(doc):
|
469 |
+
return process_file(doc)
|
470 |
+
else:
|
471 |
+
all_texts, all_metadata, all_docs = [], [], []
|
472 |
+
for root, _, files in os.walk(doc):
|
473 |
+
for file in files:
|
474 |
+
file_path = os.path.join(root, file)
|
475 |
+
texts, metadata_list, langchain_docs = process_file(file_path)
|
476 |
+
all_texts.extend(texts)
|
477 |
+
all_metadata.extend(metadata_list)
|
478 |
+
all_docs.extend(langchain_docs)
|
479 |
+
return all_texts, all_metadata, all_docs
|
480 |
+
|
481 |
+
|
482 |
+
def parse_doc_streamlit(docs: List,
|
483 |
+
kit_dir: str,
|
484 |
+
additional_metadata: Optional[Dict] = None,
|
485 |
+
) -> List[Document]:
|
486 |
+
"""
|
487 |
+
Parse the uploaded documents and return a list of LangChain documents.
|
488 |
+
|
489 |
+
Args:
|
490 |
+
docs (List[UploadFile]): A list of uploaded files.
|
491 |
+
kit_dir (str): The directory of the current kit.
|
492 |
+
additional_metadata (Optional[Dict], optional): Additional metadata to include in the processed documents.
|
493 |
+
Defaults to an empty dictionary.
|
494 |
+
|
495 |
+
Returns:
|
496 |
+
List[Document]: A list of LangChain documents.
|
497 |
+
"""
|
498 |
+
if additional_metadata is None:
|
499 |
+
additional_metadata = {}
|
500 |
+
|
501 |
+
# Create the data/tmp folder if it doesn't exist
|
502 |
+
temp_folder = os.path.join(kit_dir, "data/tmp")
|
503 |
+
if not os.path.exists(temp_folder):
|
504 |
+
os.makedirs(temp_folder)
|
505 |
+
else:
|
506 |
+
# If there are already files there, delete them
|
507 |
+
for filename in os.listdir(temp_folder):
|
508 |
+
file_path = os.path.join(temp_folder, filename)
|
509 |
+
try:
|
510 |
+
if os.path.isfile(file_path) or os.path.islink(file_path):
|
511 |
+
os.unlink(file_path)
|
512 |
+
elif os.path.isdir(file_path):
|
513 |
+
shutil.rmtree(file_path)
|
514 |
+
except Exception as e:
|
515 |
+
print(f'Failed to delete {file_path}. Reason: {e}')
|
516 |
+
|
517 |
+
# Save all selected files to the tmp dir with their file names
|
518 |
+
for doc in docs:
|
519 |
+
temp_file = os.path.join(temp_folder, doc.name)
|
520 |
+
with open(temp_file, "wb") as f:
|
521 |
+
f.write(doc.getvalue())
|
522 |
+
|
523 |
+
# Pass in the temp folder for processing into the parse_doc_universal function
|
524 |
+
_, _, langchain_docs = parse_doc_universal(doc=temp_folder, additional_metadata=additional_metadata)
|
525 |
+
return langchain_docs
|
utils/vectordb/create_vector_db.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Define the script's usage example
|
2 |
+
USAGE_EXAMPLE = """
|
3 |
+
Example usage:
|
4 |
+
|
5 |
+
To process input *.txt files at input_path and save the vector db output at output_db:
|
6 |
+
python create_vector_db.py input_path output_db --chunk_size 100 --chunk_overlap 10
|
7 |
+
|
8 |
+
Required arguments:
|
9 |
+
- input_path: Path to the input dir containing the .txt files
|
10 |
+
- output_path: Path to the output vector db.
|
11 |
+
|
12 |
+
Optional arguments:
|
13 |
+
- --chunk_size: Size of the chunks (default: None).
|
14 |
+
- --chunk_overlap: Overlap between chunks (default: None).
|
15 |
+
"""
|
16 |
+
|
17 |
+
import argparse
|
18 |
+
import logging
|
19 |
+
import os
|
20 |
+
|
21 |
+
from langchain.document_loaders import DirectoryLoader
|
22 |
+
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
23 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
24 |
+
from langchain.vectorstores import FAISS, Chroma, Qdrant
|
25 |
+
|
26 |
+
# Configure the logger
|
27 |
+
logging.basicConfig(
|
28 |
+
level=logging.INFO, # Set the logging level (e.g., INFO, DEBUG)
|
29 |
+
format="%(asctime)s [%(levelname)s] - %(message)s", # Define the log message format
|
30 |
+
handlers=[
|
31 |
+
logging.StreamHandler(), # Output logs to the console
|
32 |
+
logging.FileHandler("create_vector_db.log"),
|
33 |
+
],
|
34 |
+
)
|
35 |
+
|
36 |
+
# Create a logger object
|
37 |
+
logger = logging.getLogger(__name__)
|
38 |
+
|
39 |
+
|
40 |
+
# Parse the arguments
|
41 |
+
def parse_arguments():
|
42 |
+
parser = argparse.ArgumentParser(description="Process command line arguments.")
|
43 |
+
parser.add_argument("-input_path", type=dir_path, help="path to input directory")
|
44 |
+
parser.add_argument("--chunk_size", type=int, help="chunk size for splitting")
|
45 |
+
parser.add_argument("--chunk_overlap", type=int, help="chunk overlap for splitting")
|
46 |
+
parser.add_argument("-output_path", type=dir_path, help="path to input directory")
|
47 |
+
|
48 |
+
return parser.parse_args()
|
49 |
+
|
50 |
+
|
51 |
+
# Check valid path
|
52 |
+
def dir_path(path):
|
53 |
+
if os.path.isdir(path):
|
54 |
+
return path
|
55 |
+
else:
|
56 |
+
raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path")
|
57 |
+
|
58 |
+
|
59 |
+
def main(input_path, output_db, chunk_size, chunk_overlap, db_type):
|
60 |
+
# Load files from input_location
|
61 |
+
loader = DirectoryLoader(input_path, glob="*.txt")
|
62 |
+
docs = loader.load()
|
63 |
+
logger.info(f"Total {len(docs)} files loaded")
|
64 |
+
|
65 |
+
# get the text chunks
|
66 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
67 |
+
chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
|
68 |
+
)
|
69 |
+
chunks = text_splitter.split_documents(docs)
|
70 |
+
logger.info(f"Total {len(chunks)} chunks created")
|
71 |
+
|
72 |
+
# create vector store
|
73 |
+
encode_kwargs = {"normalize_embeddings": True}
|
74 |
+
embedding_model = "BAAI/bge-large-en"
|
75 |
+
embeddings = HuggingFaceInstructEmbeddings(
|
76 |
+
model_name=embedding_model,
|
77 |
+
embed_instruction="", # no instruction is needed for candidate passages
|
78 |
+
query_instruction="Represent this sentence for searching relevant passages: ",
|
79 |
+
encode_kwargs=encode_kwargs,
|
80 |
+
)
|
81 |
+
logger.info(
|
82 |
+
f"Processing embeddings using {embedding_model}. This could take time depending on the number of chunks ..."
|
83 |
+
)
|
84 |
+
|
85 |
+
if db_type == "faiss":
|
86 |
+
vectorstore = FAISS.from_documents(documents=chunks, embedding=embeddings)
|
87 |
+
# save vectorstore
|
88 |
+
vectorstore.save_local(output_db)
|
89 |
+
elif db_type == "chromadb":
|
90 |
+
vectorstore = Chroma.from_documents(
|
91 |
+
documents=chunks, embedding=embeddings, persist_directory=output_db
|
92 |
+
)
|
93 |
+
elif db_type == "qdrant":
|
94 |
+
vectorstore = Qdrant.from_documents(
|
95 |
+
documents=chunks,
|
96 |
+
embedding=embeddings,
|
97 |
+
path=output_db,
|
98 |
+
collection_name="test_collection",
|
99 |
+
)
|
100 |
+
elif db_type == "qdrant-server":
|
101 |
+
url = "http://localhost:6333/"
|
102 |
+
vectorstore = Qdrant.from_documents(
|
103 |
+
documents=chunks,
|
104 |
+
embedding=embeddings,
|
105 |
+
url=url,
|
106 |
+
prefer_grpc=True,
|
107 |
+
collection_name="anaconda",
|
108 |
+
)
|
109 |
+
|
110 |
+
logger.info(f"Vector store saved to {output_db}")
|
111 |
+
|
112 |
+
|
113 |
+
if __name__ == "__main__":
|
114 |
+
parser = argparse.ArgumentParser(description="Process data with optional chunking")
|
115 |
+
|
116 |
+
# Required arguments
|
117 |
+
parser.add_argument("input_path", type=str, help="Path to the input directory")
|
118 |
+
parser.add_argument("output_db", type=str, help="Path to the output vectordb")
|
119 |
+
|
120 |
+
# Optional arguments
|
121 |
+
parser.add_argument(
|
122 |
+
"--chunk_size", type=int, default=1000, help="Chunk size (default: 1000)"
|
123 |
+
)
|
124 |
+
parser.add_argument(
|
125 |
+
"--chunk_overlap", type=int, default=200, help="Chunk overlap (default: 200)"
|
126 |
+
)
|
127 |
+
parser.add_argument(
|
128 |
+
"--db_type",
|
129 |
+
type=str,
|
130 |
+
default="faiss",
|
131 |
+
help="Type of vectorstore (default: faiss)",
|
132 |
+
)
|
133 |
+
|
134 |
+
args = parser.parse_args()
|
135 |
+
main(
|
136 |
+
args.input_path,
|
137 |
+
args.output_db,
|
138 |
+
args.chunk_size,
|
139 |
+
args.chunk_overlap,
|
140 |
+
args.db_type,
|
141 |
+
)
|
utils/vectordb/vector_db.py
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Define the script's usage example
|
2 |
+
USAGE_EXAMPLE = """
|
3 |
+
Example usage:
|
4 |
+
|
5 |
+
To process input *.txt files at input_path and save the vector db output at output_db:
|
6 |
+
python create_vector_db.py input_path output_db --chunk_size 100 --chunk_overlap 10
|
7 |
+
|
8 |
+
Required arguments:
|
9 |
+
- input_path: Path to the input dir containing the .txt files
|
10 |
+
- output_path: Path to the output vector db.
|
11 |
+
|
12 |
+
Optional arguments:
|
13 |
+
- --chunk_size: Size of the chunks (default: None).
|
14 |
+
- --chunk_overlap: Overlap between chunks (default: None).
|
15 |
+
"""
|
16 |
+
|
17 |
+
import os
|
18 |
+
import sys
|
19 |
+
import argparse
|
20 |
+
import logging
|
21 |
+
|
22 |
+
from langchain_community.document_loaders import DirectoryLoader, UnstructuredURLLoader
|
23 |
+
from langchain_community.embeddings import HuggingFaceInstructEmbeddings
|
24 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
|
25 |
+
from langchain_community.vectorstores import FAISS, Chroma, Qdrant
|
26 |
+
|
27 |
+
vectordb_dir = os.path.dirname(os.path.abspath(__file__))
|
28 |
+
utils_dir = os.path.abspath(os.path.join(vectordb_dir, ".."))
|
29 |
+
repo_dir = os.path.abspath(os.path.join(utils_dir, ".."))
|
30 |
+
|
31 |
+
sys.path.append(repo_dir)
|
32 |
+
sys.path.append(utils_dir)
|
33 |
+
|
34 |
+
from utils.model_wrappers.api_gateway import APIGateway
|
35 |
+
import uuid
|
36 |
+
import streamlit as st
|
37 |
+
|
38 |
+
EMBEDDING_MODEL = "intfloat/e5-large-v2"
|
39 |
+
NORMALIZE_EMBEDDINGS = True
|
40 |
+
VECTORDB_LOG_FILE_NAME = "vector_db.log"
|
41 |
+
|
42 |
+
# Configure the logger
|
43 |
+
logging.basicConfig(
|
44 |
+
level=logging.INFO, # Set the logging level (e.g., INFO, DEBUG)
|
45 |
+
format="%(asctime)s [%(levelname)s] - %(message)s", # Define the log message format
|
46 |
+
handlers=[
|
47 |
+
logging.StreamHandler(), # Output logs to the console
|
48 |
+
logging.FileHandler(VECTORDB_LOG_FILE_NAME),
|
49 |
+
],
|
50 |
+
)
|
51 |
+
|
52 |
+
# Create a logger object
|
53 |
+
logger = logging.getLogger(__name__)
|
54 |
+
|
55 |
+
|
56 |
+
class VectorDb():
|
57 |
+
"""
|
58 |
+
A class for creating, updating and loading FAISS or Chroma vector databases,
|
59 |
+
to use them with retrieval augmented generation tasks with langchain
|
60 |
+
|
61 |
+
Args:
|
62 |
+
None
|
63 |
+
|
64 |
+
Attributes:
|
65 |
+
None
|
66 |
+
|
67 |
+
Methods:
|
68 |
+
load_files: Load files from an input directory as langchain documents
|
69 |
+
get_text_chunks: Get text chunks from a list of documents
|
70 |
+
get_token_chunks: Get token chunks from a list of documents
|
71 |
+
create_vector_store: Create a vector store from chunks and an embedding model
|
72 |
+
load_vdb: load a previous stored vector database
|
73 |
+
update_vdb: Update an existing vector store with new chunks
|
74 |
+
create_vdb: Create a vector database from the raw files in a specific input directory
|
75 |
+
"""
|
76 |
+
def __init__(self) -> None:
|
77 |
+
self.collection_id = str(uuid.uuid4())
|
78 |
+
self.vector_collections = set()
|
79 |
+
|
80 |
+
def load_files(self, input_path, recursive=False, load_txt=True, load_pdf=False, urls = None) -> list:
|
81 |
+
"""Load files from input location
|
82 |
+
|
83 |
+
Args:
|
84 |
+
input_path : input location of files
|
85 |
+
recursive (bool, optional): flag to load files recursively. Defaults to False.
|
86 |
+
load_txt (bool, optional): flag to load txt files. Defaults to True.
|
87 |
+
load_pdf (bool, optional): flag to load pdf files. Defaults to False.
|
88 |
+
urls (list, optional): list of urls to load. Defaults to None.
|
89 |
+
|
90 |
+
Returns:
|
91 |
+
list: list of documents
|
92 |
+
"""
|
93 |
+
docs=[]
|
94 |
+
text_loader_kwargs={'autodetect_encoding': True}
|
95 |
+
if input_path is not None:
|
96 |
+
if load_txt:
|
97 |
+
loader = DirectoryLoader(input_path, glob="*.txt", recursive=recursive, show_progress=True, loader_kwargs=text_loader_kwargs)
|
98 |
+
docs.extend(loader.load())
|
99 |
+
if load_pdf:
|
100 |
+
loader = DirectoryLoader(input_path, glob="*.pdf", recursive=recursive, show_progress=True, loader_kwargs=text_loader_kwargs)
|
101 |
+
docs.extend(loader.load())
|
102 |
+
if urls:
|
103 |
+
loader = UnstructuredURLLoader(urls=urls)
|
104 |
+
docs.extend(loader.load())
|
105 |
+
|
106 |
+
logger.info(f"Total {len(docs)} files loaded")
|
107 |
+
|
108 |
+
return docs
|
109 |
+
|
110 |
+
def get_text_chunks(self, docs: list, chunk_size: int, chunk_overlap: int, meta_data: list = None) -> list:
|
111 |
+
"""Gets text chunks. If metadata is not None, it will create chunks with metadata elements.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
docs (list): list of documents or texts. If no metadata is passed, this parameter is a list of documents.
|
115 |
+
If metadata is passed, this parameter is a list of texts.
|
116 |
+
chunk_size (int): chunk size in number of characters
|
117 |
+
chunk_overlap (int): chunk overlap in number of characters
|
118 |
+
metadata (list, optional): list of metadata in dictionary format. Defaults to None.
|
119 |
+
|
120 |
+
Returns:
|
121 |
+
list: list of documents
|
122 |
+
"""
|
123 |
+
|
124 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
125 |
+
chunk_size=chunk_size, chunk_overlap=chunk_overlap, length_function=len
|
126 |
+
)
|
127 |
+
|
128 |
+
if meta_data is None:
|
129 |
+
logger.info(f"Splitter: splitting documents")
|
130 |
+
chunks = text_splitter.split_documents(docs)
|
131 |
+
else:
|
132 |
+
logger.info(f"Splitter: creating documents with metadata")
|
133 |
+
chunks = text_splitter.create_documents(docs, meta_data)
|
134 |
+
|
135 |
+
logger.info(f"Total {len(chunks)} chunks created")
|
136 |
+
|
137 |
+
return chunks
|
138 |
+
|
139 |
+
def get_token_chunks(self, docs: list, chunk_size: int, chunk_overlap: int, tokenizer) -> list:
|
140 |
+
"""Gets token chunks. If metadata is not None, it will create chunks with metadata elements.
|
141 |
+
|
142 |
+
Args:
|
143 |
+
docs (list): list of documents or texts. If no metadata is passed, this parameter is a list of documents.
|
144 |
+
If metadata is passed, this parameter is a list of texts.
|
145 |
+
chunk_size (int): chunk size in number of tokens
|
146 |
+
chunk_overlap (int): chunk overlap in number of tokens
|
147 |
+
|
148 |
+
Returns:
|
149 |
+
list: list of documents
|
150 |
+
"""
|
151 |
+
|
152 |
+
text_splitter = CharacterTextSplitter.from_huggingface_tokenizer(
|
153 |
+
tokenizer, chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
154 |
+
)
|
155 |
+
|
156 |
+
logger.info(f"Splitter: splitting documents")
|
157 |
+
chunks = text_splitter.split_documents(docs)
|
158 |
+
|
159 |
+
logger.info(f"Total {len(chunks)} chunks created")
|
160 |
+
|
161 |
+
return chunks
|
162 |
+
|
163 |
+
|
164 |
+
|
165 |
+
def create_vector_store(self, chunks: list, embeddings: HuggingFaceInstructEmbeddings, db_type: str,
|
166 |
+
output_db: str = None, collection_name: str = None):
|
167 |
+
"""Creates a vector store
|
168 |
+
|
169 |
+
Args:
|
170 |
+
chunks (list): list of chunks
|
171 |
+
embeddings (HuggingFaceInstructEmbeddings): embedding model
|
172 |
+
db_type (str): vector db type
|
173 |
+
output_db (str, optional): output path to save the vector db. Defaults to None.
|
174 |
+
"""
|
175 |
+
if collection_name is None:
|
176 |
+
collection_name = f"collection_{self.collection_id}"
|
177 |
+
logger.info(f'This is the collection name: {collection_name}')
|
178 |
+
|
179 |
+
if db_type == "faiss":
|
180 |
+
vector_store = FAISS.from_documents(
|
181 |
+
documents=chunks,
|
182 |
+
embedding=embeddings
|
183 |
+
)
|
184 |
+
if output_db:
|
185 |
+
vector_store.save_local(output_db)
|
186 |
+
|
187 |
+
elif db_type == "chroma":
|
188 |
+
if output_db:
|
189 |
+
vector_store = Chroma()
|
190 |
+
vector_store.delete_collection()
|
191 |
+
vector_store = Chroma.from_documents(
|
192 |
+
documents=chunks,
|
193 |
+
embedding=embeddings,
|
194 |
+
persist_directory=output_db,
|
195 |
+
collection_name=collection_name
|
196 |
+
)
|
197 |
+
else:
|
198 |
+
vector_store = Chroma()
|
199 |
+
vector_store.delete_collection()
|
200 |
+
vector_store = Chroma.from_documents(
|
201 |
+
documents=chunks,
|
202 |
+
embedding=embeddings,
|
203 |
+
collection_name=collection_name
|
204 |
+
)
|
205 |
+
self.vector_collections.add(collection_name)
|
206 |
+
|
207 |
+
elif db_type == "qdrant":
|
208 |
+
if output_db:
|
209 |
+
vector_store = Qdrant.from_documents(
|
210 |
+
documents=chunks,
|
211 |
+
embedding=embeddings,
|
212 |
+
path=output_db,
|
213 |
+
collection_name="test_collection",
|
214 |
+
)
|
215 |
+
else:
|
216 |
+
vector_store = Qdrant.from_documents(
|
217 |
+
documents=chunks,
|
218 |
+
embedding=embeddings,
|
219 |
+
collection_name="test_collection",
|
220 |
+
)
|
221 |
+
|
222 |
+
logger.info(f"Vector store saved to {output_db}")
|
223 |
+
|
224 |
+
return vector_store
|
225 |
+
|
226 |
+
def load_vdb(self, persist_directory, embedding_model, db_type="chroma", collection_name=None):
|
227 |
+
if db_type == "faiss":
|
228 |
+
vector_store = FAISS.load_local(persist_directory, embedding_model, allow_dangerous_deserialization=True)
|
229 |
+
elif db_type == "chroma":
|
230 |
+
if collection_name:
|
231 |
+
vector_store = Chroma(
|
232 |
+
persist_directory=persist_directory,
|
233 |
+
embedding_function=embedding_model,
|
234 |
+
collection_name=collection_name
|
235 |
+
)
|
236 |
+
else:
|
237 |
+
vector_store = Chroma(
|
238 |
+
persist_directory=persist_directory,
|
239 |
+
embedding_function=embedding_model
|
240 |
+
)
|
241 |
+
elif db_type == "qdrant":
|
242 |
+
# TODO: Implement Qdrant loading
|
243 |
+
pass
|
244 |
+
else:
|
245 |
+
raise ValueError(f"Unsupported database type: {db_type}")
|
246 |
+
|
247 |
+
return vector_store
|
248 |
+
|
249 |
+
def update_vdb(self, chunks: list, embeddings, db_type: str, input_db: str = None,
|
250 |
+
output_db: str = None):
|
251 |
+
|
252 |
+
if db_type == "faiss":
|
253 |
+
vector_store = FAISS.load_local(input_db, embeddings, allow_dangerous_deserialization=True)
|
254 |
+
new_vector_store = self.create_vector_store(chunks, embeddings, db_type, None)
|
255 |
+
vector_store.merge_from(new_vector_store)
|
256 |
+
if output_db:
|
257 |
+
vector_store.save_local(output_db)
|
258 |
+
|
259 |
+
elif db_type == "chroma":
|
260 |
+
# TODO implement update method for chroma
|
261 |
+
pass
|
262 |
+
elif db_type == "qdrant":
|
263 |
+
# TODO implement update method for qdrant
|
264 |
+
pass
|
265 |
+
|
266 |
+
return vector_store
|
267 |
+
|
268 |
+
def create_vdb(
|
269 |
+
self,
|
270 |
+
input_path,
|
271 |
+
chunk_size,
|
272 |
+
chunk_overlap,
|
273 |
+
db_type,
|
274 |
+
output_db=None,
|
275 |
+
recursive=False,
|
276 |
+
tokenizer=None,
|
277 |
+
load_txt=True,
|
278 |
+
load_pdf=False,
|
279 |
+
urls=None,
|
280 |
+
embedding_type="cpu",
|
281 |
+
batch_size= None,
|
282 |
+
coe = None,
|
283 |
+
select_expert = None
|
284 |
+
):
|
285 |
+
|
286 |
+
docs = self.load_files(input_path, recursive=recursive, load_txt=load_txt, load_pdf=load_pdf, urls=urls)
|
287 |
+
|
288 |
+
if tokenizer is None:
|
289 |
+
chunks = self.get_text_chunks(docs, chunk_size, chunk_overlap)
|
290 |
+
else:
|
291 |
+
chunks = self.get_token_chunks(docs, chunk_size, chunk_overlap, tokenizer)
|
292 |
+
|
293 |
+
embeddings = APIGateway.load_embedding_model(
|
294 |
+
type=embedding_type,
|
295 |
+
batch_size=batch_size,
|
296 |
+
coe=coe,
|
297 |
+
select_expert=select_expert
|
298 |
+
)
|
299 |
+
|
300 |
+
vector_store = self.create_vector_store(chunks, embeddings, db_type, output_db)
|
301 |
+
|
302 |
+
return vector_store
|
303 |
+
|
304 |
+
|
305 |
+
def dir_path(path):
|
306 |
+
if os.path.isdir(path):
|
307 |
+
return path
|
308 |
+
else:
|
309 |
+
raise argparse.ArgumentTypeError(f"readable_dir:{path} is not a valid path")
|
310 |
+
|
311 |
+
|
312 |
+
# Parse the arguments
|
313 |
+
def parse_arguments():
|
314 |
+
parser = argparse.ArgumentParser(description="Process command line arguments.")
|
315 |
+
parser.add_argument("-input_path", type=dir_path, help="path to input directory")
|
316 |
+
parser.add_argument("--chunk_size", type=int, help="chunk size for splitting")
|
317 |
+
parser.add_argument("--chunk_overlap", type=int, help="chunk overlap for splitting")
|
318 |
+
parser.add_argument("-output_path", type=dir_path, help="path to input directory")
|
319 |
+
|
320 |
+
return parser.parse_args()
|
321 |
+
|
322 |
+
|
323 |
+
if __name__ == "__main__":
|
324 |
+
parser = argparse.ArgumentParser(description="Process data with optional chunking")
|
325 |
+
|
326 |
+
# Required arguments
|
327 |
+
parser.add_argument("--input_path", type=str, help="Path to the input directory")
|
328 |
+
parser.add_argument("--output_db", type=str, help="Path to the output vectordb")
|
329 |
+
|
330 |
+
# Optional arguments
|
331 |
+
parser.add_argument(
|
332 |
+
"--chunk_size", type=int, default=1000, help="Chunk size (default: 1000)"
|
333 |
+
)
|
334 |
+
parser.add_argument(
|
335 |
+
"--chunk_overlap", type=int, default=200, help="Chunk overlap (default: 200)"
|
336 |
+
)
|
337 |
+
parser.add_argument(
|
338 |
+
"--db_type",
|
339 |
+
type=str,
|
340 |
+
default="faiss",
|
341 |
+
help="Type of vector store (default: faiss)",
|
342 |
+
)
|
343 |
+
args = parser.parse_args()
|
344 |
+
|
345 |
+
vectordb = VectorDb()
|
346 |
+
|
347 |
+
vectordb.create_vdb(
|
348 |
+
args.input_path,
|
349 |
+
args.output_db,
|
350 |
+
args.chunk_size,
|
351 |
+
args.chunk_overlap,
|
352 |
+
args.db_type,
|
353 |
+
)
|
utils/visual/env_utils.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import netrc
|
2 |
+
import os
|
3 |
+
from typing import List, Optional, Tuple
|
4 |
+
|
5 |
+
import streamlit as st
|
6 |
+
|
7 |
+
|
8 |
+
def initialize_env_variables(prod_mode: bool = False, additional_env_vars: Optional[List[str]] = None) -> None:
|
9 |
+
if additional_env_vars is None:
|
10 |
+
additional_env_vars = []
|
11 |
+
|
12 |
+
if not prod_mode:
|
13 |
+
# In non-prod mode, prioritize environment variables
|
14 |
+
st.session_state.SAMBANOVA_API_KEY = os.environ.get(
|
15 |
+
'SAMBANOVA_API_KEY', st.session_state.get('SMABANOVA_API_KEY', '')
|
16 |
+
)
|
17 |
+
for var in additional_env_vars:
|
18 |
+
st.session_state[var] = os.environ.get(var, st.session_state.get(var, ''))
|
19 |
+
else:
|
20 |
+
# In prod mode, only use session state
|
21 |
+
if 'SAMBANOVA_API_KEY' not in st.session_state:
|
22 |
+
st.session_state.SAMBANOVA_API_KEY = ''
|
23 |
+
for var in additional_env_vars:
|
24 |
+
if var not in st.session_state:
|
25 |
+
st.session_state[var] = ''
|
26 |
+
|
27 |
+
|
28 |
+
def set_env_variables(api_key, additional_vars=None, prod_mode=False):
|
29 |
+
st.session_state.SAMBANOVA_API_KEY = api_key
|
30 |
+
if additional_vars:
|
31 |
+
for key, value in additional_vars.items():
|
32 |
+
st.session_state[key] = value
|
33 |
+
if not prod_mode:
|
34 |
+
# In non-prod mode, also set environment variables
|
35 |
+
os.environ['SAMBANOVA_API_KEY'] = api_key
|
36 |
+
if additional_vars:
|
37 |
+
for key, value in additional_vars.items():
|
38 |
+
os.environ[key] = value
|
39 |
+
|
40 |
+
|
41 |
+
def env_input_fields(additional_env_vars=None) -> Tuple[str, str]:
|
42 |
+
if additional_env_vars is None:
|
43 |
+
additional_env_vars = []
|
44 |
+
|
45 |
+
api_key = st.text_input('Sambanova API Key', value=st.session_state.SAMBANOVA_API_KEY, type='password')
|
46 |
+
|
47 |
+
additional_vars = {}
|
48 |
+
for var in additional_env_vars:
|
49 |
+
additional_vars[var] = st.text_input(f'{var}', value=st.session_state.get(var, ''), type='password')
|
50 |
+
|
51 |
+
return api_key, additional_vars
|
52 |
+
|
53 |
+
|
54 |
+
def are_credentials_set(additional_env_vars=None) -> bool:
|
55 |
+
if additional_env_vars is None:
|
56 |
+
additional_env_vars = []
|
57 |
+
|
58 |
+
base_creds_set = bool(st.session_state.SAMBANOVA_API_KEY)
|
59 |
+
additional_creds_set = all(bool(st.session_state.get(var, '')) for var in additional_env_vars)
|
60 |
+
|
61 |
+
return base_creds_set and additional_creds_set
|
62 |
+
|
63 |
+
|
64 |
+
def save_credentials(api_key, additional_vars=None, prod_mode=False) -> str:
|
65 |
+
set_env_variables(api_key, additional_vars, prod_mode)
|
66 |
+
return 'Credentials saved successfully!'
|
67 |
+
|
68 |
+
|
69 |
+
def get_wandb_key():
|
70 |
+
# Check for WANDB_API_KEY in environment variables
|
71 |
+
env_wandb_api_key = os.getenv('WANDB_API_KEY')
|
72 |
+
|
73 |
+
# Check for WANDB_API_KEY in ~/.netrc
|
74 |
+
try:
|
75 |
+
netrc_path = os.path.expanduser('~/.netrc')
|
76 |
+
netrc_data = netrc.netrc(netrc_path)
|
77 |
+
netrc_wandb_api_key = netrc_data.authenticators('api.wandb.ai')
|
78 |
+
except (FileNotFoundError, netrc.NetrcParseError):
|
79 |
+
netrc_wandb_api_key = None
|
80 |
+
|
81 |
+
# If both are set, handle the conflict
|
82 |
+
if env_wandb_api_key and netrc_wandb_api_key:
|
83 |
+
print('WANDB_API_KEY is set in both the environment and ~/.netrc. Prioritizing environment variable.')
|
84 |
+
# Optionally, you can choose to remove one of them, here we remove the env variable
|
85 |
+
del os.environ['WANDB_API_KEY'] # Remove from environment to prioritize ~/.netrc
|
86 |
+
return netrc_wandb_api_key[2] if netrc_wandb_api_key else None # Return the key from .netrc
|
87 |
+
|
88 |
+
# Return the key from environment if available, otherwise from .netrc
|
89 |
+
if env_wandb_api_key:
|
90 |
+
return env_wandb_api_key
|
91 |
+
elif netrc_wandb_api_key:
|
92 |
+
return netrc_wandb_api_key[2] if netrc_wandb_api_key else None
|
93 |
+
|
94 |
+
# If neither is set, return None
|
95 |
+
return None
|