Spaces:
Sleeping
Sleeping
import os | |
import streamlit as st | |
from typing_extensions import TypedDict, List | |
from IPython.display import Image, display | |
from langchain_core.pydantic_v1 import BaseModel, Field | |
from langchain.schema import Document | |
from langgraph.graph import START, END, StateGraph | |
from langchain.prompts import PromptTemplate | |
import uuid | |
from langchain_groq import ChatGroq | |
from langchain_community.utilities import GoogleSerperAPIWrapper | |
from langchain_chroma import Chroma | |
from langchain_community.document_loaders import NewsURLLoader | |
from langchain_community.retrievers.wikipedia import WikipediaRetriever | |
from sentence_transformers import SentenceTransformer | |
from langchain.vectorstores import Chroma | |
from langchain_community.document_loaders import UnstructuredURLLoader, NewsURLLoader | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.document_loaders import WebBaseLoader | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.output_parsers import JsonOutputParser | |
from langchain_community.vectorstores.utils import filter_complex_metadata | |
from langchain.schema import Document | |
from langchain_community.document_loaders.directory import DirectoryLoader | |
from langchain.document_loaders import TextLoader | |
from langgraph.graph import START, END, StateGraph | |
from langchain.retrievers import WebResearchRetriever | |
from langchain.callbacks.manager import CallbackManager | |
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler | |
from exa_py import Exa | |
# LangSmith tracing configuration for this application.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_ENDPOINT"] = "https://api.smith.langchain.com"
os.environ["LANGCHAIN_PROJECT"] = "Civilinės_teises_Asistente_V1_Embed"

# Credentials come from the environment; each value is None when the variable is unset.
lang_api_key = os.getenv("LANGCHAIN_API_KEY")
SERPER_API_KEY = os.getenv("SERPER_API_KEY")
groq_api_key = os.getenv("GROQ_API_KEY")
exa_api_key = os.getenv("exa_api_key")

# Pass the key that was just read. The previous code passed the literal text
# "exa_api_key", so the client was never given the real credential.
# NOTE(review): with the variable unset this now hands None to Exa(); confirm
# how the client reacts to a missing key in the deployed exa_py version.
exa = Exa(api_key=exa_api_key)
def create_retriever_from_chroma(vectorstore_path="docs/chroma/", search_type='mmr', k=7, chunk_size=300, chunk_overlap=30):
    """Return a retriever backed by a Chroma store, creating the store when absent.

    When ``vectorstore_path`` exists and is non-empty the persisted store is
    opened. Otherwise every ``*.txt`` file under ``./data/`` is loaded, split
    into chunks, embedded and persisted to ``vectorstore_path``.

    Args:
        vectorstore_path: Directory of the persisted Chroma store.
        search_type: Search strategy forwarded to ``as_retriever``.
        k: Number of documents requested per query.
        chunk_size: Chunk size handed to the text splitter.
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        The retriever produced by ``vectorstore.as_retriever``.
    """
    embeddings = HuggingFaceEmbeddings(
        model_name="Alibaba-NLP/gte-multilingual-base",
        model_kwargs={
            'device': 'cpu',
            # Was the string 'False', which is truthy: remote code was already
            # being trusted. The real boolean states that explicitly.
            # NOTE(review): confirm this model needs remote code before
            # switching it off.
            "trust_remote_code": True,
        },
        encode_kwargs={'normalize_embeddings': True},
    )

    if os.path.exists(vectorstore_path) and os.listdir(vectorstore_path):
        # A populated directory is treated as an existing persisted store.
        vectorstore = Chroma(persist_directory=vectorstore_path, embedding_function=embeddings)
    else:
        st.write("Vector store doesnt exist and will be created now")
        loader = DirectoryLoader('./data/', glob="./*.txt", loader_cls=TextLoader)
        docs = loader.load()
        text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            # Separators are regular expressions, tried from coarsest to finest.
            separators=["\n\n \n\n", "\n\n\n", "\n\n", r"In \[[0-9]+\]", r"\n+", r"\s+"],
            is_separator_regex=True,
        )
        split_docs = text_splitter.split_documents(docs)
        vectorstore = Chroma.from_documents(
            documents=split_docs,
            embedding=embeddings,
            persist_directory=vectorstore_path,
        )

    return vectorstore.as_retriever(search_type=search_type, search_kwargs={"k": k})
async def handle_userinput(user_question, custom_graph):
    """Run one user question through the graph and render the result in Streamlit.

    The question is appended to ``st.session_state.messages`` and echoed in the
    chat. The graph is then invoked asynchronously; the documents found in the
    resulting state are listed in the sidebar and the ``generation`` value, if
    any, is stored in the history and shown as the assistant's message.

    Args:
        user_question: Text entered by the user.
        custom_graph: Compiled graph exposing an awaitable ``ainvoke``.

    Returns:
        None. All output goes to the Streamlit UI and session state.
    """
    import logging  # local import keeps this block independent of file-level imports

    # Record and display the user's question.
    st.session_state.messages.append({"role": "user", "content": user_question})
    st.chat_message("user").write(user_question)

    # A fresh thread id is generated for every invocation.
    config = {"configurable": {"thread_id": str(uuid.uuid4())}}

    try:
        state_dict = await custom_graph.ainvoke({"question": user_question, "steps": []}, config)

        # Show the documents that ended up in the final graph state.
        docs = state_dict.get("documents", [])
        with st.sidebar:
            st.subheader("Dokumentai, kuriuos Birutė gavo kaip kontekstą")
            with st.spinner("Kraunama..."):
                for doc in docs:
                    st.write(f"Dokumentas: {doc}")

        # Show the assistant's answer when the graph produced one.
        response = state_dict.get("generation")
        if response:
            st.session_state.messages.append({"role": "assistant", "content": response})
            st.chat_message("assistant").write(response)
    except Exception:
        # Top-level UI boundary: keep the app alive, but log the real cause
        # instead of discarding it as the previous code did.
        logging.getLogger(__name__).exception("Graph invocation failed for the user question")
        st.chat_message("assistant").write("Klaida: Arba per didelis kontekstas suteiktas modeliui, arba užklausų serveryje yra per daug")
from typing import Annotated | |
def create_workflow(retriever):
    """Assemble and compile the LangGraph question-answering workflow.

    Node order is ``ask_question`` -> ``retrieve`` -> ``grade_documents``.
    ``decide_to_generate`` then routes either through ``web_search`` or directly
    to ``generate``, after which the graph ends.

    Args:
        retriever: Retriever handed to the ``retrieve`` node.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    class GraphState(TypedDict):
        """
        Represents the state of our graph.

        Attributes:
            question: question
            generation: LLM generation
            search: whether to add search
            documents: list of documents
            steps: names of the steps executed so far
            generation_count: generations count
        """
        question: Annotated[str, "Single"]  # Ensuring only one value per step
        generation: str
        search: str
        documents: List[str]
        steps: List[str]
        generation_count: int

    llm = ChatGroq(
        model="llama-3.3-70b-versatile",
        temperature=0.2,
        max_tokens=600,
        max_retries=3,
    )
    llm_checker = ChatGroq(
        model="llama3-groq-70b-8192-tool-use-preview",
        temperature=0.1,
        max_tokens=400,
        max_retries=3,
    )

    # Build the chains once. The previous lambdas rebuilt the prompt and the
    # structured-output wrapper every time a node executed.
    document_grader = retrieval_grader_grader(llm_checker)
    answer_chain = QA_chain(llm)

    workflow = StateGraph(GraphState)

    # Define the nodes
    workflow.add_node("ask_question", ask_question)
    workflow.add_node("retrieve", lambda state: retrieve(state, retriever))
    workflow.add_node("grade_documents", lambda state: grade_documents(state, document_grader))
    workflow.add_node("generate", lambda state: generate(state, answer_chain))
    workflow.add_node("web_search", web_search)

    # Build graph
    workflow.set_entry_point("ask_question")
    workflow.add_edge("ask_question", "retrieve")
    workflow.add_edge("retrieve", "grade_documents")
    workflow.add_conditional_edges(
        "grade_documents",
        decide_to_generate,
        {
            "search": "web_search",
            "generate": "generate",
        },
    )
    workflow.add_edge("web_search", "generate")
    workflow.add_edge("generate", END)

    custom_graph = workflow.compile()
    return custom_graph
def retrieval_grader_grader(llm):
    """Build the chain that grades one retrieved document against a question.

    Args:
        llm: Chat model exposing ``with_structured_output``.

    Returns:
        A prompt piped into the structured-output model. The chain takes the
        inputs ``question`` and ``documents``.
    """
    class GradeDocuments(BaseModel):
        # Left untranslated on purpose: presumably forwarded to the model as
        # the schema description by with_structured_output — verify.
        """Ar faktas gali būti, nors truputi, naudingas atsakant į klausimą."""
        binary_score: str = Field(
            description="Documentai yra aktualūs klausimui, 'yes' arba 'no'"
        )

    # Model wrapper whose output is shaped by the GradeDocuments schema.
    grader_model = llm.with_structured_output(GradeDocuments)

    # NOTE(review): the text asks for a JSON key „balu“ while the schema field
    # is ``binary_score`` — confirm the grader is not confused by the mismatch.
    grading_template = """Jūs esate mokytojas, vertinantis viktoriną. Jums bus suteikta:
1/ KLAUSIMAS {question}
2/ Studento pateiktas FAKTAS {documents}
Jūs vertinate RELEVANCE RECALL:
yes reiškia, kad FAKTAS yra susijęs su KLAUSIMU.
no reiškia, kad FAKTAS nesusijęs su KLAUSIMU.
yes yra aukščiausias (geriausias) balas. no yra žemiausias balas, kurį galite duoti.
Jeigu galima iš Studento pateiktas FAKTAS gauti bet kokių įžvalgu susijusiu su KLAUSIMAS, duok įvertinimą yes.
Žingsnis po žingsnio paaiškinkite savo samprotavimus. Įsitikinkite, kad jūsų samprotavimai ir išvados yra teisingi.
Iš pradžių venkite tiesiog nurodyti teisingą atsakymą.
Klausimas: {question} \n
FAKTAS: \n\n {documents} \n\n
Suteikite dvejetainį balą „yes“ arba „no“, kad nurodytumėte, ar dokumentas yra susijęs su klausimu. \n
Pateikite dvejetainį balą kaip JSON su vienu raktu „balu“ ir be įžangos ar paaiškinimo.
"""
    grading_prompt = PromptTemplate(
        input_variables=['documents', 'question'],
        template=grading_template,
    )

    return grading_prompt | grader_model
def create_toxicity_checker(llm):
    """Build the chain that classifies a question as unsafe or safe.

    Args:
        llm: Chat model exposing ``with_structured_output``.

    Returns:
        A prompt piped into the structured-output model. The chain takes the
        single input ``question``.
    """
    class PromtGuard(BaseModel):
        """Binary score for toxicity check on question."""
        score: str = Field(
            description="Ar klausimams yra nesaugus, 'taip' arba 'ne'"
        )

    # Model wrapper whose output is shaped by the PromtGuard schema.
    guard_model = llm.with_structured_output(PromtGuard)

    # NOTE(review): the closing instructions talk about answer usefulness and
    # repetition, which looks copied from another prompt — confirm intent.
    guard_template = """Jūs esate klasifikatorius, kuris įvertins tekstą pagal pateiktus kriterijus.
Čia pateikiami taškų kriterijai:
(1) įžeidinėjimai
(2) grasinimai
(3) labai neigiami komentarai
(4) bet kokią asmenį identifikuojančią informaciją, pvz., API raktus
Rezultatas:
„Taip“ reiškia, kad tekstas atitinka BET KOKĮ kriterijų.
Įvertinimas „ne“ reiškia, kad tekstas neatitinka VISŲ kriterijų.
Štai klausimas: {question}
Suteikite dvejetainį balą „taip“ arba „ne“, kad nurodytumėte, ar atsakymas yra naudingas norint išspręsti klausimą.
Jei atsakyme yra pasikartojančių frazių, kartojimas, tada grąžinkite „ne“\n
Pateikite dvejetainį balą kaip JSON su vienu raktu „balu“ ir be įžangos ar paaiškinimo."""
    guard_prompt = PromptTemplate(
        input_variables=["question"],
        template=guard_template,
    )

    return guard_prompt | guard_model
def grade_question_toxicity(state, toxicity_grader):
    """Classify the question in ``state`` as acceptable or not.

    Args:
        state (dict): Current graph state; ``steps`` and ``question`` are read
            and ``"promt guard"`` is appended to ``steps``.
        toxicity_grader: Object whose ``invoke`` returns a result carrying a
            ``score`` attribute.

    Returns:
        str: ``'bad'`` when the grader answers affirmatively, ``'good'``
        otherwise (including when no score is present).
    """
    # The toxicity schema in this file asks for 'taip'/'ne', while the old code
    # only accepted the exact string "yes" and so never flagged a question.
    # Both languages are accepted, ignoring case and surrounding whitespace.
    affirmative = {"yes", "true", "1", "taip"}

    steps = state["steps"]
    steps.append("promt guard")

    score = toxicity_grader.invoke({"question": state["question"]})
    grade = getattr(score, 'score', None)
    if grade is not None and str(grade).strip().lower() in affirmative:
        return "bad"
    return "good"
def create_helpfulness_checker(llm):
    """Build the chain that judges whether an answer is useful for a question.

    Args:
        llm: Chat model exposing ``with_structured_output``.

    Returns:
        A prompt piped into the structured-output model. The chain takes the
        inputs ``question`` and ``generation``.
    """
    class helpfulness_checker(BaseModel):
        # NOTE(review): this text mentions toxicity although the schema is
        # about helpfulness; kept as-is because it may be sent to the model.
        """Binary score for toxicity check on question."""
        score: str = Field(
            description="Ar atsakymas yra naudingas?, 'taip' arba 'ne'"
        )

    # Model wrapper whose output is shaped by the helpfulness_checker schema.
    helpfulness_model = llm.with_structured_output(helpfulness_checker)

    helpfulness_template = """Jums bus pateiktas KLAUSIMAS {question} ir ATSAKYMAS {generation}.
Įvertinkite ATSAKYMĄ pagal šiuos kriterijus:
Aktualumas: ATSAKYMAS turi būti tiesiogiai susijęs su KLAUSIMU ir konkrečiai į jį atsakyti.
Pakankamas: ATSAKYME turi būti pakankamai informacijos, kad būtų galima visapusiškai atsakyti į KLAUSIMĄ. Jei ATSAKYME vartojamos tokios frazės kaip „nežinau“, „neturiu pakankamai informacijos“, „pateiktuose dokumentuose apie tai neužsimenama“ ar panašių posakių, kuriuose vengiama tiesiogiai atsakyti į KLAUSIMĄ, įvertinkite „ne“.
Aiškumas ir glaustumas: ATSAKYMAS turi būti aiškus, be jokių nereikalingų frazių ar pasikartojimų. Jei jame yra perteklinė arba netiesioginė informacija, o ne tiesioginis atsakymas, įvertinkite „ne“.
Balų skaičiavimo instrukcijos:
„Taip“ reiškia, kad ATSAKYMAS atitinka visus šiuos kriterijus ir tiesiogiai susijęs su KLAUSIMU.
Įvertinimas „ne“ reiškia, kad ATSAKYMAS neatitinka visų šių kriterijų.
Jei randate tokio žodžio tekstą, kaip aš nežinau, nepakanka informacijos arba panašaus į šį, balas yra ne.
Pateikite balą kaip JSON su vienu raktu "balas" ir be papildomo teksto"""
    helpfulness_prompt = PromptTemplate(
        input_variables=["generation", "question"],
        template=helpfulness_template,
    )

    return helpfulness_prompt | helpfulness_model
def create_hallucination_checker(llm):
    """Build the chain that checks whether an answer is grounded in the documents.

    Args:
        llm: Chat model exposing ``with_structured_output``.

    Returns:
        A prompt piped into the structured-output model. The chain takes the
        inputs ``generation`` and ``documents``.
    """
    class hallucination_checker(BaseModel):
        # NOTE(review): this text mentions toxicity although the schema is
        # about grounding; kept as-is because it may be sent to the model.
        """Binary score for toxicity check on question."""
        score: str = Field(
            description="Ar dokumentas yra susijes su atsakymu?, 'taip' arba 'ne'"
        )

    # Model wrapper whose output is shaped by the hallucination_checker schema.
    structured_llm_hallucination_checker = llm.with_structured_output(hallucination_checker)

    # Define the prompt template
    prompt = PromptTemplate(
        template="""Jūs esate mokytojas, vertinantis viktoriną.
Jums bus pateikti FAKTAI ir MOKINIO ATSAKYMAS.
Jūs vertinate MOKINIO ATSAKYMĄ iš šaltinio FAKTAI. Sutelkite dėmesį į MOKINIO ATSAKYMO teisingumą ir bet kokių haliucinacijų aptikimą.
Įsitikinkite, kad MOKINIO ATSAKYMAS atitinka šiuos kriterijus:
(1) jame nėra informacijos, nesusijusios su FAKTAIS
(2) STUDENTŲ ATSAKYMAS turėtų būti visiškai pagrįstas ir pagrįstas pirminiuose dokumentuose pateikta informacija
Rezultatas:
„Taip“ reiškia, kad studento atsakymas atitinka visus kriterijus. Tai aukščiausias (geriausias) balas.
Balas „ne“ reiškia, kad studento atsakymas neatitinka visų kriterijų. Tai yra žemiausias galimas balas, kurį galite duoti.
Žingsnis po žingsnio paaiškinkite savo samprotavimus, kad įsitikintumėte, jog argumentai ir išvados yra teisingi.
Iš pradžių venkite tiesiog nurodyti teisingą atsakymą.
MOKINIO ATSAKYMAS: {generation} \n
FAKTAI: \n\n {documents} \n\n
Suteikite dvejetainį balą „taip“ arba „ne“, kad nurodytumėte, ar dokumentas yra susijęs su klausimu. \n
Pateikite dvejetainį balą kaip JSON su vienu raktu „balu“ ir be įžangos ar paaiškinimo.""",
        input_variables=["generation", "documents"],
    )

    # The previous code piped into ``structured_llm_haliucinations_checker``,
    # a name that is never defined, so every call raised NameError.
    hallucination_grader = prompt | structured_llm_hallucination_checker
    return hallucination_grader
def create_question_rewriter(llm):
    """Build the chain that rephrases a question for retrieval.

    Args:
        llm: Language model placed between the prompt and the string parser.

    Returns:
        The chain ``prompt | llm | StrOutputParser()``, which takes the single
        input ``question``.
    """
    rewrite_template = """Esate klausimų perrašytojas, kurio specializacija yra Lietuvos teisė, tobulinanti klausimus, kad būtų galima optimizuoti jų paiešką iš teisinių dokumentų. Jūsų tikslas – išaiškinti teisinę intenciją, pašalinti dviprasmiškumą ir pakoreguoti formuluotes taip, kad jos atspindėtų teisinę kalbą, daugiausia dėmesio skiriant atitinkamiems raktiniams žodžiams, siekiant užtikrinti tikslų informacijos gavimą iš Lietuvos teisės šaltinių.
Man nereikia paaiškinimų, tik perrašyto klausimo.
Štai pradinis klausimas: \n\n {question}. Patobulintas klausimas be paaiškinimų : \n"""
    rewrite_prompt = PromptTemplate(
        input_variables=["question"],
        template=rewrite_template,
    )

    # Prompt -> model -> plain string output.
    return rewrite_prompt | llm | StrOutputParser()
def transform_query(state, question_rewriter):
    """Rewrite the question in ``state`` into a better-phrased one.

    Args:
        state (dict): Current graph state; ``question``, ``documents`` and
            ``steps`` are read, and ``"question_transformation"`` is appended
            to ``steps``.
        question_rewriter: Object whose ``invoke`` receives
            ``{"question": ...}`` and returns the rewritten question.

    Returns:
        dict: State update with the unchanged ``documents``, the rewritten
        ``question`` and the updated ``steps``.
    """
    print("---TRANSFORM QUERY---")
    question = state["question"]
    documents = state["documents"]
    steps = state["steps"]
    steps.append("question_transformation")

    # Re-write question
    better_question = question_rewriter.invoke({"question": question})
    print(f" Transformed question: {better_question}")

    # ``steps`` is returned like in the other nodes of this file; previously it
    # was appended to but left out of the update.
    return {"documents": documents, "question": better_question, "steps": steps}
def format_google_results_search(google_results):
    """Convert the organic hits of a search response into ``Document`` objects.

    Args:
        google_results (dict): Search response; only the ``"organic"`` list is
            used. A missing or null ``"organic"`` entry is treated as empty.

    Returns:
        list: One ``Document`` per organic result. ``page_content`` combines
        title, snippet and link; the metadata carries the title.
    """
    # The old code also read "answerBox" into locals it never used, which only
    # added a crash when that key was present but null. It has been dropped.
    formatted_documents = []
    for result in google_results.get("organic") or []:
        title = result.get("title", "No title")
        link = result.get("link", "Nėra svetainės adreso")
        snippet = result.get("snippet", "No snippet available")
        formatted_documents.append(
            Document(
                metadata={
                    "Organinio rezultato pavadinimas": title,
                },
                page_content=(
                    f"Pavadinimas: {title} "
                    f"Straipsnio ištrauka: {snippet} "
                    f"Nuoroda: {link} "
                ),
            )
        )
    return formatted_documents
def format_google_results_news(google_results):
    """Convert the organic hits of a news response into ``Document`` objects.

    Args:
        google_results (dict): Response; only the ``"organic"`` list is used.
            A missing or null ``"organic"`` entry is treated as empty, matching
            ``format_google_results_search``.

    Returns:
        list: One ``Document`` per result. The snippet is the page content;
        title, description, text, snippet and link go into the metadata.
    """
    formatted_documents = []
    # ``.get`` replaces direct indexing, which raised KeyError on responses
    # without organic results.
    for result in google_results.get('organic') or []:
        title = result.get('title', 'No title')
        link = result.get('link', 'No link')
        # The fallback used to be a copy-pasted 'No link'.
        description = result.get('description', 'No description')
        snippet = result.get('snippet', 'No summary available')
        text = result.get('text', 'no text')
        formatted_documents.append(
            Document(
                metadata={
                    'Title': title,
                    'Description': description,
                    'Text': text,
                    'Snippet': snippet,
                    'Source': link,
                },
                page_content=snippet,  # Using the snippet as the page content
            )
        )
    return formatted_documents
def QA_chain(llm):
    """Build the answer-generation chain for the legal assistant.

    Args:
        llm: Language model placed between the prompt and the string parser.

    Returns:
        The chain ``prompt | llm | StrOutputParser()``, which takes the inputs
        ``question`` and ``documents``.
    """
    answer_template = """Esi teisės asistentas, kurio užduotis yra atsakyti konkrečiai, informatyviai ir glaustai , pagrindžiant savo atsakymą į klausima pagal pateiktus dokumentus.
Atsakymas turi būti lietuvių kalba. Nesikartok.
Jei negali atsakyti į klausimą, pasakyk, Atsiprašau, nežinau atsakymo į jūsų klausimą.
Neužduok papildomų klausimų.
Klausimas: {question}
Dokumentai: {documents}
Atsakymas:
"""
    answer_prompt = PromptTemplate(
        input_variables=["question", "documents"],
        template=answer_template,
    )

    # Prompt -> model -> plain string output.
    return answer_prompt | llm | StrOutputParser()
def grade_generation_v_documents_and_question(state, hallucination_grader, answer_grader):
    """Decide whether the generated answer is grounded and answers the question.

    Args:
        state (dict): Current graph state; ``question``, ``documents`` and
            ``generation`` are read, plus ``generation_count`` when present.
        hallucination_grader: Object whose ``invoke`` returns a result with a
            ``score`` attribute for ``documents`` and ``generation``.
        answer_grader: Object whose ``invoke`` returns a result with a
            ``score`` attribute for ``question`` and ``generation``.

    Returns:
        str: ``"useful"`` when both graders answer affirmatively;
        ``"not useful"`` when the answer is grounded but does not address the
        question, or is ungrounded after more than one generation;
        ``"not supported"`` when it is ungrounded and may be retried.
    """
    # The grader schemas in this file ask for 'taip'/'ne', while the old code
    # only accepted the exact string "yes", so no answer could ever be
    # "useful". Both languages are accepted, ignoring case and whitespace.
    affirmative = {"yes", "true", "1", "taip"}

    def is_affirmative(result):
        grade = getattr(result, 'score', None)
        return grade is not None and str(grade).strip().lower() in affirmative

    print("---CHECK HALLUCINATIONS---")
    question = state["question"]
    documents = state["documents"]
    generation = state["generation"]
    # A missing or None counter is treated as 0, so the comparison below cannot
    # raise TypeError and no KeyError occurs when the key is absent.
    generation_count = state.get("generation_count") or 0
    print(f" generation number: {generation_count}")

    # Grading hallucinations
    grounded = is_affirmative(
        hallucination_grader.invoke({"documents": documents, "generation": generation})
    )
    if grounded:
        print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
        # Check question-answering
        print("---GRADE GENERATION vs QUESTION---")
        if is_affirmative(answer_grader.invoke({"question": question, "generation": generation})):
            print("---DECISION: GENERATION ADDRESSES QUESTION---")
            return "useful"
        print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
        return "not useful"

    if generation_count > 1:
        print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, TRANSFORM QUERY---")
        return "not useful"

    print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
    print(f" generation number after increment: {generation_count}")
    return "not supported"
def ask_question(state):
    """Initialise the graph state for a new question.

    Args:
        state (dict): Current graph state; ``steps`` and ``question`` are read
            and ``"question_asked"`` is appended to ``steps``.

    Returns:
        dict: State update with ``question``, ``steps`` and
        ``generation_count`` (0 unless the incoming state already has one).
    """
    steps = state["steps"]
    question = state["question"]
    # The old lookup used the misspelled key "generations_count", which never
    # matches the ``generation_count`` key used everywhere else, so any
    # existing counter was silently reset to 0.
    generation_count = state.get("generation_count", 0)
    steps.append("question_asked")
    return {"question": question, "steps": steps, "generation_count": generation_count}
def retrieve(state, retriever):
    """Fetch documents for the current question.

    Args:
        state (dict): Current graph state; ``steps`` and ``question`` are read
            and ``"retrieve_documents"`` is appended to ``steps``.
        retriever: Object whose ``invoke`` is called with the question.

    Returns:
        dict: State update holding the retrieved ``documents``, the
        ``question`` and the updated ``steps``.
    """
    trail = state["steps"]
    query = state["question"]

    hits = retriever.invoke(query)
    # The step is logged only after retrieval has succeeded.
    trail.append("retrieve_documents")

    return dict(documents=hits, question=query, steps=trail)
def generate(state, QA_chain):
    """Produce an answer for the question from the documents in ``state``.

    Args:
        state (dict): Current graph state; ``question``, ``documents``,
            ``steps`` and ``generation_count`` are read, and
            ``"generate_answer"`` is appended to ``steps``.
        QA_chain: Object whose ``stream`` is called with the documents and the
            question.

    Returns:
        dict: State update with ``documents``, ``question``, the ``generation``
        returned by ``QA_chain.stream``, ``steps`` and ``generation_count``
        increased by one.
    """
    query = state["question"]
    context_docs = state["documents"]

    # NOTE(review): ``stream`` presumably returns a lazy iterator rather than a
    # string, which is what ends up stored under "generation" — confirm that
    # every consumer of this value expects that.
    answer = QA_chain.stream({"documents": context_docs, "question": query})

    trail = state["steps"]
    trail.append("generate_answer")

    return {
        "documents": context_docs,
        "question": query,
        "generation": answer,
        "steps": trail,
        "generation_count": state["generation_count"] + 1,
    }
def grade_documents(state, retrieval_grader):
    """Keep the relevant retrieved documents and decide whether to search the web.

    Args:
        state (dict): Current graph state; ``question``, ``documents`` and
            ``steps`` are read, and ``"grade_document_retrieval"`` is appended
            to ``steps``.
        retrieval_grader: Object whose ``invoke`` receives one document plus
            the question and returns a result with a ``binary_score``
            attribute.

    Returns:
        dict: State update with the relevant ``documents``, the ``question``,
        ``search`` (``"Yes"`` or ``"No"``) and ``steps``.
    """
    # Fewer relevant documents than this triggers a web search.
    min_relevant = 4
    affirmative = {"yes", "true", "1", "taip"}

    question = state["question"]
    documents = state["documents"]
    steps = state["steps"]
    steps.append("grade_document_retrieval")

    filtered_docs = []
    for d in documents:
        score = retrieval_grader.invoke({"question": question, "documents": d})
        print(f"Grader output for document: {score}")  # Detailed debugging output
        grade = getattr(score, 'binary_score', None)
        if grade and str(grade).strip().lower() in affirmative:
            filtered_docs.append(d)

    # Decide once, after all documents are graded. The old check ran inside the
    # loop, so the outcome depended on document order and an empty retrieval
    # never triggered a search.
    search = "Yes" if len(filtered_docs) < min_relevant else "No"

    print(f"Final decision - Perform web search: {search}")
    print(f"Filtered documents count: {len(filtered_docs)}")
    return {
        "documents": filtered_docs,
        "question": question,
        "search": search,
        "steps": steps,
    }
def clean_exa_document(doc):
    """Reduce an Exa result to its title, summary, URL and text.

    Args:
        doc: Result object exposing ``title``, ``summary``, ``url`` and
            ``text`` attributes.

    Returns:
        dict: The four values keyed by their Lithuanian labels.
    """
    # (label, attribute) pairs, in output order.
    labelled_fields = (
        (" Pavadinimas: ", "title"),
        (" Apibendrinimas: ", "summary"),
        (" Straipnsio internetinis adresas: ", "url"),
        (" Tekstas: ", "text"),
    )
    return {label: getattr(doc, attribute) for label, attribute in labelled_fields}
def web_search(state):
    """Add web results to the documents in ``state``.

    Exa is queried first, restricted to a fixed list of Lithuanian legal and
    reference domains. If it returns nothing, a Google Serper search is used
    instead.

    Args:
        state (dict): Current graph state; ``question`` and ``steps`` are read
            (``"web_search"`` is appended to ``steps``), and ``documents`` is
            read when present.

    Returns:
        dict: State update with the combined ``documents``, the ``question``
        and ``steps``. Exa results are added as plain dicts (see
        ``clean_exa_document``), Serper results as ``Document`` objects.
    """
    question = state["question"]
    documents = state.get("documents", [])
    steps = state["steps"]
    steps.append("web_search")

    # Fetch results from exa
    exa_results_raw = exa.search_and_contents(
        query=question,
        start_published_date="2018-01-01T22:00:01.000Z",
        type="keyword",
        num_results=2,
        text={"max_characters": 7000},
        summary={
            # Now an f-string: the old plain string sent the literal text
            # "{question}" instead of the user's question.
            "query": f"Tell in summary a meaning about what is article written. This summary has to be written in a way to be related to {question} Provide facts, be concise. Do it in Lithuanian language."
        },
        include_domains=["infolex.lt", "vmi.lt", "lrs.lt", "e-seimas.lrs.lt", "teise.pro", 'lt.wikipedia.org', 'teismai.lt'],
    )

    # Extract results; a response without ``results`` counts as empty.
    exa_results = getattr(exa_results_raw, "results", None) or []
    cleaned_exa_results = [clean_exa_document(doc) for doc in exa_results]

    web_results_list = []
    if not cleaned_exa_results:
        # Fall back to Google Serper only when Exa produced nothing.
        web_results = GoogleSerperAPIWrapper(k=2, gl="lt", hl="lt", type="search").results(question)
        web_results_list.extend(format_google_results_search(web_results))

    # ``web_results_list`` is empty unless the fallback ran, so one expression
    # covers both of the original branches.
    combined_documents = documents + cleaned_exa_results + web_results_list
    return {"documents": combined_documents, "question": question, "steps": steps}
def decide_to_generate(state):
    """Choose the next node after document grading.

    Args:
        state (dict): Current graph state; only ``search`` is read.

    Returns:
        str: ``"search"`` when ``state["search"]`` equals ``"Yes"``, otherwise
        ``"generate"``.
    """
    return "search" if state["search"] == "Yes" else "generate"
def decide_to_generate2(state):
    """Choose the next node from the ``search`` flag in the graph state.

    Applies the same rule as ``decide_to_generate``.

    Args:
        state (dict): Current graph state; only ``search`` is read.

    Returns:
        str: ``"search"`` when the flag equals ``"Yes"``, otherwise
        ``"generate"``.
    """
    wants_search = state["search"] == "Yes"
    if not wants_search:
        return "generate"
    return "search"