# Spaces: Sleeping
import streamlit as st | |
import time | |
from langchain.agents import initialize_agent, Tool | |
from langchain.agents import AgentType | |
from langchain.memory import ConversationBufferWindowMemory | |
from langchain.prompts import PromptTemplate | |
import requests | |
class LLM:
    """Agent wrapper that extracts one piece of information about an entity.

    The workflow is: run a web search for the entity, hand the (truncated)
    search text to a LangChain agent together with an extraction prompt, and
    fall back to one alternative search query when nothing usable comes back.
    The string "NaN" is used throughout as the "no answer" sentinel.
    """

    def __init__(self, tools, model, search):
        """Validate the tool list and build the LangChain agent.

        Args:
            tools: List of ``Tool`` objects; one of them must be named
                "Web Search".
            model: Language model handed to ``initialize_agent``.
            search: Object exposing ``run(query, timeout=...)``; used by
                ``perform_web_search``.

        Raises:
            ValueError: If ``tools`` is not a list of ``Tool`` objects, or if
                no tool named "Web Search" is present.
        """
        if not isinstance(tools, list) or not all(isinstance(t, Tool) for t in tools):
            raise ValueError("Tools must be a list of Tool objects")
        web_search_available = any(
            tool.name == "Web Search" for tool in tools
        )
        if not web_search_available:
            raise ValueError("Web Search tool must be included in tools list")
        self.tools = tools
        self.model = model
        self.search = search
        # NOTE(review): ``initialize_agent`` documents this parameter as
        # ``agent``; ``agent_type`` appears to be forwarded with the remaining
        # keyword arguments instead. Left unchanged to preserve current
        # behavior - confirm which agent type is actually constructed.
        self.agent = initialize_agent(
            self.tools,
            self.model,
            agent_type=AgentType.SELF_ASK_WITH_SEARCH,
            verbose=True,
            max_iterations=5,
            handle_parsing_errors=True,
            early_stopping_method="generate",
            memory=ConversationBufferWindowMemory(k=1)
        )

    def perform_web_search(self, query, max_retries=5, delay=1, timeout=8):
        """Run a web search, retrying on failure or on an empty result.

        Args:
            query: Search query passed to ``self.search.run``.
            max_retries: Maximum number of search attempts.
            delay: Seconds to sleep between attempts.
            timeout: Timeout forwarded to ``self.search.run``.

        Returns:
            The first 1500 characters of the first non-empty search result,
            or "NaN" when every attempt failed or came back empty.
        """
        for attempt in range(1, max_retries + 1):
            try:
                search_results = self.search.run(query, timeout=timeout)
                if search_results:
                    return search_results[:1500]
                # An empty result falls through and consumes this attempt.
                # Previously it did not advance the retry counter, so a
                # backend that kept returning nothing looped forever.
            except requests.exceptions.Timeout:
                st.warning(f"Web search timed out. Retrying ({attempt}/{max_retries})...")
            except Exception as e:
                st.warning(f"Web search failed. Retrying ({attempt}/{max_retries})... Error: {e}")
            # No point in waiting once the last attempt has been used up.
            if attempt < max_retries:
                time.sleep(delay)
        return "NaN"

    def get_llm_response(self, entity, query_type, web_results):
        """Ask the agent to extract ``query_type`` for ``entity`` from search text.

        Args:
            entity: Name of the entity the information is about.
            query_type: Description of the value to extract.
            web_results: Web search text inserted into the prompt.

        Returns:
            The agent's stripped "output" string, or "NaN" when the output is
            empty or the agent call raised (the error is shown via
            ``st.error``).
        """
        prompt = PromptTemplate(
            template="""
You are a highly skilled information extractor. Your job is to extract the most relevant {query_type} from the following Web Search Results.
Provide the exact value requested and return only that value—no explanations, context, or irrelevant information.
Entity: {entity}
Information to Extract: {query_type}
Web Search Results:
{web_results}
If you cannot find relevant information, return "NaN". Do not return anything else.
Your extracted response:
""",
            input_variables=["entity", "query_type", "web_results"]
        )
        try:
            response = self.agent.invoke({
                "input": prompt.format(
                    query_type=query_type,
                    entity=entity,
                    web_results=web_results,
                )
            })
            raw_response = response.get("output", "").strip()
            # An empty answer is reported with the same sentinel as a failure.
            return raw_response or "NaN"
        except Exception as e:
            st.error(f"Error processing response: {str(e)}")
            return "NaN"

    def refine_answer_with_searches(self, entity, query_type, max_retries=2):
        """Search and extract, with one fallback query if the first try fails.

        Args:
            entity: Name of the entity the information is about.
            query_type: Description of the value to extract.
            max_retries: Any value greater than 0 enables the single fallback
                search; it is not used as a loop count.

        Returns:
            Tuple ``(extracted_answer, search_results)`` where
            ``search_results`` is the text of the last search performed.
        """
        search_query = f"{entity} current {query_type}"
        search_results = self.perform_web_search(search_query)
        extracted_answer = self.get_llm_response(entity, query_type, search_results)
        if extracted_answer == "NaN" and max_retries > 0:
            # First phrasing produced nothing usable: try a broader query.
            alternative_query = f"{entity} {query_type} detailed information"
            search_results = self.perform_web_search(alternative_query)
            extracted_answer = self.get_llm_response(entity, query_type, search_results)
        return extracted_answer, search_results